├── .gitignore ├── LICENSE ├── README.md ├── assets ├── demo1_audio.wav ├── demo1_video.mp4 ├── demo2_audio.wav ├── demo2_video.mp4 ├── demo3_audio.wav └── demo3_video.mp4 ├── cog.yaml ├── configs ├── audio.yaml ├── scheduler_config.json ├── syncnet │ ├── syncnet_16_latent.yaml │ ├── syncnet_16_pixel.yaml │ ├── syncnet_16_pixel_attn.yaml │ └── syncnet_25_pixel.yaml └── unet │ ├── stage1.yaml │ ├── stage2.yaml │ └── stage2_efficient.yaml ├── data_processing_pipeline.sh ├── docs ├── changelog_v1.5.md ├── framework.png └── syncnet_arch.md ├── eval ├── detectors │ ├── README.md │ ├── __init__.py │ └── s3fd │ │ ├── __init__.py │ │ ├── box_utils.py │ │ └── nets.py ├── draw_syncnet_lines.py ├── eval_fvd.py ├── eval_sync_conf.py ├── eval_sync_conf.sh ├── eval_syncnet_acc.py ├── eval_syncnet_acc.sh ├── fvd.py ├── hyper_iqa.py ├── inference_videos.py ├── syncnet │ ├── __init__.py │ ├── syncnet.py │ └── syncnet_eval.py └── syncnet_detect.py ├── gradio_app.py ├── inference.sh ├── latentsync ├── data │ ├── syncnet_dataset.py │ └── unet_dataset.py ├── models │ ├── attention.py │ ├── motion_module.py │ ├── resnet.py │ ├── stable_syncnet.py │ ├── unet.py │ ├── unet_blocks.py │ ├── utils.py │ └── wav2lip_syncnet.py ├── pipelines │ └── lipsync_pipeline.py ├── trepa │ ├── loss.py │ ├── third_party │ │ ├── VideoMAEv2 │ │ │ ├── __init__.py │ │ │ ├── utils.py │ │ │ ├── videomaev2_finetune.py │ │ │ └── videomaev2_pretrain.py │ │ └── __init__.py │ └── utils │ │ ├── __init__.py │ │ ├── data_utils.py │ │ └── metric_utils.py ├── utils │ ├── affine_transform.py │ ├── audio.py │ ├── av_reader.py │ ├── face_detector.py │ ├── image_processor.py │ ├── mask.png │ ├── mask2.png │ ├── mask3.png │ ├── mask4.png │ └── util.py └── whisper │ ├── audio2feature.py │ └── whisper │ ├── __init__.py │ ├── __main__.py │ ├── assets │ ├── gpt2 │ │ ├── merges.txt │ │ ├── special_tokens_map.json │ │ ├── tokenizer_config.json │ │ └── vocab.json │ ├── mel_filters.npz │ └── multilingual │ │ ├── added_tokens.json │ │ ├── merges.txt │ │ ├── special_tokens_map.json │ │ ├── tokenizer_config.json │ │ └── vocab.json │ ├── audio.py │ ├── decoding.py │ ├── model.py │ ├── normalizers │ ├── __init__.py │ ├── basic.py │ ├── english.json │ └── english.py │ ├── tokenizer.py │ ├── transcribe.py │ └── utils.py ├── predict.py ├── preprocess ├── affine_transform.py ├── data_processing_pipeline.py ├── detect_shot.py ├── filter_high_resolution.py ├── filter_visual_quality.py ├── remove_broken_videos.py ├── remove_incorrect_affined.py ├── resample_fps_hz.py ├── segment_videos.py └── sync_av.py ├── requirements.txt ├── scripts ├── inference.py ├── train_syncnet.py └── train_unet.py ├── setup_env.sh ├── tools ├── count_total_videos_time.py ├── download_web_videos.py ├── move_files_recur.py ├── occupy_gpu.py ├── plot_videos_time_distribution.py ├── remove_outdated_files.py └── write_fileslist.py ├── train_syncnet.sh └── train_unet.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # PyCharm files 2 | .idea/ 3 | 4 | # macOS dir files 5 | .DS_Store 6 | 7 | # VS Code configuration dir 8 | .vscode/ 9 | 10 | # Jupyter Notebook cache files 11 | .ipynb_checkpoints/ 12 | *.ipynb 13 | 14 | # Python cache files 15 | __pycache__/ 16 | 17 | # folders 18 | wandb/ 19 | *debug* 20 | /debug 21 | /output 22 | /validation 23 | /test 24 | /models/ 25 | /venv/ 26 | /detect_results/ 27 | /temp 28 | 29 | # checkpoint files 30 | *.safetensors 31 | *.ckpt 32 | *.pt 33 | 34 | # data files 35 | *.mp4 36 | *.avi 37 | *.wav 38 | *.png 39 | 
*.jpg 40 | *.jpeg 41 | *.csv 42 | *.pdf 43 | 44 | # log files 45 | *.log 46 | 47 | !/latentsync/utils/mask*.png 48 | /checkpoints/ 49 | !/assets/* 50 | !/docs/* 51 | !/internal/*.png 52 | .cog/ 53 | -------------------------------------------------------------------------------- /assets/demo1_audio.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/LatentSync/3c3a1a2b62c9d854e3e6ab5525373973380a74a5/assets/demo1_audio.wav -------------------------------------------------------------------------------- /assets/demo1_video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/LatentSync/3c3a1a2b62c9d854e3e6ab5525373973380a74a5/assets/demo1_video.mp4 -------------------------------------------------------------------------------- /assets/demo2_audio.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/LatentSync/3c3a1a2b62c9d854e3e6ab5525373973380a74a5/assets/demo2_audio.wav -------------------------------------------------------------------------------- /assets/demo2_video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/LatentSync/3c3a1a2b62c9d854e3e6ab5525373973380a74a5/assets/demo2_video.mp4 -------------------------------------------------------------------------------- /assets/demo3_audio.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/LatentSync/3c3a1a2b62c9d854e3e6ab5525373973380a74a5/assets/demo3_audio.wav -------------------------------------------------------------------------------- /assets/demo3_video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/LatentSync/3c3a1a2b62c9d854e3e6ab5525373973380a74a5/assets/demo3_video.mp4 -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://cog.run/yaml 3 | 4 | build: 5 | gpu: true 6 | cuda: "12.1" 7 | system_packages: 8 | - "ffmpeg" 9 | - "libgl1" 10 | python_version: "3.10.13" 11 | python_requirements: requirements.txt 12 | 13 | run: 14 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.10.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget 15 | 16 | # predict.py defines how predictions are run on your model 17 | predict: "predict.py:Predictor" 18 | -------------------------------------------------------------------------------- /configs/audio.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | num_mels: 80 # Number of mel-spectrogram channels and local conditioning dimensionality 3 | rescale: true # Whether to rescale audio prior to preprocessing 4 | rescaling_max: 0.9 # Rescaling value 5 | use_lws: 6 | false # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction 7 | # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder 8 | # Does not work if n_ffit is not multiple of hop_size!! 
9 | n_fft: 800 # Extra window size is filled with 0 paddings to match this parameter 10 | hop_size: 200 # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) 11 | win_size: 800 # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) 12 | sample_rate: 16000 # 16000Hz (corresponding to librispeech) (sox --i ) 13 | frame_shift_ms: null 14 | signal_normalization: true 15 | allow_clipping_in_normalization: true 16 | symmetric_mels: true 17 | max_abs_value: 4.0 18 | preemphasize: true # whether to apply filter 19 | preemphasis: 0.97 # filter coefficient. 20 | min_level_db: -100 21 | ref_level_db: 20 22 | fmin: 55 23 | fmax: 7600 24 | -------------------------------------------------------------------------------- /configs/scheduler_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "DDIMScheduler", 3 | "beta_end": 0.012, 4 | "beta_schedule": "scaled_linear", 5 | "beta_start": 0.00085, 6 | "clip_sample": false, 7 | "num_train_timesteps": 1000, 8 | "set_alpha_to_one": false, 9 | "steps_offset": 1, 10 | "trained_betas": null, 11 | "skip_prk_steps": true 12 | } 13 | -------------------------------------------------------------------------------- /configs/syncnet/syncnet_16_latent.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | audio_encoder: # input (1, 80, 52) 3 | in_channels: 1 4 | block_out_channels: [32, 64, 128, 256, 512, 1024] 5 | downsample_factors: [[2, 1], 2, 2, 2, 2, [2, 3]] 6 | attn_blocks: [0, 0, 0, 0, 0, 0] 7 | dropout: 0.0 8 | visual_encoder: # input (64, 32, 32) 9 | in_channels: 64 10 | block_out_channels: [64, 128, 256, 256, 512, 1024] 11 | downsample_factors: [2, 2, 2, 1, 2, 2] 12 | attn_blocks: [0, 0, 0, 0, 0, 0] 13 | dropout: 0.0 14 | 15 | ckpt: 16 | resume_ckpt_path: "" 17 | inference_ckpt_path: "" 18 | save_ckpt_steps: 2500 19 | 20 | data: 21 | train_output_dir: debug/syncnet 22 | num_val_samples: 1200 23 | batch_size: 120 # 40 24 | gradient_accumulation_steps: 1 25 | num_workers: 12 # 12 26 | latent_space: true 27 | num_frames: 16 28 | resolution: 256 29 | train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt 30 | train_data_dir: "" 31 | val_fileslist: "" 32 | val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val 33 | audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel 34 | lower_half: false 35 | audio_sample_rate: 16000 36 | video_fps: 25 37 | 38 | optimizer: 39 | lr: 1e-5 40 | max_grad_norm: 1.0 41 | 42 | run: 43 | max_train_steps: 10000000 44 | validation_steps: 2500 45 | mixed_precision_training: true 46 | seed: 42 47 | -------------------------------------------------------------------------------- /configs/syncnet/syncnet_16_pixel.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | audio_encoder: # input (1, 80, 52) 3 | in_channels: 1 4 | block_out_channels: [32, 64, 128, 256, 512, 1024, 2048] 5 | downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]] 6 | attn_blocks: [0, 0, 0, 0, 0, 0, 0] 7 | dropout: 0.0 8 | visual_encoder: # input (48, 128, 256) 9 | in_channels: 48 10 | block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048] 11 | downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2] 12 | attn_blocks: [0, 0, 0, 0, 0, 0, 0, 0] 13 | dropout: 0.0 14 | 15 | ckpt: 16 | resume_ckpt_path: "" 17 | inference_ckpt_path: "" 18 | save_ckpt_steps: 2500 19 | 20 | data: 21 | train_output_dir: debug/syncnet 22 | 
num_val_samples: 2048 23 | batch_size: 256 # 256 24 | gradient_accumulation_steps: 1 25 | num_workers: 12 # 12 26 | latent_space: false 27 | num_frames: 16 28 | resolution: 256 29 | train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt 30 | train_data_dir: "" 31 | val_fileslist: "" 32 | val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val 33 | audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel 34 | lower_half: true 35 | audio_sample_rate: 16000 36 | video_fps: 25 37 | 38 | optimizer: 39 | lr: 1e-5 40 | max_grad_norm: 1.0 41 | 42 | run: 43 | max_train_steps: 10000000 44 | validation_steps: 2500 45 | mixed_precision_training: true 46 | seed: 42 47 | -------------------------------------------------------------------------------- /configs/syncnet/syncnet_16_pixel_attn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | audio_encoder: # input (1, 80, 52) 3 | in_channels: 1 4 | block_out_channels: [32, 64, 128, 256, 512, 1024, 2048] 5 | downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]] 6 | attn_blocks: [0, 0, 0, 1, 1, 0, 0] 7 | dropout: 0.0 8 | visual_encoder: # input (48, 128, 256) 9 | in_channels: 48 10 | block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048] 11 | downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2] 12 | attn_blocks: [0, 0, 0, 0, 1, 1, 0, 0] 13 | dropout: 0.0 14 | 15 | ckpt: 16 | resume_ckpt_path: "" 17 | inference_ckpt_path: checkpoints/stable_syncnet.pt 18 | save_ckpt_steps: 2500 19 | 20 | data: 21 | train_output_dir: debug/syncnet 22 | num_val_samples: 2048 23 | batch_size: 256 # 256 24 | gradient_accumulation_steps: 1 25 | num_workers: 12 # 12 26 | latent_space: false 27 | num_frames: 16 28 | resolution: 256 29 | train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt 30 | train_data_dir: "" 31 | val_fileslist: "" 32 | val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val 33 | audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel 34 | lower_half: true 35 | audio_sample_rate: 16000 36 | video_fps: 25 37 | 38 | optimizer: 39 | lr: 1e-5 40 | max_grad_norm: 1.0 41 | 42 | run: 43 | max_train_steps: 10000000 44 | validation_steps: 2500 45 | mixed_precision_training: true 46 | seed: 42 47 | -------------------------------------------------------------------------------- /configs/syncnet/syncnet_25_pixel.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | audio_encoder: # input (1, 80, 80) 3 | in_channels: 1 4 | block_out_channels: [64, 128, 256, 256, 512, 1024] 5 | downsample_factors: [2, 2, 2, 2, 2, 2] 6 | dropout: 0.0 7 | visual_encoder: # input (75, 128, 256) 8 | in_channels: 75 9 | block_out_channels: [128, 128, 256, 256, 512, 512, 1024, 1024] 10 | downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2] 11 | dropout: 0.0 12 | 13 | ckpt: 14 | resume_ckpt_path: "" 15 | inference_ckpt_path: "" 16 | save_ckpt_steps: 2500 17 | 18 | data: 19 | train_output_dir: debug/syncnet 20 | num_val_samples: 2048 21 | batch_size: 64 # 64 22 | gradient_accumulation_steps: 1 23 | num_workers: 12 # 12 24 | latent_space: false 25 | num_frames: 25 26 | resolution: 256 27 | train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt 28 | train_data_dir: "" 29 | val_fileslist: "" 30 | val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val 31 | audio_mel_cache_dir: 
/mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel 32 | lower_half: true 33 | audio_sample_rate: 16000 34 | video_fps: 25 35 | 36 | optimizer: 37 | lr: 1e-5 38 | max_grad_norm: 1.0 39 | 40 | run: 41 | max_train_steps: 10000000 42 | validation_steps: 2500 43 | mixed_precision_training: true 44 | seed: 42 45 | -------------------------------------------------------------------------------- /configs/unet/stage1.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml 3 | train_output_dir: debug/unet 4 | train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt 5 | train_data_dir: "" 6 | audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds 7 | audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel 8 | 9 | val_video_path: assets/demo1_video.mp4 10 | val_audio_path: assets/demo1_audio.wav 11 | batch_size: 1 # 24 12 | num_workers: 12 # 12 13 | num_frames: 16 14 | resolution: 256 15 | mask_image_path: latentsync/utils/mask.png 16 | audio_sample_rate: 16000 17 | video_fps: 25 18 | audio_feat_length: [2, 2] 19 | 20 | ckpt: 21 | resume_ckpt_path: checkpoints/latentsync_unet.pt 22 | save_ckpt_steps: 10000 23 | 24 | run: 25 | pixel_space_supervise: false 26 | use_syncnet: false 27 | sync_loss_weight: 0.05 28 | perceptual_loss_weight: 0.1 # 0.1 29 | recon_loss_weight: 1 # 1 30 | guidance_scale: 2.0 # [1.0 - 3.0] 31 | trepa_loss_weight: 10 32 | inference_steps: 20 33 | seed: 1247 34 | use_mixed_noise: true 35 | mixed_noise_alpha: 1 # 1 36 | mixed_precision_training: true 37 | enable_gradient_checkpointing: true 38 | max_train_steps: 10000000 39 | max_train_epochs: -1 40 | 41 | optimizer: 42 | lr: 1e-5 43 | scale_lr: false 44 | max_grad_norm: 1.0 45 | lr_scheduler: constant 46 | lr_warmup_steps: 0 47 | 48 | model: 49 | act_fn: silu 50 | add_audio_layer: true 51 | attention_head_dim: 8 52 | block_out_channels: [320, 640, 1280, 1280] 53 | center_input_sample: false 54 | cross_attention_dim: 384 55 | down_block_types: 56 | [ 57 | "CrossAttnDownBlock3D", 58 | "CrossAttnDownBlock3D", 59 | "CrossAttnDownBlock3D", 60 | "DownBlock3D", 61 | ] 62 | mid_block_type: UNetMidBlock3DCrossAttn 63 | up_block_types: 64 | [ 65 | "UpBlock3D", 66 | "CrossAttnUpBlock3D", 67 | "CrossAttnUpBlock3D", 68 | "CrossAttnUpBlock3D", 69 | ] 70 | downsample_padding: 1 71 | flip_sin_to_cos: true 72 | freq_shift: 0 73 | in_channels: 13 # 49 74 | layers_per_block: 2 75 | mid_block_scale_factor: 1 76 | norm_eps: 1e-5 77 | norm_num_groups: 32 78 | out_channels: 4 # 16 79 | sample_size: 64 80 | resnet_time_scale_shift: default # Choose between [default, scale_shift] 81 | 82 | use_motion_module: false 83 | motion_module_resolutions: [1, 2, 4, 8] 84 | motion_module_mid_block: false 85 | motion_module_decoder_only: false 86 | motion_module_type: Vanilla 87 | motion_module_kwargs: 88 | num_attention_heads: 8 89 | num_transformer_block: 1 90 | attention_block_types: 91 | - Temporal_Self 92 | - Temporal_Self 93 | temporal_position_encoding: true 94 | temporal_position_encoding_max_len: 24 95 | temporal_attention_dim_div: 1 96 | zero_initialize: true 97 | -------------------------------------------------------------------------------- /configs/unet/stage2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml 3 | train_output_dir: debug/unet 4 | train_fileslist: 
/mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt 5 | train_data_dir: "" 6 | audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds 7 | audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel 8 | 9 | val_video_path: assets/demo1_video.mp4 10 | val_audio_path: assets/demo1_audio.wav 11 | batch_size: 1 # 4 12 | num_workers: 12 # 12 13 | num_frames: 16 14 | resolution: 256 15 | mask_image_path: latentsync/utils/mask.png 16 | audio_sample_rate: 16000 17 | video_fps: 25 18 | audio_feat_length: [2, 2] 19 | 20 | ckpt: 21 | resume_ckpt_path: checkpoints/latentsync_unet.pt 22 | save_ckpt_steps: 10000 23 | 24 | run: 25 | pixel_space_supervise: true 26 | use_syncnet: true 27 | sync_loss_weight: 0.05 28 | perceptual_loss_weight: 0.1 # 0.1 29 | recon_loss_weight: 1 # 1 30 | guidance_scale: 2.0 # [1.0 - 3.0] 31 | trepa_loss_weight: 10 32 | inference_steps: 20 33 | trainable_modules: 34 | - motion_modules. 35 | - attentions. 36 | seed: 1247 37 | use_mixed_noise: true 38 | mixed_noise_alpha: 1 # 1 39 | mixed_precision_training: true 40 | enable_gradient_checkpointing: true 41 | max_train_steps: 10000000 42 | max_train_epochs: -1 43 | 44 | optimizer: 45 | lr: 1e-5 46 | scale_lr: false 47 | max_grad_norm: 1.0 48 | lr_scheduler: constant 49 | lr_warmup_steps: 0 50 | 51 | model: 52 | act_fn: silu 53 | add_audio_layer: true 54 | attention_head_dim: 8 55 | block_out_channels: [320, 640, 1280, 1280] 56 | center_input_sample: false 57 | cross_attention_dim: 384 58 | down_block_types: 59 | [ 60 | "CrossAttnDownBlock3D", 61 | "CrossAttnDownBlock3D", 62 | "CrossAttnDownBlock3D", 63 | "DownBlock3D", 64 | ] 65 | mid_block_type: UNetMidBlock3DCrossAttn 66 | up_block_types: 67 | [ 68 | "UpBlock3D", 69 | "CrossAttnUpBlock3D", 70 | "CrossAttnUpBlock3D", 71 | "CrossAttnUpBlock3D", 72 | ] 73 | downsample_padding: 1 74 | flip_sin_to_cos: true 75 | freq_shift: 0 76 | in_channels: 13 # 49 77 | layers_per_block: 2 78 | mid_block_scale_factor: 1 79 | norm_eps: 1e-5 80 | norm_num_groups: 32 81 | out_channels: 4 # 16 82 | sample_size: 64 83 | resnet_time_scale_shift: default # Choose between [default, scale_shift] 84 | 85 | use_motion_module: true 86 | motion_module_resolutions: [1, 2, 4, 8] 87 | motion_module_mid_block: false 88 | motion_module_decoder_only: false 89 | motion_module_type: Vanilla 90 | motion_module_kwargs: 91 | num_attention_heads: 8 92 | num_transformer_block: 1 93 | attention_block_types: 94 | - Temporal_Self 95 | - Temporal_Self 96 | temporal_position_encoding: true 97 | temporal_position_encoding_max_len: 24 98 | temporal_attention_dim_div: 1 99 | zero_initialize: true 100 | -------------------------------------------------------------------------------- /configs/unet/stage2_efficient.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml 3 | train_output_dir: debug/unet 4 | train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt 5 | train_data_dir: "" 6 | audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds 7 | audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel 8 | 9 | val_video_path: assets/demo1_video.mp4 10 | val_audio_path: assets/demo1_audio.wav 11 | batch_size: 1 # 4 12 | num_workers: 12 # 12 13 | num_frames: 16 14 | resolution: 256 15 | mask_image_path: latentsync/utils/mask.png 16 | audio_sample_rate: 16000 17 | video_fps: 25 18 | audio_feat_length: [2, 2] 19 | 20 | ckpt: 
21 | resume_ckpt_path: checkpoints/latentsync_unet.pt 22 | save_ckpt_steps: 10000 23 | 24 | run: 25 | pixel_space_supervise: true 26 | use_syncnet: true 27 | sync_loss_weight: 0.05 28 | perceptual_loss_weight: 0.1 # 0.1 29 | recon_loss_weight: 1 # 1 30 | guidance_scale: 2.0 # [1.0 - 3.0] 31 | trepa_loss_weight: 0 32 | inference_steps: 20 33 | trainable_modules: 34 | - motion_modules. 35 | - attn2. 36 | seed: 1247 37 | use_mixed_noise: true 38 | mixed_noise_alpha: 1 # 1 39 | mixed_precision_training: true 40 | enable_gradient_checkpointing: true 41 | max_train_steps: 10000000 42 | max_train_epochs: -1 43 | 44 | optimizer: 45 | lr: 1e-5 46 | scale_lr: false 47 | max_grad_norm: 1.0 48 | lr_scheduler: constant 49 | lr_warmup_steps: 0 50 | 51 | model: 52 | act_fn: silu 53 | add_audio_layer: true 54 | attention_head_dim: 8 55 | block_out_channels: [320, 640, 1280, 1280] 56 | center_input_sample: false 57 | cross_attention_dim: 384 58 | down_block_types: 59 | [ 60 | "CrossAttnDownBlock3D", 61 | "CrossAttnDownBlock3D", 62 | "CrossAttnDownBlock3D", 63 | "DownBlock3D", 64 | ] 65 | mid_block_type: UNetMidBlock3DCrossAttn 66 | up_block_types: 67 | [ 68 | "UpBlock3D", 69 | "CrossAttnUpBlock3D", 70 | "CrossAttnUpBlock3D", 71 | "CrossAttnUpBlock3D", 72 | ] 73 | downsample_padding: 1 74 | flip_sin_to_cos: true 75 | freq_shift: 0 76 | in_channels: 13 # 49 77 | layers_per_block: 2 78 | mid_block_scale_factor: 1 79 | norm_eps: 1e-5 80 | norm_num_groups: 32 81 | out_channels: 4 # 16 82 | sample_size: 64 83 | resnet_time_scale_shift: default # Choose between [default, scale_shift] 84 | 85 | use_motion_module: true 86 | motion_module_resolutions: [1, 2, 4, 8] 87 | motion_module_mid_block: false 88 | motion_module_decoder_only: true 89 | motion_module_type: Vanilla 90 | motion_module_kwargs: 91 | num_attention_heads: 8 92 | num_transformer_block: 1 93 | attention_block_types: 94 | - Temporal_Self 95 | - Temporal_Self 96 | temporal_position_encoding: true 97 | temporal_position_encoding_max_len: 24 98 | temporal_attention_dim_div: 1 99 | zero_initialize: true 100 | -------------------------------------------------------------------------------- /data_processing_pipeline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m preprocess.data_processing_pipeline \ 4 | --total_num_workers 96 \ 5 | --per_gpu_num_workers 12 \ 6 | --resolution 256 \ 7 | --sync_conf_threshold 3 \ 8 | --temp_dir temp \ 9 | --input_dir /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/raw 10 | -------------------------------------------------------------------------------- /docs/changelog_v1.5.md: -------------------------------------------------------------------------------- 1 | # LatentSync 1.5 2 | 3 | ## What's new in LatentSync 1.5? 4 | 5 | 1. Add temporal layer: our previous claim that the [temporal layer](https://arxiv.org/abs/2307.04725) severely impairs lip-sync accuracy was incorrect; the issue was actually caused by a bug in the code implementation. We have corrected our [paper](https://arxiv.org/abs/2412.09262) and updated the code. After incorporating the temporal layer, LatentSync 1.5 demonstrates significantly improved temporal consistency compared to version 1.0. 6 | 7 | 2. Improves performance on Chinese videos: many issues reported poor performance on Chinese videos, so we added Chinese data to the training of the new model version. 8 | 9 | 3. Reduce the VRAM requirement of the stage2 training to **20 GB** through the following optimizations: 10 | 11 | 1. 
Implement gradient checkpointing in the U-Net, VAE, SyncNet and VideoMAE. 12 | 2. Replace xFormers with PyTorch's native implementation of FlashAttention-2. 13 | 3. Clear the CUDA cache after loading checkpoints. 14 | 4. The stage2 training only requires training the temporal layer and audio cross-attention layer, which significantly reduces the VRAM requirement compared to the previous full-parameter fine-tuning (a minimal sketch of this selective fine-tuning appears at the end of this document). 15 | 16 | Now you can train LatentSync on a single **RTX 3090**! Start the stage2 training with `configs/unet/stage2_efficient.yaml`. 17 | 18 | 4. Other code optimizations: 19 | 20 | 1. Remove the dependency on xFormers and Triton. 21 | 2. Upgrade the diffusers version to `0.32.2`. 22 | 23 | ## LatentSync 1.5 Demo
Original video | Lip-synced video (demo video table with four clip pairs omitted)
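The efficient stage2 recipe selects trainable parameters by name prefix: `run.trainable_modules` in `configs/unet/stage2_efficient.yaml` lists `motion_modules.` (the temporal layers) and `attn2.` (the audio cross-attention layers). The sketch below illustrates that kind of prefix-based selective fine-tuning; it is a minimal sketch rather than the project's actual training loop, and `unet` stands for any `torch.nn.Module`.

```python
import torch

# Prefixes mirror run.trainable_modules in configs/unet/stage2_efficient.yaml.
TRAINABLE_PREFIXES = ("motion_modules.", "attn2.")


def select_trainable_params(unet: torch.nn.Module, prefixes=TRAINABLE_PREFIXES):
    """Freeze all parameters, then re-enable gradients only for those whose
    qualified name contains one of the configured prefixes (hypothetical helper)."""
    unet.requires_grad_(False)
    trainable = []
    for name, param in unet.named_parameters():
        # e.g. "up_blocks.1.motion_modules.0.proj_in.weight" matches "motion_modules."
        if any(prefix in name for prefix in prefixes):
            param.requires_grad_(True)
            trainable.append(param)
    return trainable


# Only the selected parameters are handed to the optimizer, so gradient and
# optimizer-state memory shrink accordingly:
# optimizer = torch.optim.AdamW(select_trainable_params(unet), lr=1e-5)
```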
-------------------------------------------------------------------------------- /docs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/LatentSync/3c3a1a2b62c9d854e3e6ab5525373973380a74a5/docs/framework.png -------------------------------------------------------------------------------- /docs/syncnet_arch.md: -------------------------------------------------------------------------------- 1 | # Customize the architecture of SyncNet 2 | 3 | The config file of SyncNet defines the architectures of the audio and visual encoders. Let's first look at an example of an audio encoder: 4 | 5 | ```yaml 6 | audio_encoder: # input (1, 80, 52) 7 | in_channels: 1 8 | block_out_channels: [32, 64, 128, 256, 512, 1024, 2048] 9 | downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]] 10 | attn_blocks: [0, 0, 0, 1, 1, 0, 0] 11 | dropout: 0.0 12 | ``` 13 | 14 | The above model arch accepts a `1 x 80 x 52` image (mel spectrogram) and outputs a `2048 x 1 x 1` feature map. If the resolution of the input image changes, you need to redefine the `downsample_factors` so that the output has the shape `D x 1 x 1` and can be used to compute cosine similarity. Also reset the `block_out_channels`: in most cases, deeper networks require larger numbers of channels to store more features. We recommend reading the [EfficientNet](https://arxiv.org/abs/1905.11946) paper, which discusses how to scale the depth and width of CNNs in a balanced way. The `attn_blocks` list defines whether a certain layer has a self-attention layer, where 1 indicates presence and 0 indicates absence. 15 | 16 | Now we look at an example of a visual encoder: 17 | 18 | ```yaml 19 | visual_encoder: # input (48, 128, 256) 20 | in_channels: 48 # (16 x 3) 21 | block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048] 22 | downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2] 23 | attn_blocks: [0, 0, 0, 0, 1, 1, 0, 0] 24 | dropout: 0.0 25 | ``` 26 | 27 | It is important to note that `in_channels` equals `num_frames * image_channels`. For a pixel-space SyncNet, `image_channels` is 3, while for a latent-space SyncNet, `image_channels` equals the `latent_channels` of the VAE you are using, typically 4 (SD 1.5, SDXL) or 16 (FLUX, SD3). In the example above, the visual encoder has an input frame length of 16 and is a pixel-space SyncNet, so `in_channels` is `16 x 3 = 48`. -------------------------------------------------------------------------------- /eval/detectors/README.md: -------------------------------------------------------------------------------- 1 | # Face detector 2 | 3 | This face detector is adapted from `https://github.com/cs-giung/face-detection-pytorch`.
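The `S3FD` wrapper defined in `eval/detectors/s3fd/__init__.py` below loads `checkpoints/auxiliary/sfd_face.pth` (fetching it via `check_model_and_download` if missing) and exposes a single `detect_faces` method. A minimal usage sketch, assuming a placeholder image path and an RGB input frame:

```python
import cv2
from eval.detectors import S3FD  # re-exported in eval/detectors/__init__.py

detector = S3FD(device="cuda")
frame_bgr = cv2.imread("some_frame.jpg")  # placeholder path
frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

# Returns an (N, 5) array of [x1, y1, x2, y2, confidence] face boxes.
bboxes = detector.detect_faces(frame_rgb, conf_th=0.8, scales=[1])
for x1, y1, x2, y2, score in bboxes:
    print(f"face ({x1:.0f}, {y1:.0f})-({x2:.0f}, {y2:.0f}), conf {score:.2f}")
```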
4 | -------------------------------------------------------------------------------- /eval/detectors/__init__.py: -------------------------------------------------------------------------------- 1 | from .s3fd import S3FD -------------------------------------------------------------------------------- /eval/detectors/s3fd/__init__.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import cv2 4 | import torch 5 | from torchvision import transforms 6 | from .nets import S3FDNet 7 | from .box_utils import nms_ 8 | from latentsync.utils.util import check_model_and_download 9 | 10 | PATH_WEIGHT = "checkpoints/auxiliary/sfd_face.pth" 11 | img_mean = np.array([104.0, 117.0, 123.0])[:, np.newaxis, np.newaxis].astype("float32") 12 | 13 | 14 | class S3FD: 15 | 16 | def __init__(self, device="cuda"): 17 | 18 | tstamp = time.time() 19 | self.device = device 20 | 21 | print("[S3FD] loading with", self.device) 22 | self.net = S3FDNet(device=self.device).to(self.device) 23 | check_model_and_download(PATH_WEIGHT) 24 | state_dict = torch.load(PATH_WEIGHT, map_location=self.device, weights_only=True) 25 | self.net.load_state_dict(state_dict) 26 | self.net.eval() 27 | print("[S3FD] finished loading (%.4f sec)" % (time.time() - tstamp)) 28 | 29 | def detect_faces(self, image, conf_th=0.8, scales=[1]): 30 | 31 | w, h = image.shape[1], image.shape[0] 32 | 33 | bboxes = np.empty(shape=(0, 5)) 34 | 35 | with torch.no_grad(): 36 | for s in scales: 37 | scaled_img = cv2.resize(image, dsize=(0, 0), fx=s, fy=s, interpolation=cv2.INTER_LINEAR) 38 | 39 | scaled_img = np.swapaxes(scaled_img, 1, 2) 40 | scaled_img = np.swapaxes(scaled_img, 1, 0) 41 | scaled_img = scaled_img[[2, 1, 0], :, :] 42 | scaled_img = scaled_img.astype("float32") 43 | scaled_img -= img_mean 44 | scaled_img = scaled_img[[2, 1, 0], :, :] 45 | x = torch.from_numpy(scaled_img).unsqueeze(0).to(self.device) 46 | y = self.net(x) 47 | 48 | detections = y.data 49 | scale = torch.Tensor([w, h, w, h]) 50 | 51 | for i in range(detections.size(1)): 52 | j = 0 53 | while detections[0, i, j, 0] > conf_th: 54 | score = detections[0, i, j, 0] 55 | pt = (detections[0, i, j, 1:] * scale).cpu().numpy() 56 | bbox = (pt[0], pt[1], pt[2], pt[3], score) 57 | bboxes = np.vstack((bboxes, bbox)) 58 | j += 1 59 | 60 | keep = nms_(bboxes, 0.1) 61 | bboxes = bboxes[keep] 62 | 63 | return bboxes 64 | -------------------------------------------------------------------------------- /eval/detectors/s3fd/nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from .box_utils import Detect, PriorBox 6 | 7 | 8 | class L2Norm(nn.Module): 9 | 10 | def __init__(self, n_channels, scale): 11 | super(L2Norm, self).__init__() 12 | self.n_channels = n_channels 13 | self.gamma = scale or None 14 | self.eps = 1e-10 15 | self.weight = nn.Parameter(torch.Tensor(self.n_channels)) 16 | self.reset_parameters() 17 | 18 | def reset_parameters(self): 19 | init.constant_(self.weight, self.gamma) 20 | 21 | def forward(self, x): 22 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps 23 | x = torch.div(x, norm) 24 | out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x 25 | return out 26 | 27 | 28 | class S3FDNet(nn.Module): 29 | 30 | def __init__(self, device='cuda'): 31 | super(S3FDNet, self).__init__() 32 | self.device = device 33 | 34 | self.vgg = 
nn.ModuleList([ 35 | nn.Conv2d(3, 64, 3, 1, padding=1), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(64, 64, 3, 1, padding=1), 38 | nn.ReLU(inplace=True), 39 | nn.MaxPool2d(2, 2), 40 | 41 | nn.Conv2d(64, 128, 3, 1, padding=1), 42 | nn.ReLU(inplace=True), 43 | nn.Conv2d(128, 128, 3, 1, padding=1), 44 | nn.ReLU(inplace=True), 45 | nn.MaxPool2d(2, 2), 46 | 47 | nn.Conv2d(128, 256, 3, 1, padding=1), 48 | nn.ReLU(inplace=True), 49 | nn.Conv2d(256, 256, 3, 1, padding=1), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(256, 256, 3, 1, padding=1), 52 | nn.ReLU(inplace=True), 53 | nn.MaxPool2d(2, 2, ceil_mode=True), 54 | 55 | nn.Conv2d(256, 512, 3, 1, padding=1), 56 | nn.ReLU(inplace=True), 57 | nn.Conv2d(512, 512, 3, 1, padding=1), 58 | nn.ReLU(inplace=True), 59 | nn.Conv2d(512, 512, 3, 1, padding=1), 60 | nn.ReLU(inplace=True), 61 | nn.MaxPool2d(2, 2), 62 | 63 | nn.Conv2d(512, 512, 3, 1, padding=1), 64 | nn.ReLU(inplace=True), 65 | nn.Conv2d(512, 512, 3, 1, padding=1), 66 | nn.ReLU(inplace=True), 67 | nn.Conv2d(512, 512, 3, 1, padding=1), 68 | nn.ReLU(inplace=True), 69 | nn.MaxPool2d(2, 2), 70 | 71 | nn.Conv2d(512, 1024, 3, 1, padding=6, dilation=6), 72 | nn.ReLU(inplace=True), 73 | nn.Conv2d(1024, 1024, 1, 1), 74 | nn.ReLU(inplace=True), 75 | ]) 76 | 77 | self.L2Norm3_3 = L2Norm(256, 10) 78 | self.L2Norm4_3 = L2Norm(512, 8) 79 | self.L2Norm5_3 = L2Norm(512, 5) 80 | 81 | self.extras = nn.ModuleList([ 82 | nn.Conv2d(1024, 256, 1, 1), 83 | nn.Conv2d(256, 512, 3, 2, padding=1), 84 | nn.Conv2d(512, 128, 1, 1), 85 | nn.Conv2d(128, 256, 3, 2, padding=1), 86 | ]) 87 | 88 | self.loc = nn.ModuleList([ 89 | nn.Conv2d(256, 4, 3, 1, padding=1), 90 | nn.Conv2d(512, 4, 3, 1, padding=1), 91 | nn.Conv2d(512, 4, 3, 1, padding=1), 92 | nn.Conv2d(1024, 4, 3, 1, padding=1), 93 | nn.Conv2d(512, 4, 3, 1, padding=1), 94 | nn.Conv2d(256, 4, 3, 1, padding=1), 95 | ]) 96 | 97 | self.conf = nn.ModuleList([ 98 | nn.Conv2d(256, 4, 3, 1, padding=1), 99 | nn.Conv2d(512, 2, 3, 1, padding=1), 100 | nn.Conv2d(512, 2, 3, 1, padding=1), 101 | nn.Conv2d(1024, 2, 3, 1, padding=1), 102 | nn.Conv2d(512, 2, 3, 1, padding=1), 103 | nn.Conv2d(256, 2, 3, 1, padding=1), 104 | ]) 105 | 106 | self.softmax = nn.Softmax(dim=-1) 107 | self.detect = Detect() 108 | 109 | def forward(self, x): 110 | size = x.size()[2:] 111 | sources = list() 112 | loc = list() 113 | conf = list() 114 | 115 | for k in range(16): 116 | x = self.vgg[k](x) 117 | s = self.L2Norm3_3(x) 118 | sources.append(s) 119 | 120 | for k in range(16, 23): 121 | x = self.vgg[k](x) 122 | s = self.L2Norm4_3(x) 123 | sources.append(s) 124 | 125 | for k in range(23, 30): 126 | x = self.vgg[k](x) 127 | s = self.L2Norm5_3(x) 128 | sources.append(s) 129 | 130 | for k in range(30, len(self.vgg)): 131 | x = self.vgg[k](x) 132 | sources.append(x) 133 | 134 | # apply extra layers and cache source layer outputs 135 | for k, v in enumerate(self.extras): 136 | x = F.relu(v(x), inplace=True) 137 | if k % 2 == 1: 138 | sources.append(x) 139 | 140 | # apply multibox head to source layers 141 | loc_x = self.loc[0](sources[0]) 142 | conf_x = self.conf[0](sources[0]) 143 | 144 | max_conf, _ = torch.max(conf_x[:, 0:3, :, :], dim=1, keepdim=True) 145 | conf_x = torch.cat((max_conf, conf_x[:, 3:, :, :]), dim=1) 146 | 147 | loc.append(loc_x.permute(0, 2, 3, 1).contiguous()) 148 | conf.append(conf_x.permute(0, 2, 3, 1).contiguous()) 149 | 150 | for i in range(1, len(sources)): 151 | x = sources[i] 152 | conf.append(self.conf[i](x).permute(0, 2, 3, 1).contiguous()) 153 | loc.append(self.loc[i](x).permute(0, 2, 3, 
1).contiguous()) 154 | 155 | features_maps = [] 156 | for i in range(len(loc)): 157 | feat = [] 158 | feat += [loc[i].size(1), loc[i].size(2)] 159 | features_maps += [feat] 160 | 161 | loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) 162 | conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) 163 | 164 | with torch.no_grad(): 165 | self.priorbox = PriorBox(size, features_maps) 166 | self.priors = self.priorbox.forward() 167 | 168 | output = self.detect.forward( 169 | loc.view(loc.size(0), -1, 4), 170 | self.softmax(conf.view(conf.size(0), -1, 2)), 171 | self.priors.type(type(x.data)).to(self.device) 172 | ) 173 | 174 | return output 175 | -------------------------------------------------------------------------------- /eval/draw_syncnet_lines.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import matplotlib.pyplot as plt 17 | 18 | 19 | class Chart: 20 | def __init__(self): 21 | self.loss_list = [] 22 | 23 | def add_ckpt(self, ckpt_path, line_name): 24 | ckpt = torch.load(ckpt_path, map_location="cpu") 25 | train_step_list = ckpt["train_step_list"] 26 | train_loss_list = ckpt["train_loss_list"] 27 | val_step_list = ckpt["val_step_list"] 28 | val_loss_list = ckpt["val_loss_list"] 29 | self.loss_list.append((line_name, train_step_list, train_loss_list, val_step_list, val_loss_list)) 30 | 31 | def draw(self, save_path, plot_val=True): 32 | # Global settings 33 | plt.rcParams["font.size"] = 14 34 | plt.rcParams["font.family"] = "serif" 35 | plt.rcParams["font.sans-serif"] = ["Arial", "DejaVu Sans", "Lucida Grande"] 36 | plt.rcParams["font.serif"] = ["Times New Roman", "DejaVu Serif"] 37 | 38 | # Creating the plot 39 | plt.figure(figsize=(7.766, 4.8)) # Golden ratio 40 | for loss in self.loss_list: 41 | if plot_val: 42 | (line,) = plt.plot(loss[1], loss[2], label=loss[0], linewidth=0.5, alpha=0.5) 43 | line_color = line.get_color() 44 | plt.plot(loss[3], loss[4], linewidth=1.5, color=line_color) 45 | else: 46 | plt.plot(loss[1], loss[2], label=loss[0], linewidth=1) 47 | plt.xlabel("Step") 48 | plt.ylabel("Loss") 49 | legend = plt.legend() 50 | # legend = plt.legend(loc='upper right', bbox_to_anchor=(1, 0.82)) 51 | 52 | # Adjust the linewidth of legend 53 | for line in legend.get_lines(): 54 | line.set_linewidth(2) 55 | 56 | plt.savefig(save_path, transparent=True) 57 | plt.close() 58 | 59 | 60 | if __name__ == "__main__": 61 | chart = Chart() 62 | chart.add_ckpt("output/syncnet/train-2024_10_28-23:16:40/checkpoints/checkpoint-20000.pt", "Wav2Lip SyncNet") 63 | chart.add_ckpt("output/syncnet/train-2024_10_29-20:13:43/checkpoints/checkpoint-20000.pt", "StableSyncNet") 64 | chart.draw("ablation.pdf", plot_val=True) 65 | -------------------------------------------------------------------------------- /eval/eval_fvd.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import mediapipe as mp 16 | import cv2 17 | from decord import VideoReader 18 | import os 19 | import numpy as np 20 | import torch 21 | import tqdm 22 | from eval.fvd import compute_our_fvd 23 | 24 | 25 | class FVD: 26 | def __init__(self, resolution=(224, 224)): 27 | self.face_detector = mp.solutions.face_detection.FaceDetection(model_selection=0, min_detection_confidence=0.5) 28 | self.resolution = resolution 29 | 30 | def detect_face(self, image): 31 | height, width = image.shape[:2] 32 | # Process the image and detect faces. 33 | results = self.face_detector.process(image) 34 | 35 | if not results.detections: # Face not detected 36 | raise RuntimeError("Face not detected") 37 | 38 | detection = results.detections[0] # Only use the first face in the image 39 | bounding_box = detection.location_data.relative_bounding_box 40 | xmin = int(bounding_box.xmin * width) 41 | ymin = int(bounding_box.ymin * height) 42 | face_width = int(bounding_box.width * width) 43 | face_height = int(bounding_box.height * height) 44 | 45 | # Crop the image to the bounding box. 
46 | xmin = max(0, xmin) 47 | ymin = max(0, ymin) 48 | xmax = min(width, xmin + face_width) 49 | ymax = min(height, ymin + face_height) 50 | image = image[ymin:ymax, xmin:xmax] 51 | 52 | return image 53 | 54 | def detect_video(self, video_path): 55 | vr = VideoReader(video_path) 56 | video_frames = vr[20:36].asnumpy() 57 | vr.seek(0) # avoid memory leak 58 | faces = [] 59 | for frame in video_frames: 60 | face = self.detect_face(frame) 61 | face = cv2.resize(face, (self.resolution[1], self.resolution[0]), interpolation=cv2.INTER_AREA) 62 | faces.append(face) 63 | 64 | if len(faces) != 16: 65 | raise RuntimeError("Insufficient consecutive frames of faces (less than 16).") 66 | faces = np.stack(faces, axis=0) # (f, h, w, c) 67 | faces = torch.from_numpy(faces) 68 | return faces 69 | 70 | def detect_videos(self, videos_dir: str): 71 | videos_list = [] 72 | 73 | if videos_dir.endswith(".mp4"): 74 | video_faces = self.detect_video(videos_dir) 75 | videos_list.append(video_faces) 76 | else: 77 | for file in tqdm.tqdm(os.listdir(videos_dir)): 78 | if file.endswith(".mp4"): 79 | video_path = os.path.join(videos_dir, file) 80 | video_faces = self.detect_video(video_path) 81 | videos_list.append(video_faces) 82 | 83 | videos_list = torch.stack(videos_list) / 255.0 84 | return videos_list 85 | 86 | 87 | def eval_fvd(real_videos_dir: str, fake_videos_dir: str): 88 | fvd = FVD() 89 | real_videos = fvd.detect_videos(real_videos_dir) 90 | fake_videos = fvd.detect_videos(fake_videos_dir) 91 | fvd_value = compute_our_fvd(real_videos, fake_videos, device="cpu") 92 | print(f"FVD: {fvd_value:.3f}") 93 | 94 | 95 | if __name__ == "__main__": 96 | real_videos_dir = "dir1" 97 | fake_videos_dir = "dir2" 98 | eval_fvd(real_videos_dir, fake_videos_dir) 99 | -------------------------------------------------------------------------------- /eval/eval_sync_conf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
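# Evaluates lip-sync quality with the pretrained SyncNet: SyncNetDetector first crops face tracks
# from the input video, then the SyncNet confidence score and audio-visual offset are computed for
# each cropped track and averaged, either for a single --video_path or for every .mp4 in --videos_dir.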
14 | 15 | import argparse 16 | import os 17 | import tqdm 18 | from statistics import fmean 19 | from eval.syncnet import SyncNetEval 20 | from eval.syncnet_detect import SyncNetDetector 21 | from latentsync.utils.util import red_text 22 | import torch 23 | 24 | 25 | def syncnet_eval(syncnet, syncnet_detector, video_path, temp_dir, detect_results_dir="detect_results"): 26 | syncnet_detector(video_path=video_path, min_track=50) 27 | crop_videos = os.listdir(os.path.join(detect_results_dir, "crop")) 28 | if crop_videos == []: 29 | raise Exception(red_text(f"Face not detected in {video_path}")) 30 | av_offset_list = [] 31 | conf_list = [] 32 | for video in crop_videos: 33 | av_offset, _, conf = syncnet.evaluate( 34 | video_path=os.path.join(detect_results_dir, "crop", video), temp_dir=temp_dir 35 | ) 36 | av_offset_list.append(av_offset) 37 | conf_list.append(conf) 38 | av_offset = int(fmean(av_offset_list)) 39 | conf = fmean(conf_list) 40 | print(f"Input video: {video_path}\nSyncNet confidence: {conf:.2f}\nAV offset: {av_offset}") 41 | return av_offset, conf 42 | 43 | 44 | def main(): 45 | parser = argparse.ArgumentParser(description="SyncNet") 46 | parser.add_argument("--initial_model", type=str, default="checkpoints/auxiliary/syncnet_v2.model", help="") 47 | parser.add_argument("--video_path", type=str, default=None, help="") 48 | parser.add_argument("--videos_dir", type=str, default="/root/processed") 49 | parser.add_argument("--temp_dir", type=str, default="temp", help="") 50 | 51 | args = parser.parse_args() 52 | 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | 55 | syncnet = SyncNetEval(device=device) 56 | syncnet.loadParameters(args.initial_model) 57 | 58 | syncnet_detector = SyncNetDetector(device=device, detect_results_dir="detect_results") 59 | 60 | if args.video_path is not None: 61 | syncnet_eval(syncnet, syncnet_detector, args.video_path, args.temp_dir) 62 | else: 63 | sync_conf_list = [] 64 | video_names = sorted([f for f in os.listdir(args.videos_dir) if f.endswith(".mp4")]) 65 | for video_name in tqdm.tqdm(video_names): 66 | try: 67 | _, conf = syncnet_eval( 68 | syncnet, syncnet_detector, os.path.join(args.videos_dir, video_name), args.temp_dir 69 | ) 70 | sync_conf_list.append(conf) 71 | except Exception as e: 72 | print(e) 73 | print(f"The average sync confidence is {fmean(sync_conf_list):.02f}") 74 | 75 | 76 | if __name__ == "__main__": 77 | main() 78 | -------------------------------------------------------------------------------- /eval/eval_sync_conf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python -m eval.eval_sync_conf --video_path "video_out.mp4" 3 | -------------------------------------------------------------------------------- /eval/eval_syncnet_acc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
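# Measures SyncNet classification accuracy on the validation set: each audio-visual pair is encoded
# by StableSyncNet, predicted as in-sync when the cosine similarity of the two embeddings exceeds 0.5,
# and the prediction is compared against the ground-truth label y.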
14 | 15 | import argparse 16 | from tqdm.auto import tqdm 17 | import torch 18 | import torch.nn as nn 19 | from einops import rearrange 20 | from latentsync.models.stable_syncnet import StableSyncNet 21 | from latentsync.data.syncnet_dataset import SyncNetDataset 22 | from diffusers import AutoencoderKL 23 | from omegaconf import OmegaConf 24 | from accelerate.utils import set_seed 25 | 26 | 27 | def main(config): 28 | set_seed(config.run.seed) 29 | 30 | device = "cuda" if torch.cuda.is_available() else "cpu" 31 | 32 | if config.data.latent_space: 33 | vae = AutoencoderKL.from_pretrained( 34 | "runwayml/stable-diffusion-inpainting", subfolder="vae", revision="fp16", torch_dtype=torch.float16 35 | ) 36 | vae.requires_grad_(False) 37 | vae.to(device) 38 | 39 | # Dataset and Dataloader setup 40 | dataset = SyncNetDataset(config.data.val_data_dir, config.data.val_fileslist, config) 41 | 42 | test_dataloader = torch.utils.data.DataLoader( 43 | dataset, 44 | batch_size=config.data.batch_size, 45 | shuffle=False, 46 | num_workers=config.data.num_workers, 47 | drop_last=False, 48 | worker_init_fn=dataset.worker_init_fn, 49 | ) 50 | 51 | # Model 52 | syncnet = StableSyncNet(OmegaConf.to_container(config.model)).to(device) 53 | 54 | print(f"Load checkpoint from: {config.ckpt.inference_ckpt_path}") 55 | checkpoint = torch.load(config.ckpt.inference_ckpt_path, map_location=device, weights_only=True) 56 | 57 | syncnet.load_state_dict(checkpoint["state_dict"]) 58 | syncnet.to(dtype=torch.float16) 59 | syncnet.requires_grad_(False) 60 | syncnet.eval() 61 | 62 | global_step = 0 63 | num_val_batches = config.data.num_val_samples // config.data.batch_size 64 | progress_bar = tqdm(range(0, num_val_batches), initial=0, desc="Testing accuracy") 65 | 66 | num_correct_preds = 0 67 | num_total_preds = 0 68 | 69 | while True: 70 | for step, batch in enumerate(test_dataloader): 71 | ### >>>> Test >>>> ### 72 | 73 | frames = batch["frames"].to(device, dtype=torch.float16) 74 | audio_samples = batch["audio_samples"].to(device, dtype=torch.float16) 75 | y = batch["y"].to(device, dtype=torch.float16).squeeze(1) 76 | 77 | if config.data.latent_space: 78 | frames = rearrange(frames, "b f c h w -> (b f) c h w") 79 | 80 | with torch.no_grad(): 81 | frames = vae.encode(frames).latent_dist.sample() * 0.18215 82 | 83 | frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=config.data.num_frames) 84 | else: 85 | frames = rearrange(frames, "b f c h w -> b (f c) h w") 86 | 87 | if config.data.lower_half: 88 | height = frames.shape[2] 89 | frames = frames[:, :, height // 2 :, :] 90 | 91 | with torch.no_grad(): 92 | vision_embeds, audio_embeds = syncnet(frames, audio_samples) 93 | 94 | sims = nn.functional.cosine_similarity(vision_embeds, audio_embeds) 95 | 96 | preds = (sims > 0.5).to(dtype=torch.float16) 97 | num_correct_preds += (preds == y).sum().item() 98 | num_total_preds += len(sims) 99 | 100 | progress_bar.update(1) 101 | global_step += 1 102 | 103 | if global_step >= num_val_batches: 104 | progress_bar.close() 105 | print(f"SyncNet Accuracy: {num_correct_preds / num_total_preds*100:.2f}%") 106 | return 107 | 108 | 109 | if __name__ == "__main__": 110 | parser = argparse.ArgumentParser(description="Code to test the accuracy of SyncNet") 111 | 112 | parser.add_argument("--config_path", type=str, default="configs/syncnet/syncnet_16_latent.yaml") 113 | args = parser.parse_args() 114 | 115 | # Load a configuration file 116 | config = OmegaConf.load(args.config_path) 117 | 118 | main(config) 119 | 
-------------------------------------------------------------------------------- /eval/eval_syncnet_acc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m eval.eval_syncnet_acc --config_path "configs/syncnet/syncnet_16_pixel_attn.yaml" 4 | -------------------------------------------------------------------------------- /eval/fvd.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/universome/fvd-comparison/blob/master/our_fvd.py 2 | 3 | from typing import Tuple 4 | import scipy 5 | import numpy as np 6 | import torch 7 | from latentsync.utils.util import check_model_and_download 8 | 9 | 10 | def compute_fvd(feats_fake: np.ndarray, feats_real: np.ndarray) -> float: 11 | mu_gen, sigma_gen = compute_stats(feats_fake) 12 | mu_real, sigma_real = compute_stats(feats_real) 13 | 14 | m = np.square(mu_gen - mu_real).sum() 15 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 16 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 17 | 18 | return float(fid) 19 | 20 | 21 | def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 22 | mu = feats.mean(axis=0) # [d] 23 | sigma = np.cov(feats, rowvar=False) # [d, d] 24 | 25 | return mu, sigma 26 | 27 | 28 | @torch.no_grad() 29 | def compute_our_fvd(videos_fake: np.ndarray, videos_real: np.ndarray, device: str = "cuda") -> float: 30 | i3d_path = "checkpoints/auxiliary/i3d_torchscript.pt" 31 | check_model_and_download(i3d_path) 32 | i3d_kwargs = dict( 33 | rescale=False, resize=False, return_features=True 34 | ) # Return raw features before the softmax layer. 35 | 36 | with open(i3d_path, "rb") as f: 37 | i3d_model = torch.jit.load(f).eval().to(device) 38 | 39 | videos_fake = videos_fake.permute(0, 4, 1, 2, 3).to(device) 40 | videos_real = videos_real.permute(0, 4, 1, 2, 3).to(device) 41 | 42 | feats_fake = i3d_model(videos_fake, **i3d_kwargs).cpu().numpy() 43 | feats_real = i3d_model(videos_real, **i3d_kwargs).cpu().numpy() 44 | 45 | return compute_fvd(feats_fake, feats_real) 46 | 47 | 48 | def main(): 49 | # input shape: (b, f, h, w, c) 50 | videos_fake = torch.rand(10, 16, 224, 224, 3) 51 | videos_real = torch.rand(10, 16, 224, 224, 3) 52 | 53 | our_fvd_result = compute_our_fvd(videos_fake, videos_real) 54 | print(f"[FVD scores] Ours: {our_fvd_result}") 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | -------------------------------------------------------------------------------- /eval/inference_videos.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
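# Batch inference helper: reads a video fileslist and an audio fileslist, shuffles and pairs them
# one-to-one, then runs scripts.inference on each pair and writes the results to "<output_dir>__<seed>".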
14 | 15 | import os 16 | import subprocess 17 | from tqdm import tqdm 18 | import random 19 | 20 | 21 | def inference_video_from_fileslist( 22 | video_fileslist: str, 23 | audio_fileslist: str, 24 | output_dir: str, 25 | unet_config_path: str, 26 | ckpt_path: str, 27 | seed: int = 42, 28 | ): 29 | with open(video_fileslist, "r", encoding="utf-8") as file: 30 | video_paths = [line.strip() for line in file.readlines()] 31 | 32 | with open(audio_fileslist, "r", encoding="utf-8") as file: 33 | audio_paths = [line.strip() for line in file.readlines()] 34 | 35 | random.seed(seed) 36 | 37 | output_dir = f"{output_dir}__{seed}" 38 | os.makedirs(output_dir, exist_ok=True) 39 | 40 | random.shuffle(video_paths) 41 | random.shuffle(audio_paths) 42 | 43 | min_length = min(len(video_paths), len(audio_paths)) 44 | 45 | video_paths = video_paths[:min_length] 46 | audio_paths = audio_paths[:min_length] 47 | 48 | random.shuffle(video_paths) 49 | random.shuffle(audio_paths) 50 | 51 | for index, video_path in tqdm(enumerate(video_paths), total=len(video_paths)): 52 | audio_path = audio_paths[index] 53 | video_name = os.path.basename(video_path)[:-4] 54 | audio_name = os.path.basename(audio_path)[:-4] 55 | video_out_path = os.path.join(output_dir, f"{video_name}__{audio_name}.mp4") 56 | inference_command = f"python -m scripts.inference --guidance_scale 1.5 --unet_config_path {unet_config_path} --video_path {video_path} --audio_path {audio_path} --video_out_path {video_out_path} --inference_ckpt_path {ckpt_path}" 57 | subprocess.run(inference_command, shell=True) 58 | 59 | 60 | if __name__ == "__main__": 61 | video_fileslist = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/video_fileslist.txt" 62 | audio_fileslist = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/audio_fileslist.txt" 63 | output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/inference_videos_results" 64 | 65 | unet_config_path = "configs/unet/stage2.yaml" 66 | ckpt_path = "checkpoints/latentsync_unet.pt" 67 | 68 | seed = 42 69 | 70 | inference_video_from_fileslist(video_fileslist, audio_fileslist, output_dir, unet_config_path, ckpt_path, seed) 71 | -------------------------------------------------------------------------------- /eval/syncnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .syncnet_eval import SyncNetEval 2 | -------------------------------------------------------------------------------- /eval/syncnet/syncnet.py: -------------------------------------------------------------------------------- 1 | # https://github.com/joonson/syncnet_python/blob/master/SyncNetModel.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def save(model, filename): 8 | with open(filename, "wb") as f: 9 | torch.save(model, f) 10 | print("%s saved." 
% filename) 11 | 12 | 13 | def load(filename): 14 | net = torch.load(filename) 15 | return net 16 | 17 | 18 | class S(nn.Module): 19 | def __init__(self, num_layers_in_fc_layers=1024): 20 | super(S, self).__init__() 21 | 22 | self.__nFeatures__ = 24 23 | self.__nChs__ = 32 24 | self.__midChs__ = 32 25 | 26 | self.netcnnaud = nn.Sequential( 27 | nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), 28 | nn.BatchNorm2d(64), 29 | nn.ReLU(inplace=True), 30 | nn.MaxPool2d(kernel_size=(1, 1), stride=(1, 1)), 31 | nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), 32 | nn.BatchNorm2d(192), 33 | nn.ReLU(inplace=True), 34 | nn.MaxPool2d(kernel_size=(3, 3), stride=(1, 2)), 35 | nn.Conv2d(192, 384, kernel_size=(3, 3), padding=(1, 1)), 36 | nn.BatchNorm2d(384), 37 | nn.ReLU(inplace=True), 38 | nn.Conv2d(384, 256, kernel_size=(3, 3), padding=(1, 1)), 39 | nn.BatchNorm2d(256), 40 | nn.ReLU(inplace=True), 41 | nn.Conv2d(256, 256, kernel_size=(3, 3), padding=(1, 1)), 42 | nn.BatchNorm2d(256), 43 | nn.ReLU(inplace=True), 44 | nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2)), 45 | nn.Conv2d(256, 512, kernel_size=(5, 4), padding=(0, 0)), 46 | nn.BatchNorm2d(512), 47 | nn.ReLU(), 48 | ) 49 | 50 | self.netfcaud = nn.Sequential( 51 | nn.Linear(512, 512), 52 | nn.BatchNorm1d(512), 53 | nn.ReLU(), 54 | nn.Linear(512, num_layers_in_fc_layers), 55 | ) 56 | 57 | self.netfclip = nn.Sequential( 58 | nn.Linear(512, 512), 59 | nn.BatchNorm1d(512), 60 | nn.ReLU(), 61 | nn.Linear(512, num_layers_in_fc_layers), 62 | ) 63 | 64 | self.netcnnlip = nn.Sequential( 65 | nn.Conv3d(3, 96, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=0), 66 | nn.BatchNorm3d(96), 67 | nn.ReLU(inplace=True), 68 | nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2)), 69 | nn.Conv3d(96, 256, kernel_size=(1, 5, 5), stride=(1, 2, 2), padding=(0, 1, 1)), 70 | nn.BatchNorm3d(256), 71 | nn.ReLU(inplace=True), 72 | nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)), 73 | nn.Conv3d(256, 256, kernel_size=(1, 3, 3), padding=(0, 1, 1)), 74 | nn.BatchNorm3d(256), 75 | nn.ReLU(inplace=True), 76 | nn.Conv3d(256, 256, kernel_size=(1, 3, 3), padding=(0, 1, 1)), 77 | nn.BatchNorm3d(256), 78 | nn.ReLU(inplace=True), 79 | nn.Conv3d(256, 256, kernel_size=(1, 3, 3), padding=(0, 1, 1)), 80 | nn.BatchNorm3d(256), 81 | nn.ReLU(inplace=True), 82 | nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2)), 83 | nn.Conv3d(256, 512, kernel_size=(1, 6, 6), padding=0), 84 | nn.BatchNorm3d(512), 85 | nn.ReLU(inplace=True), 86 | ) 87 | 88 | def forward_aud(self, x): 89 | 90 | mid = self.netcnnaud(x) 91 | # N x ch x 24 x M 92 | mid = mid.view((mid.size()[0], -1)) 93 | # N x (ch x 24) 94 | out = self.netfcaud(mid) 95 | 96 | return out 97 | 98 | def forward_lip(self, x): 99 | 100 | mid = self.netcnnlip(x) 101 | mid = mid.view((mid.size()[0], -1)) 102 | # N x (ch x 24) 103 | out = self.netfclip(mid) 104 | 105 | return out 106 | 107 | def forward_lipfeat(self, x): 108 | 109 | mid = self.netcnnlip(x) 110 | out = mid.view((mid.size()[0], -1)) 111 | # N x (ch x 24) 112 | 113 | return out 114 | -------------------------------------------------------------------------------- /eval/syncnet/syncnet_eval.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/joonson/syncnet_python/blob/master/SyncNetInstance.py 2 | 3 | import torch 4 | import numpy 5 | import time, pdb, argparse, subprocess, os, math, glob 6 | import cv2 7 | import python_speech_features 8 | 9 | from scipy 
import signal 10 | from scipy.io import wavfile 11 | from .syncnet import S 12 | from shutil import rmtree 13 | from latentsync.utils.util import check_model_and_download 14 | 15 | 16 | # ==================== Get OFFSET ==================== 17 | 18 | # Video 25 FPS, Audio 16000HZ 19 | 20 | 21 | def calc_pdist(feat1, feat2, vshift=10): 22 | win_size = vshift * 2 + 1 23 | 24 | feat2p = torch.nn.functional.pad(feat2, (0, 0, vshift, vshift)) 25 | 26 | dists = [] 27 | 28 | for i in range(0, len(feat1)): 29 | 30 | dists.append( 31 | torch.nn.functional.pairwise_distance(feat1[[i], :].repeat(win_size, 1), feat2p[i : i + win_size, :]) 32 | ) 33 | 34 | return dists 35 | 36 | 37 | # ==================== MAIN DEF ==================== 38 | 39 | 40 | class SyncNetEval(torch.nn.Module): 41 | def __init__(self, dropout=0, num_layers_in_fc_layers=1024, device="cpu"): 42 | super().__init__() 43 | 44 | self.__S__ = S(num_layers_in_fc_layers=num_layers_in_fc_layers).to(device) 45 | self.device = device 46 | 47 | def evaluate(self, video_path, temp_dir="temp", batch_size=20, vshift=15): 48 | 49 | self.__S__.eval() 50 | 51 | # ========== ========== 52 | # Convert files 53 | # ========== ========== 54 | 55 | if os.path.exists(temp_dir): 56 | rmtree(temp_dir) 57 | 58 | os.makedirs(temp_dir) 59 | 60 | # temp_video_path = os.path.join(temp_dir, "temp.mp4") 61 | # command = f"ffmpeg -loglevel error -nostdin -y -i {video_path} -vf scale='224:224' {temp_video_path}" 62 | # subprocess.call(command, shell=True) 63 | 64 | command = f"ffmpeg -loglevel error -nostdin -y -i {video_path} -f image2 {os.path.join(temp_dir, '%06d.jpg')}" 65 | subprocess.call(command, shell=True, stdout=None) 66 | 67 | command = f"ffmpeg -loglevel error -nostdin -y -i {video_path} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {os.path.join(temp_dir, 'audio.wav')}" 68 | subprocess.call(command, shell=True, stdout=None) 69 | 70 | # ========== ========== 71 | # Load video 72 | # ========== ========== 73 | 74 | images = [] 75 | 76 | flist = glob.glob(os.path.join(temp_dir, "*.jpg")) 77 | flist.sort() 78 | 79 | for fname in flist: 80 | img_input = cv2.imread(fname) 81 | img_input = cv2.resize(img_input, (224, 224)) # HARD CODED, CHANGE BEFORE RELEASE 82 | images.append(img_input) 83 | 84 | im = numpy.stack(images, axis=3) 85 | im = numpy.expand_dims(im, axis=0) 86 | im = numpy.transpose(im, (0, 3, 4, 1, 2)) 87 | 88 | imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float()) 89 | 90 | # ========== ========== 91 | # Load audio 92 | # ========== ========== 93 | 94 | sample_rate, audio = wavfile.read(os.path.join(temp_dir, "audio.wav")) 95 | mfcc = zip(*python_speech_features.mfcc(audio, sample_rate)) 96 | mfcc = numpy.stack([numpy.array(i) for i in mfcc]) 97 | 98 | cc = numpy.expand_dims(numpy.expand_dims(mfcc, axis=0), axis=0) 99 | cct = torch.autograd.Variable(torch.from_numpy(cc.astype(float)).float()) 100 | 101 | # ========== ========== 102 | # Check audio and video input length 103 | # ========== ========== 104 | 105 | # if (float(len(audio)) / 16000) != (float(len(images)) / 25): 106 | # print( 107 | # "WARNING: Audio (%.4fs) and video (%.4fs) lengths are different." 
108 | # % (float(len(audio)) / 16000, float(len(images)) / 25) 109 | # ) 110 | 111 | min_length = min(len(images), math.floor(len(audio) / 640)) 112 | 113 | # ========== ========== 114 | # Generate video and audio feats 115 | # ========== ========== 116 | 117 | lastframe = min_length - 5 118 | im_feat = [] 119 | cc_feat = [] 120 | 121 | tS = time.time() 122 | for i in range(0, lastframe, batch_size): 123 | 124 | im_batch = [imtv[:, :, vframe : vframe + 5, :, :] for vframe in range(i, min(lastframe, i + batch_size))] 125 | im_in = torch.cat(im_batch, 0) 126 | im_out = self.__S__.forward_lip(im_in.to(self.device)) 127 | im_feat.append(im_out.data.cpu()) 128 | 129 | cc_batch = [ 130 | cct[:, :, :, vframe * 4 : vframe * 4 + 20] for vframe in range(i, min(lastframe, i + batch_size)) 131 | ] 132 | cc_in = torch.cat(cc_batch, 0) 133 | cc_out = self.__S__.forward_aud(cc_in.to(self.device)) 134 | cc_feat.append(cc_out.data.cpu()) 135 | 136 | im_feat = torch.cat(im_feat, 0) 137 | cc_feat = torch.cat(cc_feat, 0) 138 | 139 | # ========== ========== 140 | # Compute offset 141 | # ========== ========== 142 | 143 | dists = calc_pdist(im_feat, cc_feat, vshift=vshift) 144 | mean_dists = torch.mean(torch.stack(dists, 1), 1) 145 | 146 | min_dist, minidx = torch.min(mean_dists, 0) 147 | 148 | av_offset = vshift - minidx 149 | conf = torch.median(mean_dists) - min_dist 150 | 151 | fdist = numpy.stack([dist[minidx].numpy() for dist in dists]) 152 | # fdist = numpy.pad(fdist, (3,3), 'constant', constant_values=15) 153 | fconf = torch.median(mean_dists).numpy() - fdist 154 | framewise_conf = signal.medfilt(fconf, kernel_size=9) 155 | 156 | # numpy.set_printoptions(formatter={"float": "{: 0.3f}".format}) 157 | rmtree(temp_dir) 158 | return av_offset.item(), min_dist.item(), conf.item() 159 | 160 | def extract_feature(self, opt, videofile): 161 | 162 | self.__S__.eval() 163 | 164 | # ========== ========== 165 | # Load video 166 | # ========== ========== 167 | cap = cv2.VideoCapture(videofile) 168 | 169 | frame_num = 1 170 | images = [] 171 | while frame_num: 172 | frame_num += 1 173 | ret, image = cap.read() 174 | if ret == 0: 175 | break 176 | 177 | images.append(image) 178 | 179 | im = numpy.stack(images, axis=3) 180 | im = numpy.expand_dims(im, axis=0) 181 | im = numpy.transpose(im, (0, 3, 4, 1, 2)) 182 | 183 | imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float()) 184 | 185 | # ========== ========== 186 | # Generate video feats 187 | # ========== ========== 188 | 189 | lastframe = len(images) - 4 190 | im_feat = [] 191 | 192 | tS = time.time() 193 | for i in range(0, lastframe, opt.batch_size): 194 | 195 | im_batch = [ 196 | imtv[:, :, vframe : vframe + 5, :, :] for vframe in range(i, min(lastframe, i + opt.batch_size)) 197 | ] 198 | im_in = torch.cat(im_batch, 0) 199 | im_out = self.__S__.forward_lipfeat(im_in.to(self.device)) 200 | im_feat.append(im_out.data.cpu()) 201 | 202 | im_feat = torch.cat(im_feat, 0) 203 | 204 | # ========== ========== 205 | # Compute offset 206 | # ========== ========== 207 | 208 | print("Compute time %.3f sec." 
% (time.time() - tS)) 209 | 210 | return im_feat 211 | 212 | def loadParameters(self, path): 213 | check_model_and_download(path) 214 | loaded_state = torch.load(path, map_location=lambda storage, loc: storage, weights_only=True) 215 | 216 | self_state = self.__S__.state_dict() 217 | 218 | for name, param in loaded_state.items(): 219 | 220 | self_state[name].copy_(param) 221 | -------------------------------------------------------------------------------- /gradio_app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from pathlib import Path 3 | from scripts.inference import main 4 | from omegaconf import OmegaConf 5 | import argparse 6 | from datetime import datetime 7 | 8 | CONFIG_PATH = Path("configs/unet/stage2.yaml") 9 | CHECKPOINT_PATH = Path("checkpoints/latentsync_unet.pt") 10 | 11 | 12 | def process_video( 13 | video_path, 14 | audio_path, 15 | guidance_scale, 16 | inference_steps, 17 | seed, 18 | ): 19 | # Create the temp directory if it doesn't exist 20 | output_dir = Path("./temp") 21 | output_dir.mkdir(parents=True, exist_ok=True) 22 | 23 | # Convert paths to absolute Path objects and normalize them 24 | video_file_path = Path(video_path) 25 | video_path = video_file_path.absolute().as_posix() 26 | audio_path = Path(audio_path).absolute().as_posix() 27 | 28 | current_time = datetime.now().strftime("%Y%m%d_%H%M%S") 29 | # Set the output path for the processed video 30 | output_path = str(output_dir / f"{video_file_path.stem}_{current_time}.mp4") # Change the filename as needed 31 | 32 | config = OmegaConf.load(CONFIG_PATH) 33 | 34 | config["run"].update( 35 | { 36 | "guidance_scale": guidance_scale, 37 | "inference_steps": inference_steps, 38 | } 39 | ) 40 | 41 | # Parse the arguments 42 | args = create_args(video_path, audio_path, output_path, inference_steps, guidance_scale, seed) 43 | 44 | try: 45 | result = main( 46 | config=config, 47 | args=args, 48 | ) 49 | print("Processing completed successfully.") 50 | return output_path # Ensure the output path is returned 51 | except Exception as e: 52 | print(f"Error during processing: {str(e)}") 53 | raise gr.Error(f"Error during processing: {str(e)}") 54 | 55 | 56 | def create_args( 57 | video_path: str, audio_path: str, output_path: str, inference_steps: int, guidance_scale: float, seed: int 58 | ) -> argparse.Namespace: 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument("--inference_ckpt_path", type=str, required=True) 61 | parser.add_argument("--video_path", type=str, required=True) 62 | parser.add_argument("--audio_path", type=str, required=True) 63 | parser.add_argument("--video_out_path", type=str, required=True) 64 | parser.add_argument("--inference_steps", type=int, default=20) 65 | parser.add_argument("--guidance_scale", type=float, default=1.0) 66 | parser.add_argument("--seed", type=int, default=1247) 67 | 68 | return parser.parse_args( 69 | [ 70 | "--inference_ckpt_path", 71 | CHECKPOINT_PATH.absolute().as_posix(), 72 | "--video_path", 73 | video_path, 74 | "--audio_path", 75 | audio_path, 76 | "--video_out_path", 77 | output_path, 78 | "--inference_steps", 79 | str(inference_steps), 80 | "--guidance_scale", 81 | str(guidance_scale), 82 | "--seed", 83 | str(seed), 84 | ] 85 | ) 86 | 87 | 88 | # Create Gradio interface 89 | with gr.Blocks(title="LatentSync demo") as demo: 90 | gr.Markdown( 91 | """ 92 |

    LatentSync 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 |
102 | """ 103 | ) 104 | 105 | with gr.Row(): 106 | with gr.Column(): 107 | video_input = gr.Video(label="Input Video") 108 | audio_input = gr.Audio(label="Input Audio", type="filepath") 109 | 110 | with gr.Row(): 111 | guidance_scale = gr.Slider( 112 | minimum=1.0, 113 | maximum=3.0, 114 | value=2.0, 115 | step=0.5, 116 | label="Guidance Scale", 117 | ) 118 | inference_steps = gr.Slider(minimum=10, maximum=50, value=20, step=1, label="Inference Steps") 119 | 120 | with gr.Row(): 121 | seed = gr.Number(value=1247, label="Random Seed", precision=0) 122 | 123 | process_btn = gr.Button("Process Video") 124 | 125 | with gr.Column(): 126 | video_output = gr.Video(label="Output Video") 127 | 128 | gr.Examples( 129 | examples=[ 130 | ["assets/demo1_video.mp4", "assets/demo1_audio.wav"], 131 | ["assets/demo2_video.mp4", "assets/demo2_audio.wav"], 132 | ["assets/demo3_video.mp4", "assets/demo3_audio.wav"], 133 | ], 134 | inputs=[video_input, audio_input], 135 | ) 136 | 137 | process_btn.click( 138 | fn=process_video, 139 | inputs=[ 140 | video_input, 141 | audio_input, 142 | guidance_scale, 143 | inference_steps, 144 | seed, 145 | ], 146 | outputs=video_output, 147 | ) 148 | 149 | if __name__ == "__main__": 150 | demo.launch(inbrowser=True, share=True) 151 | -------------------------------------------------------------------------------- /inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m scripts.inference \ 4 | --unet_config_path "configs/unet/stage2.yaml" \ 5 | --inference_ckpt_path "checkpoints/latentsync_unet.pt" \ 6 | --inference_steps 20 \ 7 | --guidance_scale 2.0 \ 8 | --video_path "assets/demo1_video.mp4" \ 9 | --audio_path "assets/demo1_audio.wav" \ 10 | --video_out_path "video_out.mp4" 11 | -------------------------------------------------------------------------------- /latentsync/data/syncnet_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
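The Gradio app and inference.sh above drive the same entry point. As a rough, untested sketch (it reuses create_args from gradio_app.py and the demo assets shown above; adjust the paths for your checkout), the CLI call in inference.sh maps to Python roughly as follows:

from omegaconf import OmegaConf
from gradio_app import create_args
from scripts.inference import main

config = OmegaConf.load("configs/unet/stage2.yaml")
config["run"].update({"guidance_scale": 2.0, "inference_steps": 20})

# Same demo inputs and hyperparameters as inference.sh; create_args also injects
# checkpoints/latentsync_unet.pt as the inference checkpoint path.
args = create_args(
    video_path="assets/demo1_video.mp4",
    audio_path="assets/demo1_audio.wav",
    output_path="video_out.mp4",
    inference_steps=20,
    guidance_scale=2.0,
    seed=1247,
)
main(config=config, args=args)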
14 | 15 | import os 16 | import numpy as np 17 | from torch.utils.data import Dataset 18 | import torch 19 | import random 20 | from ..utils.util import gather_video_paths_recursively 21 | from ..utils.image_processor import ImageProcessor 22 | from ..utils.audio import melspectrogram 23 | import math 24 | from pathlib import Path 25 | 26 | from decord import AudioReader, VideoReader, cpu 27 | 28 | 29 | class SyncNetDataset(Dataset): 30 | def __init__(self, data_dir: str, fileslist: str, config): 31 | if fileslist != "": 32 | with open(fileslist) as file: 33 | self.video_paths = [line.rstrip() for line in file] 34 | elif data_dir != "": 35 | self.video_paths = gather_video_paths_recursively(data_dir) 36 | else: 37 | raise ValueError("data_dir and fileslist cannot be both empty") 38 | 39 | self.resolution = config.data.resolution 40 | self.num_frames = config.data.num_frames 41 | 42 | self.mel_window_length = math.ceil(self.num_frames / 5 * 16) 43 | 44 | self.audio_sample_rate = config.data.audio_sample_rate 45 | self.video_fps = config.data.video_fps 46 | self.image_processor = ImageProcessor(resolution=config.data.resolution) 47 | self.audio_mel_cache_dir = config.data.audio_mel_cache_dir 48 | Path(self.audio_mel_cache_dir).mkdir(parents=True, exist_ok=True) 49 | 50 | def __len__(self): 51 | return len(self.video_paths) 52 | 53 | def read_audio(self, video_path: str): 54 | ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate) 55 | original_mel = melspectrogram(ar[:].asnumpy().squeeze(0)) 56 | return torch.from_numpy(original_mel) 57 | 58 | def crop_audio_window(self, original_mel, start_index): 59 | start_idx = int(80.0 * (start_index / float(self.video_fps))) 60 | end_idx = start_idx + self.mel_window_length 61 | return original_mel[:, start_idx:end_idx].unsqueeze(0) 62 | 63 | def get_frames(self, video_reader: VideoReader): 64 | total_num_frames = len(video_reader) 65 | 66 | start_idx = random.randint(0, total_num_frames - self.num_frames) 67 | frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int) 68 | 69 | while True: 70 | wrong_start_idx = random.randint(0, total_num_frames - self.num_frames) 71 | if wrong_start_idx == start_idx: 72 | continue 73 | wrong_frames_index = np.arange(wrong_start_idx, wrong_start_idx + self.num_frames, dtype=int) 74 | break 75 | 76 | frames = video_reader.get_batch(frames_index).asnumpy() 77 | wrong_frames = video_reader.get_batch(wrong_frames_index).asnumpy() 78 | 79 | return frames, wrong_frames, start_idx 80 | 81 | def worker_init_fn(self, worker_id): 82 | self.worker_id = worker_id 83 | 84 | def __getitem__(self, idx): 85 | while True: 86 | try: 87 | idx = random.randint(0, len(self) - 1) 88 | 89 | # Get video file path 90 | video_path = self.video_paths[idx] 91 | 92 | vr = VideoReader(video_path, ctx=cpu(self.worker_id)) 93 | 94 | if len(vr) < 2 * self.num_frames: 95 | continue 96 | 97 | frames, wrong_frames, start_idx = self.get_frames(vr) 98 | 99 | mel_cache_path = os.path.join( 100 | self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt") 101 | ) 102 | 103 | if os.path.isfile(mel_cache_path): 104 | try: 105 | original_mel = torch.load(mel_cache_path, weights_only=True) 106 | except Exception as e: 107 | print(f"{type(e).__name__} - {e} - {mel_cache_path}") 108 | os.remove(mel_cache_path) 109 | original_mel = self.read_audio(video_path) 110 | torch.save(original_mel, mel_cache_path) 111 | else: 112 | original_mel = self.read_audio(video_path) 113 | 
torch.save(original_mel, mel_cache_path) 114 | 115 | mel = self.crop_audio_window(original_mel, start_idx) 116 | 117 | if mel.shape[-1] != self.mel_window_length: 118 | continue 119 | 120 | if random.choice([True, False]): 121 | y = torch.ones(1).float() 122 | chosen_frames = frames 123 | else: 124 | y = torch.zeros(1).float() 125 | chosen_frames = wrong_frames 126 | 127 | chosen_frames = self.image_processor.process_images(chosen_frames) 128 | 129 | vr.seek(0) # avoid memory leak 130 | break 131 | 132 | except Exception as e: # Handle the exception of face not detcted 133 | print(f"{type(e).__name__} - {e} - {video_path}") 134 | if "vr" in locals(): 135 | vr.seek(0) # avoid memory leak 136 | 137 | sample = dict(frames=chosen_frames, audio_samples=mel, y=y) 138 | 139 | return sample 140 | -------------------------------------------------------------------------------- /latentsync/data/unet_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import math 17 | import numpy as np 18 | from torch.utils.data import Dataset 19 | import torch 20 | import random 21 | import cv2 22 | from ..utils.image_processor import ImageProcessor, load_fixed_mask 23 | from ..utils.audio import melspectrogram 24 | from decord import AudioReader, VideoReader, cpu 25 | import torch.nn.functional as F 26 | from pathlib import Path 27 | 28 | 29 | class UNetDataset(Dataset): 30 | def __init__(self, train_data_dir: str, config): 31 | if config.data.train_fileslist != "": 32 | with open(config.data.train_fileslist) as file: 33 | self.video_paths = [line.rstrip() for line in file] 34 | elif train_data_dir != "": 35 | self.video_paths = [] 36 | for file in os.listdir(train_data_dir): 37 | if file.endswith(".mp4"): 38 | self.video_paths.append(os.path.join(train_data_dir, file)) 39 | else: 40 | raise ValueError("data_dir and fileslist cannot be both empty") 41 | 42 | self.resolution = config.data.resolution 43 | self.num_frames = config.data.num_frames 44 | 45 | self.mel_window_length = math.ceil(self.num_frames / 5 * 16) 46 | 47 | self.audio_sample_rate = config.data.audio_sample_rate 48 | self.video_fps = config.data.video_fps 49 | self.image_processor = ImageProcessor( 50 | self.resolution, mask_image=load_fixed_mask(self.resolution, config.data.mask_image_path) 51 | ) 52 | self.load_audio_data = config.model.add_audio_layer and config.run.use_syncnet 53 | self.audio_mel_cache_dir = config.data.audio_mel_cache_dir 54 | Path(self.audio_mel_cache_dir).mkdir(parents=True, exist_ok=True) 55 | 56 | def __len__(self): 57 | return len(self.video_paths) 58 | 59 | def read_audio(self, video_path: str): 60 | ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate) 61 | original_mel = melspectrogram(ar[:].asnumpy().squeeze(0)) 62 | return torch.from_numpy(original_mel) 63 | 64 | def 
crop_audio_window(self, original_mel, start_index): 65 | start_idx = int(80.0 * (start_index / float(self.video_fps))) 66 | end_idx = start_idx + self.mel_window_length 67 | return original_mel[:, start_idx:end_idx].unsqueeze(0) 68 | 69 | def get_frames(self, video_reader: VideoReader): 70 | total_num_frames = len(video_reader) 71 | 72 | start_idx = random.randint(0, total_num_frames - self.num_frames) 73 | gt_frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int) 74 | 75 | while True: 76 | ref_start_idx = random.randint(0, total_num_frames - self.num_frames) 77 | if ref_start_idx > start_idx - self.num_frames and ref_start_idx < start_idx + self.num_frames: 78 | continue 79 | ref_frames_index = np.arange(ref_start_idx, ref_start_idx + self.num_frames, dtype=int) 80 | break 81 | 82 | gt_frames = video_reader.get_batch(gt_frames_index).asnumpy() 83 | ref_frames = video_reader.get_batch(ref_frames_index).asnumpy() 84 | 85 | return gt_frames, ref_frames, start_idx 86 | 87 | def worker_init_fn(self, worker_id): 88 | self.worker_id = worker_id 89 | 90 | def __getitem__(self, idx): 91 | while True: 92 | try: 93 | idx = random.randint(0, len(self) - 1) 94 | 95 | # Get video file path 96 | video_path = self.video_paths[idx] 97 | 98 | vr = VideoReader(video_path, ctx=cpu(self.worker_id)) 99 | 100 | if len(vr) < 3 * self.num_frames: 101 | continue 102 | 103 | gt_frames, ref_frames, start_idx = self.get_frames(vr) 104 | 105 | if self.load_audio_data: 106 | mel_cache_path = os.path.join( 107 | self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt") 108 | ) 109 | 110 | if os.path.isfile(mel_cache_path): 111 | try: 112 | original_mel = torch.load(mel_cache_path, weights_only=True) 113 | except Exception as e: 114 | print(f"{type(e).__name__} - {e} - {mel_cache_path}") 115 | os.remove(mel_cache_path) 116 | original_mel = self.read_audio(video_path) 117 | torch.save(original_mel, mel_cache_path) 118 | else: 119 | original_mel = self.read_audio(video_path) 120 | torch.save(original_mel, mel_cache_path) 121 | 122 | mel = self.crop_audio_window(original_mel, start_idx) 123 | 124 | if mel.shape[-1] != self.mel_window_length: 125 | continue 126 | else: 127 | mel = [] 128 | 129 | gt_pixel_values, masked_pixel_values, masks = self.image_processor.prepare_masks_and_masked_images( 130 | gt_frames, affine_transform=False 131 | ) # (f, c, h, w) 132 | ref_pixel_values = self.image_processor.process_images(ref_frames) 133 | 134 | vr.seek(0) # avoid memory leak 135 | break 136 | 137 | except Exception as e: # Handle the exception of face not detcted 138 | print(f"{type(e).__name__} - {e} - {video_path}") 139 | if "vr" in locals(): 140 | vr.seek(0) # avoid memory leak 141 | 142 | sample = dict( 143 | gt_pixel_values=gt_pixel_values, 144 | masked_pixel_values=masked_pixel_values, 145 | ref_pixel_values=ref_pixel_values, 146 | mel=mel, 147 | masks=masks, 148 | video_path=video_path, 149 | start_idx=start_idx, 150 | ) 151 | 152 | return sample 153 | -------------------------------------------------------------------------------- /latentsync/models/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | def zero_module(module): 16 | # Zero out the parameters of a module and return it. 17 | for p in module.parameters(): 18 | p.detach().zero_() 19 | return module 20 | -------------------------------------------------------------------------------- /latentsync/models/wav2lip_syncnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/primepake/wav2lip_288x288/blob/master/models/syncnetv2.py 2 | # The code here is for ablation study. 3 | 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class Wav2LipSyncNet(nn.Module): 9 | def __init__(self, act_fn="leaky"): 10 | super().__init__() 11 | 12 | # input image sequences: (15, 128, 256) 13 | self.visual_encoder = nn.Sequential( 14 | Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3, act_fn=act_fn), # (128, 256) 15 | Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1, act_fn=act_fn), # (126, 127) 16 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 17 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 18 | Conv2d(64, 128, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (63, 64) 19 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 20 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 21 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 22 | Conv2d(128, 256, kernel_size=3, stride=3, padding=1, act_fn=act_fn), # (21, 22) 23 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 24 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 25 | Conv2d(256, 512, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (11, 11) 26 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 27 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 28 | Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (6, 6) 29 | Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 30 | Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 31 | Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1, act_fn="relu"), # (3, 3) 32 | Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"), # (1, 1) 33 | Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"), 34 | ) 35 | 36 | # input audio sequences: (1, 80, 16) 37 | self.audio_encoder = nn.Sequential( 38 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1, act_fn=act_fn), 39 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 40 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 41 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1, act_fn=act_fn), # (27, 16) 42 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 43 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, 
residual=True, act_fn=act_fn), 44 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1, act_fn=act_fn), # (9, 6) 45 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 46 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 47 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1, act_fn=act_fn), # (3, 3) 48 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 49 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 50 | Conv2d(256, 512, kernel_size=3, stride=1, padding=1, act_fn=act_fn), 51 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 52 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn), 53 | Conv2d(512, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"), # (1, 1) 54 | Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"), 55 | ) 56 | 57 | def forward(self, image_sequences, audio_sequences): 58 | vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1) 59 | audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1) 60 | 61 | vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c) 62 | audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c) 63 | 64 | # Make them unit vectors 65 | vision_embeds = F.normalize(vision_embeds, p=2, dim=1) 66 | audio_embeds = F.normalize(audio_embeds, p=2, dim=1) 67 | 68 | return vision_embeds, audio_embeds 69 | 70 | 71 | class Conv2d(nn.Module): 72 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, act_fn="relu", *args, **kwargs): 73 | super().__init__(*args, **kwargs) 74 | self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout)) 75 | if act_fn == "relu": 76 | self.act_fn = nn.ReLU() 77 | elif act_fn == "tanh": 78 | self.act_fn = nn.Tanh() 79 | elif act_fn == "silu": 80 | self.act_fn = nn.SiLU() 81 | elif act_fn == "leaky": 82 | self.act_fn = nn.LeakyReLU(0.2, inplace=True) 83 | 84 | self.residual = residual 85 | 86 | def forward(self, x): 87 | out = self.conv_block(x) 88 | if self.residual: 89 | out += x 90 | return self.act_fn(out) 91 | -------------------------------------------------------------------------------- /latentsync/trepa/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
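Both encoders in Wav2LipSyncNet above return L2-normalized embeddings, which the original Wav2Lip recipe scores with a cosine-similarity BCE loss. The sketch below only illustrates that pattern; the clamp is an assumption to keep the similarity inside BCE's valid range, and it is not necessarily the loss used for training in this repo.

import torch
import torch.nn.functional as F


def wav2lip_style_sync_loss(vision_embeds: torch.Tensor, audio_embeds: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # vision_embeds, audio_embeds: (b, c) unit vectors from Wav2LipSyncNet.forward
    # y: (b, 1), 1.0 for in-sync pairs, 0.0 for shifted (out-of-sync) pairs
    d = F.cosine_similarity(vision_embeds, audio_embeds, dim=1).unsqueeze(1)  # (b, 1)
    d = d.clamp(min=0.0, max=1.0)  # cosine similarity may be negative; BCE expects values in [0, 1]
    return F.binary_cross_entropy(d, y)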
14 | 15 | import torch 16 | import torch.nn.functional as F 17 | from einops import rearrange 18 | from .third_party.VideoMAEv2.utils import load_videomae_model 19 | from ..utils.util import check_model_and_download 20 | 21 | 22 | class TREPALoss: 23 | def __init__( 24 | self, 25 | device="cuda", 26 | ckpt_path="checkpoints/auxiliary/vit_g_hybrid_pt_1200e_ssv2_ft.pth", 27 | with_cp=False, 28 | ): 29 | check_model_and_download(ckpt_path) 30 | self.model = load_videomae_model(device, ckpt_path, with_cp).eval().to(dtype=torch.float16) 31 | self.model.requires_grad_(False) 32 | 33 | def __call__(self, videos_fake, videos_real): 34 | batch_size = videos_fake.shape[0] 35 | num_frames = videos_fake.shape[2] 36 | videos_fake = rearrange(videos_fake.clone(), "b c f h w -> (b f) c h w") 37 | videos_real = rearrange(videos_real.clone(), "b c f h w -> (b f) c h w") 38 | 39 | videos_fake = F.interpolate(videos_fake, size=(224, 224), mode="bicubic") 40 | videos_real = F.interpolate(videos_real, size=(224, 224), mode="bicubic") 41 | 42 | videos_fake = rearrange(videos_fake, "(b f) c h w -> b c f h w", f=num_frames) 43 | videos_real = rearrange(videos_real, "(b f) c h w -> b c f h w", f=num_frames) 44 | 45 | # Because input pixel range is [-1, 1], and model expects pixel range to be [0, 1] 46 | videos_fake = (videos_fake / 2 + 0.5).clamp(0, 1) 47 | videos_real = (videos_real / 2 + 0.5).clamp(0, 1) 48 | 49 | feats_fake = self.model.forward_features(videos_fake) 50 | feats_real = self.model.forward_features(videos_real) 51 | 52 | feats_fake = F.normalize(feats_fake, p=2, dim=1) 53 | feats_real = F.normalize(feats_real, p=2, dim=1) 54 | 55 | return F.mse_loss(feats_fake, feats_real) 56 | 57 | 58 | if __name__ == "__main__": 59 | torch.manual_seed(42) 60 | 61 | # input shape: (b, c, f, h, w) 62 | videos_fake = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16) 63 | videos_real = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16) 64 | 65 | trepa_loss = TREPALoss(device="cuda", with_cp=True) 66 | loss = trepa_loss(videos_fake, videos_real) 67 | print(loss) 68 | -------------------------------------------------------------------------------- /latentsync/trepa/third_party/VideoMAEv2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/LatentSync/3c3a1a2b62c9d854e3e6ab5525373973380a74a5/latentsync/trepa/third_party/VideoMAEv2/__init__.py -------------------------------------------------------------------------------- /latentsync/trepa/third_party/VideoMAEv2/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import requests 4 | from tqdm import tqdm 5 | from torchvision import transforms 6 | from .videomaev2_finetune import vit_giant_patch14_224 7 | 8 | 9 | def to_normalized_float_tensor(vid): 10 | return vid.permute(3, 0, 1, 2).to(torch.float32) / 255 11 | 12 | 13 | # NOTE: for those functions, which generally expect mini-batches, we keep them 14 | # as non-minibatch so that they are applied as if they were 4d (thus image). 
15 | # this way, we only apply the transformation in the spatial domain 16 | def resize(vid, size, interpolation="bilinear"): 17 | # NOTE: using bilinear interpolation because we don't work on minibatches 18 | # at this level 19 | scale = None 20 | if isinstance(size, int): 21 | scale = float(size) / min(vid.shape[-2:]) 22 | size = None 23 | return torch.nn.functional.interpolate(vid, size=size, scale_factor=scale, mode=interpolation, align_corners=False) 24 | 25 | 26 | class ToFloatTensorInZeroOne(object): 27 | def __call__(self, vid): 28 | return to_normalized_float_tensor(vid) 29 | 30 | 31 | class Resize(object): 32 | def __init__(self, size): 33 | self.size = size 34 | 35 | def __call__(self, vid): 36 | return resize(vid, self.size) 37 | 38 | 39 | def preprocess_videomae(videos): 40 | transform = transforms.Compose([ToFloatTensorInZeroOne(), Resize((224, 224))]) 41 | return torch.stack([transform(f) for f in torch.from_numpy(videos)]) 42 | 43 | 44 | def load_videomae_model(device, ckpt_path=None, with_cp=False): 45 | if ckpt_path is None: 46 | current_dir = os.path.dirname(os.path.abspath(__file__)) 47 | ckpt_path = os.path.join(current_dir, "vit_g_hybrid_pt_1200e_ssv2_ft.pth") 48 | 49 | if not os.path.exists(ckpt_path): 50 | # download the ckpt to the path 51 | ckpt_url = "https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/videomaev2/vit_g_hybrid_pt_1200e_ssv2_ft.pth" 52 | response = requests.get(ckpt_url, stream=True, allow_redirects=True) 53 | total_size = int(response.headers.get("content-length", 0)) 54 | block_size = 1024 55 | 56 | with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar: 57 | with open(ckpt_path, "wb") as fw: 58 | for data in response.iter_content(block_size): 59 | progress_bar.update(len(data)) 60 | fw.write(data) 61 | 62 | model = vit_giant_patch14_224( 63 | img_size=224, 64 | pretrained=False, 65 | num_classes=174, 66 | all_frames=16, 67 | tubelet_size=2, 68 | drop_path_rate=0.3, 69 | use_mean_pooling=True, 70 | with_cp=with_cp, 71 | ) 72 | 73 | ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True) 74 | for model_key in ["model", "module"]: 75 | if model_key in ckpt: 76 | ckpt = ckpt[model_key] 77 | break 78 | model.load_state_dict(ckpt) 79 | 80 | del ckpt 81 | torch.cuda.empty_cache() 82 | return model.to(device) 83 | -------------------------------------------------------------------------------- /latentsync/trepa/third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/LatentSync/3c3a1a2b62c9d854e3e6ab5525373973380a74a5/latentsync/trepa/third_party/__init__.py -------------------------------------------------------------------------------- /latentsync/trepa/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/LatentSync/3c3a1a2b62c9d854e3e6ab5525373973380a74a5/latentsync/trepa/utils/__init__.py -------------------------------------------------------------------------------- /latentsync/trepa/utils/metric_utils.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/universome/stylegan-v/blob/master/src/metrics/metric_utils.py 2 | import os 3 | import random 4 | import torch 5 | import pickle 6 | import numpy as np 7 | 8 | from typing import List, Tuple 9 | 10 | def seed_everything(seed): 11 | random.seed(seed) 12 | os.environ['PYTHONHASHSEED'] = str(seed) 13 | 
np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed(seed) 16 | 17 | 18 | class FeatureStats: 19 | ''' 20 | Class to store statistics of features, including all features and mean/covariance. 21 | 22 | Args: 23 | capture_all: Whether to store all the features. 24 | capture_mean_cov: Whether to store mean and covariance. 25 | max_items: Maximum number of items to store. 26 | ''' 27 | def __init__(self, capture_all: bool = False, capture_mean_cov: bool = False, max_items: int = None): 28 | ''' 29 | ''' 30 | self.capture_all = capture_all 31 | self.capture_mean_cov = capture_mean_cov 32 | self.max_items = max_items 33 | self.num_items = 0 34 | self.num_features = None 35 | self.all_features = None 36 | self.raw_mean = None 37 | self.raw_cov = None 38 | 39 | def set_num_features(self, num_features: int): 40 | ''' 41 | Set the number of features diminsions. 42 | 43 | Args: 44 | num_features: Number of features diminsions. 45 | ''' 46 | if self.num_features is not None: 47 | assert num_features == self.num_features 48 | else: 49 | self.num_features = num_features 50 | self.all_features = [] 51 | self.raw_mean = np.zeros([num_features], dtype=np.float64) 52 | self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64) 53 | 54 | def is_full(self) -> bool: 55 | ''' 56 | Check if the maximum number of samples is reached. 57 | 58 | Returns: 59 | True if the storage is full, False otherwise. 60 | ''' 61 | return (self.max_items is not None) and (self.num_items >= self.max_items) 62 | 63 | def append(self, x: np.ndarray): 64 | ''' 65 | Add the newly computed features to the list. Update the mean and covariance. 66 | 67 | Args: 68 | x: New features to record. 69 | ''' 70 | x = np.asarray(x, dtype=np.float32) 71 | assert x.ndim == 2 72 | if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items): 73 | if self.num_items >= self.max_items: 74 | return 75 | x = x[:self.max_items - self.num_items] 76 | 77 | self.set_num_features(x.shape[1]) 78 | self.num_items += x.shape[0] 79 | if self.capture_all: 80 | self.all_features.append(x) 81 | if self.capture_mean_cov: 82 | x64 = x.astype(np.float64) 83 | self.raw_mean += x64.sum(axis=0) 84 | self.raw_cov += x64.T @ x64 85 | 86 | def append_torch(self, x: torch.Tensor, rank: int, num_gpus: int): 87 | ''' 88 | Add the newly computed PyTorch features to the list. Update the mean and covariance. 89 | 90 | Args: 91 | x: New features to record. 92 | rank: Rank of the current GPU. 93 | num_gpus: Total number of GPUs. 94 | ''' 95 | assert isinstance(x, torch.Tensor) and x.ndim == 2 96 | assert 0 <= rank < num_gpus 97 | if num_gpus > 1: 98 | ys = [] 99 | for src in range(num_gpus): 100 | y = x.clone() 101 | torch.distributed.broadcast(y, src=src) 102 | ys.append(y) 103 | x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples 104 | self.append(x.cpu().numpy()) 105 | 106 | def get_all(self) -> np.ndarray: 107 | ''' 108 | Get all the stored features as NumPy Array. 109 | 110 | Returns: 111 | Concatenation of the stored features. 112 | ''' 113 | assert self.capture_all 114 | return np.concatenate(self.all_features, axis=0) 115 | 116 | def get_all_torch(self) -> torch.Tensor: 117 | ''' 118 | Get all the stored features as PyTorch Tensor. 119 | 120 | Returns: 121 | Concatenation of the stored features. 122 | ''' 123 | return torch.from_numpy(self.get_all()) 124 | 125 | def get_mean_cov(self) -> Tuple[np.ndarray, np.ndarray]: 126 | ''' 127 | Get the mean and covariance of the stored features. 
128 | 129 | Returns: 130 | Mean and covariance of the stored features. 131 | ''' 132 | assert self.capture_mean_cov 133 | mean = self.raw_mean / self.num_items 134 | cov = self.raw_cov / self.num_items 135 | cov = cov - np.outer(mean, mean) 136 | return mean, cov 137 | 138 | def save(self, pkl_file: str): 139 | ''' 140 | Save the features and statistics to a pickle file. 141 | 142 | Args: 143 | pkl_file: Path to the pickle file. 144 | ''' 145 | with open(pkl_file, 'wb') as f: 146 | pickle.dump(self.__dict__, f) 147 | 148 | @staticmethod 149 | def load(pkl_file: str) -> 'FeatureStats': 150 | ''' 151 | Load the features and statistics from a pickle file. 152 | 153 | Args: 154 | pkl_file: Path to the pickle file. 155 | ''' 156 | with open(pkl_file, 'rb') as f: 157 | s = pickle.load(f) 158 | obj = FeatureStats(capture_all=s['capture_all'], max_items=s['max_items']) 159 | obj.__dict__.update(s) 160 | print('Loaded %d features from %s' % (obj.num_items, pkl_file)) 161 | return obj 162 | -------------------------------------------------------------------------------- /latentsync/utils/affine_transform.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/guanjz20/StyleSync/blob/main/utils.py 2 | 3 | import numpy as np 4 | import cv2 5 | import torch 6 | from einops import rearrange 7 | import kornia 8 | 9 | 10 | class AlignRestore(object): 11 | def __init__(self, align_points=3, resolution=256, device="cpu", dtype=torch.float16): 12 | if align_points == 3: 13 | self.upscale_factor = 1 14 | ratio = resolution / 256 * 2.8 15 | self.crop_ratio = (ratio, ratio) 16 | self.face_template = np.array([[19 - 2, 30 - 10], [56 + 2, 30 - 10], [37.5, 45 - 5]]) 17 | self.face_template = self.face_template * ratio 18 | self.face_size = (int(75 * self.crop_ratio[0]), int(100 * self.crop_ratio[1])) 19 | self.p_bias = None 20 | self.device = device 21 | self.dtype = dtype 22 | self.fill_value = torch.tensor([127, 127, 127], device=device, dtype=dtype) 23 | self.mask = torch.ones((1, 1, self.face_size[1], self.face_size[0]), device=device, dtype=dtype) 24 | 25 | def align_warp_face(self, img, landmarks3, smooth=True): 26 | affine_matrix, self.p_bias = self.transformation_from_points( 27 | landmarks3, self.face_template, smooth, self.p_bias 28 | ) 29 | 30 | img = rearrange(torch.from_numpy(img).to(device=self.device, dtype=self.dtype), "h w c -> c h w").unsqueeze(0) 31 | affine_matrix = torch.from_numpy(affine_matrix).to(device=self.device, dtype=self.dtype).unsqueeze(0) 32 | 33 | cropped_face = kornia.geometry.transform.warp_affine( 34 | img, 35 | affine_matrix, 36 | (self.face_size[1], self.face_size[0]), 37 | mode="bilinear", 38 | padding_mode="fill", 39 | fill_value=self.fill_value, 40 | ) 41 | cropped_face = rearrange(cropped_face.squeeze(0), "c h w -> h w c").cpu().numpy().astype(np.uint8) 42 | return cropped_face, affine_matrix 43 | 44 | def restore_img(self, input_img, face, affine_matrix): 45 | h, w, _ = input_img.shape 46 | 47 | if isinstance(affine_matrix, np.ndarray): 48 | affine_matrix = torch.from_numpy(affine_matrix).to(device=self.device, dtype=self.dtype).unsqueeze(0) 49 | 50 | inv_affine_matrix = kornia.geometry.transform.invert_affine_transform(affine_matrix) 51 | face = face.to(dtype=self.dtype).unsqueeze(0) 52 | 53 | inv_face = kornia.geometry.transform.warp_affine( 54 | face, inv_affine_matrix, (h, w), mode="bilinear", padding_mode="fill", fill_value=self.fill_value 55 | ).squeeze(0) 56 | inv_face = (inv_face / 2 + 
0.5).clamp(0, 1) * 255 57 | 58 | input_img = rearrange(torch.from_numpy(input_img).to(device=self.device, dtype=self.dtype), "h w c -> c h w") 59 | inv_mask = kornia.geometry.transform.warp_affine( 60 | self.mask, inv_affine_matrix, (h, w), padding_mode="zeros" 61 | ) # (1, 1, h_up, w_up) 62 | 63 | inv_mask_erosion = kornia.morphology.erosion( 64 | inv_mask, 65 | torch.ones( 66 | (int(2 * self.upscale_factor), int(2 * self.upscale_factor)), device=self.device, dtype=self.dtype 67 | ), 68 | ) 69 | 70 | inv_mask_erosion_t = inv_mask_erosion.squeeze(0).expand_as(inv_face) 71 | pasted_face = inv_mask_erosion_t * inv_face 72 | total_face_area = torch.sum(inv_mask_erosion.float()) 73 | w_edge = int(total_face_area**0.5) // 20 74 | erosion_radius = w_edge * 2 75 | 76 | # This step will consume a large amount of GPU memory. 77 | # inv_mask_center = kornia.morphology.erosion( 78 | # inv_mask_erosion, torch.ones((erosion_radius, erosion_radius), device=self.device, dtype=self.dtype) 79 | # ) 80 | 81 | # Run on CPU to avoid consuming a large amount of GPU memory. 82 | inv_mask_erosion = inv_mask_erosion.squeeze().cpu().numpy().astype(np.float32) 83 | inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) 84 | inv_mask_center = torch.from_numpy(inv_mask_center).to(device=self.device, dtype=self.dtype)[None, None, ...] 85 | 86 | blur_size = w_edge * 2 + 1 87 | sigma = 0.3 * ((blur_size - 1) * 0.5 - 1) + 0.8 88 | inv_soft_mask = kornia.filters.gaussian_blur2d( 89 | inv_mask_center, (blur_size, blur_size), (sigma, sigma) 90 | ).squeeze(0) 91 | inv_soft_mask_3d = inv_soft_mask.expand_as(inv_face) 92 | img_back = inv_soft_mask_3d * pasted_face + (1 - inv_soft_mask_3d) * input_img 93 | 94 | img_back = rearrange(img_back, "c h w -> h w c").contiguous().to(dtype=torch.uint8) 95 | img_back = img_back.cpu().numpy() 96 | return img_back 97 | 98 | def transformation_from_points(self, points1: torch.Tensor, points0: torch.Tensor, smooth=True, p_bias=None): 99 | if isinstance(points0, np.ndarray): 100 | points2 = torch.tensor(points0, device=self.device, dtype=torch.float32) 101 | else: 102 | points2 = points0.clone() 103 | 104 | if isinstance(points1, np.ndarray): 105 | points1_tensor = torch.tensor(points1, device=self.device, dtype=torch.float32) 106 | else: 107 | points1_tensor = points1.clone() 108 | 109 | c1 = torch.mean(points1_tensor, dim=0) 110 | c2 = torch.mean(points2, dim=0) 111 | 112 | points1_centered = points1_tensor - c1 113 | points2_centered = points2 - c2 114 | 115 | s1 = torch.std(points1_centered) 116 | s2 = torch.std(points2_centered) 117 | 118 | points1_normalized = points1_centered / s1 119 | points2_normalized = points2_centered / s2 120 | 121 | covariance = torch.matmul(points1_normalized.T, points2_normalized) 122 | U, S, V = torch.svd(covariance) 123 | 124 | R = torch.matmul(V, U.T) 125 | 126 | det = torch.det(R) 127 | if det < 0: 128 | V[:, -1] = -V[:, -1] 129 | R = torch.matmul(V, U.T) 130 | 131 | sR = (s2 / s1) * R 132 | T = c2.reshape(2, 1) - (s2 / s1) * torch.matmul(R, c1.reshape(2, 1)) 133 | 134 | M = torch.cat((sR, T), dim=1) 135 | 136 | if smooth: 137 | bias = points2_normalized[2] - points1_normalized[2] 138 | if p_bias is None: 139 | p_bias = bias 140 | else: 141 | bias = p_bias * 0.2 + bias * 0.8 142 | p_bias = bias 143 | M[:, 2] = M[:, 2] + bias 144 | 145 | return M.cpu().numpy(), p_bias 146 | -------------------------------------------------------------------------------- /latentsync/utils/audio.py: 
-------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/Rudrabha/Wav2Lip/blob/master/audio.py 2 | 3 | import librosa 4 | import librosa.filters 5 | import numpy as np 6 | from scipy import signal 7 | from scipy.io import wavfile 8 | from omegaconf import OmegaConf 9 | import torch 10 | 11 | audio_config_path = "configs/audio.yaml" 12 | 13 | config = OmegaConf.load(audio_config_path) 14 | 15 | 16 | def load_wav(path, sr): 17 | return librosa.core.load(path, sr=sr)[0] 18 | 19 | 20 | def save_wav(wav, path, sr): 21 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) 22 | # proposed by @dsmiller 23 | wavfile.write(path, sr, wav.astype(np.int16)) 24 | 25 | 26 | def save_wavenet_wav(wav, path, sr): 27 | librosa.output.write_wav(path, wav, sr=sr) 28 | 29 | 30 | def preemphasis(wav, k, preemphasize=True): 31 | if preemphasize: 32 | return signal.lfilter([1, -k], [1], wav) 33 | return wav 34 | 35 | 36 | def inv_preemphasis(wav, k, inv_preemphasize=True): 37 | if inv_preemphasize: 38 | return signal.lfilter([1], [1, -k], wav) 39 | return wav 40 | 41 | 42 | def get_hop_size(): 43 | hop_size = config.audio.hop_size 44 | if hop_size is None: 45 | assert config.audio.frame_shift_ms is not None 46 | hop_size = int(config.audio.frame_shift_ms / 1000 * config.audio.sample_rate) 47 | return hop_size 48 | 49 | 50 | def linearspectrogram(wav): 51 | D = _stft(preemphasis(wav, config.audio.preemphasis, config.audio.preemphasize)) 52 | S = _amp_to_db(np.abs(D)) - config.audio.ref_level_db 53 | 54 | if config.audio.signal_normalization: 55 | return _normalize(S) 56 | return S 57 | 58 | 59 | def melspectrogram(wav): 60 | D = _stft(preemphasis(wav, config.audio.preemphasis, config.audio.preemphasize)) 61 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - config.audio.ref_level_db 62 | 63 | if config.audio.signal_normalization: 64 | return _normalize(S) 65 | return S 66 | 67 | 68 | def _lws_processor(): 69 | import lws 70 | 71 | return lws.lws(config.audio.n_fft, get_hop_size(), fftsize=config.audio.win_size, mode="speech") 72 | 73 | 74 | def _stft(y): 75 | if config.audio.use_lws: 76 | return _lws_processor(config.audio).stft(y).T 77 | else: 78 | return librosa.stft(y=y, n_fft=config.audio.n_fft, hop_length=get_hop_size(), win_length=config.audio.win_size) 79 | 80 | 81 | ########################################################## 82 | # Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) 
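# Worked example for the two helpers below (numbers are illustrative, not taken from
# configs/audio.yaml): with fsize=800, fshift=200 and a 16000-sample signal,
# pad = fsize - fshift = 600, so num_frames returns (16000 + 1200 - 800) // 200 + 1 = 83,
# and pad_lr returns (600, 600), since r = (83 - 1) * 200 + 800 - (16000 + 2 * 600) = 0.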
83 | def num_frames(length, fsize, fshift): 84 | """Compute number of time frames of spectrogram""" 85 | pad = fsize - fshift 86 | if length % fshift == 0: 87 | M = (length + pad * 2 - fsize) // fshift + 1 88 | else: 89 | M = (length + pad * 2 - fsize) // fshift + 2 90 | return M 91 | 92 | 93 | def pad_lr(x, fsize, fshift): 94 | """Compute left and right padding""" 95 | M = num_frames(len(x), fsize, fshift) 96 | pad = fsize - fshift 97 | T = len(x) + 2 * pad 98 | r = (M - 1) * fshift + fsize - T 99 | return pad, pad + r 100 | 101 | 102 | ########################################################## 103 | # Librosa correct padding 104 | def librosa_pad_lr(x, fsize, fshift): 105 | return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] 106 | 107 | 108 | # Conversions 109 | _mel_basis = None 110 | 111 | 112 | def _linear_to_mel(spectogram): 113 | global _mel_basis 114 | if _mel_basis is None: 115 | _mel_basis = _build_mel_basis() 116 | return np.dot(_mel_basis, spectogram) 117 | 118 | 119 | def _build_mel_basis(): 120 | assert config.audio.fmax <= config.audio.sample_rate // 2 121 | return librosa.filters.mel( 122 | sr=config.audio.sample_rate, 123 | n_fft=config.audio.n_fft, 124 | n_mels=config.audio.num_mels, 125 | fmin=config.audio.fmin, 126 | fmax=config.audio.fmax, 127 | ) 128 | 129 | 130 | def _amp_to_db(x): 131 | min_level = np.exp(config.audio.min_level_db / 20 * np.log(10)) 132 | return 20 * np.log10(np.maximum(min_level, x)) 133 | 134 | 135 | def _db_to_amp(x): 136 | return np.power(10.0, (x) * 0.05) 137 | 138 | 139 | def _normalize(S): 140 | if config.audio.allow_clipping_in_normalization: 141 | if config.audio.symmetric_mels: 142 | return np.clip( 143 | (2 * config.audio.max_abs_value) * ((S - config.audio.min_level_db) / (-config.audio.min_level_db)) 144 | - config.audio.max_abs_value, 145 | -config.audio.max_abs_value, 146 | config.audio.max_abs_value, 147 | ) 148 | else: 149 | return np.clip( 150 | config.audio.max_abs_value * ((S - config.audio.min_level_db) / (-config.audio.min_level_db)), 151 | 0, 152 | config.audio.max_abs_value, 153 | ) 154 | 155 | assert S.max() <= 0 and S.min() - config.audio.min_level_db >= 0 156 | if config.audio.symmetric_mels: 157 | return (2 * config.audio.max_abs_value) * ( 158 | (S - config.audio.min_level_db) / (-config.audio.min_level_db) 159 | ) - config.audio.max_abs_value 160 | else: 161 | return config.audio.max_abs_value * ((S - config.audio.min_level_db) / (-config.audio.min_level_db)) 162 | 163 | 164 | def _denormalize(D): 165 | if config.audio.allow_clipping_in_normalization: 166 | if config.audio.symmetric_mels: 167 | return ( 168 | (np.clip(D, -config.audio.max_abs_value, config.audio.max_abs_value) + config.audio.max_abs_value) 169 | * -config.audio.min_level_db 170 | / (2 * config.audio.max_abs_value) 171 | ) + config.audio.min_level_db 172 | else: 173 | return ( 174 | np.clip(D, 0, config.audio.max_abs_value) * -config.audio.min_level_db / config.audio.max_abs_value 175 | ) + config.audio.min_level_db 176 | 177 | if config.audio.symmetric_mels: 178 | return ( 179 | (D + config.audio.max_abs_value) * -config.audio.min_level_db / (2 * config.audio.max_abs_value) 180 | ) + config.audio.min_level_db 181 | else: 182 | return (D * -config.audio.min_level_db / config.audio.max_abs_value) + config.audio.min_level_db 183 | 184 | 185 | def get_melspec_overlap(audio_samples, melspec_length=52): 186 | mel_spec_overlap = melspectrogram(audio_samples.numpy()) 187 | mel_spec_overlap = torch.from_numpy(mel_spec_overlap) 188 | i = 0 189 | 
mel_spec_overlap_list = [] 190 | while i + melspec_length < mel_spec_overlap.shape[1] - 3: 191 | mel_spec_overlap_list.append(mel_spec_overlap[:, i : i + melspec_length].unsqueeze(0)) 192 | i += 3 193 | mel_spec_overlap = torch.stack(mel_spec_overlap_list) 194 | return mel_spec_overlap 195 | -------------------------------------------------------------------------------- /latentsync/utils/av_reader.py: -------------------------------------------------------------------------------- 1 | # We modified the original AVReader class of decord to solve the problem of memory leak. 2 | # For more details, refer to: https://github.com/dmlc/decord/issues/208 3 | 4 | import numpy as np 5 | from decord.video_reader import VideoReader 6 | from decord.audio_reader import AudioReader 7 | 8 | from decord.ndarray import cpu 9 | from decord import ndarray as _nd 10 | from decord.bridge import bridge_out 11 | 12 | 13 | class AVReader(object): 14 | """Individual audio video reader with convenient indexing function. 15 | 16 | Parameters 17 | ---------- 18 | uri: str 19 | Path of file. 20 | ctx: decord.Context 21 | The context to decode the file, can be decord.cpu() or decord.gpu(). 22 | sample_rate: int, default is -1 23 | Desired output sample rate of the audio, unchanged if `-1` is specified. 24 | mono: bool, default is True 25 | Desired output channel layout of the audio. `True` is mono layout. `False` is unchanged. 26 | width : int, default is -1 27 | Desired output width of the video, unchanged if `-1` is specified. 28 | height : int, default is -1 29 | Desired output height of the video, unchanged if `-1` is specified. 30 | num_threads : int, default is 0 31 | Number of decoding thread, auto if `0` is specified. 32 | fault_tol : int, default is -1 33 | The threshold of corupted and recovered frames. This is to prevent silent fault 34 | tolerance when for example 50% frames of a video cannot be decoded and duplicate 35 | frames are returned. You may find the fault tolerant feature sweet in many cases, 36 | but not for training models. Say `N = # recovered frames` 37 | If `fault_tol` < 0, nothing will happen. 38 | If 0 < `fault_tol` < 1.0, if N > `fault_tol * len(video)`, raise `DECORDLimitReachedError`. 39 | If 1 < `fault_tol`, if N > `fault_tol`, raise `DECORDLimitReachedError`. 40 | """ 41 | 42 | def __init__( 43 | self, uri, ctx=cpu(0), sample_rate=44100, mono=True, width=-1, height=-1, num_threads=0, fault_tol=-1 44 | ): 45 | self.__audio_reader = AudioReader(uri, ctx, sample_rate, mono) 46 | self.__audio_reader.add_padding() 47 | if hasattr(uri, "read"): 48 | uri.seek(0) 49 | self.__video_reader = VideoReader(uri, ctx, width, height, num_threads, fault_tol) 50 | self.__video_reader.seek(0) 51 | 52 | def __len__(self): 53 | """Get length of the video. Note that sometimes FFMPEG reports inaccurate number of frames, 54 | we always follow what FFMPEG reports. 55 | Returns 56 | ------- 57 | int 58 | The number of frames in the video file. 59 | """ 60 | return len(self.__video_reader) 61 | 62 | def __getitem__(self, idx): 63 | """Get audio samples and video frame at `idx`. 64 | 65 | Parameters 66 | ---------- 67 | idx : int or slice 68 | The frame index, can be negative which means it will index backwards, 69 | or slice of frame indices. 
70 | 71 | Returns 72 | ------- 73 | (ndarray/list of ndarray, ndarray) 74 | First element is samples of shape CxS or a list of length N containing samples of shape CxS, 75 | where N is the number of frames, C is the number of channels, 76 | S is the number of samples of the corresponding frame. 77 | 78 | Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3, 79 | where N is the length of the slice. 80 | """ 81 | assert self.__video_reader is not None and self.__audio_reader is not None 82 | if isinstance(idx, slice): 83 | return self.get_batch(range(*idx.indices(len(self.__video_reader)))) 84 | if idx < 0: 85 | idx += len(self.__video_reader) 86 | if idx >= len(self.__video_reader) or idx < 0: 87 | raise IndexError("Index: {} out of bound: {}".format(idx, len(self.__video_reader))) 88 | audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx) 89 | audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx) 90 | audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx) 91 | results = (self.__audio_reader[audio_start_idx:audio_end_idx], self.__video_reader[idx]) 92 | self.__video_reader.seek(0) 93 | return results 94 | 95 | def get_batch(self, indices): 96 | """Get entire batch of audio samples and video frames. 97 | 98 | Parameters 99 | ---------- 100 | indices : list of integers 101 | A list of frame indices. If negative indices detected, the indices will be indexed from backward 102 | Returns 103 | ------- 104 | (list of ndarray, ndarray) 105 | First element is a list of length N containing samples of shape CxS, 106 | where N is the number of frames, C is the number of channels, 107 | S is the number of samples of the corresponding frame. 108 | 109 | Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3, 110 | where N is the length of the slice. 
111 | 112 | """ 113 | assert self.__video_reader is not None and self.__audio_reader is not None 114 | indices = self._validate_indices(indices) 115 | audio_arr = [] 116 | prev_video_idx = None 117 | prev_audio_end_idx = None 118 | for idx in list(indices): 119 | frame_start_time, frame_end_time = self.__video_reader.get_frame_timestamp(idx) 120 | # timestamp and sample conversion could have some error that could cause non-continuous audio 121 | # we detect if retrieving continuous frame and make the audio continuous 122 | if prev_video_idx and idx == prev_video_idx + 1: 123 | audio_start_idx = prev_audio_end_idx 124 | else: 125 | audio_start_idx = self.__audio_reader._time_to_sample(frame_start_time) 126 | audio_end_idx = self.__audio_reader._time_to_sample(frame_end_time) 127 | audio_arr.append(self.__audio_reader[audio_start_idx:audio_end_idx]) 128 | prev_video_idx = idx 129 | prev_audio_end_idx = audio_end_idx 130 | results = (audio_arr, self.__video_reader.get_batch(indices)) 131 | self.__video_reader.seek(0) 132 | return results 133 | 134 | def _get_slice(self, sl): 135 | audio_arr = np.empty(shape=(self.__audio_reader.shape()[0], 0), dtype="float32") 136 | for idx in list(sl): 137 | audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx) 138 | audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx) 139 | audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx) 140 | audio_arr = np.concatenate( 141 | (audio_arr, self.__audio_reader[audio_start_idx:audio_end_idx].asnumpy()), axis=1 142 | ) 143 | results = (bridge_out(_nd.array(audio_arr)), self.__video_reader.get_batch(sl)) 144 | self.__video_reader.seek(0) 145 | return results 146 | 147 | def _validate_indices(self, indices): 148 | """Validate int64 integers and convert negative integers to positive by backward search""" 149 | assert self.__video_reader is not None and self.__audio_reader is not None 150 | indices = np.array(indices, dtype=np.int64) 151 | # process negative indices 152 | indices[indices < 0] += len(self.__video_reader) 153 | if not (indices >= 0).all(): 154 | raise IndexError("Invalid negative indices: {}".format(indices[indices < 0] + len(self.__video_reader))) 155 | if not (indices < len(self.__video_reader)).all(): 156 | raise IndexError("Out of bound indices: {}".format(indices[indices >= len(self.__video_reader)])) 157 | return indices 158 | -------------------------------------------------------------------------------- /latentsync/utils/face_detector.py: -------------------------------------------------------------------------------- 1 | from insightface.app import FaceAnalysis 2 | import numpy as np 3 | import torch 4 | 5 | INSIGHTFACE_DETECT_SIZE = 512 6 | 7 | 8 | class FaceDetector: 9 | def __init__(self, device="cuda"): 10 | self.app = FaceAnalysis( 11 | allowed_modules=["detection", "landmark_2d_106"], 12 | root="checkpoints/auxiliary", 13 | providers=["CUDAExecutionProvider"], 14 | ) 15 | self.app.prepare(ctx_id=cuda_to_int(device), det_size=(INSIGHTFACE_DETECT_SIZE, INSIGHTFACE_DETECT_SIZE)) 16 | 17 | def __call__(self, frame, threshold=0.5): 18 | f_h, f_w, _ = frame.shape 19 | 20 | faces = self.app.get(frame) 21 | 22 | get_face_store = None 23 | max_size = 0 24 | 25 | if len(faces) == 0: 26 | return None, None 27 | else: 28 | for face in faces: 29 | bbox = face.bbox.astype(np.int_).tolist() 30 | w, h = bbox[2] - bbox[0], bbox[3] - bbox[1] 31 | if w < 50 or h < 80: 32 | continue 33 | if w / h > 1.5 or w / h < 0.2: 34 | continue 35 | if face.det_score < 
threshold: 36 | continue 37 | size_now = w * h 38 | 39 | if size_now > max_size: 40 | max_size = size_now 41 | get_face_store = face 42 | 43 | if get_face_store is None: 44 | return None, None 45 | else: 46 | face = get_face_store 47 | lmk = np.round(face.landmark_2d_106).astype(np.int_) 48 | 49 | halk_face_coord = np.mean([lmk[74], lmk[73]], axis=0) # lmk[73] 50 | 51 | sub_lmk = lmk[LMK_ADAPT_ORIGIN_ORDER] 52 | halk_face_dist = np.max(sub_lmk[:, 1]) - halk_face_coord[1] 53 | upper_bond = halk_face_coord[1] - halk_face_dist # *0.94 54 | 55 | x1, y1, x2, y2 = (np.min(sub_lmk[:, 0]), int(upper_bond), np.max(sub_lmk[:, 0]), np.max(sub_lmk[:, 1])) 56 | 57 | if y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0: 58 | x1, y1, x2, y2 = face.bbox.astype(np.int_).tolist() 59 | 60 | y2 += int((x2 - x1) * 0.1) 61 | x1 -= int((x2 - x1) * 0.05) 62 | x2 += int((x2 - x1) * 0.05) 63 | 64 | x1 = max(0, x1) 65 | y1 = max(0, y1) 66 | x2 = min(f_w, x2) 67 | y2 = min(f_h, y2) 68 | 69 | return (x1, y1, x2, y2), lmk 70 | 71 | 72 | def cuda_to_int(cuda_str: str) -> int: 73 | """ 74 | Convert the string with format "cuda:X" to integer X. 75 | """ 76 | if cuda_str == "cuda": 77 | return 0 78 | device = torch.device(cuda_str) 79 | if device.type != "cuda": 80 | raise ValueError(f"Device type must be 'cuda', got: {device.type}") 81 | return device.index 82 | 83 | 84 | LMK_ADAPT_ORIGIN_ORDER = [ 85 | 1, 86 | 10, 87 | 12, 88 | 14, 89 | 16, 90 | 3, 91 | 5, 92 | 7, 93 | 0, 94 | 23, 95 | 21, 96 | 19, 97 | 32, 98 | 30, 99 | 28, 100 | 26, 101 | 17, 102 | 43, 103 | 48, 104 | 49, 105 | 51, 106 | 50, 107 | 102, 108 | 103, 109 | 104, 110 | 105, 111 | 101, 112 | 73, 113 | 74, 114 | 86, 115 | ] 116 | -------------------------------------------------------------------------------- /latentsync/utils/image_processor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
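To make the interface of the FaceDetector above concrete, here is a minimal usage sketch (not part of the repository). It assumes the InsightFace models are already available under checkpoints/auxiliary and that "frame.jpg" is a placeholder image path; the detector returns an (x1, y1, x2, y2) crop box for the largest face that passes the size, aspect-ratio and score filters, plus its 106-point landmarks.

import cv2
from latentsync.utils.face_detector import FaceDetector

detector = FaceDetector(device="cuda:0")  # cuda_to_int("cuda:0") -> ctx_id 0
# cv2.imread returns BGR; converting to RGB is assumed here to match the
# RGB-ordered frames the repo's video readers hand to the detector.
frame = cv2.cvtColor(cv2.imread("frame.jpg"), cv2.COLOR_BGR2RGB)

bbox, landmarks = detector(frame, threshold=0.5)
if bbox is None:
    print("No face passed the size/aspect-ratio/score filters")
else:
    x1, y1, x2, y2 = bbox
    print("Crop box:", (x1, y1, x2, y2), "landmarks shape:", landmarks.shape)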
14 | 15 | from latentsync.utils.util import read_video, write_video 16 | from torchvision import transforms 17 | import cv2 18 | from einops import rearrange 19 | import torch 20 | import numpy as np 21 | from typing import Union 22 | from .affine_transform import AlignRestore 23 | from .face_detector import FaceDetector 24 | 25 | 26 | def load_fixed_mask(resolution: int, mask_image_path="latentsync/utils/mask.png") -> torch.Tensor: 27 | mask_image = cv2.imread(mask_image_path) 28 | mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB) 29 | mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4) / 255.0 30 | mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w") 31 | return mask_image 32 | 33 | 34 | class ImageProcessor: 35 | def __init__(self, resolution: int = 512, device: str = "cpu", mask_image=None): 36 | self.resolution = resolution 37 | self.resize = transforms.Resize( 38 | (resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True 39 | ) 40 | self.normalize = transforms.Normalize([0.5], [0.5], inplace=True) 41 | 42 | self.restorer = AlignRestore(resolution=resolution, device=device) 43 | 44 | if mask_image is None: 45 | self.mask_image = load_fixed_mask(resolution) 46 | else: 47 | self.mask_image = mask_image 48 | 49 | if device == "cpu": 50 | self.face_detector = None 51 | else: 52 | self.face_detector = FaceDetector(device=device) 53 | 54 | def affine_transform(self, image: torch.Tensor) -> np.ndarray: 55 | if self.face_detector is None: 56 | raise NotImplementedError("Using the CPU for face detection is not supported") 57 | bbox, landmark_2d_106 = self.face_detector(image) 58 | if bbox is None: 59 | raise RuntimeError("Face not detected") 60 | 61 | pt_left_eye = np.mean(landmark_2d_106[[43, 48, 49, 51, 50]], axis=0) # left eyebrow center 62 | pt_right_eye = np.mean(landmark_2d_106[101:106], axis=0) # right eyebrow center 63 | pt_nose = np.mean(landmark_2d_106[[74, 77, 83, 86]], axis=0) # nose center 64 | 65 | landmarks3 = np.round([pt_left_eye, pt_right_eye, pt_nose]) 66 | 67 | face, affine_matrix = self.restorer.align_warp_face(image.copy(), landmarks3=landmarks3, smooth=True) 68 | box = [0, 0, face.shape[1], face.shape[0]] # x1, y1, x2, y2 69 | face = cv2.resize(face, (self.resolution, self.resolution), interpolation=cv2.INTER_LANCZOS4) 70 | face = rearrange(torch.from_numpy(face), "h w c -> c h w") 71 | return face, box, affine_matrix 72 | 73 | def preprocess_fixed_mask_image(self, image: torch.Tensor, affine_transform=False): 74 | if affine_transform: 75 | image, _, _ = self.affine_transform(image) 76 | else: 77 | image = self.resize(image) 78 | pixel_values = self.normalize(image / 255.0) 79 | masked_pixel_values = pixel_values * self.mask_image 80 | return pixel_values, masked_pixel_values, self.mask_image[0:1] 81 | 82 | def prepare_masks_and_masked_images(self, images: Union[torch.Tensor, np.ndarray], affine_transform=False): 83 | if isinstance(images, np.ndarray): 84 | images = torch.from_numpy(images) 85 | if images.shape[3] == 3: 86 | images = rearrange(images, "f h w c -> f c h w") 87 | 88 | results = [self.preprocess_fixed_mask_image(image, affine_transform=affine_transform) for image in images] 89 | 90 | pixel_values_list, masked_pixel_values_list, masks_list = list(zip(*results)) 91 | return torch.stack(pixel_values_list), torch.stack(masked_pixel_values_list), torch.stack(masks_list) 92 | 93 | def process_images(self, images: Union[torch.Tensor, np.ndarray]): 94 | if 
isinstance(images, np.ndarray): 95 | images = torch.from_numpy(images) 96 | if images.shape[3] == 3: 97 | images = rearrange(images, "f h w c -> f c h w") 98 | images = self.resize(images) 99 | pixel_values = self.normalize(images / 255.0) 100 | return pixel_values 101 | 102 | 103 | class VideoProcessor: 104 | def __init__(self, resolution: int = 512, device: str = "cpu"): 105 | self.image_processor = ImageProcessor(resolution, device) 106 | 107 | def affine_transform_video(self, video_path): 108 | video_frames = read_video(video_path, change_fps=False) 109 | results = [] 110 | for frame in video_frames: 111 | frame, _, _ = self.image_processor.affine_transform(frame) 112 | results.append(frame) 113 | results = torch.stack(results) 114 | 115 | results = rearrange(results, "f c h w -> f h w c").numpy() 116 | return results 117 | 118 | 119 | if __name__ == "__main__": 120 | video_processor = VideoProcessor(256, "cuda") 121 | video_frames = video_processor.affine_transform_video("assets/demo2_video.mp4") 122 | write_video("output.mp4", video_frames, fps=25) 123 | -------------------------------------------------------------------------------- /latentsync/utils/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/LatentSync/3c3a1a2b62c9d854e3e6ab5525373973380a74a5/latentsync/utils/mask.png -------------------------------------------------------------------------------- /latentsync/utils/mask2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/LatentSync/3c3a1a2b62c9d854e3e6ab5525373973380a74a5/latentsync/utils/mask2.png -------------------------------------------------------------------------------- /latentsync/utils/mask3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/LatentSync/3c3a1a2b62c9d854e3e6ab5525373973380a74a5/latentsync/utils/mask3.png -------------------------------------------------------------------------------- /latentsync/utils/mask4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/LatentSync/3c3a1a2b62c9d854e3e6ab5525373973380a74a5/latentsync/utils/mask4.png -------------------------------------------------------------------------------- /latentsync/whisper/audio2feature.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/TMElyralab/MuseTalk/blob/main/musetalk/whisper/audio2feature.py 2 | 3 | from .whisper import load_model 4 | import numpy as np 5 | import torch 6 | import os 7 | from pathlib import Path 8 | 9 | 10 | class Audio2Feature: 11 | def __init__( 12 | self, 13 | model_path="checkpoints/whisper/tiny.pt", 14 | device=None, 15 | audio_embeds_cache_dir=None, 16 | num_frames=16, 17 | audio_feat_length=[2, 2], 18 | ): 19 | self.model = load_model(model_path, device) 20 | self.audio_embeds_cache_dir = audio_embeds_cache_dir 21 | if audio_embeds_cache_dir is not None and audio_embeds_cache_dir != "": 22 | Path(audio_embeds_cache_dir).mkdir(parents=True, exist_ok=True) 23 | self.num_frames = num_frames 24 | self.embedding_dim = self.model.dims.n_audio_state 25 | self.audio_feat_length = audio_feat_length 26 | 27 | def get_sliced_feature(self, feature_array, vid_idx, fps=25): 28 | """ 29 | Get sliced features based on a given index 30 | :param feature_array: 31 | :param start_idx: the 
start index of the feature 32 | :param audio_feat_length: 33 | :return: 34 | """ 35 | length = len(feature_array) 36 | selected_feature = [] 37 | selected_idx = [] 38 | 39 | center_idx = int(vid_idx * 50 / fps) 40 | left_idx = center_idx - self.audio_feat_length[0] * 2 41 | right_idx = center_idx + (self.audio_feat_length[1] + 1) * 2 42 | 43 | for idx in range(left_idx, right_idx): 44 | idx = max(0, idx) 45 | idx = min(length - 1, idx) 46 | x = feature_array[idx] 47 | selected_feature.append(x) 48 | selected_idx.append(idx) 49 | 50 | selected_feature = torch.cat(selected_feature, dim=0) 51 | selected_feature = selected_feature.reshape(-1, self.embedding_dim) # 50*384 52 | return selected_feature, selected_idx 53 | 54 | def get_sliced_feature_sparse(self, feature_array, vid_idx, fps=25): 55 | """ 56 | Get sliced features based on a given index 57 | :param feature_array: 58 | :param start_idx: the start index of the feature 59 | :param audio_feat_length: 60 | :return: 61 | """ 62 | length = len(feature_array) 63 | selected_feature = [] 64 | selected_idx = [] 65 | 66 | for dt in range(-self.audio_feat_length[0], self.audio_feat_length[1] + 1): 67 | left_idx = int((vid_idx + dt) * 50 / fps) 68 | if left_idx < 1 or left_idx > length - 1: 69 | left_idx = max(0, left_idx) 70 | left_idx = min(length - 1, left_idx) 71 | 72 | x = feature_array[left_idx] 73 | x = x[np.newaxis, :, :] 74 | x = np.repeat(x, 2, axis=0) 75 | selected_feature.append(x) 76 | selected_idx.append(left_idx) 77 | selected_idx.append(left_idx) 78 | else: 79 | x = feature_array[left_idx - 1 : left_idx + 1] 80 | selected_feature.append(x) 81 | selected_idx.append(left_idx - 1) 82 | selected_idx.append(left_idx) 83 | selected_feature = np.concatenate(selected_feature, axis=0) 84 | selected_feature = selected_feature.reshape(-1, self.embedding_dim) # 50*384 85 | selected_feature = torch.from_numpy(selected_feature) 86 | return selected_feature, selected_idx 87 | 88 | def feature2chunks(self, feature_array, fps): 89 | whisper_chunks = [] 90 | whisper_idx_multiplier = 50.0 / fps 91 | i = 0 92 | print(f"video in {fps} FPS, audio idx in 50FPS") 93 | 94 | while True: 95 | start_idx = int(i * whisper_idx_multiplier) 96 | selected_feature, selected_idx = self.get_sliced_feature(feature_array=feature_array, vid_idx=i, fps=fps) 97 | # print(f"i:{i},selected_idx {selected_idx}") 98 | whisper_chunks.append(selected_feature) 99 | i += 1 100 | if start_idx > len(feature_array): 101 | break 102 | 103 | return whisper_chunks 104 | 105 | def _audio2feat(self, audio_path: str): 106 | # get the sample rate of the audio 107 | result = self.model.transcribe(audio_path) 108 | embed_list = [] 109 | for emb in result["segments"]: 110 | encoder_embeddings = emb["encoder_embeddings"] 111 | encoder_embeddings = encoder_embeddings.transpose(0, 2, 1, 3) 112 | encoder_embeddings = encoder_embeddings.squeeze(0) 113 | start_idx = int(emb["start"]) 114 | end_idx = int(emb["end"]) 115 | emb_end_idx = int((end_idx - start_idx) / 2) 116 | embed_list.append(encoder_embeddings[:emb_end_idx]) 117 | concatenated_array = torch.from_numpy(np.concatenate(embed_list, axis=0)) 118 | return concatenated_array 119 | 120 | def audio2feat(self, audio_path): 121 | if self.audio_embeds_cache_dir == "" or self.audio_embeds_cache_dir is None: 122 | return self._audio2feat(audio_path) 123 | 124 | audio_embeds_cache_path = os.path.join( 125 | self.audio_embeds_cache_dir, os.path.basename(audio_path).replace(".mp4", "_embeds.pt") 126 | ) 127 | 128 | if 
os.path.isfile(audio_embeds_cache_path): 129 | try: 130 | audio_feat = torch.load(audio_embeds_cache_path, weights_only=True) 131 | except Exception as e: 132 | print(f"{type(e).__name__} - {e} - {audio_embeds_cache_path}") 133 | os.remove(audio_embeds_cache_path) 134 | audio_feat = self._audio2feat(audio_path) 135 | torch.save(audio_feat, audio_embeds_cache_path) 136 | else: 137 | audio_feat = self._audio2feat(audio_path) 138 | torch.save(audio_feat, audio_embeds_cache_path) 139 | 140 | return audio_feat 141 | 142 | def crop_overlap_audio_window(self, audio_feat, start_index): 143 | selected_feature_list = [] 144 | for i in range(start_index, start_index + self.num_frames): 145 | selected_feature, selected_idx = self.get_sliced_feature(feature_array=audio_feat, vid_idx=i, fps=25) 146 | selected_feature_list.append(selected_feature) 147 | mel_overlap = torch.stack(selected_feature_list) 148 | return mel_overlap 149 | 150 | 151 | if __name__ == "__main__": 152 | audio_encoder = Audio2Feature(model_path="checkpoints/whisper/tiny.pt") 153 | audio_path = "assets/demo1_audio.wav" 154 | array = audio_encoder.audio2feat(audio_path) 155 | print(array.shape) 156 | fps = 25 157 | whisper_idx_multiplier = 50.0 / fps 158 | 159 | i = 0 160 | print(f"video in {fps} FPS, audio idx in 50FPS") 161 | while True: 162 | start_idx = int(i * whisper_idx_multiplier) 163 | selected_feature, selected_idx = audio_encoder.get_sliced_feature(feature_array=array, vid_idx=i, fps=fps) 164 | print(f"video idx {i},\t audio idx {selected_idx},\t shape {selected_feature.shape}") 165 | i += 1 166 | if start_idx > len(array): 167 | break 168 | -------------------------------------------------------------------------------- /latentsync/whisper/whisper/__init__.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import io 3 | import os 4 | import urllib 5 | import warnings 6 | from typing import List, Optional, Union 7 | 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from .audio import load_audio, log_mel_spectrogram, pad_or_trim 12 | from .decoding import DecodingOptions, DecodingResult, decode, detect_language 13 | from .model import Whisper, ModelDimensions 14 | from .transcribe import transcribe 15 | 16 | 17 | _MODELS = { 18 | "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", 19 | "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", 20 | "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt", 21 | "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt", 22 | "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt", 23 | "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt", 24 | "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt", 25 | "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", 26 | "large": 
"https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt", 27 | "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt", 28 | "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", 29 | "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", 30 | } 31 | 32 | 33 | def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: 34 | os.makedirs(root, exist_ok=True) 35 | 36 | expected_sha256 = url.split("/")[-2] 37 | download_target = os.path.join(root, os.path.basename(url)) 38 | 39 | if os.path.exists(download_target) and not os.path.isfile(download_target): 40 | raise RuntimeError(f"{download_target} exists and is not a regular file") 41 | 42 | if os.path.isfile(download_target): 43 | model_bytes = open(download_target, "rb").read() 44 | if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: 45 | return model_bytes if in_memory else download_target 46 | else: 47 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 48 | 49 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 50 | with tqdm( 51 | total=int(source.info().get("Content-Length")), ncols=80, unit="iB", unit_scale=True, unit_divisor=1024 52 | ) as loop: 53 | while True: 54 | buffer = source.read(8192) 55 | if not buffer: 56 | break 57 | 58 | output.write(buffer) 59 | loop.update(len(buffer)) 60 | 61 | model_bytes = open(download_target, "rb").read() 62 | if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: 63 | raise RuntimeError( 64 | "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." 65 | ) 66 | 67 | return model_bytes if in_memory else download_target 68 | 69 | 70 | def available_models() -> List[str]: 71 | """Returns the names of available models""" 72 | return list(_MODELS.keys()) 73 | 74 | 75 | def load_model( 76 | name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False 77 | ) -> Whisper: 78 | """ 79 | Load a Whisper ASR model 80 | 81 | Parameters 82 | ---------- 83 | name : str 84 | one of the official model names listed by `whisper.available_models()`, or 85 | path to a model checkpoint containing the model dimensions and the model state_dict. 
86 | device : Union[str, torch.device] 87 | the PyTorch device to put the model into 88 | download_root: str 89 | path to download the model files; by default, it uses "~/.cache/whisper" 90 | in_memory: bool 91 | whether to preload the model weights into host memory 92 | 93 | Returns 94 | ------- 95 | model : Whisper 96 | The Whisper ASR model instance 97 | """ 98 | 99 | if device is None: 100 | device = "cuda" if torch.cuda.is_available() else "cpu" 101 | if download_root is None: 102 | download_root = os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache", "whisper")) 103 | 104 | if name in _MODELS: 105 | checkpoint_file = _download(_MODELS[name], download_root, in_memory) 106 | elif os.path.isfile(name): 107 | checkpoint_file = open(name, "rb").read() if in_memory else name 108 | else: 109 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 110 | 111 | with io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") as fp: 112 | checkpoint = torch.load(fp, map_location=device, weights_only=True) 113 | del checkpoint_file 114 | 115 | dims = ModelDimensions(**checkpoint["dims"]) 116 | model = Whisper(dims) 117 | model.load_state_dict(checkpoint["model_state_dict"]) 118 | 119 | del checkpoint 120 | torch.cuda.empty_cache() 121 | 122 | return model.to(device) 123 | -------------------------------------------------------------------------------- /latentsync/whisper/whisper/__main__.py: -------------------------------------------------------------------------------- 1 | from .transcribe import cli 2 | 3 | 4 | cli() 5 | -------------------------------------------------------------------------------- /latentsync/whisper/whisper/assets/gpt2/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /latentsync/whisper/whisper/assets/gpt2/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"} -------------------------------------------------------------------------------- /latentsync/whisper/whisper/assets/mel_filters.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/LatentSync/3c3a1a2b62c9d854e3e6ab5525373973380a74a5/latentsync/whisper/whisper/assets/mel_filters.npz -------------------------------------------------------------------------------- /latentsync/whisper/whisper/assets/multilingual/added_tokens.json: -------------------------------------------------------------------------------- 1 | {"<|endoftext|>": 50257} 2 | -------------------------------------------------------------------------------- /latentsync/whisper/whisper/assets/multilingual/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /latentsync/whisper/whisper/assets/multilingual/tokenizer_config.json: 
-------------------------------------------------------------------------------- 1 | {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"} -------------------------------------------------------------------------------- /latentsync/whisper/whisper/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from typing import Union 4 | 5 | import ffmpeg 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from .utils import exact_div 11 | 12 | # hard-coded audio hyperparameters 13 | SAMPLE_RATE = 16000 14 | N_FFT = 400 15 | N_MELS = 80 16 | HOP_LENGTH = 160 17 | CHUNK_LENGTH = 30 18 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk 19 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input 20 | 21 | 22 | def load_audio(file: str, sr: int = SAMPLE_RATE): 23 | """ 24 | Open an audio file and read as mono waveform, resampling as necessary 25 | 26 | Parameters 27 | ---------- 28 | file: str 29 | The audio file to open 30 | 31 | sr: int 32 | The sample rate to resample the audio if necessary 33 | 34 | Returns 35 | ------- 36 | A NumPy array containing the audio waveform, in float32 dtype. 37 | """ 38 | try: 39 | # This launches a subprocess to decode audio while down-mixing and resampling as necessary. 40 | # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. 41 | out, _ = ( 42 | ffmpeg.input(file, threads=0) 43 | .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr) 44 | .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) 45 | ) 46 | except ffmpeg.Error as e: 47 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e 48 | 49 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 50 | 51 | 52 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): 53 | """ 54 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 55 | """ 56 | if torch.is_tensor(array): 57 | if array.shape[axis] > length: 58 | array = array.index_select(dim=axis, index=torch.arange(length)) 59 | 60 | if array.shape[axis] < length: 61 | pad_widths = [(0, 0)] * array.ndim 62 | pad_widths[axis] = (0, length - array.shape[axis]) 63 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) 64 | else: 65 | if array.shape[axis] > length: 66 | array = array.take(indices=range(length), axis=axis) 67 | 68 | if array.shape[axis] < length: 69 | pad_widths = [(0, 0)] * array.ndim 70 | pad_widths[axis] = (0, length - array.shape[axis]) 71 | array = np.pad(array, pad_widths) 72 | 73 | return array 74 | 75 | 76 | @lru_cache(maxsize=None) 77 | def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: 78 | """ 79 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 
80 | Allows decoupling librosa dependency; saved using: 81 | 82 | np.savez_compressed( 83 | "mel_filters.npz", 84 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), 85 | ) 86 | """ 87 | assert n_mels == 80, f"Unsupported n_mels: {n_mels}" 88 | with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f: 89 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) 90 | 91 | 92 | def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS): 93 | """ 94 | Compute the log-Mel spectrogram of 95 | 96 | Parameters 97 | ---------- 98 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*) 99 | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz 100 | 101 | n_mels: int 102 | The number of Mel-frequency filters, only 80 is supported 103 | 104 | Returns 105 | ------- 106 | torch.Tensor, shape = (80, n_frames) 107 | A Tensor that contains the Mel spectrogram 108 | """ 109 | if not torch.is_tensor(audio): 110 | if isinstance(audio, str): 111 | audio = load_audio(audio) 112 | audio = torch.from_numpy(audio) 113 | 114 | window = torch.hann_window(N_FFT).to(audio.device) 115 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) 116 | 117 | magnitudes = stft[:, :-1].abs() ** 2 118 | 119 | filters = mel_filters(audio.device, n_mels) 120 | mel_spec = filters @ magnitudes 121 | 122 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() 123 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 124 | log_spec = (log_spec + 4.0) / 4.0 125 | return log_spec 126 | -------------------------------------------------------------------------------- /latentsync/whisper/whisper/normalizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import BasicTextNormalizer 2 | from .english import EnglishTextNormalizer 3 | -------------------------------------------------------------------------------- /latentsync/whisper/whisper/normalizers/basic.py: -------------------------------------------------------------------------------- 1 | import re 2 | import unicodedata 3 | 4 | import regex 5 | 6 | # non-ASCII letters that are not separated by "NFKD" normalization 7 | ADDITIONAL_DIACRITICS = { 8 | "œ": "oe", 9 | "Œ": "OE", 10 | "ø": "o", 11 | "Ø": "O", 12 | "æ": "ae", 13 | "Æ": "AE", 14 | "ß": "ss", 15 | "ẞ": "SS", 16 | "đ": "d", 17 | "Đ": "D", 18 | "ð": "d", 19 | "Ð": "D", 20 | "þ": "th", 21 | "Þ": "th", 22 | "ł": "l", 23 | "Ł": "L", 24 | } 25 | 26 | 27 | def remove_symbols_and_diacritics(s: str, keep=""): 28 | """ 29 | Replace any other markers, symbols, and punctuations with a space, 30 | and drop any diacritics (category 'Mn' and some manual mappings) 31 | """ 32 | return "".join( 33 | c 34 | if c in keep 35 | else ADDITIONAL_DIACRITICS[c] 36 | if c in ADDITIONAL_DIACRITICS 37 | else "" 38 | if unicodedata.category(c) == "Mn" 39 | else " " 40 | if unicodedata.category(c)[0] in "MSP" 41 | else c 42 | for c in unicodedata.normalize("NFKD", s) 43 | ) 44 | 45 | 46 | def remove_symbols(s: str): 47 | """ 48 | Replace any other markers, symbols, punctuations with a space, keeping diacritics 49 | """ 50 | return "".join( 51 | " " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s) 52 | ) 53 | 54 | 55 | class BasicTextNormalizer: 56 | def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): 57 | self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols 
58 | self.split_letters = split_letters 59 | 60 | def __call__(self, s: str): 61 | s = s.lower() 62 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets 63 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis 64 | s = self.clean(s).lower() 65 | 66 | if self.split_letters: 67 | s = " ".join(regex.findall(r"\X", s, regex.U)) 68 | 69 | s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space 70 | 71 | return s 72 | -------------------------------------------------------------------------------- /latentsync/whisper/whisper/utils.py: -------------------------------------------------------------------------------- 1 | import zlib 2 | from typing import Iterator, TextIO 3 | 4 | 5 | def exact_div(x, y): 6 | assert x % y == 0 7 | return x // y 8 | 9 | 10 | def str2bool(string): 11 | str2val = {"True": True, "False": False} 12 | if string in str2val: 13 | return str2val[string] 14 | else: 15 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") 16 | 17 | 18 | def optional_int(string): 19 | return None if string == "None" else int(string) 20 | 21 | 22 | def optional_float(string): 23 | return None if string == "None" else float(string) 24 | 25 | 26 | def compression_ratio(text) -> float: 27 | return len(text) / len(zlib.compress(text.encode("utf-8"))) 28 | 29 | 30 | def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'): 31 | assert seconds >= 0, "non-negative timestamp expected" 32 | milliseconds = round(seconds * 1000.0) 33 | 34 | hours = milliseconds // 3_600_000 35 | milliseconds -= hours * 3_600_000 36 | 37 | minutes = milliseconds // 60_000 38 | milliseconds -= minutes * 60_000 39 | 40 | seconds = milliseconds // 1_000 41 | milliseconds -= seconds * 1_000 42 | 43 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" 44 | return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" 45 | 46 | 47 | def write_txt(transcript: Iterator[dict], file: TextIO): 48 | for segment in transcript: 49 | print(segment['text'].strip(), file=file, flush=True) 50 | 51 | 52 | def write_vtt(transcript: Iterator[dict], file: TextIO): 53 | print("WEBVTT\n", file=file) 54 | for segment in transcript: 55 | print( 56 | f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" 57 | f"{segment['text'].strip().replace('-->', '->')}\n", 58 | file=file, 59 | flush=True, 60 | ) 61 | 62 | 63 | def write_srt(transcript: Iterator[dict], file: TextIO): 64 | """ 65 | Write a transcript to a file in SRT format. 
66 | 67 | Example usage: 68 | from pathlib import Path 69 | from whisper.utils import write_srt 70 | 71 | result = transcribe(model, audio_path, temperature=temperature, **args) 72 | 73 | # save SRT 74 | audio_basename = Path(audio_path).stem 75 | with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt: 76 | write_srt(result["segments"], file=srt) 77 | """ 78 | for i, segment in enumerate(transcript, start=1): 79 | # write srt lines 80 | print( 81 | f"{i}\n" 82 | f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " 83 | f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" 84 | f"{segment['text'].strip().replace('-->', '->')}\n", 85 | file=file, 86 | flush=True, 87 | ) 88 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # Prediction interface for Cog ⚙️ 2 | # https://cog.run/python 3 | 4 | from cog import BasePredictor, Input, Path 5 | import os 6 | import time 7 | import subprocess 8 | 9 | MODEL_CACHE = "checkpoints" 10 | MODEL_URL = "https://weights.replicate.delivery/default/chunyu-li/LatentSync/model.tar" 11 | 12 | 13 | def download_weights(url, dest): 14 | start = time.time() 15 | print("downloading url: ", url) 16 | print("downloading to: ", dest) 17 | subprocess.check_call(["pget", "-xf", url, dest], close_fds=False) 18 | print("downloading took: ", time.time() - start) 19 | 20 | 21 | class Predictor(BasePredictor): 22 | def setup(self) -> None: 23 | """Load the model into memory to make running multiple predictions efficient""" 24 | # Download the model weights 25 | if not os.path.exists(MODEL_CACHE): 26 | download_weights(MODEL_URL, MODEL_CACHE) 27 | 28 | # Soft links for the auxiliary models 29 | os.system("mkdir -p ~/.cache/torch/hub/checkpoints") 30 | os.system( 31 | "ln -s $(pwd)/checkpoints/auxiliary/2DFAN4-cd938726ad.zip ~/.cache/torch/hub/checkpoints/2DFAN4-cd938726ad.zip" 32 | ) 33 | os.system( 34 | "ln -s $(pwd)/checkpoints/auxiliary/s3fd-619a316812.pth ~/.cache/torch/hub/checkpoints/s3fd-619a316812.pth" 35 | ) 36 | os.system( 37 | "ln -s $(pwd)/checkpoints/auxiliary/vgg16-397923af.pth ~/.cache/torch/hub/checkpoints/vgg16-397923af.pth" 38 | ) 39 | 40 | def predict( 41 | self, 42 | video: Path = Input(description="Input video", default=None), 43 | audio: Path = Input(description="Input audio to ", default=None), 44 | guidance_scale: float = Input(description="Guidance scale", ge=1, le=3, default=2.0), 45 | inference_steps: int = Input(description="Inference steps", ge=20, le=50, default=20), 46 | seed: int = Input(description="Set to 0 for Random seed", default=0), 47 | ) -> Path: 48 | """Run a single prediction on the model""" 49 | if seed <= 0: 50 | seed = int.from_bytes(os.urandom(2), "big") 51 | print(f"Using seed: {seed}") 52 | 53 | video_path = str(video) 54 | audio_path = str(audio) 55 | config_path = "configs/unet/stage2.yaml" 56 | ckpt_path = "checkpoints/latentsync_unet.pt" 57 | output_path = "/tmp/video_out.mp4" 58 | 59 | # Run the following command: 60 | os.system( 61 | f"python -m scripts.inference --unet_config_path {config_path} --inference_ckpt_path {ckpt_path} --guidance_scale {str(guidance_scale)} --video_path {video_path} --audio_path {audio_path} --video_out_path {output_path} --seed {seed} --inference_steps {inference_steps}" 62 | ) 63 | return Path(output_path) 64 | 
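For readers who want to bypass Cog entirely, the following standalone sketch reproduces the command that Predictor.predict() assembles above. The flag names and the config/checkpoint paths are taken verbatim from the os.system() call; the video, audio, output and seed values are placeholders.

import subprocess

cmd = [
    "python", "-m", "scripts.inference",
    "--unet_config_path", "configs/unet/stage2.yaml",
    "--inference_ckpt_path", "checkpoints/latentsync_unet.pt",
    "--guidance_scale", "2.0",
    "--video_path", "assets/demo1_video.mp4",
    "--audio_path", "assets/demo1_audio.wav",
    "--video_out_path", "video_out.mp4",
    "--seed", "1247",
    "--inference_steps", "20",
]
subprocess.run(cmd, check=True)  # raises CalledProcessError if inference fails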
-------------------------------------------------------------------------------- /preprocess/affine_transform.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from latentsync.utils.util import write_video 16 | from latentsync.utils.image_processor import VideoProcessor 17 | import torch 18 | import os 19 | import subprocess 20 | from multiprocessing import Process 21 | import shutil 22 | 23 | paths = [] 24 | 25 | 26 | def gather_video_paths(input_dir, output_dir): 27 | for video in sorted(os.listdir(input_dir)): 28 | if video.endswith(".mp4"): 29 | video_input = os.path.join(input_dir, video) 30 | video_output = os.path.join(output_dir, video) 31 | if os.path.isfile(video_output): 32 | continue 33 | paths.append((video_input, video_output)) 34 | elif os.path.isdir(os.path.join(input_dir, video)): 35 | gather_video_paths(os.path.join(input_dir, video), os.path.join(output_dir, video)) 36 | 37 | 38 | def combine_video_audio(video_frames, video_input_path, video_output_path, process_temp_dir): 39 | video_name = os.path.basename(video_input_path)[:-4] 40 | audio_temp = os.path.join(process_temp_dir, f"{video_name}_temp.wav") 41 | video_temp = os.path.join(process_temp_dir, f"{video_name}_temp.mp4") 42 | 43 | write_video(video_temp, video_frames, fps=25) 44 | 45 | command = f"ffmpeg -y -loglevel error -i {video_input_path} -q:a 0 -map a {audio_temp}" 46 | subprocess.run(command, shell=True) 47 | 48 | os.makedirs(os.path.dirname(video_output_path), exist_ok=True) 49 | command = f"ffmpeg -y -loglevel error -i {video_temp} -i {audio_temp} -c:v libx264 -c:a aac -map 0:v -map 1:a -q:v 0 -q:a 0 {video_output_path}" 50 | subprocess.run(command, shell=True) 51 | 52 | os.remove(audio_temp) 53 | os.remove(video_temp) 54 | 55 | 56 | def func(paths, process_temp_dir, device_id, resolution): 57 | os.makedirs(process_temp_dir, exist_ok=True) 58 | video_processor = VideoProcessor(resolution, f"cuda:{device_id}") 59 | 60 | for video_input, video_output in paths: 61 | if os.path.isfile(video_output): 62 | continue 63 | try: 64 | video_frames = video_processor.affine_transform_video(video_input) 65 | except Exception as e: # Handle the exception of face not detcted 66 | print(f"Exception: {e} - {video_input}") 67 | continue 68 | 69 | os.makedirs(os.path.dirname(video_output), exist_ok=True) 70 | combine_video_audio(video_frames, video_input, video_output, process_temp_dir) 71 | print(f"Saved: {video_output}") 72 | 73 | 74 | def split(a, n): 75 | k, m = divmod(len(a), n) 76 | return (a[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)) 77 | 78 | 79 | def affine_transform_multi_gpus(input_dir, output_dir, temp_dir, resolution, num_workers): 80 | print(f"Recursively gathering video paths of {input_dir} ...") 81 | gather_video_paths(input_dir, output_dir) 82 | num_devices = torch.cuda.device_count() 83 | if 
num_devices == 0: 84 | raise RuntimeError("No GPUs found") 85 | 86 | if os.path.exists(temp_dir): 87 | shutil.rmtree(temp_dir) 88 | os.makedirs(temp_dir, exist_ok=True) 89 | 90 | split_paths = list(split(paths, num_workers * num_devices)) 91 | 92 | processes = [] 93 | 94 | for i in range(num_devices): 95 | for j in range(num_workers): 96 | process_index = i * num_workers + j 97 | process = Process( 98 | target=func, args=(split_paths[process_index], os.path.join(temp_dir, f"process_{i}"), i, resolution) 99 | ) 100 | process.start() 101 | processes.append(process) 102 | 103 | for process in processes: 104 | process.join() 105 | 106 | 107 | if __name__ == "__main__": 108 | input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/segmented" 109 | output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/affine_transformed" 110 | temp_dir = "temp" 111 | resolution = 256 112 | num_workers = 10 # How many processes per device 113 | 114 | affine_transform_multi_gpus(input_dir, output_dir, temp_dir, resolution, num_workers) 115 | -------------------------------------------------------------------------------- /preprocess/data_processing_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
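Before the pipeline driver below stitches every stage together, here is a hedged sketch of running only the affine-transform stage defined in preprocess/affine_transform.py above. The directory names are placeholders; the function recursively mirrors the input tree, skips clips that already have an output file, and splits the remaining work across all visible GPUs.

from preprocess.affine_transform import affine_transform_multi_gpus

if __name__ == "__main__":
    affine_transform_multi_gpus(
        input_dir="data/segmented",            # placeholder: folder of segmented .mp4 clips
        output_dir="data/affine_transformed",  # placeholder: mirrored output folder
        temp_dir="temp",
        resolution=256,
        num_workers=4,                         # processes spawned per GPU
    )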
14 | 15 | import argparse 16 | import os 17 | from preprocess.affine_transform import affine_transform_multi_gpus 18 | from preprocess.remove_broken_videos import remove_broken_videos_multiprocessing 19 | from preprocess.detect_shot import detect_shot_multiprocessing 20 | from preprocess.filter_high_resolution import filter_high_resolution_multiprocessing 21 | from preprocess.resample_fps_hz import resample_fps_hz_multiprocessing 22 | from preprocess.segment_videos import segment_videos_multiprocessing 23 | from preprocess.sync_av import sync_av_multi_gpus 24 | from preprocess.filter_visual_quality import filter_visual_quality_multi_gpus 25 | from preprocess.remove_incorrect_affined import remove_incorrect_affined_multiprocessing 26 | from latentsync.utils.util import check_model_and_download 27 | 28 | 29 | def data_processing_pipeline( 30 | total_num_workers, per_gpu_num_workers, resolution, sync_conf_threshold, temp_dir, input_dir 31 | ): 32 | print("Checking models are downloaded...") 33 | check_model_and_download("checkpoints/auxiliary/syncnet_v2.model") 34 | check_model_and_download("checkpoints/auxiliary/sfd_face.pth") 35 | check_model_and_download("checkpoints/auxiliary/koniq_pretrained.pkl") 36 | 37 | print("Removing broken videos...") 38 | remove_broken_videos_multiprocessing(input_dir, total_num_workers) 39 | 40 | print("Resampling FPS hz...") 41 | resampled_dir = os.path.join(os.path.dirname(input_dir), "resampled") 42 | resample_fps_hz_multiprocessing(input_dir, resampled_dir, total_num_workers) 43 | 44 | print("Detecting shot...") 45 | shot_dir = os.path.join(os.path.dirname(input_dir), "shot") 46 | detect_shot_multiprocessing(resampled_dir, shot_dir, total_num_workers) 47 | 48 | print("Segmenting videos...") 49 | segmented_dir = os.path.join(os.path.dirname(input_dir), "segmented") 50 | segment_videos_multiprocessing(shot_dir, segmented_dir, total_num_workers) 51 | 52 | # If there are too many videos, you can first use this step to filter and reduce the quantity 53 | # print("Filtering high resolution...") 54 | # high_resolution_dir = os.path.join(os.path.dirname(input_dir), "high_resolution") 55 | # filter_high_resolution_multiprocessing(segmented_dir, high_resolution_dir, resolution, total_num_workers) 56 | 57 | print("Affine transforming videos...") 58 | affine_transformed_dir = os.path.join(os.path.dirname(input_dir), "affine_transformed") 59 | affine_transform_multi_gpus(segmented_dir, affine_transformed_dir, temp_dir, resolution, per_gpu_num_workers // 2) 60 | 61 | # print("Removing incorrect affined videos...") 62 | # remove_incorrect_affined_multiprocessing(affine_transformed_dir, total_num_workers) 63 | 64 | print("Syncing audio and video...") 65 | av_synced_dir = os.path.join(os.path.dirname(input_dir), f"av_synced_{sync_conf_threshold}") 66 | sync_av_multi_gpus(affine_transformed_dir, av_synced_dir, temp_dir, per_gpu_num_workers, sync_conf_threshold) 67 | 68 | print("Filtering visual quality...") 69 | high_visual_quality_dir = os.path.join(os.path.dirname(input_dir), "high_visual_quality") 70 | filter_visual_quality_multi_gpus(av_synced_dir, high_visual_quality_dir, per_gpu_num_workers) 71 | 72 | 73 | if __name__ == "__main__": 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument("--total_num_workers", type=int, default=100) 76 | parser.add_argument("--per_gpu_num_workers", type=int, default=20) 77 | parser.add_argument("--resolution", type=int, default=256) 78 | parser.add_argument("--sync_conf_threshold", type=int, default=3) 79 | 
parser.add_argument("--temp_dir", type=str, default="temp") 80 | parser.add_argument("--input_dir", type=str, required=True) 81 | args = parser.parse_args() 82 | 83 | data_processing_pipeline( 84 | args.total_num_workers, 85 | args.per_gpu_num_workers, 86 | args.resolution, 87 | args.sync_conf_threshold, 88 | args.temp_dir, 89 | args.input_dir, 90 | ) 91 | -------------------------------------------------------------------------------- /preprocess/detect_shot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import subprocess 17 | import tqdm 18 | from multiprocessing import Pool 19 | 20 | paths = [] 21 | 22 | 23 | def gather_paths(input_dir, output_dir): 24 | for video in sorted(os.listdir(input_dir)): 25 | if video.endswith(".mp4"): 26 | video_input = os.path.join(input_dir, video) 27 | video_output = os.path.join(output_dir, video) 28 | if os.path.isfile(video_output): 29 | continue 30 | paths.append([video_input, output_dir]) 31 | elif os.path.isdir(os.path.join(input_dir, video)): 32 | gather_paths(os.path.join(input_dir, video), os.path.join(output_dir, video)) 33 | 34 | 35 | def detect_shot(video_input, output_dir): 36 | os.makedirs(output_dir, exist_ok=True) 37 | video = os.path.basename(video_input)[:-4] 38 | command = f"scenedetect --quiet -i {video_input} detect-adaptive --threshold 2 split-video --filename '{video}_shot_$SCENE_NUMBER' --output {output_dir}" 39 | # command = f"scenedetect --quiet -i {video_input} detect-adaptive --threshold 2 split-video --high-quality --filename '{video}_shot_$SCENE_NUMBER' --output {output_dir}" 40 | subprocess.run(command, shell=True) 41 | 42 | 43 | def multi_run_wrapper(args): 44 | return detect_shot(*args) 45 | 46 | 47 | def detect_shot_multiprocessing(input_dir, output_dir, num_workers): 48 | print(f"Recursively gathering video paths of {input_dir} ...") 49 | gather_paths(input_dir, output_dir) 50 | 51 | print(f"Detecting shot of {input_dir} ...") 52 | with Pool(num_workers) as pool: 53 | for _ in tqdm.tqdm(pool.imap_unordered(multi_run_wrapper, paths), total=len(paths)): 54 | pass 55 | 56 | 57 | if __name__ == "__main__": 58 | input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_resolution" 59 | output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/shot" 60 | num_workers = 50 61 | 62 | detect_shot_multiprocessing(input_dir, output_dir, num_workers) 63 | -------------------------------------------------------------------------------- /preprocess/filter_high_resolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import mediapipe as mp 16 | from latentsync.utils.util import read_video 17 | import os 18 | import tqdm 19 | import shutil 20 | from multiprocessing import Pool 21 | 22 | paths = [] 23 | 24 | 25 | def gather_video_paths(input_dir, output_dir, resolution): 26 | for video in sorted(os.listdir(input_dir)): 27 | if video.endswith(".mp4"): 28 | video_input = os.path.join(input_dir, video) 29 | video_output = os.path.join(output_dir, video) 30 | if os.path.isfile(video_output): 31 | continue 32 | paths.append([video_input, video_output, resolution]) 33 | elif os.path.isdir(os.path.join(input_dir, video)): 34 | gather_video_paths(os.path.join(input_dir, video), os.path.join(output_dir, video), resolution) 35 | 36 | 37 | class FaceDetector: 38 | def __init__(self, resolution=256): 39 | self.face_detection = mp.solutions.face_detection.FaceDetection( 40 | model_selection=0, min_detection_confidence=0.5 41 | ) 42 | self.resolution = resolution 43 | 44 | def detect_face(self, image): 45 | height, width = image.shape[:2] 46 | # Process the image and detect faces. 47 | results = self.face_detection.process(image) 48 | 49 | if not results.detections: # Face not detected 50 | raise Exception("Face not detected") 51 | 52 | if len(results.detections) != 1: 53 | return False 54 | detection = results.detections[0] # Only use the first face in the image 55 | 56 | bounding_box = detection.location_data.relative_bounding_box 57 | face_width = int(bounding_box.width * width) 58 | face_height = int(bounding_box.height * height) 59 | if face_width < self.resolution or face_height < self.resolution: 60 | return False 61 | return True 62 | 63 | def detect_video(self, video_path): 64 | video_frames = read_video(video_path, change_fps=False) 65 | if len(video_frames) == 0: 66 | return False 67 | for frame in video_frames: 68 | if not self.detect_face(frame): 69 | return False 70 | return True 71 | 72 | def close(self): 73 | self.face_detection.close() 74 | 75 | 76 | def filter_video(video_input, video_out, resolution): 77 | if os.path.isfile(video_out): 78 | return 79 | face_detector = FaceDetector(resolution) 80 | try: 81 | save = face_detector.detect_video(video_input) 82 | except Exception as e: 83 | # print(f"Exception: {e} Input video: {video_input}") 84 | face_detector.close() 85 | return 86 | if save: 87 | os.makedirs(os.path.dirname(video_out), exist_ok=True) 88 | shutil.copy(video_input, video_out) 89 | face_detector.close() 90 | 91 | 92 | def multi_run_wrapper(args): 93 | return filter_video(*args) 94 | 95 | 96 | def filter_high_resolution_multiprocessing(input_dir, output_dir, resolution, num_workers): 97 | print(f"Recursively gathering video paths of {input_dir} ...") 98 | gather_video_paths(input_dir, output_dir, resolution) 99 | 100 | print(f"Filtering high resolution videos in {input_dir} ...") 101 | with Pool(num_workers) as pool: 102 | for _ in tqdm.tqdm(pool.imap_unordered(multi_run_wrapper, paths), total=len(paths)): 103 | pass 104 | 105 | 106 | if __name__ == "__main__": 107 | input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/raw" 108 | 
output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li//VoxCeleb2/high_resolution" 109 | resolution = 256 110 | num_workers = 50 111 | 112 | filter_high_resolution_multiprocessing(input_dir, output_dir, resolution, num_workers) 113 | -------------------------------------------------------------------------------- /preprocess/filter_visual_quality.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import tqdm 17 | import torch 18 | import torchvision 19 | import shutil 20 | from multiprocessing import Process 21 | import numpy as np 22 | from decord import VideoReader 23 | from einops import rearrange 24 | from eval.hyper_iqa import HyperNet, TargetNet 25 | 26 | 27 | paths = [] 28 | 29 | 30 | def gather_paths(input_dir, output_dir): 31 | # os.makedirs(output_dir, exist_ok=True) 32 | 33 | for video in tqdm.tqdm(sorted(os.listdir(input_dir))): 34 | if video.endswith(".mp4"): 35 | video_input = os.path.join(input_dir, video) 36 | video_output = os.path.join(output_dir, video) 37 | if os.path.isfile(video_output): 38 | continue 39 | paths.append((video_input, video_output)) 40 | elif os.path.isdir(os.path.join(input_dir, video)): 41 | gather_paths(os.path.join(input_dir, video), os.path.join(output_dir, video)) 42 | 43 | 44 | def read_video(video_path: str): 45 | vr = VideoReader(video_path) 46 | first_frame = vr[0].asnumpy() 47 | middle_frame = vr[len(vr) // 2].asnumpy() 48 | last_frame = vr[-1].asnumpy() 49 | vr.seek(0) 50 | video_frames = np.stack([first_frame, middle_frame, last_frame], axis=0) 51 | video_frames = torch.from_numpy(rearrange(video_frames, "b h w c -> b c h w")) 52 | video_frames = video_frames / 255.0 53 | return video_frames 54 | 55 | 56 | def func(paths, device_id): 57 | device = f"cuda:{device_id}" 58 | 59 | model_hyper = HyperNet(16, 112, 224, 112, 56, 28, 14, 7).to(device) 60 | model_hyper.train(False) 61 | 62 | # load the pre-trained model on the koniq-10k dataset 63 | model_hyper.load_state_dict( 64 | (torch.load("checkpoints/auxiliary/koniq_pretrained.pkl", map_location=device, weights_only=True)) 65 | ) 66 | 67 | transforms = torchvision.transforms.Compose( 68 | [ 69 | torchvision.transforms.CenterCrop(size=224), 70 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 71 | ] 72 | ) 73 | 74 | for video_input, video_output in paths: 75 | try: 76 | video_frames = read_video(video_input) 77 | video_frames = transforms(video_frames) 78 | video_frames = video_frames.clone().detach().to(device) 79 | paras = model_hyper(video_frames) # 'paras' contains the network weights conveyed to target network 80 | 81 | # Building target network 82 | model_target = TargetNet(paras).to(device) 83 | for param in model_target.parameters(): 84 | param.requires_grad = False 85 | 86 | # Quality prediction 87 | pred = model_target(paras["target_in_vec"]) # 
'paras['target_in_vec']' is the input to target net 88 | 89 | # quality score ranges from 0-100, a higher score indicates a better quality 90 | quality_score = pred.mean().item() 91 | print(f"Input video: {video_input}\nVisual quality score: {quality_score:.2f}") 92 | 93 | if quality_score >= 40: 94 | os.makedirs(os.path.dirname(video_output), exist_ok=True) 95 | shutil.copy(video_input, video_output) 96 | except Exception as e: 97 | print(e) 98 | 99 | 100 | def split(a, n): 101 | k, m = divmod(len(a), n) 102 | return (a[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)) 103 | 104 | 105 | def filter_visual_quality_multi_gpus(input_dir, output_dir, num_workers): 106 | gather_paths(input_dir, output_dir) 107 | num_devices = torch.cuda.device_count() 108 | if num_devices == 0: 109 | raise RuntimeError("No GPUs found") 110 | split_paths = list(split(paths, num_workers * num_devices)) 111 | processes = [] 112 | 113 | for i in range(num_devices): 114 | for j in range(num_workers): 115 | process_index = i * num_workers + j 116 | process = Process(target=func, args=(split_paths[process_index], i)) 117 | process.start() 118 | processes.append(process) 119 | 120 | for process in processes: 121 | process.join() 122 | 123 | 124 | if __name__ == "__main__": 125 | input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/av_synced" 126 | output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality" 127 | num_workers = 20 # How many processes per device 128 | 129 | filter_visual_quality_multi_gpus(input_dir, output_dir, num_workers) 130 | -------------------------------------------------------------------------------- /preprocess/remove_broken_videos.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | from multiprocessing import Pool 17 | import tqdm 18 | 19 | from latentsync.utils.av_reader import AVReader 20 | from latentsync.utils.util import gather_video_paths_recursively 21 | 22 | 23 | def remove_broken_video(video_path): 24 | try: 25 | AVReader(video_path) 26 | except Exception: 27 | os.remove(video_path) 28 | 29 | 30 | def remove_broken_videos_multiprocessing(input_dir, num_workers): 31 | video_paths = gather_video_paths_recursively(input_dir) 32 | 33 | print("Removing broken videos...") 34 | with Pool(num_workers) as pool: 35 | for _ in tqdm.tqdm(pool.imap_unordered(remove_broken_video, video_paths), total=len(video_paths)): 36 | pass 37 | 38 | 39 | if __name__ == "__main__": 40 | input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/raw" 41 | num_workers = 50 42 | 43 | remove_broken_videos_multiprocessing(input_dir, num_workers) 44 | -------------------------------------------------------------------------------- /preprocess/remove_incorrect_affined.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import mediapipe as mp 16 | from latentsync.utils.util import read_video, gather_video_paths_recursively 17 | import os 18 | import tqdm 19 | from multiprocessing import Pool 20 | 21 | 22 | class FaceDetector: 23 | def __init__(self): 24 | self.face_detection = mp.solutions.face_detection.FaceDetection( 25 | model_selection=0, min_detection_confidence=0.5 26 | ) 27 | 28 | def detect_face(self, image): 29 | # Process the image and detect faces. 
30 | results = self.face_detection.process(image) 31 | 32 | if not results.detections: # Face not detected 33 | return False 34 | 35 | if len(results.detections) != 1: 36 | return False 37 | return True 38 | 39 | def detect_video(self, video_path): 40 | try: 41 | video_frames = read_video(video_path, change_fps=False) 42 | except Exception as e: 43 | print(f"Exception: {e} - {video_path}") 44 | return False 45 | if len(video_frames) == 0: 46 | return False 47 | for frame in video_frames: 48 | if not self.detect_face(frame): 49 | return False 50 | return True 51 | 52 | def close(self): 53 | self.face_detection.close() 54 | 55 | 56 | def remove_incorrect_affined(video_path): 57 | if not os.path.isfile(video_path): 58 | return 59 | face_detector = FaceDetector() 60 | has_face = face_detector.detect_video(video_path) 61 | if not has_face: 62 | os.remove(video_path) 63 | print(f"Removed: {video_path}") 64 | face_detector.close() 65 | 66 | 67 | def remove_incorrect_affined_multiprocessing(input_dir, num_workers): 68 | video_paths = gather_video_paths_recursively(input_dir) 69 | print(f"Total videos: {len(video_paths)}") 70 | 71 | print(f"Removing incorrect affined videos in {input_dir} ...") 72 | with Pool(num_workers) as pool: 73 | for _ in tqdm.tqdm(pool.imap_unordered(remove_incorrect_affined, video_paths), total=len(video_paths)): 74 | pass 75 | 76 | 77 | if __name__ == "__main__": 78 | input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/affine_transformed" 79 | num_workers = 50 80 | 81 | remove_incorrect_affined_multiprocessing(input_dir, num_workers) 82 | -------------------------------------------------------------------------------- /preprocess/resample_fps_hz.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | import subprocess 17 | import tqdm 18 | from multiprocessing import Pool 19 | import cv2 20 | 21 | paths = [] 22 | 23 | 24 | def gather_paths(input_dir, output_dir): 25 | for video in sorted(os.listdir(input_dir)): 26 | if video.endswith(".mp4"): 27 | video_input = os.path.join(input_dir, video) 28 | video_output = os.path.join(output_dir, video) 29 | if os.path.isfile(video_output): 30 | continue 31 | paths.append([video_input, video_output]) 32 | elif os.path.isdir(os.path.join(input_dir, video)): 33 | gather_paths(os.path.join(input_dir, video), os.path.join(output_dir, video)) 34 | 35 | 36 | def get_video_fps(video_path: str): 37 | cam = cv2.VideoCapture(video_path) 38 | fps = cam.get(cv2.CAP_PROP_FPS) 39 | return fps 40 | 41 | 42 | def resample_fps_hz(video_input, video_output): 43 | os.makedirs(os.path.dirname(video_output), exist_ok=True) 44 | if get_video_fps(video_input) == 25: 45 | command = f"ffmpeg -loglevel error -y -i {video_input} -c:v copy -ar 16000 -q:a 0 {video_output}" 46 | else: 47 | command = f"ffmpeg -loglevel error -y -i {video_input} -r 25 -ar 16000 -q:a 0 {video_output}" 48 | subprocess.run(command, shell=True) 49 | 50 | 51 | def multi_run_wrapper(args): 52 | return resample_fps_hz(*args) 53 | 54 | 55 | def resample_fps_hz_multiprocessing(input_dir, output_dir, num_workers): 56 | print(f"Recursively gathering video paths of {input_dir} ...") 57 | gather_paths(input_dir, output_dir) 58 | 59 | print(f"Resampling FPS and Hz of {input_dir} ...") 60 | with Pool(num_workers) as pool: 61 | for _ in tqdm.tqdm(pool.imap_unordered(multi_run_wrapper, paths), total=len(paths)): 62 | pass 63 | 64 | 65 | if __name__ == "__main__": 66 | input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/raw" 67 | output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/resampled" 68 | num_workers = 20 69 | 70 | resample_fps_hz_multiprocessing(input_dir, output_dir, num_workers) 71 | -------------------------------------------------------------------------------- /preprocess/segment_videos.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | import subprocess 17 | import tqdm 18 | from multiprocessing import Pool 19 | 20 | paths = [] 21 | 22 | 23 | def gather_paths(input_dir, output_dir): 24 | for video in sorted(os.listdir(input_dir)): 25 | if video.endswith(".mp4"): 26 | video_basename = video[:-4] 27 | video_input = os.path.join(input_dir, video) 28 | video_output = os.path.join(output_dir, f"{video_basename}_%03d.mp4") 29 | if os.path.isfile(video_output): 30 | continue 31 | paths.append([video_input, video_output]) 32 | elif os.path.isdir(os.path.join(input_dir, video)): 33 | gather_paths(os.path.join(input_dir, video), os.path.join(output_dir, video)) 34 | 35 | 36 | def segment_video(video_input, video_output): 37 | os.makedirs(os.path.dirname(video_output), exist_ok=True) 38 | command = f"ffmpeg -loglevel error -y -i {video_input} -map 0 -c:v copy -segment_time 5 -f segment -reset_timestamps 1 -q:a 0 {video_output}" 39 | # command = f'ffmpeg -loglevel error -y -i {video_input} -map 0 -segment_time 5 -f segment -reset_timestamps 1 -force_key_frames "expr:gte(t,n_forced*5)" -crf 18 -q:a 0 {video_output}' 40 | subprocess.run(command, shell=True) 41 | 42 | 43 | def multi_run_wrapper(args): 44 | return segment_video(*args) 45 | 46 | 47 | def segment_videos_multiprocessing(input_dir, output_dir, num_workers): 48 | print(f"Recursively gathering video paths of {input_dir} ...") 49 | gather_paths(input_dir, output_dir) 50 | 51 | print(f"Segmenting videos of {input_dir} ...") 52 | with Pool(num_workers) as pool: 53 | for _ in tqdm.tqdm(pool.imap_unordered(multi_run_wrapper, paths), total=len(paths)): 54 | pass 55 | 56 | 57 | if __name__ == "__main__": 58 | input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/shot" 59 | output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/segmented" 60 | num_workers = 50 61 | 62 | segment_videos_multiprocessing(input_dir, output_dir, num_workers) 63 | -------------------------------------------------------------------------------- /preprocess/sync_av.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | import tqdm 17 | from eval.syncnet import SyncNetEval 18 | from eval.syncnet_detect import SyncNetDetector 19 | from eval.eval_sync_conf import syncnet_eval 20 | import torch 21 | import subprocess 22 | import shutil 23 | from multiprocessing import Process 24 | 25 | paths = [] 26 | 27 | 28 | def gather_paths(input_dir, output_dir): 29 | # os.makedirs(output_dir, exist_ok=True) 30 | 31 | for video in tqdm.tqdm(sorted(os.listdir(input_dir))): 32 | if video.endswith(".mp4"): 33 | video_input = os.path.join(input_dir, video) 34 | video_output = os.path.join(output_dir, video) 35 | if os.path.isfile(video_output): 36 | continue 37 | paths.append((video_input, video_output)) 38 | elif os.path.isdir(os.path.join(input_dir, video)): 39 | gather_paths(os.path.join(input_dir, video), os.path.join(output_dir, video)) 40 | 41 | 42 | def adjust_offset(video_input: str, video_output: str, av_offset: int, fps: int = 25): 43 | command = f"ffmpeg -loglevel error -y -i {video_input} -itsoffset {av_offset/fps} -i {video_input} -map 0:v -map 1:a -c copy -q:v 0 -q:a 0 {video_output}" 44 | subprocess.run(command, shell=True) 45 | 46 | 47 | def func(sync_conf_threshold, paths, device_id, process_temp_dir): 48 | os.makedirs(process_temp_dir, exist_ok=True) 49 | device = f"cuda:{device_id}" 50 | 51 | syncnet = SyncNetEval(device=device) 52 | syncnet.loadParameters("checkpoints/auxiliary/syncnet_v2.model") 53 | 54 | detect_results_dir = os.path.join(process_temp_dir, "detect_results") 55 | syncnet_eval_results_dir = os.path.join(process_temp_dir, "syncnet_eval_results") 56 | 57 | syncnet_detector = SyncNetDetector(device=device, detect_results_dir=detect_results_dir) 58 | 59 | for video_input, video_output in paths: 60 | try: 61 | av_offset, conf = syncnet_eval( 62 | syncnet, syncnet_detector, video_input, syncnet_eval_results_dir, detect_results_dir 63 | ) 64 | 65 | if conf >= sync_conf_threshold and abs(av_offset) <= 6: 66 | os.makedirs(os.path.dirname(video_output), exist_ok=True) 67 | if av_offset == 0: 68 | shutil.copy(video_input, video_output) 69 | else: 70 | adjust_offset(video_input, video_output, av_offset) 71 | except Exception as e: 72 | print(e) 73 | 74 | 75 | def split(a, n): 76 | k, m = divmod(len(a), n) 77 | return (a[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)) 78 | 79 | 80 | def sync_av_multi_gpus(input_dir, output_dir, temp_dir, num_workers, sync_conf_threshold): 81 | gather_paths(input_dir, output_dir) 82 | num_devices = torch.cuda.device_count() 83 | if num_devices == 0: 84 | raise RuntimeError("No GPUs found") 85 | split_paths = list(split(paths, num_workers * num_devices)) 86 | processes = [] 87 | 88 | for i in range(num_devices): 89 | for j in range(num_workers): 90 | process_index = i * num_workers + j 91 | process = Process( 92 | target=func, 93 | args=( 94 | sync_conf_threshold, 95 | split_paths[process_index], 96 | i, 97 | os.path.join(temp_dir, f"process_{process_index}"), 98 | ), 99 | ) 100 | process.start() 101 | processes.append(process) 102 | 103 | for process in processes: 104 | process.join() 105 | 106 | 107 | if __name__ == "__main__": 108 | input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/affine_transformed" 109 | output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/av_synced" 110 | temp_dir = "temp" 111 | num_workers = 20 # How many processes per device 112 | sync_conf_threshold = 3 113 | 114 | sync_av_multi_gpus(input_dir, output_dir, temp_dir, num_workers, sync_conf_threshold) 115 | 
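A note on the multi-GPU scheduling used by sync_av.py and filter_visual_quality.py above: the split helper divides the gathered paths into num_workers * num_devices near-equal chunks, and the process with index i * num_workers + j works through one chunk on device cuda:i. The following is a minimal, self-contained Python sketch of that partitioning only; it is an illustration, not a file from the repository, and the 2 GPUs, 3 workers per GPU, and placeholder file names are assumed values.

def split(a, n):
    # Same partitioning scheme as in sync_av.py / filter_visual_quality.py:
    # the first m chunks receive one extra element when len(a) is not divisible by n.
    k, m = divmod(len(a), n)
    return (a[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n))

paths = [f"video_{idx}.mp4" for idx in range(10)]  # placeholder work items
num_devices, num_workers = 2, 3                    # assumed: 2 GPUs, 3 worker processes per GPU

chunks = list(split(paths, num_workers * num_devices))
for i in range(num_devices):         # device index
    for j in range(num_workers):     # worker index on that device
        print(f"cuda:{i} worker {j} -> {chunks[i * num_workers + j]}")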
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.5.1 2 | torchvision==0.20.1 3 | --extra-index-url https://download.pytorch.org/whl/cu121 4 | diffusers==0.32.2 5 | transformers==4.48.0 6 | decord==0.6.0 7 | accelerate==0.26.1 8 | einops==0.7.0 9 | omegaconf==2.3.0 10 | opencv-python==4.9.0.80 11 | mediapipe==0.10.11 12 | python_speech_features==0.6 13 | librosa==0.10.1 14 | scenedetect==0.6.1 15 | ffmpeg-python==0.2.0 16 | imageio==2.31.1 17 | imageio-ffmpeg==0.5.1 18 | lpips==0.1.4 19 | face-alignment==1.4.1 20 | gradio==5.24.0 21 | huggingface-hub==0.30.2 22 | numpy==1.26.4 23 | kornia==0.8.0 24 | insightface==0.7.3 25 | onnxruntime-gpu==1.21.0 -------------------------------------------------------------------------------- /scripts/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import os 17 | from omegaconf import OmegaConf 18 | import torch 19 | from diffusers import AutoencoderKL, DDIMScheduler 20 | from latentsync.models.unet import UNet3DConditionModel 21 | from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline 22 | from accelerate.utils import set_seed 23 | from latentsync.whisper.audio2feature import Audio2Feature 24 | 25 | 26 | def main(config, args): 27 | if not os.path.exists(args.video_path): 28 | raise RuntimeError(f"Video path '{args.video_path}' not found") 29 | if not os.path.exists(args.audio_path): 30 | raise RuntimeError(f"Audio path '{args.audio_path}' not found") 31 | 32 | # Check if the GPU supports float16 33 | is_fp16_supported = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] > 7 34 | dtype = torch.float16 if is_fp16_supported else torch.float32 35 | 36 | print(f"Input video path: {args.video_path}") 37 | print(f"Input audio path: {args.audio_path}") 38 | print(f"Loaded checkpoint path: {args.inference_ckpt_path}") 39 | 40 | scheduler = DDIMScheduler.from_pretrained("configs") 41 | 42 | if config.model.cross_attention_dim == 768: 43 | whisper_model_path = "checkpoints/whisper/small.pt" 44 | elif config.model.cross_attention_dim == 384: 45 | whisper_model_path = "checkpoints/whisper/tiny.pt" 46 | else: 47 | raise NotImplementedError("cross_attention_dim must be 768 or 384") 48 | 49 | audio_encoder = Audio2Feature( 50 | model_path=whisper_model_path, 51 | device="cuda", 52 | num_frames=config.data.num_frames, 53 | audio_feat_length=config.data.audio_feat_length, 54 | ) 55 | 56 | vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype) 57 | vae.config.scaling_factor = 0.18215 58 | vae.config.shift_factor = 0 59 | 60 | denoising_unet, _ = UNet3DConditionModel.from_pretrained( 61 | OmegaConf.to_container(config.model), 62 | args.inference_ckpt_path, 
63 | device="cpu", 64 | ) 65 | 66 | denoising_unet = denoising_unet.to(dtype=dtype) 67 | 68 | pipeline = LipsyncPipeline( 69 | vae=vae, 70 | audio_encoder=audio_encoder, 71 | denoising_unet=denoising_unet, 72 | scheduler=scheduler, 73 | ).to("cuda") 74 | 75 | if args.seed != -1: 76 | set_seed(args.seed) 77 | else: 78 | torch.seed() 79 | 80 | print(f"Initial seed: {torch.initial_seed()}") 81 | 82 | pipeline( 83 | video_path=args.video_path, 84 | audio_path=args.audio_path, 85 | video_out_path=args.video_out_path, 86 | video_mask_path=args.video_out_path.replace(".mp4", "_mask.mp4"), 87 | num_frames=config.data.num_frames, 88 | num_inference_steps=args.inference_steps, 89 | guidance_scale=args.guidance_scale, 90 | weight_dtype=dtype, 91 | width=config.data.resolution, 92 | height=config.data.resolution, 93 | mask_image_path=config.data.mask_image_path, 94 | ) 95 | 96 | 97 | if __name__ == "__main__": 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument("--unet_config_path", type=str, default="configs/unet.yaml") 100 | parser.add_argument("--inference_ckpt_path", type=str, required=True) 101 | parser.add_argument("--video_path", type=str, required=True) 102 | parser.add_argument("--audio_path", type=str, required=True) 103 | parser.add_argument("--video_out_path", type=str, required=True) 104 | parser.add_argument("--inference_steps", type=int, default=20) 105 | parser.add_argument("--guidance_scale", type=float, default=1.0) 106 | parser.add_argument("--seed", type=int, default=1247) 107 | args = parser.parse_args() 108 | 109 | config = OmegaConf.load(args.unet_config_path) 110 | 111 | main(config, args) 112 | -------------------------------------------------------------------------------- /setup_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Create a new conda environment 4 | conda create -y -n latentsync python=3.10.13 5 | conda activate latentsync 6 | 7 | # Install ffmpeg 8 | conda install -y -c conda-forge ffmpeg 9 | 10 | # Python dependencies 11 | pip install -r requirements.txt 12 | 13 | # OpenCV dependencies 14 | sudo apt -y install libgl1 15 | 16 | # Download the checkpoints required for inference from HuggingFace 17 | huggingface-cli download ByteDance/LatentSync-1.5 whisper/tiny.pt --local-dir checkpoints 18 | huggingface-cli download ByteDance/LatentSync-1.5 latentsync_unet.pt --local-dir checkpoints -------------------------------------------------------------------------------- /tools/count_total_videos_time.py: -------------------------------------------------------------------------------- 1 | from latentsync.utils.util import count_video_time 2 | from tqdm import tqdm 3 | 4 | 5 | def count_total_videos_time(fileslist_path: str): 6 | with open(fileslist_path, "r") as f: 7 | filepaths = f.readlines() 8 | 9 | # Remove trailing newline characters 10 | filepaths = [filepath.strip() for filepath in filepaths] 11 | 12 | total_videos_time = 0 13 | for filepath in tqdm(filepaths): 14 | total_videos_time += count_video_time(filepath) 15 | 16 | print(f"Fileslist path: {fileslist_path}") 17 | print(f"Total videos time: {round(total_videos_time/3600)} hours") 18 | 19 | 20 | if __name__ == "__main__": 21 | fileslist_path = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt" 22 | count_total_videos_time(fileslist_path) 23 | -------------------------------------------------------------------------------- /tools/download_web_videos.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import subprocess 17 | from concurrent.futures import ThreadPoolExecutor 18 | from tqdm import tqdm 19 | 20 | """ 21 | To use this python script, first install yt-dlp by: 22 | 23 | pip install -U yt-dlp 24 | """ 25 | 26 | 27 | def download_video(video_url, video_path): 28 | download_video_command = f"yt-dlp -f bestvideo+bestaudio --skip-unavailable-fragments --merge-output-format mp4 '{video_url}' --output '{video_path}' --external-downloader aria2c --external-downloader-args '-x 16 -k 1M'" 29 | try: 30 | subprocess.run(download_video_command, shell=True) # ignore_security_alert_wait_for_fix RCE 31 | except KeyboardInterrupt: 32 | print("Stopped") 33 | exit() 34 | except: 35 | print(f"Error downloading video {video_url}") 36 | 37 | 38 | def download_videos(num_workers, video_urls, video_paths): 39 | with ThreadPoolExecutor(max_workers=num_workers) as executor: 40 | executor.map(download_video, video_urls, video_paths) 41 | 42 | 43 | def extract_vid(video_url): 44 | if "clip" in video_url: 45 | print(f"Cannot download youtube clip video: {video_url}") 46 | return None 47 | elif "watch?v=" in video_url: # ignore_security_alert_wait_for_fix RCE 48 | return video_url.split("watch?v=")[1][:11] 49 | elif "shorts/" in video_url: 50 | return video_url.split("shorts/")[1][:11] 51 | elif "youtu.be/" in video_url: 52 | return video_url.split("youtu.be/")[1][:11] 53 | elif "&v=" in video_url: 54 | return video_url.split("&v=")[1][:11] 55 | elif "bilibili.com/video/" in video_url: 56 | return video_url.split("bilibili.com/video/")[1][:12] 57 | elif "douyin.com/video/" in video_url: 58 | return video_url.split("douyin.com/video/")[1][:19] 59 | elif "douyin.com/user/self?modal_id=" in video_url: 60 | return video_url.split("douyin.com/user/self?modal_id=")[1][:19] 61 | else: 62 | print(f"Invalid video url: {video_url}") 63 | return None 64 | 65 | 66 | def main(urls_txt_path, output_dir, num_workers): 67 | os.makedirs(output_dir, exist_ok=True) 68 | 69 | with open(urls_txt_path, "r") as file: 70 | # Read lines into a list and strip newline characters 71 | all_video_urls = [line.strip() for line in file] 72 | 73 | video_paths = [] 74 | video_urls = [] 75 | 76 | print("Extracting vid...") 77 | for video_url in tqdm(all_video_urls): 78 | vid = extract_vid(video_url) 79 | if vid is None: 80 | continue 81 | video_path = os.path.join(output_dir, f"vid_{vid}.mp4") 82 | if os.path.isfile(video_path): 83 | continue 84 | os.makedirs(os.path.dirname(video_path), exist_ok=True) 85 | video_paths.append(video_path) 86 | video_urls.append(video_url) 87 | 88 | if len(video_paths) == 0: 89 | print("All videos have been downloaded") 90 | exit() 91 | else: 92 | print(f"Downloading {len(video_paths)} videos") 93 | 94 | download_videos(num_workers, video_urls, video_paths) 95 | 
96 | 97 | if __name__ == "__main__": 98 | urls_txt_path = "video_urls.txt" 99 | output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/youtube/raw" 100 | num_workers = 50 101 | 102 | maximum_duration = 60 * 30 # set video maximum duration as 30 minutes 103 | 104 | main(urls_txt_path, output_dir, num_workers) 105 | -------------------------------------------------------------------------------- /tools/move_files_recur.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import shutil 17 | from tqdm import tqdm 18 | 19 | paths = [] 20 | 21 | 22 | def gather_paths(input_dir, output_dir): 23 | os.makedirs(output_dir, exist_ok=True) 24 | 25 | for video in sorted(os.listdir(input_dir)): 26 | if video.endswith(".mp4"): 27 | video_input = os.path.join(input_dir, video) 28 | video_output = os.path.join(output_dir, video) 29 | if os.path.isfile(video_output): 30 | continue 31 | paths.append([video_input, output_dir]) 32 | elif os.path.isdir(os.path.join(input_dir, video)): 33 | gather_paths(os.path.join(input_dir, video), os.path.join(output_dir, video)) 34 | 35 | 36 | def main(input_dir, output_dir): 37 | print(f"Recursively gathering video paths of {input_dir} ...") 38 | gather_paths(input_dir, output_dir) 39 | 40 | for video_input, output_dir in tqdm(paths): 41 | shutil.move(video_input, output_dir) 42 | 43 | 44 | if __name__ == "__main__": 45 | # from input_dir to output_dir 46 | input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/willdata2" 47 | output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/willdata" 48 | 49 | main(input_dir, output_dir) 50 | -------------------------------------------------------------------------------- /tools/occupy_gpu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import torch 16 | import os 17 | import torch.multiprocessing as mp 18 | import time 19 | 20 | 21 | def check_mem(cuda_device): 22 | devices_info = ( 23 | os.popen('"/usr/bin/nvidia-smi" --query-gpu=memory.total,memory.used --format=csv,nounits,noheader') 24 | .read() 25 | .strip() 26 | .split("\n") 27 | ) 28 | total, used = devices_info[int(cuda_device)].split(",") 29 | return total, used 30 | 31 | 32 | def loop(cuda_device): 33 | cuda_i = torch.device(f"cuda:{cuda_device}") 34 | total, used = check_mem(cuda_device) 35 | total = int(total) 36 | used = int(used) 37 | max_mem = int(total * 0.9) 38 | block_mem = max_mem - used 39 | while True: 40 | x = torch.rand(50, 512, 512, dtype=torch.float, device=cuda_i) 41 | y = torch.rand(50, 512, 512, dtype=torch.float, device=cuda_i) 42 | time.sleep(0.001) 43 | x = torch.matmul(x, y) 44 | 45 | 46 | def main(): 47 | if torch.cuda.is_available(): 48 | num_processes = torch.cuda.device_count() 49 | processes = list() 50 | for i in range(num_processes): 51 | p = mp.Process(target=loop, args=(i,)) 52 | p.start() 53 | processes.append(p) 54 | for p in processes: 55 | p.join() 56 | 57 | 58 | if __name__ == "__main__": 59 | torch.multiprocessing.set_start_method("spawn") 60 | main() 61 | -------------------------------------------------------------------------------- /tools/plot_videos_time_distribution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import matplotlib.pyplot as plt 16 | from latentsync.utils.util import count_video_time, gather_video_paths_recursively 17 | from tqdm import tqdm 18 | 19 | 20 | def plot_histogram(data, fig_path): 21 | # Create histogram 22 | plt.hist(data, bins=30, edgecolor="black") 23 | 24 | # Add titles and labels 25 | plt.title("Histogram of Data Distribution") 26 | plt.xlabel("Video time") 27 | plt.ylabel("Frequency") 28 | 29 | # Save plot as an image file 30 | plt.savefig(fig_path) # Save as PNG file. You can also use 'histogram.jpg', 'histogram.pdf', etc. 31 | 32 | 33 | def main(input_dir, fig_path): 34 | video_paths = gather_video_paths_recursively(input_dir) 35 | video_times = [] 36 | for video_path in tqdm(video_paths): 37 | video_times.append(count_video_time(video_path)) 38 | plot_histogram(video_times, fig_path) 39 | 40 | 41 | if __name__ == "__main__": 42 | input_dir = "validation" 43 | fig_path = "histogram.png" 44 | 45 | main(input_dir, fig_path) 46 | -------------------------------------------------------------------------------- /tools/remove_outdated_files.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import subprocess 17 | 18 | 19 | def remove_outdated_files(input_dir, begin_date, end_date): 20 | # Remove files from a specific time period 21 | for subdir in os.listdir(input_dir): 22 | if subdir >= begin_date and subdir <= end_date: 23 | subdir_path = os.path.join(input_dir, subdir) 24 | command = f"rm -rf {subdir_path}" 25 | subprocess.run(command, shell=True) 26 | print(f"Deleted: {subdir_path}") 27 | 28 | 29 | if __name__ == "__main__": 30 | input_dir = "/mnt/bn/video-datasets/output/unet" 31 | begin_date = "train-2024_05_29-12:22:35" 32 | end_date = "train-2024_09_26-00:10:46" 33 | 34 | remove_outdated_files(input_dir, begin_date, end_date) 35 | -------------------------------------------------------------------------------- /tools/write_fileslist.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from tqdm import tqdm 16 | from latentsync.utils.util import gather_video_paths_recursively 17 | 18 | 19 | class FileslistWriter: 20 | def __init__(self, fileslist_path: str): 21 | self.fileslist_path = fileslist_path 22 | with open(fileslist_path, "w") as _: 23 | pass 24 | 25 | def append_dataset(self, dataset_dir: str): 26 | print(f"Dataset dir: {dataset_dir}") 27 | video_paths = gather_video_paths_recursively(dataset_dir) 28 | with open(self.fileslist_path, "a") as f: 29 | for video_path in tqdm(video_paths): 30 | f.write(f"{video_path}\n") 31 | 32 | 33 | if __name__ == "__main__": 34 | fileslist_path = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt" 35 | 36 | writer = FileslistWriter(fileslist_path) 37 | writer.append_dataset("/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/train") 38 | writer.append_dataset("/mnt/bn/maliva-gen-ai-v2/chunyu.li/HDTF/high_visual_quality/train") 39 | -------------------------------------------------------------------------------- /train_syncnet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | torchrun --nnodes=1 --nproc_per_node=1 --master_port=25678 -m scripts.train_syncnet \ 4 | --config_path "configs/syncnet/syncnet_16_pixel_attn.yaml" 5 | -------------------------------------------------------------------------------- /train_unet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | torchrun --nnodes=1 --nproc_per_node=1 --master_port=25679 -m scripts.train_unet \ 4 | --unet_config_path "configs/unet/stage1.yaml" 5 | --------------------------------------------------------------------------------
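The preprocessing scripts above each expose a *_multiprocessing or *_multi_gpus entry point, and their __main__ blocks imply a directory layout (raw, resampled, shot, segmented, affine_transformed, av_synced, high_visual_quality). Below is a hypothetical sketch of chaining those entry points in one place; the root path, worker counts, and exact stage order are assumptions, the shot-detection and affine-transform steps belong to scripts not shown in this excerpt, and the repository's own pipeline script may differ.

from preprocess.remove_broken_videos import remove_broken_videos_multiprocessing
from preprocess.resample_fps_hz import resample_fps_hz_multiprocessing
from preprocess.segment_videos import segment_videos_multiprocessing
from preprocess.remove_incorrect_affined import remove_incorrect_affined_multiprocessing
from preprocess.sync_av import sync_av_multi_gpus
from preprocess.filter_visual_quality import filter_visual_quality_multi_gpus


def run_pipeline(root: str):
    # Drop unreadable files, then normalize to 25 fps video and 16 kHz audio.
    remove_broken_videos_multiprocessing(f"{root}/raw", num_workers=16)
    resample_fps_hz_multiprocessing(f"{root}/raw", f"{root}/resampled", num_workers=16)

    # Shot detection (resampled -> shot) and affine transformation
    # (segmented -> affine_transformed) are handled by scripts not shown in this excerpt.
    segment_videos_multiprocessing(f"{root}/shot", f"{root}/segmented", num_workers=16)
    remove_incorrect_affined_multiprocessing(f"{root}/affine_transformed", num_workers=16)

    # GPU stages: audio-visual sync correction, then visual-quality filtering.
    sync_av_multi_gpus(
        f"{root}/affine_transformed",
        f"{root}/av_synced",
        "temp",
        num_workers=4,
        sync_conf_threshold=3,
    )
    filter_visual_quality_multi_gpus(f"{root}/av_synced", f"{root}/high_visual_quality", num_workers=4)


if __name__ == "__main__":
    run_pipeline("/path/to/VoxCeleb2")  # placeholder dataset root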