├── assets ├── schemes.png ├── teaser.webp └── Pipeline.webp ├── .gitignore ├── diffusion ├── __pycache__ │ ├── respace.cpython-312.pyc │ ├── __init__.cpython-312.pyc │ ├── diffusion_utils.cpython-312.pyc │ └── gaussian_diffusion.cpython-312.pyc ├── __init__.py ├── diffusion_utils.py ├── respace.py └── timestep_sampler.py ├── preparation ├── change_fps.sh ├── extract_audio.py ├── process_audio.py ├── extract_frames.py └── audio_processor.py ├── cache.sh ├── datasets ├── __init__.py ├── frames_dataset.py └── video_transforms.py ├── configs ├── sample.yaml └── train.yaml ├── models ├── vae │ ├── __init__.py │ └── vae.py ├── __init__.py ├── audio_proj.py ├── utils.py ├── basic_modules.py └── model_cat.py ├── README.md ├── sample_pl.py ├── sample.py ├── sample_long.py ├── sample_long_pl.py ├── train_pl.py └── utils.py /assets/schemes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhang-haojie/LetsTalk/HEAD/assets/schemes.png -------------------------------------------------------------------------------- /assets/teaser.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhang-haojie/LetsTalk/HEAD/assets/teaser.webp -------------------------------------------------------------------------------- /assets/Pipeline.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhang-haojie/LetsTalk/HEAD/assets/Pipeline.webp -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | 5 | pretrained* 6 | results* 7 | sample_videos* -------------------------------------------------------------------------------- /diffusion/__pycache__/respace.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhang-haojie/LetsTalk/HEAD/diffusion/__pycache__/respace.cpython-312.pyc -------------------------------------------------------------------------------- /diffusion/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhang-haojie/LetsTalk/HEAD/diffusion/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /diffusion/__pycache__/diffusion_utils.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhang-haojie/LetsTalk/HEAD/diffusion/__pycache__/diffusion_utils.cpython-312.pyc -------------------------------------------------------------------------------- /diffusion/__pycache__/gaussian_diffusion.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhang-haojie/LetsTalk/HEAD/diffusion/__pycache__/gaussian_diffusion.cpython-312.pyc -------------------------------------------------------------------------------- /preparation/change_fps.sh: -------------------------------------------------------------------------------- 1 | 2 | SOURCE_FOLDER="/path/to/dataset/videos" 3 | OUTPUT_FOLDER="/path/to/dataset/videos_25fps" 4 | TARGET_FRAMERATE=25 5 | 6 | mkdir -p "$OUTPUT_FOLDER" 7 | 8 | export SOURCE_FOLDER OUTPUT_FOLDER TARGET_FRAMERATE 9 | 10 | find 
"$SOURCE_FOLDER" -name '*.mp4' | parallel ffmpeg -i {} -r "$TARGET_FRAMERATE" "$OUTPUT_FOLDER/{/.}.mp4" 11 | 12 | echo "Frame rate conversion completed for all videos." 13 | -------------------------------------------------------------------------------- /cache.sh: -------------------------------------------------------------------------------- 1 | pip install pytorch-lightning -i https://mirrors.tencent.com/tencent_pypi/simple 2 | pip install omegaconf -i https://mirrors.tencent.com/tencent_pypi/simple 3 | pip install diffusers -i https://mirrors.tencent.com/tencent_pypi/simple 4 | pip install tensorboard -i https://mirrors.tencent.com/tencent_pypi/simple 5 | pip install moviepy==1.0.3 -i https://mirrors.tencent.com/tencent_pypi/simple 6 | 7 | 8 | python3 datasets/images2latent.py -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from datasets import video_transforms 3 | from .frames_dataset import VideoFramesDataset, VideoLatentDataset, FramesLatentDataset, VideoDubbingDataset 4 | 5 | 6 | def get_dataset(args): 7 | temporal_sample = video_transforms.TemporalRandomCrop((args.num_frames + args.initial_frames) * args.frame_interval) 8 | 9 | transform = transforms.Compose([ 10 | video_transforms.ToTensorVideo(), # TCHW 11 | video_transforms.UCFCenterCropVideo(args.image_size), 12 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 13 | ]) 14 | 15 | if args.dataset == "frame": 16 | return VideoFramesDataset(args, transform=transform, temporal_sample=temporal_sample) 17 | elif "frame" in args.dataset: 18 | return FramesLatentDataset(args, temporal_sample=temporal_sample) 19 | elif args.dataset == "video": 20 | return VideoDubbingDataset(args, transform=transform) 21 | return VideoLatentDataset(args, temporal_sample=temporal_sample) 22 | -------------------------------------------------------------------------------- /configs/sample.yaml: -------------------------------------------------------------------------------- 1 | # dataset 2 | dataset: frame 3 | 4 | data_dir: /path/to/dataset/videos # s 5 | audio_dir: /path/to/dataset/audio 6 | pretrained_model_path: /path/to/checkpoints/latte 7 | 8 | # path: 9 | pretrained: 10 | wav2vec: ./pretrained/wav2vec/wav2vec2-base-960h 11 | audio_separator: ./pretrained/audio_separator/Kim_Vocal_2.onnx 12 | save_video_path: ./sample_videos 13 | 14 | # model config: 15 | model: VDT-B/2 16 | num_frames: 16 17 | initial_frames: 2 18 | clip_frames: 16 19 | frame_interval: 1 20 | image_size: 256 21 | in_channels: 4 22 | temp_comp_rate: 1 23 | fixed_spatial: False 24 | attention_bias: True 25 | learn_sigma: True 26 | extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation 27 | num_classes: 28 | use_seque: False 29 | audio_margin: 2 30 | audio_dim: 768 31 | audio_token: 32 32 | 33 | precision: bf16 # GPU Axx: bf16-mixed 34 | gradient_checkpointing: False 35 | 36 | # sample config: 37 | seed: 38 | sample_method: ddpm 39 | num_sampling_steps: 250 40 | cfg_scale: 1.0 41 | sample_rate: 16000 42 | fps: 25 43 | num_workers: 8 44 | per_proc_batch_size: 1 45 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | # dataset 2 | dataset: frame 3 | 4 | data_dir: /path/to/dataset/videos # s 5 | audio_dir: 
/path/to/dataset/audio 6 | pretrained_model_path: /path/to/checkpoints/latte 7 | 8 | # save and load 9 | results_dir: ./results 10 | pretrained: 11 | 12 | # model config: 13 | model: VDT-B/2 14 | num_frames: 16 15 | initial_frames: 0 16 | clip_frames: 16 17 | frame_interval: 1 18 | fixed_spatial: False 19 | image_size: 256 20 | in_channels: 4 21 | temp_comp_rate: 1 22 | num_sampling_steps: 250 23 | learn_sigma: True # important 24 | extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation 25 | use_seque: False 26 | audio_margin: 2 27 | audio_dim: 768 28 | audio_token: 32 29 | 30 | precision: bf16 # GPU Axx: bf16-mixed 31 | 32 | # train config: 33 | save_ceph: True # important 34 | learning_rate: 1e-4 35 | ckpt_every: 100000 36 | clip_max_norm: 0.1 37 | start_clip_iter: 20000 38 | local_batch_size: 1 # important 39 | max_train_steps: 300000 40 | global_seed: 3407 41 | num_workers: 8 42 | log_every: 50 43 | lr_warmup_steps: 0 44 | resume_from_checkpoint: 45 | gradient_accumulation_steps: 1 # TODO 46 | num_classes: 47 | 48 | # low VRAM and speed up training 49 | use_compile: False 50 | gradient_checkpointing: True 51 | enable_xformers_memory_efficient_attention: False -------------------------------------------------------------------------------- /preparation/extract_audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | import subprocess 5 | from tqdm import tqdm 6 | 7 | def main(video_root, audio_root): 8 | error_files = [] 9 | video_files = [] 10 | for root, _, files in os.walk(video_root): 11 | for file in files: 12 | if file.endswith('.mp4'): 13 | video_files.append(os.path.join(root, file)) 14 | 15 | for video_path in tqdm(video_files): 16 | relative_path = os.path.relpath(video_path, video_root) 17 | wav_path = os.path.join(audio_root, os.path.splitext(relative_path)[0] + '.wav') 18 | 19 | if not os.path.exists(os.path.dirname(wav_path)): 20 | os.makedirs(os.path.dirname(wav_path)) 21 | 22 | ffmpeg_command = [ 23 | 'ffmpeg', '-y', 24 | '-i', video_path, 25 | '-vn', '-acodec', 26 | "pcm_s16le", '-ar', '16000', '-ac', '2', 27 | wav_path 28 | ] 29 | 30 | try: 31 | subprocess.run(ffmpeg_command, check=True) 32 | except subprocess.CalledProcessError as e: 33 | print(f"Error extracting audio from video {video_path}: {e}") 34 | error_files.append(video_path) 35 | continue 36 | 37 | for e_file in error_files: 38 | print(f"Error extracting audio from video {e_file}") 39 | 40 | if __name__ == '__main__': 41 | video_root = "path/to/your/input/folder" 42 | audio_root = "path/to/your/output/folder" 43 | os.makedirs(audio_root, exist_ok=True) 44 | 45 | main(video_root, audio_root) 46 | -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . 
import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=False, 15 | predict_xstart=False, 16 | learn_sigma=True, 17 | # learn_sigma=False, 18 | rescale_learned_sigmas=False, 19 | diffusion_steps=1000 20 | ): 21 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 22 | if use_kl: 23 | loss_type = gd.LossType.RESCALED_KL 24 | elif rescale_learned_sigmas: 25 | loss_type = gd.LossType.RESCALED_MSE 26 | else: 27 | loss_type = gd.LossType.MSE 28 | if timestep_respacing is None or timestep_respacing == "": 29 | timestep_respacing = [diffusion_steps] 30 | return SpacedDiffusion( 31 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 32 | betas=betas, 33 | model_mean_type=( 34 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 35 | ), 36 | model_var_type=( 37 | ( 38 | gd.ModelVarType.FIXED_LARGE 39 | if not sigma_small 40 | else gd.ModelVarType.FIXED_SMALL 41 | ) 42 | if not learn_sigma 43 | else gd.ModelVarType.LEARNED_RANGE 44 | ), 45 | loss_type=loss_type 46 | # rescale_timesteps=rescale_timesteps, 47 | ) 48 | -------------------------------------------------------------------------------- /preparation/process_audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from pathlib import Path 4 | from tqdm import tqdm 5 | from audio_processor import AudioProcessor 6 | 7 | 8 | def get_audio_paths(source_dir: Path): 9 | """Get .wav files from the source directory.""" 10 | return sorted([item for item in source_dir.iterdir() if item.is_file() and item.suffix == ".wav"]) 11 | 12 | def process_audio(audio_path: Path, output_dir: Path, audio_processor: AudioProcessor): 13 | """Process a single audio file and save its embedding.""" 14 | audio_emb, _ = audio_processor.preprocess(audio_path) 15 | torch.save(audio_emb, os.path.join(output_dir, f"{audio_path.stem}.pt")) 16 | 17 | def process_all_audios(input_audio_list, output_dir): 18 | """Process all audio files in the list.""" 19 | wav2vec_model_path = "pretrained/wav2vec/wav2vec2-base-960h" 20 | audio_separator_model_file = "pretrained/audio_separator/Kim_Vocal_2.onnx" 21 | audio_processor = AudioProcessor( 22 | 16000, 23 | 25, 24 | wav2vec_model_path, 25 | os.path.dirname(audio_separator_model_file), 26 | os.path.basename(audio_separator_model_file), 27 | os.path.join(output_dir, "vocals"), 28 | # only_last_features=True 29 | ) 30 | error_files = [] 31 | for audio_path in tqdm(input_audio_list, desc="Processing audios"): 32 | try: 33 | process_audio(audio_path, output_dir, audio_processor) 34 | except: 35 | error_files.append(audio_path) 36 | print("Error file:") 37 | for error_file in error_files: 38 | print(error_file) 39 | 40 | 41 | if __name__ == "__main__": 42 | input_dir = Path("path/to/your/output/folder") # Set your input directory 43 | output_dir = os.path.join(input_dir.parent, "audio_emb") 44 | os.makedirs(output_dir, exist_ok=True) 45 | 46 | audio_paths = get_audio_paths(input_dir) 47 | process_all_audios(audio_paths, output_dir) 48 | -------------------------------------------------------------------------------- /models/vae/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from pathlib import Path 5 | from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D 6 | 7 | 
VAE_PATH = {"884-16c-hy": "./ckpts/hunyuan-video-t2v-720p/vae"} 8 | 9 | def load_vae(vae_type: str="884-16c-hy", 10 | vae_precision: str=None, 11 | sample_size: tuple=None, 12 | vae_path: str=None, 13 | ): 14 | """the fucntion to load the 3D VAE model 15 | 16 | Args: 17 | vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy". 18 | vae_precision (str, optional): the precision to load vae. Defaults to None. 19 | sample_size (tuple, optional): the tiling size. Defaults to None. 20 | vae_path (str, optional): the path to vae. Defaults to None. 21 | logger (_type_, optional): logger. Defaults to None. 22 | device (_type_, optional): device to load vae. Defaults to None. 23 | """ 24 | if vae_path is None: 25 | vae_path = VAE_PATH[vae_type] 26 | 27 | config = AutoencoderKLCausal3D.load_config(vae_path) 28 | if sample_size: 29 | vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size) 30 | else: 31 | vae = AutoencoderKLCausal3D.from_config(config) 32 | 33 | vae_ckpt = Path(vae_path) / "pytorch_model.pt" 34 | assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}" 35 | 36 | ckpt = torch.load(vae_ckpt, map_location=vae.device) 37 | if "state_dict" in ckpt: 38 | ckpt = ckpt["state_dict"] 39 | if any(k.startswith("vae.") for k in ckpt.keys()): 40 | ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")} 41 | vae.load_state_dict(ckpt) 42 | 43 | spatial_compression_ratio = vae.config.spatial_compression_ratio 44 | time_compression_ratio = vae.config.time_compression_ratio 45 | 46 | vae.requires_grad_(False) 47 | vae.eval() 48 | 49 | return vae, vae_path, spatial_compression_ratio, time_compression_ratio 50 | 51 | 52 | if __name__ == "__main__": 53 | vae_name = "884-16c-hy" 54 | vae_precision = "vae-precision" 55 | vae, _, s_ratio, t_ratio = load_vae( 56 | vae_name, 57 | ) 58 | vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio} 59 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_cat import VDTcat_models 2 | from .model_cat_long import VDTcatlong_models 3 | from .lite_cat_long import Litecatlong_models 4 | 5 | 6 | def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit 7 | from torch.optim.lr_scheduler import LambdaLR 8 | def fn(step): 9 | if warmup_steps > 0: 10 | return min(step / warmup_steps, 1) 11 | else: 12 | return 1 13 | return LambdaLR(optimizer, fn) 14 | 15 | 16 | def get_lr_scheduler(optimizer, name, **kwargs): 17 | if name == 'warmup': 18 | return customized_lr_scheduler(optimizer, **kwargs) 19 | elif name == 'cosine': 20 | from torch.optim.lr_scheduler import CosineAnnealingLR 21 | return CosineAnnealingLR(optimizer, **kwargs) 22 | else: 23 | raise NotImplementedError(name) 24 | 25 | 26 | def get_models(args): 27 | if 'Litecatlong' in args.model: 28 | return Litecatlong_models[args.model]( 29 | input_size=args.latent_size, 30 | context_dim=args.audio_dim, 31 | in_channels=args.in_channels, 32 | num_classes=args.num_classes, 33 | num_frames=args.clip_frames, 34 | initial_frames=args.initial_frames, 35 | learn_sigma=args.learn_sigma, 36 | extras=args.extras, 37 | temp_comp_rate=args.temp_comp_rate, 38 | gradient_checkpointing=args.gradient_checkpointing 39 | ) 40 | elif 'VDTcatlong' in args.model: 41 | return VDTcatlong_models[args.model]( 42 | input_size=args.latent_size, 43 | context_dim=args.audio_dim, 44 | in_channels=args.in_channels, 45 | 
num_classes=args.num_classes, 46 | num_frames=args.clip_frames, 47 | initial_frames=args.initial_frames, 48 | learn_sigma=args.learn_sigma, 49 | extras=args.extras, 50 | temp_comp_rate=args.temp_comp_rate, 51 | gradient_checkpointing=args.gradient_checkpointing 52 | ) 53 | elif 'VDTcat' in args.model: 54 | return VDTcat_models[args.model]( 55 | input_size=args.latent_size, 56 | context_dim=args.audio_dim, 57 | in_channels=args.in_channels, 58 | num_classes=args.num_classes, 59 | num_frames=args.clip_frames, 60 | learn_sigma=args.learn_sigma, 61 | extras=args.extras, 62 | temp_comp_rate=args.temp_comp_rate, 63 | gradient_checkpointing=args.gradient_checkpointing 64 | ) 65 | else: 66 | raise NotImplementedError('{} Model Not Supported!'.format(args.model)) 67 | -------------------------------------------------------------------------------- /preparation/extract_frames.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from pathlib import Path 3 | from concurrent.futures import ThreadPoolExecutor, as_completed 4 | 5 | 6 | def convert_video_to_images(video_path: Path, output_dir: Path): 7 | """ 8 | Convert a video file into a sequence of images at 25 fps and save them in the output directory. 9 | 10 | Args: 11 | video_path (Path): The path to the input video file. 12 | output_dir (Path): The directory where the extracted images will be saved. 13 | 14 | Returns: 15 | None 16 | """ 17 | # Ensure the output directory exists 18 | output_dir.mkdir(parents=True, exist_ok=True) 19 | 20 | # ffmpeg command to convert video to images at 25 fps 21 | ffmpeg_command = [ 22 | 'ffmpeg', 23 | '-i', str(video_path), 24 | '-vf', 'fps=25', 25 | str(output_dir / '%04d.jpg') # Save images as 0001.jpg, 0002.jpg, etc. 26 | ] 27 | 28 | try: 29 | print(f"Running command: {' '.join(ffmpeg_command)}") 30 | subprocess.run(ffmpeg_command, check=True) 31 | except subprocess.CalledProcessError as e: 32 | print(f"Error converting video {video_path} to images: {e}") 33 | raise 34 | 35 | 36 | def process_videos_in_folder(folder_path: Path, output_root: Path, max_workers=4): 37 | """ 38 | Traverse through all mp4 files in a folder and convert each video to images in parallel using threads. 39 | 40 | Args: 41 | folder_path (Path): The directory containing mp4 files. 42 | output_root (Path): The root directory where extracted frames will be saved. 43 | max_workers (int): Maximum number of threads for parallel processing. 
44 | 45 | Returns: 46 | None 47 | """ 48 | # Gather all mp4 files 49 | video_files = list(folder_path.glob('*.mp4')) 50 | 51 | # Use ThreadPoolExecutor to process videos in parallel 52 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 53 | futures = [] 54 | for video_file in video_files: 55 | # Create a directory named after the video file (without extension) 56 | output_dir = output_root / video_file.stem 57 | print(f"Submitting video: {video_file.name} for processing") 58 | 59 | # Submit the task to the thread pool 60 | futures.append(executor.submit(convert_video_to_images, video_file, output_dir)) 61 | 62 | # Wait for all futures to complete 63 | for future in as_completed(futures): 64 | try: 65 | future.result() # This will raise any exception that occurred during the task execution 66 | except Exception as e: 67 | print(f"Error during processing: {e}") 68 | 69 | 70 | if __name__ == "__main__": 71 | # Define the input folder containing .mp4 videos and the output root directory 72 | input_folder = Path("path/to/your/input/folder") # Replace with your input folder path 73 | output_folder = Path("path/to/your/output/folder") # Replace with your desired output folder path 74 | 75 | # Start processing with multi-threading (adjust the number of threads by changing max_workers) 76 | process_videos_in_folder(input_folder, output_folder, max_workers=4) 77 | -------------------------------------------------------------------------------- /diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 
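Note: internally, x is standardized as (x - means) * exp(-log_scales) and the log-density is evaluated under a standard normal.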
54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 
3 | ## Efficient Long-duration Talking Video Synthesis with Linear Diffusion Transformer under Multimodal Guidance
4 | 
5 | #### [Haojie Zhang](https://zhang-haojie.github.io/), [Zhihao Liang](https://lzhnb.github.io/), Ruibo Fu, Bingyan Liu, Zhengqi Wen,
6 | #### Xuefei Liu, Jianhua Tao, Yaling Liang
7 | 8 | 9 | 10 | 11 | 
12 | 13 | 14 | ## 🚀 Introduction 15 | **TL;DR:** We propose LetsTalk, a diffusion transformer for audio-driven portrait animation. By leveraging DC-VAE and linear attention, LetsTalk enables efficient multimodal fusion and consistent portrait generation, while memory bank and noise-regularized training further improve the quality and stability of long-duration videos. 16 | 17 |
18 | image 19 |
20 | 21 | **Abstract:** Long-duration talking video synthesis faces enduring challenges in achieving high video quality, portrait and temporal consistency, and computational efficiency. As video length increases, issues such as visual degradation, identity inconsistency, temporal incoherence, and error accumulation become increasingly problematic, severely affecting the realism and reliability of the results. 22 | To address these challenges, we present LetsTalk, a diffusion transformer framework equipped with multimodal guidance and a novel memory bank mechanism, explicitly maintaining contextual continuity and enabling robust, high-quality, and efficient generation of long-duration talking videos. In particular, LetsTalk introduces a noise-regularized memory bank to alleviate error accumulation and sampling artifacts during extended video generation. To further improve efficiency and spatiotemporal consistency, LetsTalk employs a deep compression autoencoder and a spatiotemporal-aware transformer with linear attention for effective multimodal fusion. We systematically analyze three fusion schemes and show that combining deep (Symbiotic Fusion) for portrait features and shallow (Direct Fusion) for audio achieves superior visual realism and precise speech-driven motion, while preserving diversity of movements. Extensive experiments demonstrate that LetsTalk establishes new state-of-the-art in generation quality, producing temporally coherent and realistic talking videos with enhanced diversity and liveliness, and maintains remarkable efficiency with 8x fewer parameters than previous approaches. 23 | 24 | 25 | ## 🎁 Overview 26 | 27 |
28 | image 29 |
30 | 31 | Overview of our LetsTalk framework for robust long-duration talking head video generation. Our system combines a deep compression autoencoder to reduce spatial redundancy while preserving temporal features, and transformer blocks with intertwined temporal and spatial attention to effectively capture both intra-frame details and long-range dependencies. 32 | Portrait and audio embeddings are extracted; Symbiotic Fusion integrates the portrait embedding, and Direct Fusion incorporates the audio embedding, providing effective multimodal guidance for video synthesis. Portrait embeddings are repeated along the temporal axis for consistent conditioning across frames. 33 | To further support long-sequence generation, a memory bank module is introduced to maintain temporal consistency, while a dedicated noise-regularized training strategy helps align the memory bank usage between training and inference stages, ensuring stable and high-fidelity generation. 34 | 35 |
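To make the two fusion paths above concrete, here is a minimal PyTorch sketch. It is an illustrative toy, not the repository's implementation (see `models/model_cat.py` and `models/basic_modules.py` for that); the block structure, module names, and feature dimension are assumptions, while the token counts loosely mirror the sample config (16 frames, 32 audio context tokens).

```python
import torch
from torch import nn


class ToyFusionBlock(nn.Module):
    """Illustrative block: self-attention jointly mixes video and portrait tokens
    (Symbiotic Fusion), cross-attention injects audio tokens (Direct Fusion)."""

    def __init__(self, dim: int = 768, heads: int = 8):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.norm1, self.norm2, self.norm3 = nn.LayerNorm(dim), nn.LayerNorm(dim), nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, audio: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h)[0]                        # Symbiotic Fusion via joint self-attention
        x = x + self.cross_attn(self.norm2(x), audio, audio)[0]   # Direct Fusion: audio tokens as key/value
        return x + self.mlp(self.norm3(x))


# Hypothetical token shapes: 16 frames x 64 spatial tokens, portrait tokens repeated along time,
# and 32 audio context tokens (as produced by AudioProjModel below).
video_tokens = torch.randn(1, 16 * 64, 768)
portrait_tokens = torch.randn(1, 64, 768).repeat(1, 16, 1)
audio_tokens = torch.randn(1, 32, 768)

block = ToyFusionBlock()
x = torch.cat([video_tokens, portrait_tokens], dim=1)  # Symbiotic Fusion: concatenate before the backbone
out = block(x, audio_tokens)
print(out.shape)  # torch.Size([1, 2048, 768])
```

In LetsTalk itself, the backbone additionally interleaves temporal and spatial attention and uses linear attention for efficiency, as described above.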
36 | image 37 |
38 | 39 | Illustration of three multimodal fusion schemes, our transformer backbone is formed by the left-side blocks. 40 | 41 | (a) **Direct Fusion**. Directly feeding condition into each block's cross-attention module; 42 | 43 | (b) **Siamese Fusion**. Maintaining a similar transformer and feeding the condition into it, extracting the corresponding features to guide the features in the backbone; 44 | 45 | (c) **Symbiotic Fusion**. Concatenating modality with the input at the beginning, then feeding it into the backbone, achieving fusion via the inherent self-attention mechanisms. 46 | 47 | 50 | 51 | 60 | 61 | 62 | 63 | ## 🎫 Citation 64 | If you find this project useful in your research, please consider the citation: 65 | 66 | ```BibTeX 67 | @article{zhang2024efficient, 68 | title={Efficient Long-duration Talking Video Synthesis with Linear Diffusion Transformer under Multimodal Guidance}, 69 | author={Zhang, Haojie and Liang, Zhihao and Fu, Ruibo and Liu, Bingyan and Wen, Zhengqi and Liu, Xuefei and Tao, Jianhua and Liang, Yaling}, 70 | journal={arXiv preprint arXiv:2411.16748}, 71 | year={2024} 72 | } 73 | ``` 74 | -------------------------------------------------------------------------------- /models/audio_proj.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides the implementation of an Audio Projection Model, which is designed for 3 | audio processing tasks. The model takes audio embeddings as input and outputs context tokens 4 | that can be used for various downstream applications, such as audio analysis or synthesis. 5 | 6 | The AudioProjModel class is based on the ModelMixin class from the diffusers library, which 7 | provides a foundation for building custom models. This implementation includes multiple linear 8 | layers with ReLU activation functions and a LayerNorm for normalization. 9 | 10 | Key Features: 11 | - Audio embedding input with flexible sequence length and block structure. 12 | - Multiple linear layers for feature transformation. 13 | - ReLU activation for non-linear transformation. 14 | - LayerNorm for stabilizing and speeding up training. 15 | - Rearrangement of input embeddings to match the model's expected input shape. 16 | - Customizable number of blocks, channels, and context tokens for adaptability. 17 | 18 | The module is structured to be easily integrated into larger systems or used as a standalone 19 | component for audio feature extraction and processing. 20 | 21 | Classes: 22 | - AudioProjModel: A class representing the audio projection model with configurable parameters. 23 | 24 | Functions: 25 | - (none) 26 | 27 | Dependencies: 28 | - torch: For tensor operations and neural network components. 29 | - diffusers: For the ModelMixin base class. 30 | - einops: For tensor rearrangement operations. 31 | 32 | """ 33 | 34 | import torch 35 | from diffusers import ModelMixin 36 | from einops import rearrange 37 | from torch import nn 38 | 39 | 40 | class AudioProjModel(ModelMixin): 41 | """Audio Projection Model 42 | 43 | This class defines an audio projection model that takes audio embeddings as input 44 | and produces context tokens as output. The model is based on the ModelMixin class 45 | and consists of multiple linear layers and activation functions. It can be used 46 | for various audio processing tasks. 47 | 48 | Attributes: 49 | seq_len (int): The length of the audio sequence. 50 | blocks (int): The number of blocks in the audio projection model. 
51 | channels (int): The number of channels in the audio projection model. 52 | intermediate_dim (int): The intermediate dimension of the model. 53 | context_tokens (int): The number of context tokens in the output. 54 | output_dim (int): The output dimension of the context tokens. 55 | 56 | Methods: 57 | __init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, context_tokens=32, output_dim=768): 58 | Initializes the AudioProjModel with the given parameters. 59 | forward(self, audio_embeds): 60 | Defines the forward pass for the AudioProjModel. 61 | Parameters: 62 | audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels). 63 | Returns: 64 | context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim). 65 | 66 | """ 67 | 68 | def __init__( 69 | self, 70 | seq_len=5, 71 | blocks=12, # add a new parameter blocks 72 | channels=768, # add a new parameter channels 73 | intermediate_dim=512, 74 | output_dim=768, 75 | context_tokens=32, 76 | ): 77 | super().__init__() 78 | 79 | self.seq_len = seq_len 80 | self.blocks = blocks 81 | self.channels = channels 82 | self.input_dim = ( 83 | seq_len * blocks * channels 84 | ) # update input_dim to be the product of blocks and channels. 85 | self.intermediate_dim = intermediate_dim 86 | self.context_tokens = context_tokens 87 | self.output_dim = output_dim 88 | 89 | # define multiple linear layers 90 | self.proj1 = nn.Linear(self.input_dim, intermediate_dim) 91 | self.proj2 = nn.Linear(intermediate_dim, intermediate_dim) 92 | self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim) 93 | 94 | self.norm = nn.LayerNorm(output_dim) 95 | 96 | def forward(self, audio_embeds): 97 | """ 98 | Defines the forward pass for the AudioProjModel. 99 | 100 | Parameters: 101 | audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels). 102 | 103 | Returns: 104 | context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim). 
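Example: with the default configuration (seq_len=5, blocks=12, channels=768, context_tokens=32, output_dim=768), an input of shape (2, 16, 5, 12, 768) produces context tokens of shape (2, 16, 32, 768).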
105 | """ 106 | # merge 107 | video_length = audio_embeds.shape[1] 108 | audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") 109 | batch_size, window_size, blocks, channels = audio_embeds.shape 110 | audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) 111 | 112 | audio_embeds = torch.relu(self.proj1(audio_embeds)) 113 | audio_embeds = torch.relu(self.proj2(audio_embeds)) 114 | 115 | context_tokens = self.proj3(audio_embeds).reshape( 116 | batch_size, self.context_tokens, self.output_dim 117 | ) 118 | 119 | context_tokens = self.norm(context_tokens) 120 | context_tokens = rearrange( 121 | context_tokens, "(bz f) m c -> bz f m c", f=video_length 122 | ) 123 | 124 | return context_tokens 125 | -------------------------------------------------------------------------------- /diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | import torch 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 
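For instance, space_timesteps(300, [10, 15, 20]) keeps 10 + 15 + 20 = 45 of the 300 original steps, while space_timesteps(1000, "ddim50") keeps 50 evenly strided steps.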
30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | # @torch.compile 95 | def training_losses( 96 | self, model, *args, **kwargs 97 | ): # pylint: disable=signature-differs 98 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 99 | 100 | def condition_mean(self, cond_fn, *args, **kwargs): 101 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 102 | 103 | def condition_score(self, cond_fn, *args, **kwargs): 104 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 105 | 106 | def _wrap_model(self, model): 107 | if isinstance(model, _WrappedModel): 108 | return model 109 | return _WrappedModel( 110 | model, self.timestep_map, self.original_num_steps 111 | ) 112 | 113 | def _scale_timesteps(self, t): 114 | # Scaling is done by the wrapped model. 
115 | return t 116 | 117 | 118 | class _WrappedModel: 119 | def __init__(self, model, timestep_map, original_num_steps): 120 | self.model = model 121 | self.timestep_map = timestep_map 122 | # self.rescale_timesteps = rescale_timesteps 123 | self.original_num_steps = original_num_steps 124 | 125 | def __call__(self, x, ts, **kwargs): 126 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 127 | new_ts = map_tensor[ts] 128 | # if self.rescale_timesteps: 129 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 130 | return self.model(x, new_ts, **kwargs) 131 | -------------------------------------------------------------------------------- /sample_pl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import imageio 4 | import logging 5 | import argparse 6 | from einops import rearrange, repeat 7 | from pytorch_lightning import LightningModule, Trainer 8 | from torch.utils.data import DataLoader 9 | from models import get_models 10 | from models.audio_proj import AudioProjModel 11 | from diffusion import create_diffusion 12 | from diffusers.models import AutoencoderKL 13 | from omegaconf import OmegaConf 14 | from utils import find_model, cleanup 15 | from datasets import get_dataset 16 | 17 | 18 | class LatteSamplingModule(LightningModule): 19 | def __init__(self, args, logger: logging.Logger): 20 | super(LatteSamplingModule, self).__init__() 21 | self.args = args 22 | self.logging = logger 23 | self.model = get_models(args).to(self.device) 24 | self.audioproj = AudioProjModel( 25 | seq_len=5, 26 | blocks=12, 27 | channels=args.audio_dim, 28 | intermediate_dim=512, 29 | output_dim=args.audio_dim, 30 | context_tokens=args.audio_token, 31 | ) 32 | 33 | state_dict, audioproj_dict = find_model(args.pretrained) 34 | self.model.load_state_dict(state_dict) 35 | self.logging.info(f"Loaded model checkpoint from {args.pretrained}") 36 | self.audioproj.load_state_dict(audioproj_dict) 37 | 38 | self.diffusion = create_diffusion(str(args.num_sampling_steps)) 39 | self.vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(self.device) 40 | 41 | self.model.eval() 42 | self.vae.eval() 43 | self.audioproj.eval() 44 | 45 | def sample(self, z, model_kwargs, sample_method='ddpm'): 46 | if sample_method == 'ddim': 47 | return self.diffusion.ddim_sample_loop( 48 | self.model.forward_with_cfg if self.args.cfg_scale > 1.0 else self.model.forward, 49 | z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=self.device 50 | ) 51 | elif sample_method == 'ddpm': 52 | return self.diffusion.p_sample_loop( 53 | self.model.forward_with_cfg if self.args.cfg_scale > 1.0 else self.model.forward, 54 | z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=self.device 55 | ) 56 | else: 57 | raise ValueError(f"Unknown sample method: {sample_method}") 58 | 59 | @torch.no_grad() 60 | def validation_step(self, batch, batch_idx): 61 | video_names = batch["video_name"] 62 | if "latent" in self.args.dataset: 63 | ref_latents = batch["ref_latent"] 64 | else: 65 | ref_image = batch["ref_image"] 66 | ref_latents = self.vae.encode(ref_image).latent_dist.sample().mul_(0.18215) 67 | 68 | ref_latents = repeat(ref_latents, "b c h w -> b f c h w", f=self.args.clip_frames) 69 | model_kwargs = dict(y=None, cond=ref_latents) 70 | 71 | local_batch_size = ref_latents.size(0) 72 | z = torch.randn(local_batch_size, self.args.clip_frames, self.args.in_channels, 
self.args.latent_size, self.args.latent_size, device=self.device) 73 | if "audio" in batch: 74 | audio_emb = batch["audio"] 75 | audio_emb = self.audioproj(audio_emb) 76 | model_kwargs.update(audio_embed=audio_emb) 77 | 78 | if self.args.cfg_scale > 1.0: 79 | z = torch.cat([z, z], 0) 80 | model_kwargs.update(cfg_scale=self.args.cfg_scale) 81 | 82 | samples = self.sample(z, model_kwargs, sample_method=self.args.sample_method) 83 | 84 | samples = rearrange(samples, 'b f c h w -> (b f) c h w') 85 | samples = self.vae.decode(samples / 0.18215).sample 86 | samples = rearrange(samples, '(b f) c h w -> b f c h w', b=local_batch_size) 87 | 88 | for sample, video_name in zip(samples, video_names): 89 | video_ = ((sample * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1).contiguous() 90 | video_save_path = os.path.join(self.args.save_video_path, f"{video_name}.mp4") 91 | imageio.mimwrite(video_save_path, video_, fps=self.args.fps, quality=9) 92 | self.logging.info(f"Saved video at {video_save_path}") 93 | 94 | return video_save_path 95 | 96 | 97 | def main(args): 98 | # Setup logger 99 | logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S") 100 | logger = logging.getLogger(__name__) 101 | 102 | # Determine if the current process is the main process (rank 0) 103 | is_main_process = (int(os.environ.get("LOCAL_RANK", 0)) == 0) 104 | if is_main_process: 105 | os.makedirs(args.save_video_path, exist_ok=True) 106 | print(f"Saving .mp4 samples at {args.save_video_path}") 107 | 108 | # Create dataset and dataloader 109 | dataset = get_dataset(args) 110 | val_loader = DataLoader( 111 | dataset, 112 | batch_size=args.per_proc_batch_size, # Batch size 1 for sampling 113 | shuffle=False, 114 | num_workers=args.num_workers, 115 | pin_memory=True, 116 | ) 117 | logger.info(f"Validation set contains {len(dataset)} samples") 118 | 119 | sample_size = args.image_size // 8 120 | args.latent_size = sample_size 121 | 122 | # Initialize the sampling module 123 | pl_module = LatteSamplingModule(args, logger) 124 | 125 | # Trainer 126 | trainer = Trainer( 127 | accelerator="gpu", 128 | devices=[0], # Specify GPU ids 129 | strategy="auto", 130 | logger=False, 131 | precision=args.precision if args.precision else "32-true", 132 | ) 133 | 134 | # Run validation to generate samples 135 | trainer.validate(pl_module, dataloaders=val_loader) 136 | 137 | logger.info("Sampling completed!") 138 | 139 | if __name__ == "__main__": 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument("--config", type=str, default="./configs/sample.yaml") 142 | args = parser.parse_args() 143 | omega_conf = OmegaConf.load(args.config) 144 | main(omega_conf) -------------------------------------------------------------------------------- /diffusion/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 
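Supported names are "uniform" and "loss-second-moment".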
16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | :param local_ts: an integer Tensor of timesteps. 80 | :param local_losses: a 1D Tensor of losses. 81 | """ 82 | batch_sizes = [ 83 | th.tensor([0], dtype=th.int32, device=local_ts.device) 84 | for _ in range(dist.get_world_size()) 85 | ] 86 | dist.all_gather( 87 | batch_sizes, 88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 89 | ) 90 | 91 | # Pad all_gather batches to be the maximum batch size. 92 | batch_sizes = [x.item() for x in batch_sizes] 93 | max_bs = max(batch_sizes) 94 | 95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 97 | dist.all_gather(timestep_batches, local_ts) 98 | dist.all_gather(loss_batches, local_losses) 99 | timesteps = [ 100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 101 | ] 102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 103 | self.update_with_all_losses(timesteps, losses) 104 | 105 | @abstractmethod 106 | def update_with_all_losses(self, ts, losses): 107 | """ 108 | Update the reweighting using losses from a model. 
109 | Sub-classes should override this method to update the reweighting 110 | using losses from the model. 111 | This method directly updates the reweighting without synchronizing 112 | between workers. It is called by update_with_local_losses from all 113 | ranks with identical arguments. Thus, it should have deterministic 114 | behavior to maintain state across workers. 115 | :param ts: a list of int timesteps. 116 | :param losses: a list of float losses, one per timestep. 117 | """ 118 | 119 | 120 | class LossSecondMomentResampler(LossAwareSampler): 121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 122 | self.diffusion = diffusion 123 | self.history_per_term = history_per_term 124 | self.uniform_prob = uniform_prob 125 | self._loss_history = np.zeros( 126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 127 | ) 128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 129 | 130 | def weights(self): 131 | if not self._warmed_up(): 132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 134 | weights /= np.sum(weights) 135 | weights *= 1 - self.uniform_prob 136 | weights += self.uniform_prob / len(weights) 137 | return weights 138 | 139 | def update_with_all_losses(self, ts, losses): 140 | for t, loss in zip(ts, losses): 141 | if self._loss_counts[t] == self.history_per_term: 142 | # Shift out the oldest loss term. 143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() 151 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import argparse 5 | import numpy as np 6 | from tqdm import tqdm 7 | from PIL import Image 8 | from torchvision import transforms 9 | from pathlib import Path 10 | from omegaconf import OmegaConf 11 | from diffusion import create_diffusion 12 | from diffusers.models import AutoencoderKL 13 | from einops import rearrange, repeat 14 | from models import get_models 15 | from models.audio_proj import AudioProjModel 16 | from preparation.audio_processor import AudioProcessor 17 | from utils import find_model, combine_video_audio, tensor_to_video 18 | 19 | 20 | def process_audio_emb(audio_emb): 21 | """ 22 | Process the audio embedding to concatenate with other tensors. 23 | 24 | Parameters: 25 | audio_emb (torch.Tensor): The audio embedding tensor to process. 26 | 27 | Returns: 28 | concatenated_tensors (List[torch.Tensor]): The concatenated tensor list. 29 | """ 30 | concatenated_tensors = [] 31 | 32 | for i in range(audio_emb.shape[0]): 33 | vectors_to_concat = [ 34 | audio_emb[max(min(i + j, audio_emb.shape[0]-1), 0)]for j in range(-2, 3)] 35 | concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0)) 36 | audio_emb = torch.stack(concatenated_tensors, dim=0) 37 | return audio_emb 38 | 39 | 40 | @torch.no_grad() 41 | def main(args): 42 | torch.backends.cuda.matmul.allow_tf32 = True 43 | assert torch.cuda.is_available(), "Sampling requires at least one GPU." 
44 | torch.set_grad_enabled(False) 45 | 46 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 47 | 48 | # Load model 49 | latent_size = args.image_size // 8 50 | args.latent_size = latent_size 51 | model = get_models(args).to(device) 52 | state_dict, audioproj_dict = find_model(args.pretrained) 53 | model.load_state_dict(state_dict) 54 | model.eval() 55 | 56 | audioproj = AudioProjModel( 57 | seq_len=5, 58 | blocks=12, 59 | channels=args.audio_dim, 60 | intermediate_dim=512, 61 | output_dim=args.audio_dim, 62 | context_tokens=args.audio_token, 63 | ).to(device) 64 | audioproj.load_state_dict(audioproj_dict) 65 | audioproj.eval() 66 | 67 | sample_rate = args.sample_rate 68 | assert sample_rate == 16000, "audio sample rate must be 16000" 69 | fps = args.fps 70 | wav2vec_model_path = args.wav2vec 71 | audio_separator_model_file = args.audio_separator 72 | 73 | diffusion = create_diffusion(str(args.num_sampling_steps)) 74 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="sd-vae-ft-ema").to(device) 75 | transform = transforms.Compose( 76 | [ 77 | transforms.ToTensor(), 78 | transforms.Resize(args.image_size), 79 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 80 | ] 81 | ) 82 | 83 | vae.requires_grad_(False) 84 | audioproj.requires_grad_(False) 85 | model.requires_grad_(False) 86 | 87 | # Prepare output directory 88 | os.makedirs(args.save_video_path, exist_ok=True) 89 | 90 | # Iterate through reference image folder 91 | ref_image_folder = args.data_dir 92 | audio_folder = args.audio_dir 93 | ref_image_paths = glob.glob(os.path.join(ref_image_folder, "*.jpg")) 94 | 95 | # Add progress bar for reference image processing 96 | for ref_image_path in ref_image_paths: 97 | audio_path = os.path.join(audio_folder, f"{Path(ref_image_path).stem}.wav") 98 | 99 | if not os.path.exists(audio_path): 100 | print(f"Warning: Audio file not found for {audio_path}") 101 | continue 102 | 103 | ref_name = Path(ref_image_path).stem 104 | video_save_path = os.path.join(args.save_video_path, f"{ref_name}.mp4") 105 | if os.path.exists(video_save_path): 106 | continue 107 | 108 | clip_length = args.clip_frames 109 | with AudioProcessor( 110 | sample_rate, 111 | fps, 112 | wav2vec_model_path, 113 | os.path.dirname(audio_separator_model_file), 114 | os.path.basename(audio_separator_model_file), 115 | os.path.join(args.save_video_path, "audio_preprocess") 116 | ) as audio_processor: 117 | audio_emb, audio_length = audio_processor.preprocess(audio_path, clip_length) 118 | 119 | audio_emb = process_audio_emb(audio_emb) 120 | # Load reference image 121 | ref_image_pil = Image.open(ref_image_path).convert("RGB") 122 | ref_image_np = np.array(ref_image_pil) 123 | ref_image_tensor = transform(ref_image_np).unsqueeze(0).to(device) 124 | ref_latents = vae.encode(ref_image_tensor).latent_dist.sample().mul_(0.18215) 125 | ref_latents = repeat(ref_latents, "b c h w -> b f c h w", f=clip_length) 126 | 127 | times = audio_emb.shape[0] // clip_length 128 | 129 | concat_samples = [] 130 | for t in tqdm(range(times), desc="Processing"): 131 | audio_tensor = audio_emb[ 132 | t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0]) 133 | ] 134 | 135 | audio_tensor = audio_tensor.unsqueeze(0) 136 | audio_tensor = audio_tensor.to( 137 | device=audioproj.device, dtype=audioproj.dtype) 138 | audio_tensor = audioproj(audio_tensor) 139 | 140 | z = torch.randn(1, args.clip_frames, args.in_channels, latent_size, latent_size, device=device) 141 | model_kwargs = 
dict(y=None, cond=ref_latents, audio_embed=audio_tensor) 142 | if args.cfg_scale > 1.0: 143 | z = torch.cat([z, z], 0) 144 | model_kwargs.update(cfg_scale=args.cfg_scale) 145 | 146 | # Sample images: 147 | if args.sample_method == 'ddim': 148 | samples = diffusion.ddim_sample_loop( 149 | model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device 150 | ) 151 | elif args.sample_method == 'ddpm': 152 | samples = diffusion.p_sample_loop( 153 | model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device 154 | ) 155 | 156 | # Decode and save samples 157 | samples = rearrange(samples, 'b f c h w -> (b f) c h w') 158 | samples = vae.decode(samples / 0.18215).sample.cpu() 159 | concat_samples.append(samples) 160 | 161 | tensor_result = torch.cat(concat_samples) 162 | tensor_result = tensor_result[:audio_length] 163 | tensor_result = ((tensor_result * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1).contiguous() 164 | tensor_to_video(tensor_result, video_save_path, audio_path) 165 | print("Sampling completed!") 166 | 167 | 168 | if __name__ == "__main__": 169 | parser = argparse.ArgumentParser() 170 | parser.add_argument("--config", type=str, default="./configs/sample.yaml") 171 | args = parser.parse_args() 172 | omega_conf = OmegaConf.load(args.config) 173 | main(omega_conf) 174 | -------------------------------------------------------------------------------- /sample_long.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import argparse 5 | import numpy as np 6 | from tqdm import tqdm 7 | from PIL import Image 8 | from torchvision import transforms 9 | from pathlib import Path 10 | from omegaconf import OmegaConf 11 | from diffusion import create_diffusion 12 | from diffusers.models import AutoencoderKL 13 | from einops import rearrange, repeat 14 | from models import get_models 15 | from models.audio_proj import AudioProjModel 16 | from preparation.audio_processor import AudioProcessor 17 | from utils import find_model, combine_video_audio, tensor_to_video 18 | 19 | 20 | def process_audio_emb(audio_emb): 21 | """ 22 | Process the audio embedding to concatenate with other tensors. 23 | 24 | Parameters: 25 | audio_emb (torch.Tensor): The audio embedding tensor to process. 26 | 27 | Returns: 28 | concatenated_tensors (List[torch.Tensor]): The concatenated tensor list. 29 | """ 30 | concatenated_tensors = [] 31 | 32 | for i in range(audio_emb.shape[0]): 33 | vectors_to_concat = [ 34 | audio_emb[max(min(i + j, audio_emb.shape[0]-1), 0)]for j in range(-2, 3)] 35 | concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0)) 36 | audio_emb = torch.stack(concatenated_tensors, dim=0) 37 | return audio_emb 38 | 39 | 40 | @torch.no_grad() 41 | def main(args): 42 | torch.backends.cuda.matmul.allow_tf32 = True 43 | assert torch.cuda.is_available(), "Sampling requires at least one GPU." 
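    # Illustrative sketch of the latent scaling used further down (hypothetical
    # sizes; 0.18215 is the Stable Diffusion VAE scaling factor, so encode and
    # decode must use the same constant):
    #
    #     img = torch.randn(1, 3, 256, 256, device=device)
    #     lat = vae.encode(img).latent_dist.sample().mul_(0.18215)   # (1, 4, 32, 32)
    #     rec = vae.decode(lat / 0.18215).sample                     # (1, 3, 256, 256)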
44 | torch.set_grad_enabled(False) 45 | 46 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 47 | 48 | # Load model 49 | latent_size = args.image_size // 8 50 | args.latent_size = latent_size 51 | model = get_models(args).to(device) 52 | state_dict, audioproj_dict = find_model(args.pretrained) 53 | model.load_state_dict(state_dict) 54 | model.eval() 55 | 56 | audioproj = AudioProjModel( 57 | seq_len=5, 58 | blocks=12, 59 | channels=args.audio_dim, 60 | intermediate_dim=512, 61 | output_dim=args.audio_dim, 62 | context_tokens=args.audio_token, 63 | ).to(device) 64 | audioproj.load_state_dict(audioproj_dict) 65 | audioproj.eval() 66 | 67 | sample_rate = args.sample_rate 68 | assert sample_rate == 16000, "audio sample rate must be 16000" 69 | fps = args.fps 70 | wav2vec_model_path = args.wav2vec 71 | audio_separator_model_file = args.audio_separator 72 | 73 | diffusion = create_diffusion(str(args.num_sampling_steps)) 74 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="sd-vae-ft-ema").to(device) 75 | transform = transforms.Compose( 76 | [ 77 | transforms.ToTensor(), 78 | transforms.Resize(args.image_size), 79 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 80 | ] 81 | ) 82 | 83 | vae.requires_grad_(False) 84 | audioproj.requires_grad_(False) 85 | model.requires_grad_(False) 86 | 87 | # Prepare output directory 88 | os.makedirs(args.save_video_path, exist_ok=True) 89 | 90 | # Iterate through reference image folder 91 | ref_image_folder = args.data_dir 92 | audio_folder = args.audio_dir 93 | ref_image_paths = glob.glob(os.path.join(ref_image_folder, "*.jpg")) 94 | print(f"===== Process folder: {args.data_dir} =====") 95 | 96 | tensor_result = [] 97 | # Add progress bar for reference image processing 98 | for ref_image_path in ref_image_paths: 99 | audio_path = os.path.join(audio_folder, f"{Path(ref_image_path).stem}.wav") 100 | 101 | if not os.path.exists(audio_path): 102 | print(f"Warning: Audio file not found for {audio_path}") 103 | continue 104 | 105 | ref_name = Path(ref_image_path).stem 106 | video_save_path = os.path.join(args.save_video_path, f"{ref_name}.mp4") 107 | if os.path.exists(video_save_path): 108 | print(f"Skip video {video_save_path}") 109 | continue 110 | 111 | clip_length = args.clip_frames 112 | with AudioProcessor( 113 | sample_rate, 114 | fps, 115 | wav2vec_model_path, 116 | os.path.dirname(audio_separator_model_file), 117 | os.path.basename(audio_separator_model_file), 118 | os.path.join(args.save_video_path, "audio_preprocess") 119 | ) as audio_processor: 120 | audio_emb, audio_length = audio_processor.preprocess(audio_path, clip_length) 121 | 122 | audio_emb = process_audio_emb(audio_emb) 123 | # Load reference image 124 | ref_image_pil = Image.open(ref_image_path).convert("RGB") 125 | ref_image_np = np.array(ref_image_pil) 126 | ref_image_tensor = transform(ref_image_np).unsqueeze(0).to(device) 127 | ref_latents = vae.encode(ref_image_tensor).latent_dist.sample().mul_(0.18215) 128 | ref_latents = repeat(ref_latents, "b c h w -> b f c h w", f=clip_length) 129 | 130 | times = audio_emb.shape[0] // clip_length 131 | 132 | concat_samples = [] 133 | for t in tqdm(range(times), desc="Processing"): 134 | audio_tensor = audio_emb[ 135 | t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0]) 136 | ] 137 | 138 | audio_tensor = audio_tensor.unsqueeze(0) 139 | audio_tensor = audio_tensor.to( 140 | device=audioproj.device, dtype=audioproj.dtype) 141 | audio_tensor = audioproj(audio_tensor) 
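            # Illustrative sketch of the window hand-over implemented below
            # (hypothetical values clip_frames=16, initial_frames=2):
            #
            #     # first window: motion frames come from the reference latents
            #     initial_latents = ref_latents[:, 0:2]    # (1, 2, 4, h, w)
            #     # later windows: reuse the tail of the newly generated samples
            #     initial_latents = samples[:, -2:]        # same as samples[:, 0 - 2:]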
142 | 143 | if len(concat_samples) == 0: 144 | # The first iteration 145 | initial_latents = ref_latents[:, 0:args.initial_frames, ...] 146 | 147 | z = torch.randn(1, args.clip_frames, 4, latent_size, latent_size, device=device) 148 | model_kwargs = dict(y=None, cond=ref_latents, motion=initial_latents, audio_embed=audio_tensor) 149 | if args.cfg_scale > 1.0: 150 | z = torch.cat([z, z], 0) 151 | model_kwargs.update(cfg_scale=args.cfg_scale) 152 | 153 | # Sample images: 154 | if args.sample_method == 'ddim': 155 | samples = diffusion.ddim_sample_loop( 156 | model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device 157 | ) 158 | elif args.sample_method == 'ddpm': 159 | samples = diffusion.p_sample_loop( 160 | model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device 161 | ) 162 | 163 | initial_latents = samples[:, 0 - args.initial_frames:] 164 | 165 | # Decode and save samples 166 | samples = rearrange(samples, 'b f c h w -> (b f) c h w') 167 | samples = vae.decode(samples / 0.18215).sample.cpu() 168 | concat_samples.append(samples) 169 | torch.cuda.empty_cache() 170 | 171 | tensor_result = torch.cat(concat_samples) 172 | tensor_result = tensor_result[:audio_length] 173 | tensor_result = ((tensor_result * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1).contiguous() 174 | tensor_to_video(tensor_result, video_save_path, audio_path) 175 | print("Sampling completed!") 176 | 177 | 178 | if __name__ == "__main__": 179 | parser = argparse.ArgumentParser() 180 | parser.add_argument("--config", type=str, default="./configs/sample.yaml") 181 | parser.add_argument("--data_dir", type=str, default=None) 182 | args = parser.parse_args() 183 | omega_conf = OmegaConf.load(args.config) 184 | if args.data_dir is not None: 185 | omega_conf.data_dir = args.data_dir 186 | main(omega_conf) 187 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | 15 | import numpy as np 16 | import torch.nn as nn 17 | 18 | from einops import repeat 19 | 20 | 21 | ################################################################################# 22 | # Unet Utils # 23 | ################################################################################# 24 | 25 | def checkpoint(func, inputs, params, flag): 26 | """ 27 | Evaluate a function without caching intermediate activations, allowing for 28 | reduced memory at the expense of extra compute in the backward pass. 29 | :param func: the function to evaluate. 30 | :param inputs: the argument sequence to pass to `func`. 31 | :param params: a sequence of parameters `func` depends on but does not 32 | explicitly take as arguments. 33 | :param flag: if False, disable gradient checkpointing. 
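    Illustrative usage (a minimal sketch; `block`, `x` and `use_ckpt` are
    hypothetical placeholders rather than names from this repository):
        out = checkpoint(block, (x,), block.parameters(), use_ckpt)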
34 | """ 35 | if flag: 36 | args = tuple(inputs) + tuple(params) 37 | return CheckpointFunction.apply(func, len(inputs), *args) 38 | else: 39 | return func(*inputs) 40 | 41 | 42 | class CheckpointFunction(torch.autograd.Function): 43 | @staticmethod 44 | def forward(ctx, run_function, length, *args): 45 | ctx.run_function = run_function 46 | ctx.input_tensors = list(args[:length]) 47 | ctx.input_params = list(args[length:]) 48 | 49 | with torch.no_grad(): 50 | output_tensors = ctx.run_function(*ctx.input_tensors) 51 | return output_tensors 52 | 53 | @staticmethod 54 | def backward(ctx, *output_grads): 55 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 56 | with torch.enable_grad(): 57 | # Fixes a bug where the first op in run_function modifies the 58 | # Tensor storage in place, which is not allowed for detach()'d 59 | # Tensors. 60 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 61 | output_tensors = ctx.run_function(*shallow_copies) 62 | input_grads = torch.autograd.grad( 63 | output_tensors, 64 | ctx.input_tensors + ctx.input_params, 65 | output_grads, 66 | allow_unused=True, 67 | ) 68 | del ctx.input_tensors 69 | del ctx.input_params 70 | del output_tensors 71 | return (None, None) + input_grads 72 | 73 | 74 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 75 | """ 76 | Create sinusoidal timestep embeddings. 77 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 78 | These may be fractional. 79 | :param dim: the dimension of the output. 80 | :param max_period: controls the minimum frequency of the embeddings. 81 | :return: an [N x dim] Tensor of positional embeddings. 82 | """ 83 | if not repeat_only: 84 | half = dim // 2 85 | freqs = torch.exp( 86 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 87 | ).to(device=timesteps.device) 88 | args = timesteps[:, None].float() * freqs[None] 89 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 90 | if dim % 2: 91 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 92 | else: 93 | embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous() 94 | return embedding 95 | 96 | 97 | def zero_module(module): 98 | """ 99 | Zero out the parameters of a module and return it. 100 | """ 101 | for p in module.parameters(): 102 | p.detach().zero_() 103 | return module 104 | 105 | 106 | def scale_module(module, scale): 107 | """ 108 | Scale the parameters of a module and return it. 109 | """ 110 | for p in module.parameters(): 111 | p.detach().mul_(scale) 112 | return module 113 | 114 | 115 | def mean_flat(tensor): 116 | """ 117 | Take the mean over all non-batch dimensions. 118 | """ 119 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 120 | 121 | 122 | def normalization(channels): 123 | """ 124 | Make a standard normalization layer. 125 | :param channels: number of input channels. 126 | :return: an nn.Module for normalization. 127 | """ 128 | return GroupNorm32(32, channels) 129 | 130 | 131 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 132 | class SiLU(nn.Module): 133 | def forward(self, x): 134 | return x * torch.sigmoid(x) 135 | 136 | 137 | class GroupNorm32(nn.GroupNorm): 138 | def forward(self, x): 139 | return super().forward(x.float()).type(x.dtype) 140 | 141 | def conv_nd(dims, *args, **kwargs): 142 | """ 143 | Create a 1D, 2D, or 3D convolution module. 
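    Illustrative usage (a sketch; the channel sizes are placeholders):
        conv = conv_nd(2, 64, 128, 3, padding=1)  # same as nn.Conv2d(64, 128, 3, padding=1)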
144 | """ 145 | if dims == 1: 146 | return nn.Conv1d(*args, **kwargs) 147 | elif dims == 2: 148 | return nn.Conv2d(*args, **kwargs) 149 | elif dims == 3: 150 | return nn.Conv3d(*args, **kwargs) 151 | raise ValueError(f"unsupported dimensions: {dims}") 152 | 153 | 154 | def linear(*args, **kwargs): 155 | """ 156 | Create a linear module. 157 | """ 158 | return nn.Linear(*args, **kwargs) 159 | 160 | 161 | def avg_pool_nd(dims, *args, **kwargs): 162 | """ 163 | Create a 1D, 2D, or 3D average pooling module. 164 | """ 165 | if dims == 1: 166 | return nn.AvgPool1d(*args, **kwargs) 167 | elif dims == 2: 168 | return nn.AvgPool2d(*args, **kwargs) 169 | elif dims == 3: 170 | return nn.AvgPool3d(*args, **kwargs) 171 | raise ValueError(f"unsupported dimensions: {dims}") 172 | 173 | 174 | # class HybridConditioner(nn.Module): 175 | 176 | # def __init__(self, c_concat_config, c_crossattn_config): 177 | # super().__init__() 178 | # self.concat_conditioner = instantiate_from_config(c_concat_config) 179 | # self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 180 | 181 | # def forward(self, c_concat, c_crossattn): 182 | # c_concat = self.concat_conditioner(c_concat) 183 | # c_crossattn = self.crossattn_conditioner(c_crossattn) 184 | # return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 185 | 186 | 187 | def noise_like(shape, device, repeat=False): 188 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 189 | noise = lambda: torch.randn(shape, device=device) 190 | return repeat_noise() if repeat else noise() 191 | 192 | def count_flops_attn(model, _x, y): 193 | """ 194 | A counter for the `thop` package to count the operations in an 195 | attention operation. 196 | Meant to be used like: 197 | macs, params = thop.profile( 198 | model, 199 | inputs=(inputs, timestamps), 200 | custom_ops={QKVAttention: QKVAttention.count_flops}, 201 | ) 202 | """ 203 | b, c, *spatial = y[0].shape 204 | num_spatial = int(np.prod(spatial)) 205 | # We perform two matmuls with the same number of ops. 206 | # The first computes the weight matrix, the second computes 207 | # the combination of the value vectors. 
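    # Illustrative check (hypothetical sizes): for b=1, c=64 channels and a
    # 16x16 spatial map, num_spatial = 256 and the line below counts
    # 2 * 1 * 256**2 * 64 = 8,388,608 operations for the two matmuls.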
208 | matmul_ops = 2 * b * (num_spatial ** 2) * c 209 | model.total_ops += torch.DoubleTensor([matmul_ops]) 210 | 211 | def count_params(model, verbose=False): 212 | total_params = sum(p.numel() for p in model.parameters()) 213 | if verbose: 214 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 215 | return total_params -------------------------------------------------------------------------------- /sample_long_pl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import imageio 4 | import logging 5 | import argparse 6 | from tqdm import tqdm 7 | from einops import rearrange, repeat 8 | from diffusion import create_diffusion 9 | from diffusers.models import AutoencoderKL, AutoencoderDC 10 | from omegaconf import OmegaConf 11 | from pytorch_lightning import LightningModule, Trainer 12 | from torch.utils.data import DataLoader 13 | 14 | from models import get_models 15 | from models.audio_proj import AudioProjModel 16 | from utils import find_model 17 | from datasets import get_dataset 18 | 19 | 20 | class LatteSamplingModule(LightningModule): 21 | def __init__(self, args, logger: logging.Logger): 22 | super(LatteSamplingModule, self).__init__() 23 | self.args = args 24 | self.logging = logger 25 | self.model = get_models(args).to(self.device) 26 | self.audioproj = AudioProjModel( 27 | seq_len=5, 28 | blocks=12, 29 | channels=args.audio_dim, 30 | intermediate_dim=512, 31 | output_dim=args.audio_dim, 32 | context_tokens=args.audio_token, 33 | ) 34 | 35 | state_dict, audioproj_dict = find_model(args.pretrained) 36 | self.model.load_state_dict(state_dict) 37 | self.logging.info(f"Loaded model checkpoint from {args.pretrained}") 38 | self.audioproj.load_state_dict(audioproj_dict) 39 | 40 | self.diffusion = create_diffusion(str(args.num_sampling_steps)) 41 | if args.in_channels == 32: 42 | self.vae = AutoencoderDC.from_pretrained(args.pretrained_model_path, subfolder="vae") 43 | else: 44 | self.vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae") 45 | 46 | self.model.eval() 47 | self.vae.eval() 48 | self.audioproj.eval() 49 | self.global_counter = 0 50 | 51 | def sample(self, z, model_kwargs, sample_method='ddpm'): 52 | if sample_method == 'ddim': 53 | return self.diffusion.ddim_sample_loop( 54 | self.model.forward_with_cfg if self.args.cfg_scale > 1.0 else self.model.forward, 55 | z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=self.device 56 | ) 57 | elif sample_method == 'ddpm': 58 | return self.diffusion.p_sample_loop( 59 | self.model.forward_with_cfg if self.args.cfg_scale > 1.0 else self.model.forward, 60 | z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=self.device 61 | ) 62 | else: 63 | raise ValueError(f"Unknown sample method: {sample_method}") 64 | 65 | @torch.no_grad() 66 | def validation_step(self, batch, batch_idx): 67 | video_names = batch["video_name"] 68 | for video_name in video_names: 69 | self.logging.info(f"Processing video {video_name}.") 70 | 71 | audio_emb = batch["audio"] 72 | local_batch_size = audio_emb.size(0) 73 | 74 | if "latent" in self.args.dataset: 75 | ref_latents = batch["ref_latent"] * self.vae.config.scaling_factor 76 | motion_latents = batch["motions"] * self.vae.config.scaling_factor if self.args.initial_frames != 0 else None 77 | else: 78 | ref_image = batch["ref_image"] 79 | if self.args.in_channels == 32: 80 | ref_latents = self.vae.encode(ref_image)[0] * 
self.vae.config.scaling_factor 81 | else: 82 | ref_latents = self.vae.encode(ref_image).latent_dist.sample() * self.vae.scaling_factor 83 | 84 | motions = batch["motions"] 85 | motions = rearrange(motions, "b f c h w -> (b f) c h w").contiguous() 86 | if self.args.in_channels == 32: 87 | motion_latents = self.vae.encode(motions)[0] * self.vae.config.scaling_factor 88 | else: 89 | motion_latents = self.vae.encode(motions).latent_dist.sample() * self.vae.scaling_factor 90 | motion_latents = rearrange(motion_latents, "(b f) c h w -> b f c h w", b=local_batch_size).contiguous() 91 | 92 | clip_length = self.args.clip_frames 93 | times = audio_emb.size(1) // clip_length 94 | concat_samples = [] 95 | ref_latents = repeat(ref_latents, "b c h w -> b f c h w", f=clip_length) 96 | for t in tqdm(range(times), desc="Processing"): 97 | audio_tensor = audio_emb[:, t * clip_length: (t + 1) * clip_length] 98 | audio_tensor = self.audioproj(audio_tensor) 99 | 100 | if t == 0: 101 | # The first iteration 102 | initial_latents = motion_latents 103 | 104 | z = torch.randn(local_batch_size, clip_length, self.args.in_channels, self.args.latent_size, self.args.latent_size, device=self.device) 105 | model_kwargs = dict(y=None, cond=ref_latents, motion=initial_latents, audio_embed=audio_tensor) 106 | if self.args.cfg_scale > 1.0: 107 | z = torch.cat([z, z], 0) 108 | model_kwargs.update(cfg_scale=self.args.cfg_scale) 109 | 110 | samples = self.sample(z, model_kwargs, sample_method=self.args.sample_method) 111 | initial_latents = samples[:, 0 - self.args.initial_frames:] 112 | 113 | samples = rearrange(samples, 'b f c h w -> (b f) c h w') 114 | if self.args.in_channels == 32: 115 | samples = self.vae.decode(samples / self.vae.config.scaling_factor, return_dict=False)[0] 116 | else: 117 | samples = self.vae.decode(samples / 0.18215).sample 118 | samples = rearrange(samples, '(b f) c h w -> b f c h w', b=local_batch_size) 119 | concat_samples.append(samples) 120 | torch.cuda.empty_cache() 121 | 122 | concat_samples = torch.cat(concat_samples, dim=1) 123 | for sample, video_name in zip(concat_samples, video_names): 124 | video_ = ((sample * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1).contiguous() 125 | video_save_path = os.path.join(self.args.save_video_path, f"{video_name}.mp4") 126 | imageio.mimwrite(video_save_path, video_, fps=self.args.fps, quality=9) 127 | self.logging.info(f"Saved video at {video_save_path}") 128 | 129 | return video_save_path 130 | 131 | 132 | def main(args): 133 | # Setup logger 134 | logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S") 135 | logger = logging.getLogger(__name__) 136 | 137 | # Determine if the current process is the main process (rank 0) 138 | is_main_process = (int(os.environ.get("LOCAL_RANK", 0)) == 0) 139 | if is_main_process: 140 | os.makedirs(args.save_video_path, exist_ok=True) 141 | print(f"Saving .mp4 samples at {args.save_video_path}") 142 | 143 | # Create dataset and dataloader 144 | dataset = get_dataset(args) 145 | val_loader = DataLoader( 146 | dataset, 147 | batch_size=args.per_proc_batch_size, # Batch size 1 for sampling 148 | shuffle=False, 149 | num_workers=args.num_workers, 150 | pin_memory=True, 151 | ) 152 | logger.info(f"Validation set contains {len(dataset)} samples") 153 | 154 | if args.in_channels == 32: 155 | sample_size = args.image_size // 32 156 | else: 157 | sample_size = args.image_size // 8 158 | args.latent_size = sample_size 159 | 160 | # Initialize 
the sampling module 161 | pl_module = LatteSamplingModule(args, logger) 162 | 163 | # Trainer 164 | trainer = Trainer( 165 | accelerator="gpu", 166 | devices=[0], # Specify GPU ids 167 | strategy="auto", 168 | logger=False, 169 | precision=args.precision if args.precision else "32-true", 170 | ) 171 | 172 | # Run validation to generate samples 173 | trainer.validate(pl_module, dataloaders=val_loader) 174 | 175 | logger.info("Sampling completed!") 176 | 177 | if __name__ == "__main__": 178 | parser = argparse.ArgumentParser() 179 | parser.add_argument("--config", type=str, default="./configs/sample.yaml") 180 | args = parser.parse_args() 181 | omega_conf = OmegaConf.load(args.config) 182 | main(omega_conf) -------------------------------------------------------------------------------- /datasets/frames_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import time 4 | import torch 5 | import random 6 | import numpy as np 7 | import torchvision.transforms as transforms 8 | from torch.utils.data import Dataset 9 | 10 | 11 | class BaseDataset(Dataset): 12 | def __init__( 13 | self, 14 | configs, 15 | temporal_sample=None, 16 | transform=None, 17 | ): 18 | self.data_dir = configs.data_dir 19 | self.datalists = [d for d in os.listdir(self.data_dir)] 20 | self.frame_interval = configs.frame_interval 21 | self.num_frames = configs.num_frames 22 | self.temporal_sample = temporal_sample 23 | self.transform = transform 24 | 25 | self.audio_dir = configs.audio_dir 26 | self.audio_margin = configs.audio_margin 27 | self.initial_frames = configs.initial_frames 28 | self.slide_window = (configs.in_channels == 4) 29 | 30 | def __len__(self): 31 | return len(self.datalists) 32 | 33 | def get_indices(self, total_frames): 34 | indices = np.arange(total_frames) 35 | # Randomly select frames 36 | start_index, end_index = self.temporal_sample(total_frames - self.audio_margin) 37 | selected_indices = torch.linspace(start_index, end_index - 1, self.num_frames + self.initial_frames, dtype=int) 38 | video_indices = selected_indices[self.initial_frames:] 39 | # start_index = random.randint( 40 | # max(self.initial_frames, self.audio_margin), 41 | # total_frames - self.num_frames - self.audio_margin - 1, 42 | # ) 43 | # selected_indices = indices[start_index - self.initial_frames: start_index + self.num_frames] 44 | # video_indices = torch.from_numpy(selected_indices[self.initial_frames:]) 45 | 46 | # Choose a reference frame from the remaining frames 47 | remaining_indices = np.setdiff1d(indices, selected_indices) 48 | if len(remaining_indices) == 0: 49 | remaining_indices = indices 50 | ref_index = np.random.choice(remaining_indices) 51 | 52 | # Add the reference frame index to the selected_indices 53 | selected_indices_ = np.append(selected_indices, ref_index) 54 | return selected_indices_, video_indices 55 | 56 | def load_audio_emb(self, audio_emb_path, video_indices): 57 | # Extract wav hidden features 58 | audio_emb = torch.load(audio_emb_path) 59 | if self.slide_window: 60 | audio_indices = ( 61 | torch.arange(2 * self.audio_margin + 1) - self.audio_margin 62 | ) # Generates [-2, -1, 0, 1, 2] 63 | center_indices = video_indices.unsqueeze(1) + audio_indices.unsqueeze(0) 64 | try: 65 | audio_tensor = audio_emb[center_indices] 66 | except: 67 | print(audio_emb_path) 68 | print(len(audio_emb)) 69 | else: 70 | audio_tensor = audio_emb[video_indices] 71 | return audio_tensor 72 | 73 | 74 | class VideoFramesDataset(BaseDataset): 75 | def 
load_images(self, folder_path, selected_indices): 76 | images = [] 77 | for idx in selected_indices: 78 | img_path = os.path.join(folder_path, f"{idx+1:04d}.jpg") 79 | img = cv2.imread(img_path) 80 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 81 | if img is None: 82 | raise FileNotFoundError(f"Image {img_path} not found.") 83 | images.append(img) 84 | return np.array(images) 85 | 86 | def __getitem__(self, index): 87 | sample = self.datalists[index] 88 | video_folder_path = os.path.join(self.data_dir, sample) 89 | audio_emb_path = "{}/{}.pt".format(self.audio_dir, sample) 90 | audio_emb = torch.load(audio_emb_path, weights_only=True) 91 | 92 | # Extract total frames count based on images available 93 | video_frames = len([name for name in os.listdir(video_folder_path) if name.endswith(".jpg")]) 94 | total_frames = min(audio_emb.size(0), video_frames) 95 | 96 | selected_indices, video_indices = self.get_indices(total_frames) 97 | 98 | # Load the selected frames including the reference frame 99 | all_frames = self.load_images(video_folder_path, selected_indices) 100 | all_frames = torch.from_numpy(all_frames).permute(0, 3, 1, 2).contiguous() 101 | 102 | # Apply transformation if it exists 103 | if self.transform: 104 | all_frames = self.transform(all_frames) 105 | 106 | # Separate the reference frame, motions, and image_window 107 | ref_image = all_frames[-1] # The last frame is the reference frame 108 | video_frames = all_frames[:-1] # All frames except the last one 109 | 110 | motions = video_frames[:self.initial_frames] 111 | image_window = video_frames[self.initial_frames:] 112 | 113 | audio_indices = ( 114 | torch.arange(2 * self.audio_margin + 1) - self.audio_margin 115 | ) # Generates [-2, -1, 0, 1, 2] 116 | center_indices = video_indices.unsqueeze(1) + audio_indices.unsqueeze(0) 117 | audio_tensor = audio_emb[center_indices].squeeze(1) 118 | return {"video": image_window, "audio": audio_tensor, "ref_image": ref_image, "motions": motions, "video_name": sample} 119 | 120 | 121 | class FramesLatentDataset(BaseDataset): 122 | def load_latents(self, folder_path, selected_indices): 123 | latents = [] 124 | for idx in selected_indices: 125 | latent_path = os.path.join(folder_path, f"{idx+1:04d}.npz") # Change .pt to .npz 126 | latent_data = np.load(latent_path)['latent'] # Load latent data from .npz file 127 | latent = torch.tensor(latent_data) # Convert numpy array to PyTorch tensor 128 | latents.append(latent) 129 | return torch.stack(latents) # Return stacked latents as a tensor 130 | 131 | def __getitem__(self, index): 132 | sample = self.datalists[index] 133 | video_folder_path = os.path.join(self.data_dir, sample) 134 | audio_emb_path = "{}/{}.pt".format(self.audio_dir, sample) 135 | audio_emb = torch.load(audio_emb_path, weights_only=True) 136 | 137 | # Extract total frames count based on available .npz files 138 | video_frames = len([name for name in os.listdir(video_folder_path) if name.endswith(".npz")]) 139 | total_frames = min(audio_emb.size(0), video_frames) 140 | 141 | selected_indices, video_indices = self.get_indices(total_frames) 142 | 143 | # Load the selected latents including the reference frame 144 | all_latents = self.load_latents(video_folder_path, selected_indices) 145 | 146 | # Separate the reference frame, motions, and image_window 147 | ref_latent = all_latents[-1] # The last frame is the reference frame 148 | video_latents = all_latents[:-1] # All frames except the last one 149 | 150 | motion_latents = video_latents[:self.initial_frames] 151 | latent_window = 
video_latents[self.initial_frames:] 152 | 153 | audio_indices = ( 154 | torch.arange(2 * self.audio_margin + 1) - self.audio_margin 155 | ) # Generates [-2, -1, 0, 1, 2] 156 | center_indices = video_indices.unsqueeze(1) + audio_indices.unsqueeze(0) 157 | audio_tensor = audio_emb[center_indices].squeeze(1) 158 | return {"video": latent_window, "audio": audio_tensor, "ref_latent": ref_latent, "motions": motion_latents, "video_name": sample} 159 | 160 | 161 | class VideoLatentDataset(BaseDataset): 162 | def load_latents(self, latent_path): 163 | latents = torch.load(latent_path) # Load the entire latent tensor 164 | return latents 165 | 166 | def __getitem__(self, index): 167 | sample = self.datalists[index] 168 | video_latent_path = os.path.join(self.data_dir, sample) 169 | audio_emb_path = "{}/{}.pt".format(self.audio_dir, sample) 170 | audio_emb = torch.load(audio_emb_path, weights_only=True) 171 | 172 | # Load latents and calculate total frames 173 | latents = self.load_latents(video_latent_path) # Latent shape: [num_frames, channels, height, width] 174 | video_frames = latents.shape[0] 175 | total_frames = min(audio_emb.size(0), video_frames) 176 | 177 | selected_indices, video_indices = self.get_indices(total_frames) 178 | 179 | # Select latent frames 180 | all_latents = latents[selected_indices] 181 | 182 | # Separate the reference frame, motions, and image_window 183 | ref_latent = all_latents[-1] # The last frame is the reference frame 184 | video_latents = all_latents[:-1] # All frames except the last one 185 | 186 | motion_latents = video_latents[:self.initial_frames] 187 | latent_window = video_latents[self.initial_frames:] 188 | 189 | audio_indices = ( 190 | torch.arange(2 * self.audio_margin + 1) - self.audio_margin 191 | ) # Generates [-2, -1, 0, 1, 2] 192 | center_indices = video_indices.unsqueeze(1) + audio_indices.unsqueeze(0) 193 | audio_tensor = audio_emb[center_indices].squeeze(1) 194 | return {"video": latent_window, "audio": audio_tensor, "ref_latent": ref_latent, "motions": motion_latents, "video_name": sample} 195 | 196 | 197 | class VideoDubbingDataset(BaseDataset): 198 | def __init__( 199 | self, 200 | configs, 201 | transform=None, 202 | ): 203 | self.data_dir = configs.data_dir 204 | self.datalists = [d for d in os.listdir(self.data_dir)] 205 | self.frame_interval = configs.frame_interval 206 | self.num_frames = configs.num_frames 207 | self.transform = transform 208 | self.audio_dir = configs.audio_dir 209 | self.audio_margin = configs.audio_margin 210 | 211 | def load_images(self, folder_path, selected_indices): 212 | images = [] 213 | for idx in selected_indices: 214 | img_path = os.path.join(folder_path, f"{idx+1:04d}.jpg") 215 | img = cv2.imread(img_path) 216 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 217 | if img is None: 218 | raise FileNotFoundError(f"Image {img_path} not found.") 219 | images.append(img) 220 | return np.array(images) 221 | 222 | def get_two_clip_indices(self, total_frames): 223 | # Calculate the total length needed for two clips and the gap between them 224 | total_needed = 2 * self.num_frames + self.num_frames # 2 clips + gap 225 | 226 | # We need additional margin for audio indices at both ends 227 | total_needed_with_margin = total_needed + 2 * self.audio_margin 228 | 229 | # Check if video is long enough 230 | if total_frames < total_needed_with_margin: 231 | raise ValueError(f"Video has only {total_frames} frames, but need at least {total_needed_with_margin} frames (including audio margin)") 232 | 233 | # Randomly select a starting 
point that allows both clips to fit with audio margins 234 | max_start = total_frames - total_needed_with_margin 235 | start_frame = random.randint(self.audio_margin, max_start + self.audio_margin) 236 | 237 | # First clip indices (accounting for audio margin) 238 | first_clip_indices = range(start_frame, start_frame + self.num_frames) 239 | 240 | # Second clip starts after first clip + gap 241 | second_clip_start = start_frame + self.num_frames + self.num_frames 242 | second_clip_indices = range(second_clip_start, second_clip_start + self.num_frames) 243 | 244 | # Combine all indices needed for loading 245 | all_indices = list(first_clip_indices) + list(second_clip_indices) 246 | 247 | # Return the combined indices and the indices for each clip 248 | return all_indices, first_clip_indices, second_clip_indices 249 | 250 | def __getitem__(self, index): 251 | sample = self.datalists[index] 252 | video_folder_path = os.path.join(self.data_dir, sample) 253 | 254 | audio_emb_path = "{}/{}.pt".format(self.audio_dir, sample) 255 | audio_emb = torch.load(audio_emb_path, weights_only=True) 256 | 257 | # Extract total frames count based on images available 258 | video_frames = len([name for name in os.listdir(video_folder_path) if name.endswith(".jpg")]) 259 | total_frames = min(audio_emb.size(0), video_frames) 260 | 261 | # Get indices for two clips 262 | all_indices, first_clip_indices, second_clip_indices = self.get_two_clip_indices(total_frames) 263 | 264 | # Load all needed frames 265 | all_frames = self.load_images(video_folder_path, all_indices) 266 | all_frames = torch.from_numpy(all_frames).permute(0, 3, 1, 2).contiguous() 267 | 268 | # Split into two clips 269 | first_clip_frames = all_frames[:self.num_frames] 270 | second_clip_frames = all_frames[self.num_frames:] 271 | 272 | # Apply transformation if it exists 273 | if self.transform: 274 | first_clip_frames = self.transform(first_clip_frames) 275 | second_clip_frames = self.transform(second_clip_frames) 276 | 277 | # Process audio for both clips 278 | def get_audio_tensor(video_indices): 279 | audio_indices = ( 280 | torch.arange(2 * self.audio_margin + 1) - self.audio_margin 281 | ) # Generates [-2, -1, 0, 1, 2] 282 | center_indices = torch.tensor(video_indices).unsqueeze(1) + audio_indices.unsqueeze(0) 283 | return audio_emb[center_indices].squeeze(1) 284 | 285 | first_audio_tensor = get_audio_tensor(first_clip_indices) 286 | second_audio_tensor = get_audio_tensor(second_clip_indices) 287 | 288 | return { 289 | "first_video": first_clip_frames, 290 | "first_audio": first_audio_tensor, 291 | "second_video": second_clip_frames, 292 | "second_audio": second_audio_tensor, 293 | "video_name": sample 294 | } -------------------------------------------------------------------------------- /models/vae/vae.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | from diffusers.utils import BaseOutput, is_torch_version 9 | from diffusers.utils.torch_utils import randn_tensor 10 | from diffusers.models.attention_processor import SpatialNorm 11 | from .unet_causal_3d_blocks import ( 12 | CausalConv3d, 13 | UNetMidBlockCausal3D, 14 | get_down_block3d, 15 | get_up_block3d, 16 | ) 17 | 18 | 19 | @dataclass 20 | class DecoderOutput(BaseOutput): 21 | r""" 22 | Output of decoding method. 
23 | 24 | Args: 25 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 26 | The decoded output sample from the last layer of the model. 27 | """ 28 | 29 | sample: torch.FloatTensor 30 | 31 | 32 | class EncoderCausal3D(nn.Module): 33 | r""" 34 | The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation. 35 | """ 36 | 37 | def __init__( 38 | self, 39 | in_channels: int = 3, 40 | out_channels: int = 3, 41 | down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",), 42 | block_out_channels: Tuple[int, ...] = (64,), 43 | layers_per_block: int = 2, 44 | norm_num_groups: int = 32, 45 | act_fn: str = "silu", 46 | double_z: bool = True, 47 | mid_block_add_attention=True, 48 | time_compression_ratio: int = 4, 49 | spatial_compression_ratio: int = 8, 50 | ): 51 | super().__init__() 52 | self.layers_per_block = layers_per_block 53 | 54 | self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1) 55 | self.mid_block = None 56 | self.down_blocks = nn.ModuleList([]) 57 | 58 | # down 59 | output_channel = block_out_channels[0] 60 | for i, down_block_type in enumerate(down_block_types): 61 | input_channel = output_channel 62 | output_channel = block_out_channels[i] 63 | is_final_block = i == len(block_out_channels) - 1 64 | num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio)) 65 | num_time_downsample_layers = int(np.log2(time_compression_ratio)) 66 | 67 | if time_compression_ratio == 4: 68 | add_spatial_downsample = bool(i < num_spatial_downsample_layers) 69 | add_time_downsample = bool( 70 | i >= (len(block_out_channels) - 1 - num_time_downsample_layers) 71 | and not is_final_block 72 | ) 73 | else: 74 | raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.") 75 | 76 | downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1) 77 | downsample_stride_T = (2,) if add_time_downsample else (1,) 78 | downsample_stride = tuple(downsample_stride_T + downsample_stride_HW) 79 | down_block = get_down_block3d( 80 | down_block_type, 81 | num_layers=self.layers_per_block, 82 | in_channels=input_channel, 83 | out_channels=output_channel, 84 | add_downsample=bool(add_spatial_downsample or add_time_downsample), 85 | downsample_stride=downsample_stride, 86 | resnet_eps=1e-6, 87 | downsample_padding=0, 88 | resnet_act_fn=act_fn, 89 | resnet_groups=norm_num_groups, 90 | attention_head_dim=output_channel, 91 | temb_channels=None, 92 | ) 93 | self.down_blocks.append(down_block) 94 | 95 | # mid 96 | self.mid_block = UNetMidBlockCausal3D( 97 | in_channels=block_out_channels[-1], 98 | resnet_eps=1e-6, 99 | resnet_act_fn=act_fn, 100 | output_scale_factor=1, 101 | resnet_time_scale_shift="default", 102 | attention_head_dim=block_out_channels[-1], 103 | resnet_groups=norm_num_groups, 104 | temb_channels=None, 105 | add_attention=mid_block_add_attention, 106 | ) 107 | 108 | # out 109 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) 110 | self.conv_act = nn.SiLU() 111 | 112 | conv_out_channels = 2 * out_channels if double_z else out_channels 113 | self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3) 114 | 115 | def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: 116 | r"""The forward method of the `EncoderCausal3D` class.""" 117 | assert len(sample.shape) == 5, "The input tensor should have 5 dimensions" 118 | 119 | sample = self.conv_in(sample) 120 | 
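        # Illustrative shape sketch (hypothetical input; the exact temporal
        # length depends on the causal padding and striding of the down blocks):
        #
        #     x = torch.randn(1, 3, 17, 256, 256)   # (B, C, T, H, W)
        #     z = encoder(x)   # H and W shrink by spatial_compression_ratio=8 -> 32x32;
        #                      # channels double when double_z=True (mean and logvar)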
121 | # down 122 | for down_block in self.down_blocks: 123 | sample = down_block(sample) 124 | 125 | # middle 126 | sample = self.mid_block(sample) 127 | 128 | # post-process 129 | sample = self.conv_norm_out(sample) 130 | sample = self.conv_act(sample) 131 | sample = self.conv_out(sample) 132 | 133 | return sample 134 | 135 | 136 | class DecoderCausal3D(nn.Module): 137 | r""" 138 | The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample. 139 | """ 140 | 141 | def __init__( 142 | self, 143 | in_channels: int = 3, 144 | out_channels: int = 3, 145 | up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",), 146 | block_out_channels: Tuple[int, ...] = (64,), 147 | layers_per_block: int = 2, 148 | norm_num_groups: int = 32, 149 | act_fn: str = "silu", 150 | norm_type: str = "group", # group, spatial 151 | mid_block_add_attention=True, 152 | time_compression_ratio: int = 4, 153 | spatial_compression_ratio: int = 8, 154 | ): 155 | super().__init__() 156 | self.layers_per_block = layers_per_block 157 | 158 | self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1) 159 | self.mid_block = None 160 | self.up_blocks = nn.ModuleList([]) 161 | 162 | temb_channels = in_channels if norm_type == "spatial" else None 163 | 164 | # mid 165 | self.mid_block = UNetMidBlockCausal3D( 166 | in_channels=block_out_channels[-1], 167 | resnet_eps=1e-6, 168 | resnet_act_fn=act_fn, 169 | output_scale_factor=1, 170 | resnet_time_scale_shift="default" if norm_type == "group" else norm_type, 171 | attention_head_dim=block_out_channels[-1], 172 | resnet_groups=norm_num_groups, 173 | temb_channels=temb_channels, 174 | add_attention=mid_block_add_attention, 175 | ) 176 | 177 | # up 178 | reversed_block_out_channels = list(reversed(block_out_channels)) 179 | output_channel = reversed_block_out_channels[0] 180 | for i, up_block_type in enumerate(up_block_types): 181 | prev_output_channel = output_channel 182 | output_channel = reversed_block_out_channels[i] 183 | is_final_block = i == len(block_out_channels) - 1 184 | num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio)) 185 | num_time_upsample_layers = int(np.log2(time_compression_ratio)) 186 | 187 | if time_compression_ratio == 4: 188 | add_spatial_upsample = bool(i < num_spatial_upsample_layers) 189 | add_time_upsample = bool( 190 | i >= len(block_out_channels) - 1 - num_time_upsample_layers 191 | and not is_final_block 192 | ) 193 | else: 194 | raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.") 195 | 196 | upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1) 197 | upsample_scale_factor_T = (2,) if add_time_upsample else (1,) 198 | upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW) 199 | up_block = get_up_block3d( 200 | up_block_type, 201 | num_layers=self.layers_per_block + 1, 202 | in_channels=prev_output_channel, 203 | out_channels=output_channel, 204 | prev_output_channel=None, 205 | add_upsample=bool(add_spatial_upsample or add_time_upsample), 206 | upsample_scale_factor=upsample_scale_factor, 207 | resnet_eps=1e-6, 208 | resnet_act_fn=act_fn, 209 | resnet_groups=norm_num_groups, 210 | attention_head_dim=output_channel, 211 | temb_channels=temb_channels, 212 | resnet_time_scale_shift=norm_type, 213 | ) 214 | self.up_blocks.append(up_block) 215 | prev_output_channel = output_channel 216 | 217 | # out 218 | if norm_type == "spatial": 219 | self.conv_norm_out = 
SpatialNorm(block_out_channels[0], temb_channels) 220 | else: 221 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) 222 | self.conv_act = nn.SiLU() 223 | self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3) 224 | 225 | self.gradient_checkpointing = False 226 | 227 | def forward( 228 | self, 229 | sample: torch.FloatTensor, 230 | latent_embeds: Optional[torch.FloatTensor] = None, 231 | ) -> torch.FloatTensor: 232 | r"""The forward method of the `DecoderCausal3D` class.""" 233 | assert len(sample.shape) == 5, "The input tensor should have 5 dimensions." 234 | 235 | sample = self.conv_in(sample) 236 | 237 | upscale_dtype = next(iter(self.up_blocks.parameters())).dtype 238 | if self.training and self.gradient_checkpointing: 239 | 240 | def create_custom_forward(module): 241 | def custom_forward(*inputs): 242 | return module(*inputs) 243 | 244 | return custom_forward 245 | 246 | if is_torch_version(">=", "1.11.0"): 247 | # middle 248 | sample = torch.utils.checkpoint.checkpoint( 249 | create_custom_forward(self.mid_block), 250 | sample, 251 | latent_embeds, 252 | use_reentrant=False, 253 | ) 254 | sample = sample.to(upscale_dtype) 255 | 256 | # up 257 | for up_block in self.up_blocks: 258 | sample = torch.utils.checkpoint.checkpoint( 259 | create_custom_forward(up_block), 260 | sample, 261 | latent_embeds, 262 | use_reentrant=False, 263 | ) 264 | else: 265 | # middle 266 | sample = torch.utils.checkpoint.checkpoint( 267 | create_custom_forward(self.mid_block), sample, latent_embeds 268 | ) 269 | sample = sample.to(upscale_dtype) 270 | 271 | # up 272 | for up_block in self.up_blocks: 273 | sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) 274 | else: 275 | # middle 276 | sample = self.mid_block(sample, latent_embeds) 277 | sample = sample.to(upscale_dtype) 278 | 279 | # up 280 | for up_block in self.up_blocks: 281 | sample = up_block(sample, latent_embeds) 282 | 283 | # post-process 284 | if latent_embeds is None: 285 | sample = self.conv_norm_out(sample) 286 | else: 287 | sample = self.conv_norm_out(sample, latent_embeds) 288 | sample = self.conv_act(sample) 289 | sample = self.conv_out(sample) 290 | 291 | return sample 292 | 293 | 294 | class DiagonalGaussianDistribution(object): 295 | def __init__(self, parameters: torch.Tensor, deterministic: bool = False): 296 | if parameters.ndim == 3: 297 | dim = 2 # (B, L, C) 298 | elif parameters.ndim == 5 or parameters.ndim == 4: 299 | dim = 1 # (B, C, T, H ,W) / (B, C, H, W) 300 | else: 301 | raise NotImplementedError 302 | self.parameters = parameters 303 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim) 304 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 305 | self.deterministic = deterministic 306 | self.std = torch.exp(0.5 * self.logvar) 307 | self.var = torch.exp(self.logvar) 308 | if self.deterministic: 309 | self.var = self.std = torch.zeros_like( 310 | self.mean, device=self.parameters.device, dtype=self.parameters.dtype 311 | ) 312 | 313 | def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: 314 | # make sure sample is on the same device as the parameters and has same dtype 315 | sample = randn_tensor( 316 | self.mean.shape, 317 | generator=generator, 318 | device=self.parameters.device, 319 | dtype=self.parameters.dtype, 320 | ) 321 | x = self.mean + self.std * sample 322 | return x 323 | 324 | def kl(self, other: "DiagonalGaussianDistribution" = None) -> 
torch.Tensor: 325 | if self.deterministic: 326 | return torch.Tensor([0.0]) 327 | else: 328 | reduce_dim = list(range(1, self.mean.ndim)) 329 | if other is None: 330 | return 0.5 * torch.sum( 331 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 332 | dim=reduce_dim, 333 | ) 334 | else: 335 | return 0.5 * torch.sum( 336 | torch.pow(self.mean - other.mean, 2) / other.var 337 | + self.var / other.var 338 | - 1.0 339 | - self.logvar 340 | + other.logvar, 341 | dim=reduce_dim, 342 | ) 343 | 344 | def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor: 345 | if self.deterministic: 346 | return torch.Tensor([0.0]) 347 | logtwopi = np.log(2.0 * np.pi) 348 | return 0.5 * torch.sum( 349 | logtwopi + self.logvar + 350 | torch.pow(sample - self.mean, 2) / self.var, 351 | dim=dims, 352 | ) 353 | 354 | def mode(self) -> torch.Tensor: 355 | return self.mean 356 | -------------------------------------------------------------------------------- /train_pl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import math 4 | import logging 5 | import argparse 6 | from pytorch_lightning.callbacks import Callback 7 | from pytorch_lightning import LightningModule, Trainer 8 | from pytorch_lightning.strategies import DDPStrategy 9 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 10 | from pytorch_lightning.loggers import TensorBoardLogger 11 | from glob import glob 12 | from omegaconf import OmegaConf 13 | from torch.utils.data import DataLoader 14 | from diffusers import AutoencoderDC 15 | from diffusers.models import AutoencoderKL 16 | from diffusers.optimization import get_scheduler 17 | from copy import deepcopy 18 | from einops import rearrange, repeat 19 | 20 | from models import get_models 21 | from models.audio_proj import AudioProjModel 22 | from datasets import get_dataset 23 | from diffusion import create_diffusion 24 | from utils import ( 25 | update_ema, 26 | requires_grad, 27 | clip_grad_norm_, 28 | cleanup, 29 | find_model 30 | ) 31 | 32 | 33 | class LetstalkTrainingModule(LightningModule): 34 | def __init__(self, args, logger: logging.Logger): 35 | super(LetstalkTrainingModule, self).__init__() 36 | self.args = args 37 | self.logging = logger 38 | self.model = get_models(args) 39 | if args.use_compile: 40 | self.model = torch.compile(self.model) 41 | 42 | self.ema = deepcopy(self.model) 43 | self.audioproj = AudioProjModel( 44 | seq_len=5, 45 | blocks=12, 46 | channels=args.audio_dim, 47 | intermediate_dim=512, 48 | output_dim=args.audio_dim, 49 | context_tokens=args.audio_token, 50 | ) 51 | 52 | # Load pretrained model if specified 53 | if args.pretrained is not None: 54 | self._load_pretrained_parameters(args) 55 | self.logging.info(f"Model Parameters: {sum(p.numel() for p in self.model.parameters()):,}") 56 | 57 | self.diffusion = create_diffusion(timestep_respacing="") 58 | if args.in_channels == 32: 59 | self.vae = AutoencoderDC.from_pretrained(args.pretrained_model_path, subfolder="vae") 60 | else: 61 | self.vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae") 62 | 63 | self.opt = torch.optim.AdamW(self.model.parameters(), lr=1e-4, weight_decay=0) 64 | self.lr_scheduler = None 65 | 66 | # Freeze model 67 | self.vae.requires_grad_(False) 68 | requires_grad(self.ema, False) 69 | 70 | update_ema(self.ema, self.model, decay=0) # Ensure EMA is initialized with synced weights 71 | self.model.train() # important! 
This enables embedding dropout for classifier-free guidance 72 | self.audioproj.train() 73 | self.ema.eval() 74 | 75 | def _load_pretrained_parameters(self, args): 76 | checkpoint = torch.load(args.pretrained, map_location=lambda storage, loc: storage) 77 | if "ema" in checkpoint: # supports checkpoints from train.py 78 | self.logging.info("Using ema ckpt!") 79 | checkpoint = checkpoint["ema"] 80 | 81 | model_dict = self.model.state_dict() 82 | # 1. filter out unnecessary keys 83 | pretrained_dict = {} 84 | for k, v in checkpoint.items(): 85 | if k in model_dict: 86 | pretrained_dict[k] = v 87 | else: 88 | self.logging.info("Ignoring: {}".format(k)) 89 | self.logging.info(f"Successfully Load {len(pretrained_dict) / len(checkpoint.items()) * 100}% original pretrained model weights ") 90 | 91 | # 2. overwrite entries in the existing state dict 92 | model_dict.update(pretrained_dict) 93 | self.model.load_state_dict(model_dict) 94 | self.logging.info(f"Successfully load model at {args.pretrained}!") 95 | 96 | if "audioproj" in checkpoint.keys(): 97 | audioproj_dict = checkpoint["audioproj"] 98 | self.audioproj.load_state_dict(audioproj_dict) 99 | 100 | # def on_load_checkpoint(self, checkpoint): 101 | # file_name = args.pretrained.split("/")[-1].split('.')[0] 102 | # if file_name.isdigit(): 103 | # self.global_step = int(file_name) 104 | 105 | def add_noise_to_image(self, images): 106 | image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(images.size(0),), device=self.device) 107 | image_noise_sigma = torch.exp(image_noise_sigma) 108 | noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None] 109 | return noisy_images 110 | 111 | def add_noise( 112 | self, 113 | original_samples: torch.Tensor, 114 | noise: torch.Tensor, 115 | timesteps: torch.IntTensor, 116 | ) -> torch.Tensor: 117 | # Make sure alphas_cumprod and timestep have same device and dtype as original_samples 118 | # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement 119 | # for the subsequent add_noise calls 120 | alphas_cumprod = torch.from_numpy(self.diffusion.alphas_cumprod) 121 | alphas_cumprod = alphas_cumprod.to(device=original_samples.device) 122 | alphas_cumprod = alphas_cumprod.to(dtype=original_samples.dtype) 123 | timesteps = timesteps.to(original_samples.device) 124 | 125 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 126 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 127 | while len(sqrt_alpha_prod.shape) < len(original_samples.shape): 128 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 129 | 130 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 131 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 132 | while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): 133 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 134 | 135 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 136 | return noisy_samples 137 | 138 | def training_step(self, batch, batch_idx): 139 | if "latent" in self.args.dataset: 140 | latents = batch["video"] * self.vae.config.scaling_factor 141 | ref_latents = batch["ref_latent"] * self.vae.config.scaling_factor 142 | motion_latents = batch["motions"] * self.vae.config.scaling_factor if self.args.initial_frames != 0 else None 143 | else: 144 | x = batch["video"] 145 | ref_image = batch["ref_image"] 146 | 147 | with torch.no_grad(): 148 | b, _, _, _, _ = x.shape 149 | x = rearrange(x, "b f c h w -> (b f) c h 
w").contiguous() 150 | if self.args.in_channels == 32: 151 | latents = self.vae.encode(x)[0] * self.vae.config.scaling_factor 152 | else: 153 | latents = self.vae.encode(x).latent_dist.sample() * self.vae.scaling_factor 154 | latents = rearrange(latents, "(b f) c h w -> b f c h w", b=b).contiguous() 155 | 156 | ref_image = self.add_noise_to_image(ref_image) 157 | if self.args.in_channels == 32: 158 | ref_latents = self.vae.encode(ref_image)[0] * self.vae.config.scaling_factor 159 | else: 160 | ref_latents = self.vae.encode(ref_image).latent_dist.sample().mul_(0.18215) 161 | 162 | if self.args.initial_frames != 0: 163 | motions = batch["motions"] 164 | motions = rearrange(motions, "b f c h w -> (b f) c h w").contiguous() 165 | motions = self.add_noise_to_image(motions) 166 | if self.args.in_channels == 32: 167 | motion_latents = self.vae.encode(motions)[0] * self.vae.config.scaling_factor 168 | else: 169 | motion_latents = self.vae.encode(motions).latent_dist.sample() * self.vae.scaling_factor 170 | motion_latents = rearrange(motion_latents, "(b f) c h w -> b f c h w", b=b).contiguous() 171 | 172 | ref_latents = repeat(ref_latents, "b c h w -> b f c h w", f=latents.size(1)) 173 | model_kwargs = dict(y=None, cond=ref_latents) 174 | if self.args.initial_frames != 0: 175 | motion_timesteps = torch.randint(0, 50, (latents.shape[0],), device=latents.device).long() 176 | motion_noise = torch.randn_like(motion_latents) 177 | # add motion noise 178 | noisy_motion_latents = self.add_noise( 179 | motion_latents, motion_noise, motion_timesteps 180 | ) 181 | 182 | b, f, c, h, w = noisy_motion_latents.shape 183 | rand_mask = torch.rand(h, w).to(device=noisy_motion_latents.device) 184 | mask = rand_mask > 0.25 185 | mask = mask.unsqueeze(0).unsqueeze(0).unsqueeze(0) 186 | mask = mask.expand(b, f, c, h, w) 187 | noisy_motion_latents = noisy_motion_latents * mask 188 | 189 | model_kwargs.update(motion=noisy_motion_latents) 190 | 191 | if "audio" in batch: 192 | audio_emb = batch["audio"] 193 | audio_emb = self.audioproj(audio_emb) 194 | model_kwargs.update(audio_embed=audio_emb) 195 | 196 | timesteps = torch.randint(0, self.diffusion.num_timesteps, (latents.shape[0],), device=self.device) 197 | loss_dict = self.diffusion.training_losses(self.model, latents, timesteps, model_kwargs) 198 | loss = loss_dict["loss"].mean() 199 | 200 | if self.global_step < self.args.start_clip_iter: 201 | gradient_norm = clip_grad_norm_(self.model.parameters(), self.args.clip_max_norm, clip_grad=False) 202 | else: 203 | gradient_norm = clip_grad_norm_(self.model.parameters(), self.args.clip_max_norm, clip_grad=True) 204 | 205 | self.log("train_loss", loss, prog_bar=True) 206 | self.log("gradient_norm", gradient_norm, prog_bar=True) 207 | self.log("train_step", self.global_step) 208 | 209 | if (self.global_step+1) % self.args.log_every == 0: 210 | self.logging.info( 211 | f"(step={self.global_step+1:07d}/epoch={self.current_epoch:04d}) Train Loss: {loss:.4f}, Gradient Norm: {gradient_norm:.4f}" 212 | ) 213 | return loss 214 | 215 | def on_train_batch_end(self, *args, **kwargs): 216 | update_ema(self.ema, self.model) 217 | 218 | def on_save_checkpoint(self, checkpoint): 219 | super().on_save_checkpoint(checkpoint) 220 | checkpoint_dir = self.trainer.checkpoint_callback.dirpath 221 | epoch = self.trainer.current_epoch 222 | step = self.trainer.global_step 223 | checkpoint = { 224 | "model": self.model.state_dict(), 225 | "ema": self.ema.state_dict(), 226 | "audioproj": self.audioproj.state_dict(), 227 | } 228 | 
torch.save(checkpoint, f"{checkpoint_dir}/last.ckpt") 229 | if step % self.args.ckpt_every == 0: 230 | torch.save(checkpoint, f"{checkpoint_dir}/epoch{epoch}-step{step}.ckpt") 231 | 232 | def configure_optimizers(self): 233 | self.lr_scheduler = get_scheduler( 234 | name="constant", 235 | optimizer=self.opt, 236 | num_warmup_steps=self.args.lr_warmup_steps * self.args.gradient_accumulation_steps, 237 | num_training_steps=self.args.max_train_steps * self.args.gradient_accumulation_steps, 238 | ) 239 | return [self.opt], [self.lr_scheduler] 240 | 241 | 242 | def create_logger(logging_dir): 243 | logging.basicConfig( 244 | level=logging.INFO, 245 | format="[%(asctime)s] %(message)s", 246 | datefmt="%Y-%m-%d %H:%M:%S", 247 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] 248 | ) 249 | logger = logging.getLogger(__name__) 250 | return logger 251 | 252 | 253 | def create_experiment_directory(args): 254 | os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders) 255 | experiment_index = len(glob(os.path.join(args.results_dir, "*"))) 256 | model_string_name = args.model.replace("/", "-") # e.g., Letstalk-L/2 --> Letstalk-L-2 (for naming folders) 257 | num_frame_string = f"F{args.num_frames}S{args.frame_interval}" 258 | experiment_dir = os.path.join( # Create an experiment folder 259 | args.results_dir, 260 | f"{experiment_index:03d}-{model_string_name}-{num_frame_string}-{args.dataset}" 261 | ) 262 | checkpoint_dir = os.path.join(experiment_dir, "checkpoints") # Stores saved model checkpoints 263 | os.makedirs(checkpoint_dir, exist_ok=True) 264 | 265 | return experiment_dir, checkpoint_dir 266 | 267 | 268 | def main(args): 269 | seed = args.global_seed 270 | torch.manual_seed(seed) 271 | 272 | # Determine if the current process is the main process (rank 0) 273 | is_main_process = (int(os.environ.get("LOCAL_RANK", 0)) == 0) 274 | # Setup an experiment folder and logger only if main process 275 | if is_main_process: 276 | experiment_dir, checkpoint_dir = create_experiment_directory(args) 277 | logger = create_logger(experiment_dir) 278 | OmegaConf.save(args, os.path.join(experiment_dir, "config.yaml")) 279 | logger.info(f"Experiment directory created at {experiment_dir}") 280 | else: 281 | experiment_dir = os.getenv("EXPERIMENT_DIR", "default_path") 282 | checkpoint_dir = os.getenv("CHECKPOINT_DIR", "default_path") 283 | logger = logging.getLogger(__name__) 284 | logger.addHandler(logging.NullHandler()) 285 | tb_logger = TensorBoardLogger(experiment_dir, name="letstalk") 286 | 287 | # Create the dataset and dataloader 288 | dataset = get_dataset(args) 289 | loader = DataLoader( 290 | dataset, 291 | batch_size=args.local_batch_size, 292 | shuffle=True, 293 | num_workers=args.num_workers, 294 | pin_memory=True, 295 | drop_last=True 296 | ) 297 | if is_main_process: 298 | logger.info(f"Dataset contains {len(dataset)} videos ({args.data_dir})") 299 | 300 | if args.in_channels == 32: 301 | sample_size = args.image_size // 32 302 | else: 303 | sample_size = args.image_size // 8 304 | args.latent_size = sample_size 305 | 306 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
307 | num_update_steps_per_epoch = math.ceil(len(loader)) // torch.cuda.device_count() 308 | # Afterwards we recalculate our number of training epochs 309 | num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 310 | # In multi GPUs mode, the real batchsize is local_batch_size * GPU numbers 311 | if is_main_process: 312 | logger.info(f"One epoch iteration {num_update_steps_per_epoch} steps") 313 | logger.info(f"Num train epochs: {num_train_epochs}") 314 | 315 | # Initialize the training module 316 | pl_module = LetstalkTrainingModule(args, logger) 317 | 318 | checkpoint_callback = ModelCheckpoint( 319 | dirpath=checkpoint_dir, 320 | filename="{epoch}-{step}-{train_loss:.2f}-{gradient_norm:.2f}", 321 | save_top_k=3, 322 | every_n_train_steps=args.ckpt_every, 323 | monitor="train_step", 324 | mode="max" 325 | ) 326 | 327 | # Trainer 328 | trainer = Trainer( 329 | accelerator="gpu", 330 | # devices=[0], 331 | strategy="auto", 332 | max_epochs=num_train_epochs, 333 | logger=tb_logger, 334 | callbacks=[checkpoint_callback, LearningRateMonitor()], 335 | precision=args.precision if args.precision else "32-true", 336 | ) 337 | 338 | trainer.fit(pl_module, train_dataloaders=loader, ckpt_path=args.resume_from_checkpoint if 339 | args.resume_from_checkpoint else None) 340 | 341 | pl_module.model.eval() 342 | cleanup() 343 | if is_main_process: 344 | logger.info("Done!") 345 | 346 | 347 | if __name__ == "__main__": 348 | parser = argparse.ArgumentParser() 349 | parser.add_argument("--config", type=str, default="./configs/train.yaml") 350 | args = parser.parse_args() 351 | main(OmegaConf.load(args.config)) -------------------------------------------------------------------------------- /preparation/audio_processor.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=C0301 2 | ''' 3 | This module contains the AudioProcessor class and related functions for processing audio data. 4 | It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction, 5 | and audio separation. The class is initialized with configuration parameters and can process 6 | audio files using the provided models. 7 | ''' 8 | import math 9 | import os 10 | import subprocess 11 | 12 | import librosa 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as F 16 | from audio_separator.separator import Separator 17 | from einops import rearrange 18 | from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model 19 | from transformers.modeling_outputs import BaseModelOutput 20 | 21 | 22 | class Wav2VecModel(Wav2Vec2Model): 23 | """ 24 | Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library. 25 | It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding. 26 | ... 27 | 28 | Attributes: 29 | base_model (Wav2Vec2Model): The base Wav2Vec2Model object. 30 | 31 | Methods: 32 | forward(input_values, seq_len, attention_mask=None, mask_time_indices=None 33 | , output_attentions=None, output_hidden_states=None, return_dict=None): 34 | Forward pass of the Wav2VecModel. 35 | It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model. 36 | 37 | feature_extract(input_values, seq_len): 38 | Extracts features from the input_values using the base model. 
39 | 40 | encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None): 41 | Encodes the extracted features using the base model and returns the encoded features. 42 | """ 43 | def forward( 44 | self, 45 | input_values, 46 | seq_len, 47 | attention_mask=None, 48 | mask_time_indices=None, 49 | output_attentions=None, 50 | output_hidden_states=None, 51 | return_dict=None, 52 | ): 53 | """ 54 | Forward pass of the Wav2Vec model. 55 | 56 | Args: 57 | self: The instance of the model. 58 | input_values: The input values (waveform) to the model. 59 | seq_len: The sequence length of the input values. 60 | attention_mask: Attention mask to be used for the model. 61 | mask_time_indices: Mask indices to be used for the model. 62 | output_attentions: If set to True, returns attentions. 63 | output_hidden_states: If set to True, returns hidden states. 64 | return_dict: If set to True, returns a BaseModelOutput instead of a tuple. 65 | 66 | Returns: 67 | The output of the Wav2Vec model. 68 | """ 69 | self.config.output_attentions = True 70 | 71 | output_hidden_states = ( 72 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 73 | ) 74 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 75 | 76 | extract_features = self.feature_extractor(input_values) 77 | extract_features = extract_features.transpose(1, 2) 78 | extract_features = linear_interpolation(extract_features, seq_len=seq_len) 79 | 80 | if attention_mask is not None: 81 | # compute reduced attention_mask corresponding to feature vectors 82 | attention_mask = self._get_feature_vector_attention_mask( 83 | extract_features.shape[1], attention_mask, add_adapter=False 84 | ) 85 | 86 | hidden_states, extract_features = self.feature_projection(extract_features) 87 | hidden_states = self._mask_hidden_states( 88 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask 89 | ) 90 | 91 | encoder_outputs = self.encoder( 92 | hidden_states, 93 | attention_mask=attention_mask, 94 | output_attentions=output_attentions, 95 | output_hidden_states=output_hidden_states, 96 | return_dict=return_dict, 97 | ) 98 | 99 | hidden_states = encoder_outputs[0] 100 | 101 | if self.adapter is not None: 102 | hidden_states = self.adapter(hidden_states) 103 | 104 | if not return_dict: 105 | return (hidden_states, ) + encoder_outputs[1:] 106 | return BaseModelOutput( 107 | last_hidden_state=hidden_states, 108 | hidden_states=encoder_outputs.hidden_states, 109 | attentions=encoder_outputs.attentions, 110 | ) 111 | 112 | 113 | def feature_extract( 114 | self, 115 | input_values, 116 | seq_len, 117 | ): 118 | """ 119 | Extracts features from the input values and returns the extracted features. 120 | 121 | Parameters: 122 | input_values (torch.Tensor): The input values to be processed. 123 | seq_len (torch.Tensor): The sequence lengths of the input values. 124 | 125 | Returns: 126 | extracted_features (torch.Tensor): The extracted features from the input values. 
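Example (illustrative; shapes assume the wav2vec2-base checkpoint, so the
conv feature dimension is 512):
    >>> feats = audio_encoder.feature_extract(input_values, seq_len=25)  # input_values: [1, num_samples]
    >>> feats.shape                                                      # -> [1, 25, 512]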
127 | """ 128 | extract_features = self.feature_extractor(input_values) 129 | extract_features = extract_features.transpose(1, 2) 130 | extract_features = linear_interpolation(extract_features, seq_len=seq_len) 131 | 132 | return extract_features 133 | 134 | def encode( 135 | self, 136 | extract_features, 137 | attention_mask=None, 138 | mask_time_indices=None, 139 | output_attentions=None, 140 | output_hidden_states=None, 141 | return_dict=None, 142 | ): 143 | """ 144 | Encodes the input features into the output space. 145 | 146 | Args: 147 | extract_features (torch.Tensor): The extracted features from the audio signal. 148 | attention_mask (torch.Tensor, optional): Attention mask to be used for padding. 149 | mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension. 150 | output_attentions (bool, optional): If set to True, returns the attention weights. 151 | output_hidden_states (bool, optional): If set to True, returns all hidden states. 152 | return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple. 153 | 154 | Returns: 155 | The encoded output features. 156 | """ 157 | self.config.output_attentions = True 158 | 159 | output_hidden_states = ( 160 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 161 | ) 162 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 163 | 164 | if attention_mask is not None: 165 | # compute reduced attention_mask corresponding to feature vectors 166 | attention_mask = self._get_feature_vector_attention_mask( 167 | extract_features.shape[1], attention_mask, add_adapter=False 168 | ) 169 | 170 | hidden_states, extract_features = self.feature_projection(extract_features) 171 | hidden_states = self._mask_hidden_states( 172 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask 173 | ) 174 | 175 | encoder_outputs = self.encoder( 176 | hidden_states, 177 | attention_mask=attention_mask, 178 | output_attentions=output_attentions, 179 | output_hidden_states=output_hidden_states, 180 | return_dict=return_dict, 181 | ) 182 | 183 | hidden_states = encoder_outputs[0] 184 | 185 | if self.adapter is not None: 186 | hidden_states = self.adapter(hidden_states) 187 | 188 | if not return_dict: 189 | return (hidden_states, ) + encoder_outputs[1:] 190 | return BaseModelOutput( 191 | last_hidden_state=hidden_states, 192 | hidden_states=encoder_outputs.hidden_states, 193 | attentions=encoder_outputs.attentions, 194 | ) 195 | 196 | 197 | def linear_interpolation(features, seq_len): 198 | """ 199 | Transpose the features to interpolate linearly. 200 | 201 | Args: 202 | features (torch.Tensor): The extracted features to be interpolated. 203 | seq_len (torch.Tensor): The sequence lengths of the features. 204 | 205 | Returns: 206 | torch.Tensor: The interpolated features. 207 | """ 208 | features = features.transpose(1, 2) 209 | output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear') 210 | return output_features.transpose(1, 2) 211 | 212 | 213 | def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int): 214 | p = subprocess.Popen([ 215 | "ffmpeg", "-y", "-v", "error", "-i", input_audio_file, "-ar", str(sample_rate), output_audio_file 216 | ]) 217 | ret = p.wait() 218 | assert ret == 0, "Resample audio failed!" 
219 | return output_audio_file 220 | 221 | 222 | class AudioProcessor: 223 | """ 224 | AudioProcessor is a class that handles the processing of audio files. 225 | It takes care of preprocessing the audio files, extracting features 226 | using wav2vec models, and separating audio signals if needed. 227 | 228 | :param sample_rate: Sampling rate of the audio file 229 | :param fps: Frames per second for the extracted features 230 | :param wav2vec_model_path: Path to the wav2vec model 231 | :param audio_separator_model_path: Path to the audio separator model 232 | :param audio_separator_model_name: Name of the audio separator model 233 | :param cache_dir: Directory to cache the intermediate results 234 | :param device: Device to run the processing on 235 | :param only_last_features: Whether to only use the last features 236 | """ 237 | def __init__( 238 | self, 239 | sample_rate, 240 | fps, 241 | wav2vec_model_path, 242 | audio_separator_model_path:str=None, 243 | audio_separator_model_name:str=None, 244 | cache_dir:str='', 245 | only_last_features:bool=False, 246 | device="cuda:0", 247 | ) -> None: 248 | self.sample_rate = sample_rate 249 | self.fps = fps 250 | self.device = device 251 | 252 | self.audio_encoder = Wav2VecModel.from_pretrained(wav2vec_model_path, local_files_only=True).to(device=device) 253 | self.audio_encoder.feature_extractor._freeze_parameters() 254 | self.only_last_features = only_last_features 255 | 256 | if audio_separator_model_name is not None: 257 | try: 258 | os.makedirs(cache_dir, exist_ok=True) 259 | except OSError as _: 260 | print("Fail to create the output cache dir.") 261 | self.audio_separator = Separator( 262 | output_dir=cache_dir, 263 | output_single_stem="vocals", 264 | model_file_dir=audio_separator_model_path, 265 | ) 266 | self.audio_separator.load_model(audio_separator_model_name) 267 | assert self.audio_separator.model_instance is not None, "Fail to load audio separate model." 268 | else: 269 | self.audio_separator=None 270 | print("Use audio directly without vocals seperator.") 271 | 272 | 273 | self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True) 274 | 275 | 276 | def preprocess(self, wav_file: str, clip_length: int=-1): 277 | """ 278 | Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate. 279 | The separated vocal track is then converted into wav2vec2 for further processing or analysis. 280 | 281 | Args: 282 | wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format. 283 | 284 | Raises: 285 | RuntimeError: Raises an exception if the WAV file cannot be processed. This could be due to issues 286 | such as file not found, unsupported file format, or errors during the audio processing steps. 287 | 288 | Returns: 289 | torch.tensor: Returns an audio embedding as a torch.tensor 290 | """ 291 | if self.audio_separator is not None: 292 | # 1. 
separate vocals 293 | # TODO: process in memory 294 | outputs = self.audio_separator.separate(wav_file) 295 | if len(outputs) <= 0: 296 | raise RuntimeError("Audio separate failed.") 297 | 298 | vocal_audio_file = outputs[0] 299 | vocal_audio_name, _ = os.path.splitext(vocal_audio_file) 300 | vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file) 301 | vocal_audio_file = resample_audio(vocal_audio_file, os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"), self.sample_rate) 302 | else: 303 | vocal_audio_file=wav_file 304 | 305 | # 2. extract wav2vec features 306 | speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=self.sample_rate) 307 | audio_feature = np.squeeze(self.wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values) 308 | seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps) 309 | audio_length = seq_len 310 | 311 | audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device) 312 | 313 | if clip_length>0 and seq_len % clip_length != 0: 314 | audio_feature = torch.nn.functional.pad(audio_feature, (0, (clip_length - seq_len % clip_length) * (self.sample_rate // self.fps)), 'constant', 0.0) 315 | seq_len += clip_length - seq_len % clip_length 316 | audio_feature = audio_feature.unsqueeze(0) 317 | 318 | with torch.no_grad(): 319 | embeddings = self.audio_encoder(audio_feature, seq_len=seq_len, output_hidden_states=True) 320 | assert len(embeddings) > 0, "Fail to extract audio embedding" 321 | if self.only_last_features: 322 | audio_emb = embeddings.last_hidden_state.squeeze() 323 | else: 324 | audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0) 325 | audio_emb = rearrange(audio_emb, "b s d -> s b d") 326 | 327 | audio_emb = audio_emb.cpu().detach() 328 | 329 | return audio_emb, audio_length 330 | 331 | def get_embedding(self, wav_file: str): 332 | """preprocess wav audio file convert to embeddings 333 | 334 | Args: 335 | wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format. 
336 | 337 | Returns: 338 | torch.tensor: Returns an audio embedding as a torch.tensor 339 | """ 340 | speech_array, sampling_rate = librosa.load( 341 | wav_file, sr=self.sample_rate) 342 | assert sampling_rate == 16000, "The audio sample rate must be 16000" 343 | audio_feature = np.squeeze(self.wav2vec_feature_extractor( 344 | speech_array, sampling_rate=sampling_rate).input_values) 345 | seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps) 346 | 347 | audio_feature = torch.from_numpy( 348 | audio_feature).float().to(device=self.device) 349 | audio_feature = audio_feature.unsqueeze(0) 350 | 351 | with torch.no_grad(): 352 | embeddings = self.audio_encoder( 353 | audio_feature, seq_len=seq_len, output_hidden_states=True) 354 | assert len(embeddings) > 0, "Fail to extract audio embedding" 355 | 356 | if self.only_last_features: 357 | audio_emb = embeddings.last_hidden_state.squeeze() 358 | else: 359 | audio_emb = torch.stack( 360 | embeddings.hidden_states[1:], dim=1).squeeze(0) 361 | audio_emb = rearrange(audio_emb, "b s d -> s b d") 362 | 363 | audio_emb = audio_emb.cpu().detach() 364 | 365 | return audio_emb 366 | 367 | def close(self): 368 | """ 369 | TODO: to be implemented 370 | """ 371 | return self 372 | 373 | def __enter__(self): 374 | return self 375 | 376 | def __exit__(self, _exc_type, _exc_val, _exc_tb): 377 | self.close() 378 | -------------------------------------------------------------------------------- /models/basic_modules.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn as nn 4 | from timm.models.vision_transformer import Mlp 5 | from typing import Union, Tuple 6 | from collections.abc import Iterable 7 | from itertools import repeat 8 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential 9 | 10 | 11 | __all__ = ["build_act", "get_act_name"] 12 | 13 | # register activation function here 14 | # name: module, kwargs with default values 15 | REGISTERED_ACT_DICT: dict[str, tuple[type, dict[str, any]]] = { 16 | "relu": (nn.ReLU, {"inplace": True}), 17 | "relu6": (nn.ReLU6, {"inplace": True}), 18 | "hswish": (nn.Hardswish, {"inplace": True}), 19 | "hsigmoid": (nn.Hardsigmoid, {"inplace": True}), 20 | "swish": (nn.SiLU, {"inplace": True}), 21 | "silu": (nn.SiLU, {"inplace": True}), 22 | "tanh": (nn.Tanh, {}), 23 | "sigmoid": (nn.Sigmoid, {}), 24 | "gelu": (nn.GELU, {"approximate": "tanh"}), 25 | "mish": (nn.Mish, {"inplace": True}), 26 | "identity": (nn.Identity, {}), 27 | } 28 | 29 | 30 | def build_act(name: Union[str, None], **kwargs) -> Union[nn.Module, None]: 31 | if name in REGISTERED_ACT_DICT: 32 | act_cls, default_args = copy.deepcopy(REGISTERED_ACT_DICT[name]) 33 | for key in default_args: 34 | if key in kwargs: 35 | default_args[key] = kwargs[key] 36 | return act_cls(**default_args) 37 | elif name is None or name.lower() == "none": 38 | return None 39 | else: 40 | raise ValueError(f"do not support: {name}") 41 | 42 | 43 | class LayerNorm2d(nn.LayerNorm): 44 | rmsnorm = False 45 | 46 | def forward(self, x: torch.Tensor) -> torch.Tensor: 47 | out = x if LayerNorm2d.rmsnorm else x - torch.mean(x, dim=1, keepdim=True) 48 | out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps) 49 | if self.elementwise_affine: 50 | out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1) 51 | return out 52 | 53 | def extra_repr(self) -> str: 54 | return f"{self.normalized_shape}, eps={self.eps}, 
elementwise_affine={self.elementwise_affine}, rmsnorm={self.rmsnorm}" 55 | 56 | 57 | # register normalization function here 58 | # name: module, kwargs with default values 59 | REGISTERED_NORMALIZATION_DICT: dict[str, tuple[type, dict[str, any]]] = { 60 | "bn2d": (nn.BatchNorm2d, {"num_features": None, "eps": 1e-5, "momentum": 0.1, "affine": True}), 61 | "syncbn": (nn.SyncBatchNorm, {"num_features": None, "eps": 1e-5, "momentum": 0.1, "affine": True}), 62 | "ln": (nn.LayerNorm, {"normalized_shape": None, "eps": 1e-5, "elementwise_affine": True}), 63 | "ln2d": (LayerNorm2d, {"normalized_shape": None, "eps": 1e-5, "elementwise_affine": True}), 64 | } 65 | 66 | def build_norm(name="bn2d", num_features=None, affine=True, **kwargs) -> Union[nn.Module, None]: 67 | if name in ["ln", "ln2d"]: 68 | kwargs["normalized_shape"] = num_features 69 | kwargs["elementwise_affine"] = affine 70 | else: 71 | kwargs["num_features"] = num_features 72 | kwargs["affine"] = affine 73 | if name in REGISTERED_NORMALIZATION_DICT: 74 | norm_cls, default_args = copy.deepcopy(REGISTERED_NORMALIZATION_DICT[name]) 75 | for key in default_args: 76 | if key in kwargs: 77 | default_args[key] = kwargs[key] 78 | return norm_cls(**default_args) 79 | elif name is None or name.lower() == "none": 80 | return None 81 | else: 82 | raise ValueError("do not support: %s" % name) 83 | 84 | 85 | def val2list(x: list or tuple or any, repeat_time=1) -> list: # type: ignore 86 | """Repeat `val` for `repeat_time` times and return the list or val if list/tuple.""" 87 | if isinstance(x, (list, tuple)): 88 | return list(x) 89 | return [x for _ in range(repeat_time)] 90 | 91 | 92 | def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: # type: ignore 93 | """Return tuple with min_len by repeating element at idx_repeat.""" 94 | # convert to list first 95 | x = val2list(x) 96 | 97 | # repeat elements if necessary 98 | if len(x) > 0: 99 | x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))] 100 | 101 | return tuple(x) 102 | 103 | 104 | def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]: 105 | if isinstance(kernel_size, tuple): 106 | return tuple([get_same_padding(ks) for ks in kernel_size]) 107 | else: 108 | assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number" 109 | return kernel_size // 2 110 | 111 | 112 | def apply_rotary_emb( 113 | x: torch.Tensor, 114 | freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], 115 | use_real: bool = True, 116 | use_real_unbind_dim: int = -1, 117 | ) -> Tuple[torch.Tensor, torch.Tensor]: 118 | """ 119 | Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings 120 | to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are 121 | reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting 122 | tensors contain rotary embeddings and are returned as real tensors. 123 | 124 | Args: 125 | x (`torch.Tensor`): 126 | Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply 127 | freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) 128 | 129 | Returns: 130 | Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
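Example (illustrative, default real-valued path):
    >>> cos, sin = freqs_cis                      # each [S, D]
    >>> q = apply_rotary_emb(q, (cos, sin))       # q: [B, H, S, D]
    >>> k = apply_rotary_emb(k, (cos, sin))       # same rotation applied to the keys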
131 | """ 132 | if use_real: 133 | cos, sin = freqs_cis # [S, D] 134 | cos = cos[None, None] 135 | sin = sin[None, None] 136 | cos, sin = cos.to(x.device), sin.to(x.device) 137 | 138 | if use_real_unbind_dim == -1: 139 | # Used for flux, cogvideox, hunyuan-dit 140 | x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] 141 | x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) 142 | elif use_real_unbind_dim == -2: 143 | # Used for Sana 144 | cos = cos.transpose(-1, -2) 145 | sin = sin.transpose(-1, -2) 146 | x_real, x_imag = x.reshape(*x.shape[:-2], -1, 2, x.shape[-1]).unbind(-2) # [B, H, D//2, S] 147 | x_rotated = torch.stack([-x_imag, x_real], dim=-2).flatten(2, 3) 148 | else: 149 | raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") 150 | 151 | out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) 152 | 153 | return out 154 | else: 155 | # used for lumina 156 | x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) 157 | freqs_cis = freqs_cis.unsqueeze(2) 158 | x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) 159 | 160 | return x_out.type_as(x) 161 | 162 | 163 | def get_same_padding(kernel_size: Union[int, tuple[int, ...]]) -> Union[int, tuple[int, ...]]: 164 | if isinstance(kernel_size, tuple): 165 | return tuple([get_same_padding(ks) for ks in kernel_size]) 166 | else: 167 | assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number" 168 | return kernel_size // 2 169 | 170 | 171 | def auto_grad_checkpoint(module, *args, **kwargs): 172 | if getattr(module, "grad_checkpointing", False): 173 | if isinstance(module, Iterable): 174 | gc_step = module[0].grad_checkpointing_step 175 | return checkpoint_sequential(module, gc_step, *args, **kwargs) 176 | else: 177 | return checkpoint(module, *args, **kwargs) 178 | return module(*args, **kwargs) 179 | 180 | 181 | def checkpoint_sequential(functions, step, input, *args, **kwargs): 182 | 183 | # Hack for keyword-only parameter in a python 2.7-compliant way 184 | preserve = kwargs.pop("preserve_rng_state", True) 185 | if kwargs: 186 | raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) 187 | 188 | def run_function(start, end, functions): 189 | def forward(input): 190 | for j in range(start, end + 1): 191 | input = functions[j](input, *args) 192 | return input 193 | 194 | return forward 195 | 196 | if isinstance(functions, torch.nn.Sequential): 197 | functions = list(functions.children()) 198 | 199 | # the last chunk has to be non-volatile 200 | end = -1 201 | segment = len(functions) // step 202 | for start in range(0, step * (segment - 1), step): 203 | end = start + step - 1 204 | input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve) 205 | return run_function(end + 1, len(functions) - 1, functions)(input) 206 | 207 | 208 | def _ntuple(n): 209 | def parse(x): 210 | if isinstance(x, Iterable) and not isinstance(x, str): 211 | return x 212 | return tuple(repeat(x, n)) 213 | 214 | return parse 215 | 216 | 217 | to_1tuple = _ntuple(1) 218 | to_2tuple = _ntuple(2) 219 | 220 | 221 | class RMSNorm(torch.nn.Module): 222 | def __init__(self, dim: int, scale_factor=1.0, eps: float = 1e-6): 223 | """ 224 | Initialize the RMSNorm normalization layer. 225 | 226 | Args: 227 | dim (int): The dimension of the input tensor. 228 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 
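scale_factor (float, optional): Initial value of the learnable weight
    (the weight is initialised to torch.ones(dim) * scale_factor). Default is 1.0.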
229 | 230 | Attributes: 231 | eps (float): A small value added to the denominator for numerical stability. 232 | weight (nn.Parameter): Learnable scaling parameter. 233 | 234 | """ 235 | super().__init__() 236 | self.eps = eps 237 | self.weight = nn.Parameter(torch.ones(dim) * scale_factor) 238 | 239 | def _norm(self, x): 240 | """ 241 | Apply the RMSNorm normalization to the input tensor. 242 | 243 | Args: 244 | x (torch.Tensor): The input tensor. 245 | 246 | Returns: 247 | torch.Tensor: The normalized tensor. 248 | 249 | """ 250 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 251 | 252 | def forward(self, x): 253 | """ 254 | Forward pass through the RMSNorm layer. 255 | 256 | Args: 257 | x (torch.Tensor): The input tensor. 258 | 259 | Returns: 260 | torch.Tensor: The output tensor after applying RMSNorm. 261 | 262 | """ 263 | return (self.weight * self._norm(x.float())).type_as(x) 264 | 265 | 266 | 267 | class ConvLayer(nn.Module): 268 | def __init__( 269 | self, 270 | in_dim: int, 271 | out_dim: int, 272 | kernel_size=3, 273 | stride=1, 274 | dilation=1, 275 | groups=1, 276 | padding: Union[int, None] = None, 277 | use_bias=False, 278 | dropout=0.0, 279 | norm="bn2d", 280 | act="relu", 281 | ): 282 | super().__init__() 283 | if padding is None: 284 | padding = get_same_padding(kernel_size) 285 | padding *= dilation 286 | 287 | self.in_dim = in_dim 288 | self.out_dim = out_dim 289 | self.kernel_size = kernel_size 290 | self.stride = stride 291 | self.dilation = dilation 292 | self.groups = groups 293 | self.padding = padding 294 | self.use_bias = use_bias 295 | 296 | self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None 297 | self.conv = nn.Conv2d( 298 | in_dim, 299 | out_dim, 300 | kernel_size=(kernel_size, kernel_size), 301 | stride=(stride, stride), 302 | padding=padding, 303 | dilation=(dilation, dilation), 304 | groups=groups, 305 | bias=use_bias, 306 | ) 307 | self.norm = build_norm(norm, num_features=out_dim) 308 | self.act = build_act(act) 309 | 310 | def forward(self, x: torch.Tensor) -> torch.Tensor: 311 | if self.dropout is not None: 312 | x = self.dropout(x) 313 | x = self.conv(x) 314 | if self.norm: 315 | x = self.norm(x) 316 | if self.act: 317 | x = self.act(x) 318 | return x 319 | 320 | 321 | class GLUMBConv(nn.Module): 322 | def __init__( 323 | self, 324 | in_features: int, 325 | hidden_features: int, 326 | out_feature=None, 327 | kernel_size=3, 328 | stride=1, 329 | padding: Union[int, None] = None, 330 | use_bias=False, 331 | norm=(None, None, None), 332 | act=("silu", "silu", None), 333 | dilation=1, 334 | ): 335 | out_feature = out_feature or in_features 336 | super().__init__() 337 | use_bias = val2tuple(use_bias, 3) 338 | norm = val2tuple(norm, 3) 339 | act = val2tuple(act, 3) 340 | 341 | self.glu_act = build_act(act[1], inplace=False) 342 | self.inverted_conv = ConvLayer( 343 | in_features, 344 | hidden_features * 2, 345 | 1, 346 | use_bias=use_bias[0], 347 | norm=norm[0], 348 | act=act[0], 349 | ) 350 | self.depth_conv = ConvLayer( 351 | hidden_features * 2, 352 | hidden_features * 2, 353 | kernel_size, 354 | stride=stride, 355 | groups=hidden_features * 2, 356 | padding=padding, 357 | use_bias=use_bias[1], 358 | norm=norm[1], 359 | act=None, 360 | dilation=dilation, 361 | ) 362 | self.point_conv = ConvLayer( 363 | hidden_features, 364 | out_feature, 365 | 1, 366 | use_bias=use_bias[2], 367 | norm=norm[2], 368 | act=act[2], 369 | ) 370 | # from IPython import embed; embed(header='debug dilate conv') 371 | 372 | def 
forward(self, x: torch.Tensor, HW=None) -> torch.Tensor: 373 | B, N, C = x.shape 374 | if HW is None: 375 | H = W = int(N**0.5) 376 | else: 377 | H, W = HW 378 | 379 | x = x.reshape(B, H, W, C).permute(0, 3, 1, 2) 380 | x = self.inverted_conv(x) 381 | x = self.depth_conv(x) 382 | 383 | x, gate = torch.chunk(x, 2, dim=1) 384 | gate = self.glu_act(gate) 385 | x = x * gate 386 | 387 | x = self.point_conv(x) 388 | x = x.reshape(B, C, N).permute(0, 2, 1) 389 | 390 | return x 391 | 392 | 393 | class DWMlp(Mlp): 394 | """MLP as used in Vision Transformer, MLP-Mixer and related networks""" 395 | 396 | def __init__( 397 | self, 398 | in_features, 399 | hidden_features=None, 400 | out_features=None, 401 | act_layer=nn.GELU, 402 | bias=True, 403 | drop=0.0, 404 | kernel_size=3, 405 | stride=1, 406 | dilation=1, 407 | padding=None, 408 | ): 409 | super().__init__( 410 | in_features=in_features, 411 | hidden_features=hidden_features, 412 | out_features=out_features, 413 | act_layer=act_layer, 414 | bias=bias, 415 | drop=drop, 416 | ) 417 | hidden_features = hidden_features or in_features 418 | self.hidden_features = hidden_features 419 | if padding is None: 420 | padding = get_same_padding(kernel_size) 421 | padding *= dilation 422 | 423 | self.conv = nn.Conv2d( 424 | hidden_features, 425 | hidden_features, 426 | kernel_size=(kernel_size, kernel_size), 427 | stride=(stride, stride), 428 | padding=padding, 429 | dilation=(dilation, dilation), 430 | groups=hidden_features, 431 | bias=bias, 432 | ) 433 | 434 | def forward(self, x, HW=None): 435 | B, N, C = x.shape 436 | if HW is None: 437 | H = W = int(N**0.5) 438 | else: 439 | H, W = HW 440 | x = self.fc1(x) 441 | x = self.act(x) 442 | x = self.drop1(x) 443 | x = x.reshape(B, H, W, self.hidden_features).permute(0, 3, 1, 2) 444 | x = self.conv(x) 445 | x = x.reshape(B, self.hidden_features, N).permute(0, 2, 1) 446 | x = self.fc2(x) 447 | x = self.drop2(x) 448 | return x 449 | -------------------------------------------------------------------------------- /datasets/video_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numbers 4 | from PIL import Image 5 | 6 | def _is_tensor_video_clip(clip): 7 | if not torch.is_tensor(clip): 8 | raise TypeError("clip should be Tensor. Got %s" % type(clip)) 9 | 10 | if not clip.ndimension() == 4: 11 | raise ValueError("clip should be 4D. Got %dD" % clip.dim()) 12 | 13 | return True 14 | 15 | 16 | def center_crop_arr(pil_image, image_size): 17 | """ 18 | Center cropping implementation from ADM. 19 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 20 | """ 21 | while min(*pil_image.size) >= 2 * image_size: 22 | pil_image = pil_image.resize( 23 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 24 | ) 25 | 26 | scale = image_size / min(*pil_image.size) 27 | pil_image = pil_image.resize( 28 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 29 | ) 30 | 31 | arr = np.array(pil_image) 32 | crop_y = (arr.shape[0] - image_size) // 2 33 | crop_x = (arr.shape[1] - image_size) // 2 34 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 35 | 36 | 37 | def crop(clip, i, j, h, w): 38 | """ 39 | Args: 40 | clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W) 41 | """ 42 | if len(clip.size()) != 4: 43 | raise ValueError("clip should be a 4D tensor") 44 | return clip[..., i : i + h, j : j + w] 45 | 46 | 47 | def resize(clip, target_size, interpolation_mode): 48 | if len(target_size) != 2: 49 | raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") 50 | return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) 51 | 52 | def resize_scale(clip, target_size, interpolation_mode): 53 | if len(target_size) != 2: 54 | raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") 55 | H, W = clip.size(-2), clip.size(-1) 56 | scale_ = target_size[0] / min(H, W) 57 | new_H = int(round(H * scale_)) 58 | new_W = int(round(W * scale_)) 59 | return torch.nn.functional.interpolate(clip, size=(new_H, new_W), mode=interpolation_mode, align_corners=False) 60 | 61 | 62 | def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): 63 | """ 64 | Do spatial cropping and resizing to the video clip 65 | Args: 66 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 67 | i (int): i in (i,j) i.e coordinates of the upper left corner. 68 | j (int): j in (i,j) i.e coordinates of the upper left corner. 69 | h (int): Height of the cropped region. 70 | w (int): Width of the cropped region. 71 | size (tuple(int, int)): height and width of resized clip 72 | Returns: 73 | clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W) 74 | """ 75 | if not _is_tensor_video_clip(clip): 76 | raise ValueError("clip should be a 4D torch.tensor") 77 | clip = crop(clip, i, j, h, w) 78 | clip = resize(clip, size, interpolation_mode) 79 | return clip 80 | 81 | 82 | def center_crop(clip, crop_size): 83 | if not _is_tensor_video_clip(clip): 84 | raise ValueError("clip should be a 4D torch.tensor") 85 | h, w = clip.size(-2), clip.size(-1) 86 | th, tw = crop_size 87 | if h < th or w < tw: 88 | raise ValueError(f"height {h} and width {w} must be no smaller than crop_size ({th, tw})") 89 | 90 | i = int(round((h - th) / 2.0)) 91 | j = int(round((w - tw) / 2.0)) 92 | return crop(clip, i, j, th, tw) 93 | 94 | 95 | def center_crop_using_short_edge(clip): 96 | if not _is_tensor_video_clip(clip): 97 | raise ValueError("clip should be a 4D torch.tensor") 98 | h, w = clip.size(-2), clip.size(-1) 99 | if h < w: 100 | th, tw = h, h 101 | i = 0 102 | j = int(round((w - tw) / 2.0)) 103 | else: 104 | th, tw = w, w 105 | i = int(round((h - th) / 2.0)) 106 | j = 0 107 | return crop(clip, i, j, th, tw) 108 | 109 | 110 | def random_shift_crop(clip): 111 | ''' 112 | Slide along the long edge, with the short edge as crop size 113 | ''' 114 | if not _is_tensor_video_clip(clip): 115 | raise ValueError("clip should be a 4D torch.tensor") 116 | h, w = clip.size(-2), clip.size(-1) 117 | 118 | if h <= w: 119 | long_edge = w 120 | short_edge = h 121 | else: 122 | long_edge = h 123 | short_edge =w 124 | 125 | th, tw = short_edge, short_edge 126 | 127 | i = torch.randint(0, h - th + 1, size=(1,)).item() 128 | j = torch.randint(0, w - tw + 1, size=(1,)).item() 129 | return crop(clip, i, j, th, tw) 130 | 131 | 132 | def to_tensor(clip): 133 | """ 134 | Convert tensor data type from uint8 to float, divide value by 255.0 and 135 | permute the dimensions of clip tensor 136 | Args: 137 | clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) 138 | Return: 139 | clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) 140 | """ 141 | 
_is_tensor_video_clip(clip) 142 | if not clip.dtype == torch.uint8: 143 | raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) 144 | # return clip.float().permute(3, 0, 1, 2) / 255.0 145 | return clip.float() / 255.0 146 | 147 | 148 | def normalize(clip, mean, std, inplace=False): 149 | """ 150 | Args: 151 | clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) 152 | mean (tuple): pixel RGB mean. Size is (3) 153 | std (tuple): pixel standard deviation. Size is (3) 154 | Returns: 155 | normalized clip (torch.tensor): Size is (T, C, H, W) 156 | """ 157 | if not _is_tensor_video_clip(clip): 158 | raise ValueError("clip should be a 4D torch.tensor") 159 | if not inplace: 160 | clip = clip.clone() 161 | mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) 162 | # print(mean) 163 | std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) 164 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 165 | return clip 166 | 167 | 168 | def hflip(clip): 169 | """ 170 | Args: 171 | clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) 172 | Returns: 173 | flipped clip (torch.tensor): Size is (T, C, H, W) 174 | """ 175 | if not _is_tensor_video_clip(clip): 176 | raise ValueError("clip should be a 4D torch.tensor") 177 | return clip.flip(-1) 178 | 179 | 180 | class RandomCropVideo: 181 | def __init__(self, size): 182 | if isinstance(size, numbers.Number): 183 | self.size = (int(size), int(size)) 184 | else: 185 | self.size = size 186 | 187 | def __call__(self, clip): 188 | """ 189 | Args: 190 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 191 | Returns: 192 | torch.tensor: randomly cropped video clip. 193 | size is (T, C, OH, OW) 194 | """ 195 | i, j, h, w = self.get_params(clip) 196 | return crop(clip, i, j, h, w) 197 | 198 | def get_params(self, clip): 199 | h, w = clip.shape[-2:] 200 | th, tw = self.size 201 | 202 | if h < th or w < tw: 203 | raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") 204 | 205 | if w == tw and h == th: 206 | return 0, 0, h, w 207 | 208 | i = torch.randint(0, h - th + 1, size=(1,)).item() 209 | j = torch.randint(0, w - tw + 1, size=(1,)).item() 210 | 211 | return i, j, th, tw 212 | 213 | def __repr__(self) -> str: 214 | return f"{self.__class__.__name__}(size={self.size})" 215 | 216 | class CenterCropResizeVideo: 217 | ''' 218 | First use the short side for cropping length, 219 | center crop video, then resize to the specified size 220 | ''' 221 | def __init__( 222 | self, 223 | size, 224 | interpolation_mode="bilinear", 225 | ): 226 | if isinstance(size, tuple): 227 | if len(size) != 2: 228 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 229 | self.size = size 230 | else: 231 | self.size = (size, size) 232 | 233 | self.interpolation_mode = interpolation_mode 234 | 235 | 236 | def __call__(self, clip): 237 | """ 238 | Args: 239 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 240 | Returns: 241 | torch.tensor: scale resized / center cropped video clip. 
242 | size is (T, C, crop_size, crop_size) 243 | """ 244 | clip_center_crop = center_crop_using_short_edge(clip) 245 | clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode) 246 | return clip_center_crop_resize 247 | 248 | def __repr__(self) -> str: 249 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" 250 | 251 | class UCFCenterCropVideo: 252 | ''' 253 | First scale to the specified size in equal proportion to the short edge, 254 | then center cropping 255 | ''' 256 | def __init__( 257 | self, 258 | size, 259 | interpolation_mode="bilinear", 260 | ): 261 | if isinstance(size, tuple): 262 | if len(size) != 2: 263 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 264 | self.size = size 265 | else: 266 | self.size = (size, size) 267 | 268 | self.interpolation_mode = interpolation_mode 269 | 270 | 271 | def __call__(self, clip): 272 | """ 273 | Args: 274 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 275 | Returns: 276 | torch.tensor: scale resized / center cropped video clip. 277 | size is (T, C, crop_size, crop_size) 278 | """ 279 | clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode) 280 | clip_center_crop = center_crop(clip_resize, self.size) 281 | return clip_center_crop 282 | 283 | def __repr__(self) -> str: 284 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" 285 | 286 | class KineticsRandomCropResizeVideo: 287 | ''' 288 | Slide along the long edge, with the short edge as crop size. And resie to the desired size. 289 | ''' 290 | def __init__( 291 | self, 292 | size, 293 | interpolation_mode="bilinear", 294 | ): 295 | if isinstance(size, tuple): 296 | if len(size) != 2: 297 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 298 | self.size = size 299 | else: 300 | self.size = (size, size) 301 | 302 | self.interpolation_mode = interpolation_mode 303 | 304 | def __call__(self, clip): 305 | clip_random_crop = random_shift_crop(clip) 306 | clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode) 307 | return clip_resize 308 | 309 | 310 | class CenterCropVideo: 311 | def __init__( 312 | self, 313 | size, 314 | interpolation_mode="bilinear", 315 | ): 316 | if isinstance(size, tuple): 317 | if len(size) != 2: 318 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 319 | self.size = size 320 | else: 321 | self.size = (size, size) 322 | 323 | self.interpolation_mode = interpolation_mode 324 | 325 | 326 | def __call__(self, clip): 327 | """ 328 | Args: 329 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 330 | Returns: 331 | torch.tensor: center cropped video clip. 
332 | size is (T, C, crop_size, crop_size) 333 | """ 334 | clip_center_crop = center_crop(clip, self.size) 335 | return clip_center_crop 336 | 337 | def __repr__(self) -> str: 338 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" 339 | 340 | 341 | class NormalizeVideo: 342 | """ 343 | Normalize the video clip by mean subtraction and division by standard deviation 344 | Args: 345 | mean (3-tuple): pixel RGB mean 346 | std (3-tuple): pixel RGB standard deviation 347 | inplace (boolean): whether do in-place normalization 348 | """ 349 | 350 | def __init__(self, mean, std, inplace=False): 351 | self.mean = mean 352 | self.std = std 353 | self.inplace = inplace 354 | 355 | def __call__(self, clip): 356 | """ 357 | Args: 358 | clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W) 359 | """ 360 | return normalize(clip, self.mean, self.std, self.inplace) 361 | 362 | def __repr__(self) -> str: 363 | return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" 364 | 365 | 366 | class ToTensorVideo: 367 | """ 368 | Convert tensor data type from uint8 to float, divide value by 255.0 and 369 | permute the dimensions of clip tensor 370 | """ 371 | 372 | def __init__(self): 373 | pass 374 | 375 | def __call__(self, clip): 376 | """ 377 | Args: 378 | clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) 379 | Return: 380 | clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) 381 | """ 382 | return to_tensor(clip) 383 | 384 | def __repr__(self) -> str: 385 | return self.__class__.__name__ 386 | 387 | 388 | class RandomHorizontalFlipVideo: 389 | """ 390 | Flip the video clip along the horizontal direction with a given probability 391 | Args: 392 | p (float): probability of the clip being flipped. Default value is 0.5 393 | """ 394 | 395 | def __init__(self, p=0.5): 396 | self.p = p 397 | 398 | def __call__(self, clip): 399 | """ 400 | Args: 401 | clip (torch.tensor): Size is (T, C, H, W) 402 | Return: 403 | clip (torch.tensor): Size is (T, C, H, W) 404 | """ 405 | if random.random() < self.p: 406 | clip = hflip(clip) 407 | return clip 408 | 409 | def __repr__(self) -> str: 410 | return f"{self.__class__.__name__}(p={self.p})" 411 | 412 | # ------------------------------------------------------------ 413 | # --------------------- Sampling --------------------------- 414 | # ------------------------------------------------------------ 415 | class TemporalRandomCrop(object): 416 | """Temporally crop the given frame indices at a random location. 417 | 418 | Args: 419 | size (int): Desired length of frames will be seen in the model. 
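Example (illustrative): TemporalRandomCrop(16)(total_frames=100) returns a random
    (begin_index, end_index) pair such as (37, 53), i.e. a 16-frame window.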
420 | """ 421 | 422 | def __init__(self, size): 423 | self.size = size 424 | 425 | def __call__(self, total_frames): 426 | rand_end = max(0, total_frames - self.size - 1) 427 | begin_index = random.randint(0, rand_end) 428 | end_index = min(begin_index + self.size, total_frames) 429 | return begin_index, end_index 430 | 431 | 432 | if __name__ == '__main__': 433 | from torchvision import transforms 434 | import torchvision.io as io 435 | import numpy as np 436 | from torchvision.utils import save_image 437 | import os 438 | 439 | vframes, aframes, info = io.read_video( 440 | filename='./v_Archery_g01_c03.avi', 441 | pts_unit='sec', 442 | output_format='TCHW' 443 | ) 444 | 445 | trans = transforms.Compose([ 446 | ToTensorVideo(), 447 | RandomHorizontalFlipVideo(), 448 | UCFCenterCropVideo(512), 449 | # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 450 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 451 | ]) 452 | 453 | target_video_len = 32 454 | frame_interval = 1 455 | total_frames = len(vframes) 456 | print(total_frames) 457 | 458 | temporal_sample = TemporalRandomCrop(target_video_len * frame_interval) 459 | 460 | 461 | # Sampling video frames 462 | start_frame_ind, end_frame_ind = temporal_sample(total_frames) 463 | # print(start_frame_ind) 464 | # print(end_frame_ind) 465 | assert end_frame_ind - start_frame_ind >= target_video_len 466 | frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int) 467 | print(frame_indice) 468 | 469 | select_vframes = vframes[frame_indice] 470 | print(select_vframes.shape) 471 | print(select_vframes.dtype) 472 | 473 | select_vframes_trans = trans(select_vframes) 474 | print(select_vframes_trans.shape) 475 | print(select_vframes_trans.dtype) 476 | 477 | select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8) 478 | print(select_vframes_trans_int.dtype) 479 | print(select_vframes_trans_int.permute(0, 2, 3, 1).shape) 480 | 481 | io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8) 482 | 483 | for i in range(target_video_len): 484 | save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True, value_range=(-1, 1)) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import logging 5 | import random 6 | import subprocess 7 | import numpy as np 8 | import torch.distributed as dist 9 | 10 | from torch import inf 11 | from PIL import Image 12 | from typing import Union, Iterable 13 | from collections import OrderedDict 14 | from torch.utils.tensorboard import SummaryWriter 15 | 16 | from diffusers.utils import is_bs4_available, is_ftfy_available 17 | 18 | import html 19 | import re 20 | import urllib.parse as ul 21 | from moviepy.editor import VideoFileClip, AudioFileClip, VideoClip 22 | 23 | if is_bs4_available(): 24 | from bs4 import BeautifulSoup 25 | 26 | if is_ftfy_available(): 27 | import ftfy 28 | 29 | _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] 30 | 31 | 32 | ################################################################################# 33 | # Training Clip Gradients # 34 | ################################################################################# 35 | 36 | def get_grad_norm( 37 | parameters: _tensor_or_tensors, norm_type: float = 2.0) -> torch.Tensor: 38 | r""" 39 | 
Copy from torch.nn.utils.clip_grad_norm_ 40 | 41 | Clips gradient norm of an iterable of parameters. 42 | 43 | The norm is computed over all gradients together, as if they were 44 | concatenated into a single vector. Gradients are modified in-place. 45 | 46 | Args: 47 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a 48 | single Tensor that will have gradients normalized 49 | max_norm (float or int): max norm of the gradients 50 | norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for 51 | infinity norm. 52 | error_if_nonfinite (bool): if True, an error is thrown if the total 53 | norm of the gradients from :attr:`parameters` is ``nan``, 54 | ``inf``, or ``-inf``. Default: False (will switch to True in the future) 55 | 56 | Returns: 57 | Total norm of the parameter gradients (viewed as a single vector). 58 | """ 59 | if isinstance(parameters, torch.Tensor): 60 | parameters = [parameters] 61 | grads = [p.grad for p in parameters if p.grad is not None] 62 | norm_type = float(norm_type) 63 | if len(grads) == 0: 64 | return torch.tensor(0.) 65 | device = grads[0].device 66 | if norm_type == inf: 67 | norms = [g.detach().abs().max().to(device) for g in grads] 68 | total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) 69 | else: 70 | total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) 71 | return total_norm 72 | 73 | def clip_grad_norm_( 74 | parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0, 75 | error_if_nonfinite: bool = False, clip_grad = True) -> torch.Tensor: 76 | r""" 77 | Copy from torch.nn.utils.clip_grad_norm_ 78 | 79 | Clips gradient norm of an iterable of parameters. 80 | 81 | The norm is computed over all gradients together, as if they were 82 | concatenated into a single vector. Gradients are modified in-place. 83 | 84 | Args: 85 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a 86 | single Tensor that will have gradients normalized 87 | max_norm (float or int): max norm of the gradients 88 | norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for 89 | infinity norm. 90 | error_if_nonfinite (bool): if True, an error is thrown if the total 91 | norm of the gradients from :attr:`parameters` is ``nan``, 92 | ``inf``, or ``-inf``. Default: False (will switch to True in the future) 93 | 94 | Returns: 95 | Total norm of the parameter gradients (viewed as a single vector). 96 | """ 97 | if isinstance(parameters, torch.Tensor): 98 | parameters = [parameters] 99 | grads = [p.grad for p in parameters if p.grad is not None] 100 | max_norm = float(max_norm) 101 | norm_type = float(norm_type) 102 | if len(grads) == 0: 103 | return torch.tensor(0.) 104 | device = grads[0].device 105 | if norm_type == inf: 106 | norms = [g.detach().abs().max().to(device) for g in grads] 107 | total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) 108 | else: 109 | total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) 110 | 111 | if clip_grad: 112 | if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): 113 | raise RuntimeError( 114 | f'The total norm of order {norm_type} for gradients from ' 115 | '`parameters` is non-finite, so it cannot be clipped. 
To disable ' 116 | 'this error and scale the gradients by the non-finite norm anyway, ' 117 | 'set `error_if_nonfinite=False`') 118 | clip_coef = max_norm / (total_norm + 1e-6) 119 | # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so 120 | # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization 121 | # when the gradients do not reside in CPU memory. 122 | clip_coef_clamped = torch.clamp(clip_coef, max=1.0) 123 | for g in grads: 124 | g.detach().mul_(clip_coef_clamped.to(g.device)) 125 | # gradient_cliped = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) 126 | return total_norm 127 | 128 | def get_experiment_dir(root_dir, args): 129 | # if args.pretrained is not None and 'Latte-XL-2-256x256.pt' not in args.pretrained: 130 | # root_dir += '-WOPRE' 131 | if args.use_compile: 132 | root_dir += '-Compile' # speedup by torch compile 133 | if args.fixed_spatial: 134 | root_dir += '-FixedSpa' 135 | if args.enable_xformers_memory_efficient_attention: 136 | root_dir += '-Xfor' 137 | if args.gradient_checkpointing: 138 | root_dir += '-Gc' 139 | if args.mixed_precision: 140 | root_dir += '-Amp' 141 | if args.image_size == 512: 142 | root_dir += '-512' 143 | return root_dir 144 | 145 | ################################################################################# 146 | # Training Logger # 147 | ################################################################################# 148 | 149 | def create_logger(logging_dir): 150 | """ 151 | Create a logger that writes to a log file and stdout. 152 | """ 153 | if dist.get_rank() == 0: # real logger 154 | logging.basicConfig( 155 | level=logging.INFO, 156 | # format='[\033[34m%(asctime)s\033[0m] %(message)s', 157 | format='[%(asctime)s] %(message)s', 158 | datefmt='%Y-%m-%d %H:%M:%S', 159 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] 160 | ) 161 | logger = logging.getLogger(__name__) 162 | 163 | else: # dummy logger (does nothing) 164 | logger = logging.getLogger(__name__) 165 | logger.addHandler(logging.NullHandler()) 166 | return logger 167 | 168 | 169 | def create_tensorboard(tensorboard_dir): 170 | """ 171 | Create a tensorboard that saves losses. 172 | """ 173 | if dist.get_rank() == 0: # real tensorboard 174 | # tensorboard 175 | writer = SummaryWriter(tensorboard_dir) 176 | 177 | return writer 178 | 179 | def write_tensorboard(writer, *args): 180 | ''' 181 | write the loss information to a tensorboard file. 182 | Only for pytorch DDP mode. 183 | ''' 184 | if dist.get_rank() == 0: # real tensorboard 185 | writer.add_scalar(args[0], args[1], args[2]) 186 | 187 | ################################################################################# 188 | # EMA Update/ DDP Training Utils # 189 | ################################################################################# 190 | 191 | @torch.no_grad() 192 | def update_ema(ema_model, model, decay=0.9999): 193 | """ 194 | Step the EMA model towards the current model. 195 | """ 196 | ema_params = OrderedDict(ema_model.named_parameters()) 197 | model_params = OrderedDict(model.named_parameters()) 198 | 199 | for name, param in model_params.items(): 200 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed 201 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 202 | 203 | def requires_grad(model, flag=True): 204 | """ 205 | Set requires_grad flag for all parameters in a model. 
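Example (illustrative): requires_grad(ema_model, False) keeps the EMA copy out of
    the optimizer, while update_ema() still refreshes its weights in-place.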
206 | """ 207 | for p in model.parameters(): 208 | p.requires_grad = flag 209 | 210 | def cleanup(): 211 | """ 212 | End DDP training. 213 | """ 214 | dist.destroy_process_group() 215 | 216 | 217 | def setup_distributed(backend="nccl", port=None): 218 | """Initialize distributed training environment. 219 | support both slurm and torch.distributed.launch 220 | see torch.distributed.init_process_group() for more details 221 | """ 222 | num_gpus = torch.cuda.device_count() 223 | 224 | if "SLURM_JOB_ID" in os.environ: 225 | rank = int(os.environ["SLURM_PROCID"]) 226 | world_size = int(os.environ["SLURM_NTASKS"]) 227 | node_list = os.environ["SLURM_NODELIST"] 228 | addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") 229 | # specify master port 230 | if port is not None: 231 | os.environ["MASTER_PORT"] = str(port) 232 | elif "MASTER_PORT" not in os.environ: 233 | # os.environ["MASTER_PORT"] = "29566" 234 | os.environ["MASTER_PORT"] = str(29567 + num_gpus) 235 | if "MASTER_ADDR" not in os.environ: 236 | os.environ["MASTER_ADDR"] = addr 237 | os.environ["WORLD_SIZE"] = str(world_size) 238 | os.environ["LOCAL_RANK"] = str(rank % num_gpus) 239 | os.environ["RANK"] = str(rank) 240 | else: 241 | rank = int(os.environ["RANK"]) 242 | world_size = int(os.environ["WORLD_SIZE"]) 243 | 244 | # torch.cuda.set_device(rank % num_gpus) 245 | 246 | dist.init_process_group( 247 | backend=backend, 248 | world_size=world_size, 249 | rank=rank, 250 | ) 251 | 252 | ################################################################################# 253 | # Testing Utils # 254 | ################################################################################# 255 | 256 | def save_video_grid(video, nrow=None): 257 | b, t, h, w, c = video.shape 258 | 259 | if nrow is None: 260 | nrow = math.ceil(math.sqrt(b)) 261 | ncol = math.ceil(b / nrow) 262 | padding = 1 263 | video_grid = torch.zeros((t, (padding + h) * nrow + padding, 264 | (padding + w) * ncol + padding, c), dtype=torch.uint8) 265 | 266 | for i in range(b): 267 | r = i // ncol 268 | c = i % ncol 269 | start_r = (padding + h) * r 270 | start_c = (padding + w) * c 271 | video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] 272 | 273 | return video_grid 274 | 275 | def find_model(model_name, use_ema=True): 276 | """ 277 | Finds a pre-trained Latte model, downloading it if necessary. Alternatively, loads a model from a local path. 278 | """ 279 | assert os.path.isfile(model_name), f'Could not find Latte checkpoint at {model_name}' 280 | checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage) 281 | 282 | # Check if 'audioproj' exists in the checkpoint before returning it 283 | audioproj_dict = checkpoint.get("audioproj", None) 284 | 285 | if use_ema: # supports checkpoints from train.py 286 | print('Using Ema!') 287 | checkpoint = checkpoint["ema"] 288 | else: 289 | print('Using model!') 290 | checkpoint = checkpoint['model'] 291 | 292 | # Return the model and audioproj_dict if it exists, otherwise return the model alone 293 | return (checkpoint, audioproj_dict) if audioproj_dict is not None else (checkpoint,) 294 | 295 | 296 | ################################################################################# 297 | # MMCV Utils # 298 | ################################################################################# 299 | 300 | 301 | def collect_env(): 302 | # Copyright (c) OpenMMLab. All rights reserved. 
303 | from mmcv.utils import collect_env as collect_base_env
304 | from mmcv.utils import get_git_hash
305 | """Collect the information of the running environments."""
306 |
307 | env_info = collect_base_env()
308 | env_info['MMClassification'] = get_git_hash()[:7]
309 |
310 | for name, val in env_info.items():
311 | print(f'{name}: {val}')
312 |
313 | print(torch.cuda.get_arch_list())
314 | print(torch.version.cuda)
315 |
316 |
317 | #################################################################################
318 | #                              Pixart-alpha Utils                               #
319 | #################################################################################
320 |
321 | bad_punct_regex = re.compile(
322 | r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
323 | )
324 |
325 | def text_preprocessing(text, clean_caption=False):
326 | if clean_caption and not is_bs4_available():
327 | clean_caption = False
328 |
329 | if clean_caption and not is_ftfy_available():
330 | clean_caption = False
331 |
332 | if not isinstance(text, (tuple, list)):
333 | text = [text]
334 |
335 | def process(text: str):
336 | if clean_caption:  # the flag shadows the module-level clean_caption() below, so look it up explicitly
337 | text = globals()["clean_caption"](text)
338 | text = globals()["clean_caption"](text)
339 | else:
340 | text = text.lower().strip()
341 | return text
342 |
343 | return [process(t) for t in text]
344 |
345 | # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
346 | def clean_caption(caption):
347 | caption = str(caption)
348 | caption = ul.unquote_plus(caption)
349 | caption = caption.strip().lower()
350 | caption = re.sub("<person>", "person", caption)
351 | # urls:
352 | caption = re.sub(
353 | r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
354 | "",
355 | caption,
356 | )  # regex for urls
357 | caption = re.sub(
358 | r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
359 | "",
360 | caption,
361 | )  # regex for urls
362 | # html:
363 | caption = BeautifulSoup(caption, features="html.parser").text
364 |
365 | # @
366 | caption = re.sub(r"@[\w\d]+\b", "", caption)
367 |
368 | # 31C0—31EF CJK Strokes
369 | # 31F0—31FF Katakana Phonetic Extensions
370 | # 3200—32FF Enclosed CJK Letters and Months
371 | # 3300—33FF CJK Compatibility
372 | # 3400—4DBF CJK Unified Ideographs Extension A
373 | # 4DC0—4DFF Yijing Hexagram Symbols
374 | # 4E00—9FFF CJK Unified Ideographs
375 | caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
376 | caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
377 | caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
378 | caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
379 | caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
380 | caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
381 | caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
382 | #######################################################
383 |
384 | # all types of dash --> "-"
385 | caption = re.sub(
386 | r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
387 | "-",
388 | caption,
389 | )
390 |
391 | # normalize all quotation marks to one standard
392 | caption = re.sub(r"[`´«»“”¨]", '"', caption)
393 | caption = re.sub(r"[‘’]", "'", caption)
394 |
395 | # &quot;
396 | caption = re.sub(r"&quot;?", "", caption)
397 | # &amp
398 | caption = re.sub(r"&amp", "", caption)
399 |
400 | # ip addresses:
401 | caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
402 | 403 | # article ids: 404 | caption = re.sub(r"\d:\d\d\s+$", "", caption) 405 | 406 | # \n 407 | caption = re.sub(r"\\n", " ", caption) 408 | 409 | # "#123" 410 | caption = re.sub(r"#\d{1,3}\b", "", caption) 411 | # "#12345.." 412 | caption = re.sub(r"#\d{5,}\b", "", caption) 413 | # "123456.." 414 | caption = re.sub(r"\b\d{6,}\b", "", caption) 415 | # filenames: 416 | caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) 417 | 418 | # 419 | caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" 420 | caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" 421 | 422 | caption = re.sub(bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT 423 | caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " 424 | 425 | # this-is-my-cute-cat / this_is_my_cute_cat 426 | regex2 = re.compile(r"(?:\-|\_)") 427 | if len(re.findall(regex2, caption)) > 3: 428 | caption = re.sub(regex2, " ", caption) 429 | 430 | caption = ftfy.fix_text(caption) 431 | caption = html.unescape(html.unescape(caption)) 432 | 433 | caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 434 | caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc 435 | caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 436 | 437 | caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) 438 | caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) 439 | caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) 440 | caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) 441 | caption = re.sub(r"\bpage\s+\d+\b", "", caption) 442 | 443 | caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... 444 | 445 | caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) 446 | 447 | caption = re.sub(r"\b\s+\:\s+", r": ", caption) 448 | caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) 449 | caption = re.sub(r"\s+", " ", caption) 450 | 451 | caption.strip() 452 | 453 | caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) 454 | caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) 455 | caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) 456 | caption = re.sub(r"^\.\S+$", "", caption) 457 | 458 | return caption.strip() 459 | 460 | 461 | def combine_video_audio(video_path, audio_path, output_path): 462 | """ 463 | Combine a video file and an audio file into one video with audio. 464 | 465 | Parameters: 466 | video_path (str): Path to the input video file (MP4). 467 | audio_path (str): Path to the input audio file (WAV). 468 | output_path (str): Path to save the output video file (MP4). 469 | """ 470 | video_clip = VideoFileClip(video_path) 471 | audio_clip = AudioFileClip(audio_path) 472 | 473 | video_with_audio = video_clip.set_audio(audio_clip) 474 | 475 | video_with_audio.write_videofile(output_path, codec="libx264", audio_codec="aac") 476 | 477 | video_clip.close() 478 | audio_clip.close() 479 | 480 | 481 | def tensor_to_video(tensor, output_video_file, audio_source, fps=25): 482 | """ 483 | Converts a Tensor with shape [c, f, h, w] into a video and adds an audio track from the specified audio file. 484 | 485 | Args: 486 | tensor (Tensor): The Tensor to be converted, shaped [c, f, h, w]. 487 | output_video_file (str): The file path where the output video will be saved. 488 | audio_source (str): The path to the audio file (WAV file) that contains the audio track to be added. 489 | fps (int): The frame rate of the output video. Default is 25 fps. 
490 | """ 491 | if isinstance(tensor, torch.Tensor): 492 | tensor = np.array(tensor).astype(np.uint8) 493 | 494 | def make_frame(t): 495 | # get index 496 | frame_index = min(int(t * fps), tensor.shape[0] - 1) 497 | return tensor[frame_index] 498 | new_video_clip = VideoClip(make_frame, duration=tensor.shape[0] / fps) 499 | audio_clip = AudioFileClip(audio_source).subclip(0, tensor.shape[0] / fps) 500 | new_video_clip = new_video_clip.set_audio(audio_clip) 501 | new_video_clip.write_videofile(output_video_file, fps=fps, codec="libx264", audio_codec='aac') 502 | 503 | 504 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 505 | """ 506 | Get a pre-defined beta schedule for the given name. 507 | The beta schedule library consists of beta schedules which remain similar 508 | in the limit of num_diffusion_timesteps. 509 | Beta schedules may be added, but should not be removed or changed once 510 | they are committed to maintain backwards compatibility. 511 | """ 512 | if schedule_name == "linear": 513 | # Linear schedule from Ho et al, extended to work for any number of 514 | # diffusion steps. 515 | scale = 1000 / num_diffusion_timesteps 516 | return get_beta_schedule( 517 | "linear", 518 | beta_start=scale * 0.0001, 519 | beta_end=scale * 0.02, 520 | num_diffusion_timesteps=num_diffusion_timesteps, 521 | ) 522 | elif schedule_name == "squaredcos_cap_v2": 523 | return betas_for_alpha_bar( 524 | num_diffusion_timesteps, 525 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 526 | ) 527 | else: 528 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 529 | 530 | 531 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 532 | """ 533 | This is the deprecated API for creating beta schedules. 534 | See get_named_beta_schedule() for the new library of schedules. 535 | """ 536 | if beta_schedule == "quad": 537 | betas = ( 538 | np.linspace( 539 | beta_start ** 0.5, 540 | beta_end ** 0.5, 541 | num_diffusion_timesteps, 542 | dtype=np.float64, 543 | ) 544 | ** 2 545 | ) 546 | elif beta_schedule == "linear": 547 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 548 | elif beta_schedule == "warmup10": 549 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) 550 | elif beta_schedule == "warmup50": 551 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) 552 | elif beta_schedule == "const": 553 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 554 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 555 | betas = 1.0 / np.linspace( 556 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 557 | ) 558 | else: 559 | raise NotImplementedError(beta_schedule) 560 | assert betas.shape == (num_diffusion_timesteps,) 561 | return betas 562 | 563 | 564 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): 565 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 566 | warmup_time = int(num_diffusion_timesteps * warmup_frac) 567 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) 568 | return betas 569 | 570 | 571 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 572 | """ 573 | Create a beta schedule that discretizes the given alpha_t_bar function, 574 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 
575 | :param num_diffusion_timesteps: the number of betas to produce. 576 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 577 | produces the cumulative product of (1-beta) up to that 578 | part of the diffusion process. 579 | :param max_beta: the maximum beta to use; use values lower than 1 to 580 | prevent singularities. 581 | """ 582 | betas = [] 583 | for i in range(num_diffusion_timesteps): 584 | t1 = i / num_diffusion_timesteps 585 | t2 = (i + 1) / num_diffusion_timesteps 586 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 587 | return np.array(betas) 588 | -------------------------------------------------------------------------------- /models/model_cat.py: -------------------------------------------------------------------------------- 1 | # All rights reserved. 2 | 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # -------------------------------------------------------- 6 | # References: 7 | # GLIDE: https://github.com/openai/glide-text2im 8 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 9 | # -------------------------------------------------------- 10 | import math 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import numpy as np 15 | from torch.utils.checkpoint import checkpoint 16 | from einops import rearrange, repeat 17 | from timm.models.vision_transformer import Mlp, PatchEmbed 18 | 19 | # the xformers lib allows less memory, faster training and inference 20 | try: 21 | import xformers 22 | import xformers.ops 23 | except: 24 | XFORMERS_IS_AVAILBLE = False 25 | 26 | # from timm.models.layers.helpers import to_2tuple 27 | # from timm.models.layers.trace_utils import _assert 28 | 29 | def modulate(x, shift, scale, T): 30 | N, M = x.shape[-2], x.shape[-1] 31 | B = scale.shape[0] 32 | x = rearrange(x, '(b t) n m-> b (t n) m',b=B,t=T,n=N,m=M) 33 | x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 34 | x = rearrange(x, 'b (t n) m-> (b t) n m',b=B,t=T,n=N,m=M) 35 | return x 36 | 37 | ################################################################################# 38 | # Attention Layers from TIMM # 39 | ################################################################################# 40 | 41 | class Attention(nn.Module): 42 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'): 43 | super().__init__() 44 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 45 | self.num_heads = num_heads 46 | head_dim = dim // num_heads 47 | self.scale = head_dim ** -0.5 48 | self.attention_mode = attention_mode 49 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 50 | self.attn_drop = nn.Dropout(attn_drop) 51 | self.proj = nn.Linear(dim, dim) 52 | self.proj_drop = nn.Dropout(proj_drop) 53 | 54 | def forward(self, x): 55 | B, N, C = x.shape 56 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() 57 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 58 | 59 | if self.attention_mode == 'xformers': # cause loss nan while using with amp 60 | # https://github.com/facebookresearch/xformers/blob/e8bd8f932c2f48e3a3171d06749eecbbf1de420c/xformers/ops/fmha/__init__.py#L135 61 | q_xf = q.transpose(1,2).contiguous() 62 | k_xf = k.transpose(1,2).contiguous() 63 | v_xf = v.transpose(1,2).contiguous() 64 | x = 
xformers.ops.memory_efficient_attention(q_xf, k_xf, v_xf).reshape(B, N, C)
65 |
66 | elif self.attention_mode == 'flash':
67 | # cause loss nan while using with amp
68 | # Optionally use the context manager to ensure one of the fused kernels is run
69 | with torch.backends.cuda.sdp_kernel(enable_math=False):
70 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v).reshape(B, N, C)  # requires pytorch 2.0
71 |
72 | elif self.attention_mode == 'math':
73 | attn = (q @ k.transpose(-2, -1)) * self.scale
74 | attn = attn.softmax(dim=-1)
75 | attn = self.attn_drop(attn)
76 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
77 |
78 | else:
79 | raise NotImplementedError
80 |
81 | x = self.proj(x)
82 | x = self.proj_drop(x)
83 | return x
84 |
85 |
86 | class AudioAttention(nn.Module):
87 | def __init__(self, dim, num_heads=8, context_dim=None, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'):
88 | super().__init__()
89 | assert dim % num_heads == 0, 'dim should be divisible by num_heads'
90 | self.num_heads = num_heads
91 | head_dim = dim // num_heads
92 | self.scale = head_dim ** -0.5
93 | self.attention_mode = attention_mode
94 |
95 | context_dim = context_dim if context_dim is not None else dim
96 |
97 | # Separate layers for query and key-value pairs
98 | self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
99 | self.k_proj = nn.Linear(context_dim, dim, bias=qkv_bias)
100 | self.v_proj = nn.Linear(context_dim, dim, bias=qkv_bias)
101 |
102 | self.attn_drop = nn.Dropout(attn_drop)
103 | self.proj = nn.Linear(dim, dim)
104 | self.proj_drop = nn.Dropout(proj_drop)
105 |
106 | def forward(self, query, context):
107 | B, N, C = query.shape
108 | _, M, _ = context.shape
109 |
110 | # Query projection
111 | q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
112 |
113 | # Key-Value projection
114 | k = self.k_proj(context).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
115 | v = self.v_proj(context).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
116 |
117 | if self.attention_mode == 'xformers':
118 | q_xf = q.transpose(1, 2).contiguous()
119 | k_xf = k.transpose(1, 2).contiguous()
120 | v_xf = v.transpose(1, 2).contiguous()
121 | x = xformers.ops.memory_efficient_attention(q_xf, k_xf, v_xf).reshape(B, N, C)
122 |
123 | elif self.attention_mode == 'flash':
124 | with torch.backends.cuda.sdp_kernel(enable_math=False):
125 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v).reshape(B, N, C)
126 |
127 | elif self.attention_mode == 'math':
128 | attn = (q @ k.transpose(-2, -1)) * self.scale
129 | attn = attn.softmax(dim=-1)
130 | attn = self.attn_drop(attn)
131 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
132 |
133 | else:
134 | raise NotImplementedError
135 |
136 | x = self.proj(x)
137 | x = self.proj_drop(x)
138 | return x
139 |
140 |
141 | #################################################################################
142 | #               Embedding Layers for Timesteps and Class Labels                 #
143 | #################################################################################
144 |
145 | class TimestepEmbedder(nn.Module):
146 | """
147 | Embeds scalar timesteps into vector representations.
148 | """ 149 | def __init__(self, hidden_size, frequency_embedding_size=256): 150 | super().__init__() 151 | self.mlp = nn.Sequential( 152 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 153 | nn.SiLU(), 154 | nn.Linear(hidden_size, hidden_size, bias=True), 155 | ) 156 | self.frequency_embedding_size = frequency_embedding_size 157 | 158 | @staticmethod 159 | def timestep_embedding(t, dim, max_period=10000): 160 | """ 161 | Create sinusoidal timestep embeddings. 162 | :param t: a 1-D Tensor of N indices, one per batch element. 163 | These may be fractional. 164 | :param dim: the dimension of the output. 165 | :param max_period: controls the minimum frequency of the embeddings. 166 | :return: an (N, D) Tensor of positional embeddings. 167 | """ 168 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 169 | half = dim // 2 170 | freqs = torch.exp( 171 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 172 | ).to(device=t.device) 173 | args = t[:, None].float() * freqs[None] 174 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 175 | if dim % 2: 176 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 177 | return embedding 178 | 179 | def forward(self, t): 180 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 181 | t_emb = self.mlp(t_freq) 182 | return t_emb 183 | 184 | 185 | class LabelEmbedder(nn.Module): 186 | """ 187 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 188 | """ 189 | def __init__(self, num_classes, hidden_size, dropout_prob): 190 | super().__init__() 191 | use_cfg_embedding = dropout_prob > 0 192 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 193 | self.num_classes = num_classes 194 | self.dropout_prob = dropout_prob 195 | 196 | def token_drop(self, labels, force_drop_ids=None): 197 | """ 198 | Drops labels to enable classifier-free guidance. 199 | """ 200 | if force_drop_ids is None: 201 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 202 | else: 203 | drop_ids = force_drop_ids == 1 204 | labels = torch.where(drop_ids, self.num_classes, labels) 205 | return labels 206 | 207 | def forward(self, labels, train, force_drop_ids=None): 208 | use_dropout = self.dropout_prob > 0 209 | if (train and use_dropout) or (force_drop_ids is not None): 210 | labels = self.token_drop(labels, force_drop_ids) 211 | embeddings = self.embedding_table(labels) 212 | return embeddings 213 | 214 | 215 | ################################################################################# 216 | # Core VDT Model # 217 | ################################################################################# 218 | 219 | class TransformerBlock(nn.Module): 220 | """ 221 | A VDT tansformer block with adaptive layer norm zero (adaLN-Zero) conditioning. 
222 | """ 223 | def __init__(self, hidden_size, num_heads, context_dim=None, num_frames=16, mlp_ratio=4.0, **block_kwargs): 224 | super().__init__() 225 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 226 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 227 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 228 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 229 | approx_gelu = lambda: nn.GELU(approximate="tanh") 230 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) 231 | self.adaLN_modulation = nn.Sequential( 232 | nn.SiLU(), 233 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 234 | ) 235 | 236 | self.num_frames = num_frames 237 | ## Temporal Attention Parameters 238 | self.temporal_norm1 = nn.LayerNorm(hidden_size) 239 | self.temporal_attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True) 240 | self.temporal_fc = nn.Linear(hidden_size, hidden_size) 241 | 242 | if context_dim is not None: 243 | self.cross_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 244 | self.cross_attn = AudioAttention(hidden_size, num_heads, context_dim) 245 | 246 | def forward(self, x, cond, c, audio_semantic): 247 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) 248 | T = self.num_frames 249 | K, N, M = x.shape 250 | B = K // T 251 | x = rearrange(x, '(b t) n m -> (b n) t m',b=B,t=T,n=N,m=M) 252 | res_temporal = self.temporal_attn(self.temporal_norm1(x)) 253 | res_temporal = rearrange(res_temporal, '(b n) t m -> (b t) n m',b=B,t=T,n=N,m=M) 254 | res_temporal = self.temporal_fc(res_temporal) 255 | x = rearrange(x, '(b n) t m -> (b t) n m',b=B,t=T,n=N,m=M) 256 | x = x + res_temporal 257 | 258 | x = x + self.cross_attn(self.cross_norm(x), audio_semantic) 259 | 260 | x_cat = torch.cat([x, cond], dim=1) 261 | N = x_cat.size(1) 262 | 263 | attn = self.attn(modulate(self.norm1(x_cat), shift_msa, scale_msa, self.num_frames)) 264 | attn = rearrange(attn, '(b t) n m-> b (t n) m',b=B,t=T,n=N,m=M) 265 | attn = gate_msa.unsqueeze(1) * attn 266 | attn = rearrange(attn, 'b (t n) m-> (b t) n m',b=B,t=T,n=N,m=M) 267 | x_cat = x_cat + attn 268 | 269 | mlp = self.mlp(modulate(self.norm2(x_cat), shift_mlp, scale_mlp, self.num_frames)) 270 | mlp = rearrange(mlp, '(b t) n m-> b (t n) m',b=B,t=T,n=N,m=M) 271 | mlp = gate_mlp.unsqueeze(1) * mlp 272 | mlp = rearrange(mlp, 'b (t n) m-> (b t) n m',b=B,t=T,n=N,m=M) 273 | x_cat = x_cat + mlp 274 | 275 | x = x_cat[:, :x_cat.size(1)//2, ...] 276 | cond = x_cat[:, x_cat.size(1)//2:, ...] 277 | return x, cond 278 | 279 | 280 | class FinalLayer(nn.Module): 281 | """ 282 | The final layer of VDT. 283 | """ 284 | def __init__(self, hidden_size, patch_size, out_channels, num_frames): 285 | super().__init__() 286 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 287 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 288 | self.adaLN_modulation = nn.Sequential( 289 | nn.SiLU(), 290 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 291 | ) 292 | self.num_frames = num_frames 293 | 294 | def forward(self, x, c): 295 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 296 | x = modulate(self.norm_final(x), shift, scale, self.num_frames) 297 | x = self.linear(x) 298 | return x 299 | 300 | 301 | class VDT(nn.Module): 302 | """ 303 | Diffusion model with a Transformer backbone. 
304 | """ 305 | def __init__( 306 | self, 307 | input_size=32, 308 | patch_size=2, 309 | in_channels=4, 310 | hidden_size=1024, 311 | context_dim=768, 312 | depth=24, 313 | num_heads=16, 314 | mlp_ratio=4.0, 315 | num_frames=16, 316 | class_dropout_prob=0.1, 317 | num_classes=1000, 318 | learn_sigma=True, 319 | extras=1, 320 | attention_mode='math', 321 | temp_comp_rate=1, 322 | gradient_checkpointing=False, 323 | ): 324 | super().__init__() 325 | self.learn_sigma = learn_sigma 326 | self.in_channels = in_channels 327 | self.out_channels = in_channels * 2 if learn_sigma else in_channels 328 | self.patch_size = patch_size 329 | self.num_heads = num_heads 330 | self.extras = extras 331 | self.num_frames = num_frames // temp_comp_rate 332 | self.gradient_checkpointing = gradient_checkpointing 333 | 334 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) 335 | self.t_embedder = TimestepEmbedder(hidden_size) 336 | 337 | if self.extras == 2: 338 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) 339 | 340 | num_patches = self.x_embedder.num_patches 341 | # Will use fixed sin-cos embedding: 342 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) 343 | self.temp_embed = nn.Parameter(torch.zeros(1, self.num_frames, hidden_size), requires_grad=False) 344 | self.hidden_size = hidden_size 345 | 346 | self.blocks = nn.ModuleList([ 347 | TransformerBlock(hidden_size, num_heads, context_dim=context_dim, num_frames=num_frames, mlp_ratio=mlp_ratio, attention_mode=attention_mode) for _ in range(depth) 348 | ]) 349 | 350 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, num_frames) 351 | self.initialize_weights() 352 | 353 | def initialize_weights(self): 354 | # Initialize transformer layers: 355 | def _basic_init(module): 356 | if isinstance(module, nn.Linear): 357 | torch.nn.init.xavier_uniform_(module.weight) 358 | if module.bias is not None: 359 | nn.init.constant_(module.bias, 0) 360 | self.apply(_basic_init) 361 | 362 | # Initialize (and freeze) pos_embed by sin-cos embedding: 363 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) 364 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 365 | 366 | temp_embed = get_1d_sincos_temp_embed(self.temp_embed.shape[-1], self.temp_embed.shape[-2]) 367 | self.temp_embed.data.copy_(torch.from_numpy(temp_embed).float().unsqueeze(0)) 368 | 369 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 370 | w = self.x_embedder.proj.weight.data 371 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 372 | nn.init.constant_(self.x_embedder.proj.bias, 0) 373 | 374 | if self.extras == 2: 375 | # Initialize label embedding table: 376 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 377 | 378 | # Initialize timestep embedding MLP: 379 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 380 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 381 | 382 | # Zero-out adaLN modulation layers in VDT blocks: 383 | for block in self.blocks: 384 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 385 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 386 | 387 | # Zero-out output layers: 388 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 389 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 390 | nn.init.constant_(self.final_layer.linear.weight, 0) 391 | 
nn.init.constant_(self.final_layer.linear.bias, 0) 392 | 393 | def unpatchify(self, x): 394 | """ 395 | x: (N, T, patch_size**2 * C) 396 | imgs: (N, H, W, C) 397 | """ 398 | c = self.out_channels 399 | p = self.x_embedder.patch_size[0] 400 | h = w = int(x.shape[1] ** 0.5) 401 | assert h * w == x.shape[1] 402 | 403 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 404 | x = torch.einsum('nhwpqc->nchpwq', x) 405 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) 406 | return imgs 407 | 408 | # @torch.cuda.amp.autocast() 409 | # @torch.compile 410 | def forward(self, 411 | x, 412 | t, 413 | y=None, 414 | cond=None, 415 | audio_embed=None, 416 | ): 417 | """ 418 | Forward pass of VDT. 419 | x: (N, F, C, H, W) tensor of video inputs 420 | t: (N,) tensor of diffusion timesteps 421 | y: (N,) tensor of class labels 422 | cond: (N, C, H, W) 423 | audio_embed: (N, F, D) 424 | """ 425 | batches, frames, channels, high, width = x.shape 426 | x = rearrange(x, 'b f c h w -> (b f) c h w') 427 | cond = rearrange(cond, 'b f c h w -> (b f) c h w') 428 | x = self.x_embedder(x) + self.pos_embed 429 | cond = self.x_embedder(cond) + self.pos_embed 430 | 431 | # Temporal embed 432 | x = rearrange(x, '(b t) n m -> (b n) t m',b=batches,t=frames) 433 | ## Resizing time embeddings in case they don't match 434 | x = x + self.temp_embed 435 | x = rearrange(x, '(b n) t m -> (b t) n m',b=batches,t=frames) 436 | 437 | audio_semantic = rearrange(audio_embed, 'b f n k -> (b f) n k') 438 | 439 | t = self.t_embedder(t) # [5, 384] 440 | if y is not None: 441 | c = t + y 442 | else: 443 | c = t 444 | 445 | for i, block in enumerate(self.blocks): 446 | if self.gradient_checkpointing: 447 | x, cond = checkpoint(block, x, cond, c, audio_semantic) 448 | else: 449 | x, cond = block(x, cond, c, audio_semantic) 450 | 451 | x = self.final_layer(x, c) 452 | x = self.unpatchify(x) 453 | x = rearrange(x, '(b f) c h w -> b f c h w', b=batches) 454 | return x 455 | 456 | def forward_with_cfg(self, x, t, y=None, cfg_scale=7.0, text_embedding=None): 457 | """ 458 | Forward pass of VDT, but also batches the unconditional forward pass for classifier-free guidance. 459 | """ 460 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 461 | half = x[: len(x) // 2] 462 | combined = torch.cat([half, half], dim=0) 463 | model_out = self.forward(combined, t, y=y, text_embedding=text_embedding) 464 | # For exact reproducibility reasons, we apply classifier-free guidance on only 465 | # three channels by default. The standard approach to cfg applies it to all channels. 466 | # This can be done by uncommenting the following line and commenting-out the line following that. 467 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 468 | # eps, rest = model_out[:, :3], model_out[:, 3:] 469 | eps, rest = model_out[:, :, :4, ...], model_out[:, :, 4:, ...] 
470 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 471 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 472 | eps = torch.cat([half_eps, half_eps], dim=0) 473 | return torch.cat([eps, rest], dim=2) 474 | 475 | 476 | ################################################################################# 477 | # Sine/Cosine Positional Embedding Functions # 478 | ################################################################################# 479 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 480 | 481 | def get_1d_sincos_temp_embed(embed_dim, length): 482 | pos = torch.arange(0, length).unsqueeze(1) 483 | return get_1d_sincos_pos_embed_from_grid(embed_dim, pos) 484 | 485 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 486 | """ 487 | grid_size: int of the grid height and width 488 | return: 489 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 490 | """ 491 | grid_h = np.arange(grid_size, dtype=np.float32) 492 | grid_w = np.arange(grid_size, dtype=np.float32) 493 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 494 | grid = np.stack(grid, axis=0) 495 | 496 | grid = grid.reshape([2, 1, grid_size, grid_size]) 497 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 498 | if cls_token and extra_tokens > 0: 499 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 500 | return pos_embed 501 | 502 | 503 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 504 | assert embed_dim % 2 == 0 505 | 506 | # use half of dimensions to encode grid_h 507 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) 508 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) 509 | 510 | emb = np.concatenate([emb_h, emb_w], axis=1) 511 | return emb 512 | 513 | 514 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 515 | """ 516 | embed_dim: output dimension for each position 517 | pos: a list of positions to be encoded: size (M,) 518 | out: (M, D) 519 | """ 520 | assert embed_dim % 2 == 0 521 | omega = np.arange(embed_dim // 2, dtype=np.float64) 522 | omega /= embed_dim / 2. 523 | omega = 1. 
/ 10000**omega 524 | 525 | pos = pos.reshape(-1) 526 | out = np.einsum('m,d->md', pos, omega) 527 | 528 | emb_sin = np.sin(out) 529 | emb_cos = np.cos(out) 530 | 531 | emb = np.concatenate([emb_sin, emb_cos], axis=1) 532 | return emb 533 | 534 | 535 | ################################################################################# 536 | # VDT Configs # 537 | ################################################################################# 538 | 539 | def VDT_XL_2(**kwargs): 540 | return VDT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) 541 | 542 | def VDT_XL_4(**kwargs): 543 | return VDT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) 544 | 545 | def VDT_XL_8(**kwargs): 546 | return VDT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) 547 | 548 | def VDT_L_2(**kwargs): 549 | return VDT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) 550 | 551 | def VDT_L_4(**kwargs): 552 | return VDT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) 553 | 554 | def VDT_L_8(**kwargs): 555 | return VDT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) 556 | 557 | def VDT_B_2(**kwargs): 558 | return VDT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) 559 | 560 | def VDT_B_4(**kwargs): 561 | return VDT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) 562 | 563 | def VDT_B_8(**kwargs): 564 | return VDT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) 565 | 566 | def VDT_S_2(**kwargs): 567 | return VDT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) 568 | 569 | def VDT_S_4(**kwargs): 570 | return VDT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) 571 | 572 | def VDT_S_8(**kwargs): 573 | return VDT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) 574 | 575 | 576 | VDTcat_models = { 577 | 'VDTcat-XL/2': VDT_XL_2, 'VDTcat-XL/4': VDT_XL_4, 'VDTcat-XL/8': VDT_XL_8, 578 | 'VDTcat-L/2': VDT_L_2, 'VDTcat-L/4': VDT_L_4, 'VDTcat-L/8': VDT_L_8, 579 | 'VDTcat-B/2': VDT_B_2, 'VDTcat-B/4': VDT_B_4, 'VDTcat-B/8': VDT_B_8, 580 | 'VDTcat-S/2': VDT_S_2, 'VDTcat-S/4': VDT_S_4, 'VDTcat-S/8': VDT_S_8, 581 | } 582 | 583 | if __name__ == '__main__': 584 | 585 | import torch 586 | 587 | device = "cuda" if torch.cuda.is_available() else "cpu" 588 | 589 | img = torch.randn(3, 16, 4, 32, 32).to(device) 590 | t = torch.tensor([1, 2, 3]).to(device) 591 | y = torch.tensor([1, 2, 3]).to(device) 592 | network = VDT_XL_2().to(device) 593 | from thop import profile 594 | flops, params = profile(network, inputs=(img, t)) 595 | print('FLOPs = ' + str(flops/1000**3) + 'G') 596 | print('Params = ' + str(params/1000**2) + 'M') 597 | # y_embeder = LabelEmbedder(num_classes=101, hidden_size=768, dropout_prob=0.5).to(device) 598 | # lora.mark_only_lora_as_trainable(network) 599 | # out = y_embeder(y, True) 600 | # out = network(img, t, y) 601 | # print(out.shape) 602 | --------------------------------------------------------------------------------
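A quick note on using the model defined in models/model_cat.py: `VDT.forward` expects noisy video latents `x` of shape (N, F, C, H, W), condition latents `cond` of the same shape, and per-frame audio features `audio_embed` of shape (N, F, tokens, context_dim); both `cond` and `audio_embed` are required, since they are rearranged unconditionally before the transformer blocks. The sketch below is a minimal smoke test under those assumptions, using the small `VDT_S_2` config with its defaults (input_size=32, in_channels=4, num_frames=16, context_dim=768, learn_sigma=True); the import path and the choice of 32 audio tokens per frame are illustrative, not prescribed by the repository.

```python
import torch

from models.model_cat import VDT_S_2  # import path assumed from the repo layout above

# Smallest config: depth=12, hidden_size=384, patch_size=2.
model = VDT_S_2().eval()

B, F = 2, 16
x = torch.randn(B, F, 4, 32, 32)          # noisy video latents (N, F, C, H, W)
cond = torch.randn(B, F, 4, 32, 32)       # condition latents, concatenated with x token-wise in each block
audio_embed = torch.randn(B, F, 32, 768)  # per-frame audio tokens; last dim must equal context_dim
t = torch.randint(0, 1000, (B,))          # diffusion timesteps

with torch.no_grad():
    out = model(x, t, cond=cond, audio_embed=audio_embed)

print(out.shape)  # torch.Size([2, 16, 8, 32, 32]) -- learn_sigma=True doubles the output channels
```

In real use, `x` and `cond` would be VAE latents and `audio_embed` the projected audio features produced by the data pipeline rather than random noise; the number of frames must match the model's `num_frames`, since the temporal embedding is added without resizing.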