├── audio ├── infer.sh ├── module │ ├── __init__.py │ ├── commons.py │ ├── utils.py │ └── transforms.py ├── sample │ ├── src_p241_004.wav │ └── tar_p239_022.wav ├── requirements.txt ├── alias_free_torch │ ├── __init__.py │ ├── act.py │ ├── resample.py │ └── filter.py ├── model │ ├── base.py │ ├── styleencoder.py │ ├── diffusion_module.py │ ├── diffusion_f0.py │ ├── diffhiervc.py │ └── diffusion_mel.py ├── ckpt │ ├── config.json │ └── config_bigvgan.json ├── configs │ └── config_16k.json ├── augmentation │ ├── peq.py │ └── aug.py ├── utils │ ├── data_loader.py │ └── utils.py ├── vocoder │ ├── activations.py │ ├── hifigan.py │ └── bigvgan.py ├── train.py └── server.py ├── video ├── modules │ ├── __init__.py │ ├── processors │ │ ├── __init__.py │ │ └── frame │ │ │ ├── __init__.py │ │ │ ├── core.py │ │ │ └── face_enhancer.py │ ├── metadata.py │ ├── typing.py │ ├── gettext.py │ ├── cluster_analysis.py │ ├── capturer.py │ ├── globals.py │ ├── predicter.py │ ├── video_capture.py │ ├── ui.json │ ├── utilities.py │ ├── face_analyser.py │ └── core.py ├── image.jpg ├── image │ ├── Elon Musk.jpg │ └── Benedict Cumberbatch.jpg ├── models │ └── instructions.txt ├── requirements.txt ├── client.py ├── wrapper.py └── server.py ├── .gitignore ├── media └── Gui_deepfake.PNG └── LICENSE /audio/infer.sh: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /audio/module/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /video/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /video/modules/processors/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /video/modules/processors/frame/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .venv 3 | gfpgan 4 | .DS_Store -------------------------------------------------------------------------------- /video/modules/metadata.py: -------------------------------------------------------------------------------- 1 | name = 'Deep-Live-Cam' 2 | version = '1.8' 3 | edition = 'GitHub Edition' 4 | -------------------------------------------------------------------------------- /video/image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ali-Shariati-Najafabadi/Real-Time-Deepfake-Pipeline/HEAD/video/image.jpg -------------------------------------------------------------------------------- /media/Gui_deepfake.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ali-Shariati-Najafabadi/Real-Time-Deepfake-Pipeline/HEAD/media/Gui_deepfake.PNG -------------------------------------------------------------------------------- /video/image/Elon Musk.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ali-Shariati-Najafabadi/Real-Time-Deepfake-Pipeline/HEAD/video/image/Elon Musk.jpg -------------------------------------------------------------------------------- /audio/sample/src_p241_004.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ali-Shariati-Najafabadi/Real-Time-Deepfake-Pipeline/HEAD/audio/sample/src_p241_004.wav -------------------------------------------------------------------------------- /audio/sample/tar_p239_022.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ali-Shariati-Najafabadi/Real-Time-Deepfake-Pipeline/HEAD/audio/sample/tar_p239_022.wav -------------------------------------------------------------------------------- /video/image/Benedict Cumberbatch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ali-Shariati-Najafabadi/Real-Time-Deepfake-Pipeline/HEAD/video/image/Benedict Cumberbatch.jpg -------------------------------------------------------------------------------- /video/modules/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from insightface.app.common import Face 4 | import numpy 5 | 6 | Face = Face 7 | Frame = numpy.ndarray[Any, Any] 8 | -------------------------------------------------------------------------------- /audio/requirements.txt: -------------------------------------------------------------------------------- 1 | amfm_decompy==1.0.11 2 | einops==0.7.0 3 | numpy==1.21.4 4 | scipy==1.6.3 5 | torch==1.11.0+cu113 6 | torchaudio==0.11.0+cu113 7 | tqdm==4.62.3 8 | transformers==4.35.0 9 | -------------------------------------------------------------------------------- /audio/alias_free_torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | from .filter import * 5 | from .resample import * 6 | from .act import * -------------------------------------------------------------------------------- /video/models/instructions.txt: -------------------------------------------------------------------------------- 1 | The models are too large, so they can't be pushed into the GitHub repo. Just download the models below and move them into this directory. 2 | 3 | 1. https://huggingface.co/hacksider/deep-live-cam/resolve/main/inswapper_128_fp16.onnx?download=true 4 | 2. 
https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth, https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -------------------------------------------------------------------------------- /video/requirements.txt: -------------------------------------------------------------------------------- 1 | customtkinter==5.2.2 2 | cv2_enumerate_cameras==1.1.15 3 | gfpgan==1.3.8 4 | insightface==0.7.3 5 | numpy>=1.23.5,<2 6 | onnx==1.16.0 7 | onnxruntime-gpu==1.16.3; sys_platform != 'darwin' 8 | onnxruntime-silicon==1.16.3; sys_platform == 'darwin' and platform_machine == 'arm64' 9 | pillow==9.5.0 10 | protobuf==4.23.2 11 | tensorflow; 12 | torch==2.0.1+cu118; sys_platform != 'darwin' 13 | torch==2.0.1; sys_platform == 'darwin' 14 | tqdm==4.66.4 15 | pyzmq==26.2.1 16 | msgpack==1.1.0 17 | msgpack-numpy==0.4.8 -------------------------------------------------------------------------------- /audio/model/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class BaseModule(torch.nn.Module): 6 | def __init__(self): 7 | super(BaseModule, self).__init__() 8 | 9 | @property 10 | def nparams(self): 11 | num_params = 0 12 | for name, param in self.named_parameters(): 13 | if param.requires_grad: 14 | num_params += np.prod(param.detach().cpu().numpy().shape) 15 | return num_params 16 | 17 | 18 | def relocate_input(self, x: list): 19 | device = next(self.parameters()).device 20 | for i in range(len(x)): 21 | if isinstance(x[i], torch.Tensor) and x[i].device != device: 22 | x[i] = x[i].to(device) 23 | return x 24 | -------------------------------------------------------------------------------- /audio/alias_free_torch/act.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 
3 | 4 | import torch.nn as nn 5 | from .resample import UpSample1d, DownSample1d 6 | 7 | 8 | class Activation1d(nn.Module): 9 | def __init__(self, 10 | activation, 11 | up_ratio: int = 2, 12 | down_ratio: int = 2, 13 | up_kernel_size: int = 12, 14 | down_kernel_size: int = 12): 15 | super().__init__() 16 | self.up_ratio = up_ratio 17 | self.down_ratio = down_ratio 18 | self.act = activation 19 | self.upsample = UpSample1d(up_ratio, up_kernel_size) 20 | self.downsample = DownSample1d(down_ratio, down_kernel_size) 21 | 22 | # x: [B,C,T] 23 | def forward(self, x): 24 | x = self.upsample(x) 25 | x = self.act(x) 26 | x = self.downsample(x) 27 | 28 | return x -------------------------------------------------------------------------------- /video/modules/gettext.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | class LanguageManager: 5 | def __init__(self, default_language="en"): 6 | self.current_language = default_language 7 | self.translations = {} 8 | self.load_language(default_language) 9 | 10 | def load_language(self, language_code) -> bool: 11 | """load language file""" 12 | if language_code == "en": 13 | return True 14 | try: 15 | file_path = Path(__file__).parent.parent / f"locales/{language_code}.json" 16 | with open(file_path, "r", encoding="utf-8") as file: 17 | self.translations = json.load(file) 18 | self.current_language = language_code 19 | return True 20 | except FileNotFoundError: 21 | print(f"Language file not found: {language_code}") 22 | return False 23 | 24 | def _(self, key, default=None) -> str: 25 | """get translate text""" 26 | return self.translations.get(key, default if default else key) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 ali shariati 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /video/modules/cluster_analysis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.cluster import KMeans 3 | from sklearn.metrics import silhouette_score 4 | from typing import Any 5 | 6 | 7 | def find_cluster_centroids(embeddings, max_k=10) -> Any: 8 | inertia = [] 9 | cluster_centroids = [] 10 | K = range(1, max_k+1) 11 | 12 | for k in K: 13 | kmeans = KMeans(n_clusters=k, random_state=0) 14 | kmeans.fit(embeddings) 15 | inertia.append(kmeans.inertia_) 16 | cluster_centroids.append({"k": k, "centroids": kmeans.cluster_centers_}) 17 | 18 | diffs = [inertia[i] - inertia[i+1] for i in range(len(inertia)-1)] 19 | optimal_centroids = cluster_centroids[diffs.index(max(diffs)) + 1]['centroids'] 20 | 21 | return optimal_centroids 22 | 23 | def find_closest_centroid(centroids: list, normed_face_embedding) -> list: 24 | try: 25 | centroids = np.array(centroids) 26 | normed_face_embedding = np.array(normed_face_embedding) 27 | similarities = np.dot(centroids, normed_face_embedding) 28 | closest_centroid_index = np.argmax(similarities) 29 | 30 | return closest_centroid_index, centroids[closest_centroid_index] 31 | except ValueError: 32 | return None -------------------------------------------------------------------------------- /video/modules/capturer.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import cv2 3 | import modules.globals # Import the globals to check the color correction toggle 4 | 5 | 6 | def get_video_frame(video_path: str, frame_number: int = 0) -> Any: 7 | capture = cv2.VideoCapture(video_path) 8 | 9 | # Set MJPEG format to ensure correct color space handling 10 | capture.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'MJPG')) 11 | 12 | # Only force RGB conversion if color correction is enabled 13 | if modules.globals.color_correction: 14 | capture.set(cv2.CAP_PROP_CONVERT_RGB, 1) 15 | 16 | frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT) 17 | capture.set(cv2.CAP_PROP_POS_FRAMES, min(frame_total, frame_number - 1)) 18 | has_frame, frame = capture.read() 19 | 20 | if has_frame and modules.globals.color_correction: 21 | # Convert the frame color if necessary 22 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 23 | 24 | capture.release() 25 | return frame if has_frame else None 26 | 27 | 28 | def get_video_frame_total(video_path: str) -> int: 29 | capture = cv2.VideoCapture(video_path) 30 | video_frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) 31 | capture.release() 32 | return video_frame_total 33 | -------------------------------------------------------------------------------- /video/modules/globals.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Dict, Any 3 | 4 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 5 | WORKFLOW_DIR = os.path.join(ROOT_DIR, "workflow") 6 | 7 | file_types = [ 8 | ("Image", ("*.png", "*.jpg", "*.jpeg", "*.gif", "*.bmp")), 9 | ("Video", ("*.mp4", "*.mkv")), 10 | ] 11 | 12 | souce_target_map = [] 13 | simple_map = {} 14 | 15 | source_path = None 16 | target_path = None 17 | output_path = None 18 | frame_processors: List[str] = [] 19 | keep_fps = True 20 | keep_audio = True 21 | keep_frames = False 22 | many_faces = False 23 | map_faces = False 24 | color_correction = False # New global variable for color correction toggle 25 | 
nsfw_filter = False 26 | video_encoder = None 27 | video_quality = None 28 | live_mirror = False 29 | live_resizable = True 30 | max_memory = None 31 | execution_providers = [ 32 | "CUDAExecutionProvider", 33 | "CoreMLExecutionProvider", 34 | "CPUExecutionProvider", 35 | ] 36 | execution_threads = None 37 | headless = None 38 | log_level = "error" 39 | fp_ui: Dict[str, bool] = {"face_enhancer": False} 40 | camera_input_combobox = None 41 | webcam_preview_running = False 42 | show_fps = False 43 | mouth_mask = False 44 | show_mouth_mask_box = False 45 | mask_feather_ratio = 8 46 | mask_down_size = 0.50 47 | mask_size = 1 48 | -------------------------------------------------------------------------------- /video/modules/predicter.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import opennsfw2 3 | from PIL import Image 4 | import cv2 # Add OpenCV import 5 | import modules.globals # Import globals to access the color correction toggle 6 | 7 | from modules.typing import Frame 8 | 9 | MAX_PROBABILITY = 0.85 10 | 11 | # Preload the model once for efficiency 12 | model = None 13 | 14 | def predict_frame(target_frame: Frame) -> bool: 15 | # Convert the frame to RGB before processing if color correction is enabled 16 | if modules.globals.color_correction: 17 | target_frame = cv2.cvtColor(target_frame, cv2.COLOR_BGR2RGB) 18 | 19 | image = Image.fromarray(target_frame) 20 | image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO) 21 | global model 22 | if model is None: 23 | model = opennsfw2.make_open_nsfw_model() 24 | 25 | views = numpy.expand_dims(image, axis=0) 26 | _, probability = model.predict(views)[0] 27 | return probability > MAX_PROBABILITY 28 | 29 | 30 | def predict_image(target_path: str) -> bool: 31 | return opennsfw2.predict_image(target_path) > MAX_PROBABILITY 32 | 33 | 34 | def predict_video(target_path: str) -> bool: 35 | _, probabilities = opennsfw2.predict_video_frames(video_path=target_path, frame_interval=100) 36 | return any(probability > MAX_PROBABILITY for probability in probabilities) 37 | -------------------------------------------------------------------------------- /audio/ckpt/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 1000, 4 | "eval_interval": 10000, 5 | "save_interval": 10000, 6 | "seed": 1234, 7 | "epochs": 1000, 8 | "optimizer": "adamw", 9 | "lr_decay_on": true, 10 | "learning_rate": 5e-5, 11 | "betas": [0.8, 0.99], 12 | "eps": 1e-9, 13 | "batch_size": 32, 14 | "fp16_run": false, 15 | "lr_decay": 0.999875, 16 | "segment_size": 35840, 17 | "init_lr_ratio": 1, 18 | "warmup_epochs": 0, 19 | "c_mel": 1, 20 | "aug": true, 21 | "lambda_commit": 0.02 22 | }, 23 | "data": { 24 | "sampling_rate": 16000, 25 | "filter_length": 1280, 26 | "hop_length": 320, 27 | "win_length": 1280, 28 | "n_mel_channels": 80, 29 | "mel_fmin": 0, 30 | "mel_fmax": 8000 31 | }, 32 | "model": { 33 | "inter_channels": 192, 34 | "hidden_channels": 192, 35 | "filter_channels": 768, 36 | "n_heads": 2, 37 | "n_layers": 6, 38 | "kernel_size": 3, 39 | "p_dropout": 0.1, 40 | "resblock": "1", 41 | "resblock_kernel_sizes": [3,7,11], 42 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 43 | "upsample_rates": [5,4,4,2,2], 44 | "upsample_initial_channel": 512, 45 | "upsample_kernel_sizes": [11,8,8,4,4], 46 | "mixup_ratio": 0.6, 47 | "n_layers_q": 3, 48 | "use_spectral_norm": false, 49 | "hidden_size": 128 50 | }, 51 | "diffusion" : { 52 | 
"dec_dim" : 64, 53 | "spk_dim" : 128, 54 | "beta_min" : 0.05, 55 | "beta_max" : 20.0 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /audio/ckpt/config_bigvgan.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 1000, 4 | "eval_interval": 10000, 5 | "save_interval": 10000, 6 | "seed": 1234, 7 | "epochs": 1000, 8 | "optimizer": "adamw", 9 | "lr_decay_on": true, 10 | "learning_rate": 5e-5, 11 | "betas": [0.8, 0.99], 12 | "eps": 1e-9, 13 | "batch_size": 32, 14 | "fp16_run": false, 15 | "lr_decay": 0.999875, 16 | "segment_size": 35840, 17 | "init_lr_ratio": 1, 18 | "warmup_epochs": 0, 19 | "c_mel": 1, 20 | "aug": true, 21 | "lambda_commit": 0.02 22 | }, 23 | "data": { 24 | "sampling_rate": 16000, 25 | "filter_length": 1280, 26 | "hop_length": 320, 27 | "win_length": 1280, 28 | "n_mel_channels": 80, 29 | "mel_fmin": 0, 30 | "mel_fmax": 8000 31 | }, 32 | "model": { 33 | "inter_channels": 192, 34 | "hidden_channels": 192, 35 | "filter_channels": 768, 36 | "n_heads": 2, 37 | "n_layers": 8, 38 | "kernel_size": 3, 39 | "p_dropout": 0.1, 40 | "resblock": "1", 41 | "resblock_kernel_sizes": [3,7,11], 42 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 43 | "upsample_rates": [5,4,2,2,2,2], 44 | "upsample_initial_channel": 1024, 45 | "upsample_kernel_sizes": [11,8,4,4,4,4], 46 | "hidden_size": 128 47 | }, 48 | "diffusion" : { 49 | "dec_dim" : 64, 50 | "spk_dim" : 128, 51 | "beta_min" : 0.05, 52 | "beta_max" : 20.0 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /audio/configs/config_16k.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 1000, 4 | "eval_interval": 10000, 5 | "save_interval": 10000, 6 | "seed": 1234, 7 | "epochs": 1000, 8 | "optimizer": "adamw", 9 | "lr_decay_on": true, 10 | "learning_rate": 5e-5, 11 | "betas": [0.8, 0.99], 12 | "eps": 1e-9, 13 | "batch_size": 32, 14 | "fp16_run": false, 15 | "lr_decay": 0.999875, 16 | "segment_size": 35840, 17 | "init_lr_ratio": 1, 18 | "warmup_epochs": 0, 19 | "c_mel": 1, 20 | "aug": true, 21 | "lambda_commit": 0.02 22 | }, 23 | "data": { 24 | "train_filelist_path": "fp_16k/train_wav.txt", 25 | "test_filelist_path": "fp_16k/test_wav.txt", 26 | "sampling_rate": 16000, 27 | "filter_length": 1280, 28 | "hop_length": 320, 29 | "win_length": 1280, 30 | "n_mel_channels": 80, 31 | "mel_fmin": 0, 32 | "mel_fmax": 8000 33 | }, 34 | "model": { 35 | "inter_channels": 192, 36 | "hidden_channels": 192, 37 | "filter_channels": 768, 38 | "n_heads": 2, 39 | "n_layers": 6, 40 | "kernel_size": 3, 41 | "p_dropout": 0.1, 42 | "resblock": "1", 43 | "resblock_kernel_sizes": [3,7,11], 44 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 45 | "upsample_rates": [5,4,4,2,2], 46 | "upsample_initial_channel": 512, 47 | "upsample_kernel_sizes": [11,8,8,4,4], 48 | "mixup_ratio": 0.6, 49 | "n_layers_q": 3, 50 | "use_spectral_norm": false, 51 | "hidden_size": 128 52 | }, 53 | "diffusion" : { 54 | "dec_dim" : 64, 55 | "spk_dim" : 128, 56 | "beta_min" : 0.05, 57 | "beta_max" : 20.0 58 | } 59 | } -------------------------------------------------------------------------------- /audio/alias_free_torch/resample.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses 
directory. 3 | 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from .filter import LowPassFilter1d 7 | from .filter import kaiser_sinc_filter1d 8 | 9 | 10 | class UpSample1d(nn.Module): 11 | def __init__(self, ratio=2, kernel_size=None): 12 | super().__init__() 13 | self.ratio = ratio 14 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 15 | self.stride = ratio 16 | self.pad = self.kernel_size // ratio - 1 17 | self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 18 | self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 19 | filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, 20 | half_width=0.6 / ratio, 21 | kernel_size=self.kernel_size) 22 | self.register_buffer("filter", filter) 23 | 24 | # x: [B, C, T] 25 | def forward(self, x): 26 | _, C, _ = x.shape 27 | 28 | x = F.pad(x, (self.pad, self.pad), mode='replicate') 29 | x = self.ratio * F.conv_transpose1d( 30 | x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) 31 | x = x[..., self.pad_left:-self.pad_right] 32 | 33 | return x 34 | 35 | 36 | class DownSample1d(nn.Module): 37 | def __init__(self, ratio=2, kernel_size=None): 38 | super().__init__() 39 | self.ratio = ratio 40 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 41 | self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, 42 | half_width=0.6 / ratio, 43 | stride=ratio, 44 | kernel_size=self.kernel_size) 45 | 46 | def forward(self, x): 47 | xx = self.lowpass(x) 48 | 49 | return xx -------------------------------------------------------------------------------- /video/client.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import zmq 3 | import msgpack 4 | import msgpack_numpy as m 5 | import numpy as np 6 | import time 7 | 8 | m.patch() 9 | 10 | ZMQ_SERVER_ADDRESS = "tcp://localhost:5558" 11 | ZMQ_CLIENT_ADDRESS = "tcp://localhost:5559" 12 | 13 | 14 | def main(): 15 | context = zmq.Context() 16 | 17 | # Socket to send frames to the server 18 | sender = context.socket(zmq.PUSH) 19 | sender.connect(ZMQ_SERVER_ADDRESS) 20 | 21 | # Socket to receive processed frames from the server 22 | receiver = context.socket(zmq.PULL) 23 | receiver.connect(ZMQ_CLIENT_ADDRESS) 24 | 25 | cap = cv2.VideoCapture(0) 26 | 27 | if not cap.isOpened(): 28 | print("Error: Could not open webcam.") 29 | return 30 | 31 | print("Client started, sending frames to the server...") 32 | 33 | try: 34 | while True: 35 | ret, frame = cap.read() 36 | if not ret: 37 | print("Error: Could not read frame.") 38 | break 39 | 40 | # Compress and send the frame 41 | _, encoded_frame = cv2.imencode( 42 | ".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 80] 43 | ) 44 | sender.send(msgpack.packb(encoded_frame.tobytes())) 45 | print("Sent frame to server") 46 | 47 | # Receive processed frame from the server 48 | start_time = time.time() 49 | data = receiver.recv() 50 | print("Received frame from server") 51 | elapsed_time = time.time() - start_time 52 | print(f"Receive time: {elapsed_time:.4f} seconds") 53 | processed_frame_data = np.frombuffer(msgpack.unpackb(data), dtype=np.uint8) 54 | processed_frame = cv2.imdecode( 55 | processed_frame_data, cv2.IMWRITE_JPEG_QUALITY 56 | ) 57 | 58 | # Show the processed frame 59 | cv2.imshow("Real-Time-Deepfake-Pipeline", processed_frame) 60 | 61 | if cv2.waitKey(1) & 0xFF == ord("q"): 62 | break 63 | 64 | except KeyboardInterrupt: 65 | print("Client interrupted.") 66 | 67 | finally: 68 | 
cap.release() 69 | cv2.destroyAllWindows() 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | -------------------------------------------------------------------------------- /audio/model/styleencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from module.attentions import * 4 | 5 | class Mish(nn.Module): 6 | def __init__(self): 7 | super(Mish, self).__init__() 8 | def forward(self, x): 9 | return x * torch.tanh(torch.nn.functional.softplus(x)) 10 | 11 | 12 | class Conv1dGLU(nn.Module): 13 | def __init__(self, in_channels, out_channels, kernel_size, dropout): 14 | super(Conv1dGLU, self).__init__() 15 | self.out_channels = out_channels 16 | self.conv1 = nn.Conv1d(in_channels, 2 * out_channels, kernel_size=kernel_size, padding=2) 17 | self.dropout = nn.Dropout(dropout) 18 | 19 | def forward(self, x): 20 | residual = x 21 | x = self.conv1(x) 22 | x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1) 23 | x = x1 * torch.sigmoid(x2) 24 | x = residual + self.dropout(x) 25 | 26 | return x 27 | 28 | 29 | class StyleEncoder(torch.nn.Module): 30 | def __init__(self, in_dim, hidden_dim, out_dim): 31 | super().__init__() 32 | 33 | self.in_dim = in_dim 34 | self.hidden_dim = hidden_dim 35 | self.out_dim = out_dim 36 | self.kernel_size = 5 37 | self.n_head = 2 38 | self.dropout = 0.1 39 | 40 | self.spectral = nn.Sequential( 41 | nn.Conv1d(self.in_dim, self.hidden_dim, 1), 42 | Mish(), 43 | nn.Dropout(self.dropout), 44 | nn.Conv1d(self.hidden_dim, self.hidden_dim, 1), 45 | Mish(), 46 | nn.Dropout(self.dropout) 47 | ) 48 | 49 | self.temporal = nn.Sequential( 50 | Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), 51 | Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), 52 | ) 53 | 54 | self.slf_attn = MultiHeadAttention(self.hidden_dim, self.hidden_dim, self.n_head, p_dropout=self.dropout, proximal_bias=False, proximal_init=True) 55 | self.atten_drop = nn.Dropout(self.dropout) 56 | self.fc = nn.Conv1d(self.hidden_dim, self.out_dim, 1) 57 | 58 | def forward(self, x, mask=None): 59 | x = self.spectral(x)*mask 60 | x = self.temporal(x)*mask 61 | 62 | attn_mask = mask.unsqueeze(2) * mask.unsqueeze(-1) 63 | y = self.slf_attn(x,x, attn_mask=attn_mask) 64 | x = x + self.atten_drop(y) 65 | x = self.fc(x) 66 | 67 | return self.temporal_avg_pool(x, mask=mask) 68 | 69 | def temporal_avg_pool(self, x, mask=None): 70 | if mask is None: 71 | out = torch.mean(x, dim=2) 72 | else: 73 | x = x.sum(dim=2) 74 | out = torch.div(x, mask.sum(dim=2)) 75 | 76 | return out -------------------------------------------------------------------------------- /video/modules/processors/frame/core.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import importlib 3 | from concurrent.futures import ThreadPoolExecutor 4 | from types import ModuleType 5 | from typing import Any, List, Callable 6 | from tqdm import tqdm 7 | 8 | import modules 9 | import modules.globals 10 | 11 | FRAME_PROCESSORS_MODULES: List[ModuleType] = [] 12 | FRAME_PROCESSORS_INTERFACE = [ 13 | 'pre_check', 14 | 'pre_start', 15 | 'process_frame', 16 | 'process_image', 17 | 'process_video' 18 | ] 19 | 20 | 21 | def load_frame_processor_module(frame_processor: str) -> Any: 22 | try: 23 | frame_processor_module = importlib.import_module(f'modules.processors.frame.{frame_processor}') 24 | for method_name in FRAME_PROCESSORS_INTERFACE: 25 | if not 
hasattr(frame_processor_module, method_name): 26 | sys.exit() 27 | except ImportError: 28 | print(f"Frame processor {frame_processor} not found") 29 | sys.exit() 30 | return frame_processor_module 31 | 32 | 33 | def get_frame_processors_modules(frame_processors: List[str]) -> List[ModuleType]: 34 | global FRAME_PROCESSORS_MODULES 35 | 36 | if not FRAME_PROCESSORS_MODULES: 37 | for frame_processor in frame_processors: 38 | frame_processor_module = load_frame_processor_module(frame_processor) 39 | FRAME_PROCESSORS_MODULES.append(frame_processor_module) 40 | set_frame_processors_modules_from_ui(frame_processors) 41 | return FRAME_PROCESSORS_MODULES 42 | 43 | def set_frame_processors_modules_from_ui(frame_processors: List[str]) -> None: 44 | global FRAME_PROCESSORS_MODULES 45 | for frame_processor, state in modules.globals.fp_ui.items(): 46 | if state == True and frame_processor not in frame_processors: 47 | frame_processor_module = load_frame_processor_module(frame_processor) 48 | FRAME_PROCESSORS_MODULES.append(frame_processor_module) 49 | modules.globals.frame_processors.append(frame_processor) 50 | if state == False: 51 | try: 52 | frame_processor_module = load_frame_processor_module(frame_processor) 53 | FRAME_PROCESSORS_MODULES.remove(frame_processor_module) 54 | modules.globals.frame_processors.remove(frame_processor) 55 | except: 56 | pass 57 | 58 | def multi_process_frame(source_path: str, temp_frame_paths: List[str], process_frames: Callable[[str, List[str], Any], None], progress: Any = None) -> None: 59 | with ThreadPoolExecutor(max_workers=modules.globals.execution_threads) as executor: 60 | futures = [] 61 | for path in temp_frame_paths: 62 | future = executor.submit(process_frames, source_path, [path], progress) 63 | futures.append(future) 64 | for future in futures: 65 | future.result() 66 | 67 | 68 | def process_video(source_path: str, frame_paths: list[str], process_frames: Callable[[str, List[str], Any], None]) -> None: 69 | progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]' 70 | total = len(frame_paths) 71 | with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress: 72 | progress.set_postfix({'execution_providers': modules.globals.execution_providers, 'execution_threads': modules.globals.execution_threads, 'max_memory': modules.globals.max_memory}) 73 | multi_process_frame(source_path, frame_paths, process_frames, progress) 74 | -------------------------------------------------------------------------------- /audio/augmentation/peq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | class ParametricEqualizer(nn.Module): 6 | """Fast-parametric equalizer for approximation of Biquad IIR filter. 7 | """ 8 | def __init__(self, sr: int, windows: int): 9 | """Initializer. 10 | Args: 11 | sr: sample rate. 12 | windows: size of the fft window. 13 | """ 14 | super().__init__() 15 | self.sr = sr 16 | self.windows = windows 17 | 18 | def biquad(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 19 | """Construct frequency level biquad filter. 20 | Args: 21 | a: [torch.float32; [..., 3]], recursive filter, iir. 22 | b: [torch.float32; [..., 3]], finite impulse filter. 23 | Returns: 24 | [torch.float32; [..., windows // 2 + 1]], biquad filter. 
25 | """ 26 | iir = torch.fft.rfft(a, self.windows, dim=-1) 27 | fir = torch.fft.rfft(b, self.windows, dim=-1) 28 | return fir / iir 29 | 30 | def low_shelving(self, cutoff: float, q: torch.Tensor) -> torch.Tensor: 31 | """Frequency level low-shelving filter. 32 | Args: 33 | cutoff: cutoff frequency. 34 | q: [torch.float32; [B]], quality factor. 35 | Returns: 36 | [torch.float32; [B, windows // 2 + 1]], frequency filter. 37 | """ 38 | bsize, = q.shape 39 | # ref: torchaudio.functional.lowpass_biquad 40 | w0 = 2 * np.pi * cutoff / self.sr 41 | cos_w0 = np.cos(w0) 42 | # [B] 43 | alpha = np.sin(w0) / 2 / q 44 | cos_w0 = torch.tensor( 45 | [np.cos(w0)] * bsize, dtype=torch.float32, device=q.device) 46 | # [B, windows // 2 + 1] 47 | return self.biquad( 48 | a=torch.stack([1 + alpha, -2 * cos_w0, 1 - alpha], dim=-1), 49 | b=torch.stack([(1 - cos_w0) / 2, 1 - cos_w0, (1 - cos_w0) / 2], dim=-1)) 50 | 51 | def high_shelving(self, cutoff: float, q: torch.Tensor) -> torch.Tensor: 52 | """Frequency level high-shelving filter. 53 | Args: 54 | cutoff: cutoff frequency. 55 | q: [torch.float32; [B]], quality factor. 56 | Returns: 57 | [torch.float32; [B, windows // 2 + 1]], frequency filter. 58 | """ 59 | bsize, = q.shape 60 | w0 = 2 * np.pi * cutoff / self.sr 61 | 62 | alpha = np.sin(w0) / 2 / q 63 | cos_w0 = torch.tensor( 64 | [np.cos(w0)] * bsize, dtype=torch.float32, device=q.device) 65 | 66 | return self.biquad( 67 | a=torch.stack([1 + alpha, -2 * cos_w0, 1 - alpha], dim=-1), 68 | b=torch.stack([(1 + cos_w0) / 2, -1 - cos_w0, (1 + cos_w0) / 2], dim=-1)) 69 | 70 | def peaking_equalizer(self, 71 | center: torch.Tensor, 72 | gain: torch.Tensor, 73 | q: torch.Tensor) -> torch.Tensor: 74 | """Frequency level peaking equalizer. 75 | Args: 76 | center: [torch.float32; [...]], center frequency. 77 | gain: [torch.float32; [...]], boost or attenuation in decibel. 78 | q: [torch.float32; [...]], quality factor. 79 | Returns: 80 | [torch.float32; [..., windows // 2 + 1]], frequency filter. 81 | """ 82 | w0 = 2 * np.pi * center / self.sr 83 | alpha = torch.sin(w0) / 2 / q 84 | cos_w0 = torch.cos(w0) 85 | A = (gain / 40. * np.log(10)).exp() 86 | return self.biquad( 87 | a=torch.stack([1 + alpha / A, -2 * cos_w0, 1 - alpha / A], dim=-1), 88 | b=torch.stack([1 + alpha * A, -2 * cos_w0, 1 - alpha * A], dim=-1)) -------------------------------------------------------------------------------- /video/modules/video_capture.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from typing import Optional, Tuple, Callable 4 | import platform 5 | import threading 6 | 7 | # Only import Windows-specific library if on Windows 8 | if platform.system() == "Windows": 9 | from pygrabber.dshow_graph import FilterGraph 10 | 11 | 12 | class VideoCapturer: 13 | def __init__(self, device_index: int): 14 | self.device_index = device_index 15 | self.frame_callback = None 16 | self._current_frame = None 17 | self._frame_ready = threading.Event() 18 | self.is_running = False 19 | self.cap = None 20 | 21 | # Initialize Windows-specific components if on Windows 22 | if platform.system() == "Windows": 23 | self.graph = FilterGraph() 24 | # Verify device exists 25 | devices = self.graph.get_input_devices() 26 | if self.device_index >= len(devices): 27 | raise ValueError( 28 | f"Invalid device index {device_index}. 
Available devices: {len(devices)}" 29 | ) 30 | 31 | def start(self, width: int = 960, height: int = 540, fps: int = 60) -> bool: 32 | """Initialize and start video capture""" 33 | try: 34 | if platform.system() == "Windows": 35 | # Windows-specific capture methods 36 | capture_methods = [ 37 | (self.device_index, cv2.CAP_DSHOW), # Try DirectShow first 38 | (self.device_index, cv2.CAP_ANY), # Then try default backend 39 | (-1, cv2.CAP_ANY), # Try -1 as fallback 40 | (0, cv2.CAP_ANY), # Finally try 0 without specific backend 41 | ] 42 | 43 | for dev_id, backend in capture_methods: 44 | try: 45 | self.cap = cv2.VideoCapture(dev_id, backend) 46 | if self.cap.isOpened(): 47 | break 48 | self.cap.release() 49 | except Exception: 50 | continue 51 | else: 52 | # Unix-like systems (Linux/Mac) capture method 53 | self.cap = cv2.VideoCapture(self.device_index) 54 | 55 | if not self.cap or not self.cap.isOpened(): 56 | raise RuntimeError("Failed to open camera") 57 | 58 | # Configure format 59 | self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, width) 60 | self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height) 61 | self.cap.set(cv2.CAP_PROP_FPS, fps) 62 | 63 | self.is_running = True 64 | return True 65 | 66 | except Exception as e: 67 | print(f"Failed to start capture: {str(e)}") 68 | if self.cap: 69 | self.cap.release() 70 | return False 71 | 72 | def read(self) -> Tuple[bool, Optional[np.ndarray]]: 73 | """Read a frame from the camera""" 74 | if not self.is_running or self.cap is None: 75 | return False, None 76 | 77 | ret, frame = self.cap.read() 78 | if ret: 79 | self._current_frame = frame 80 | if self.frame_callback: 81 | self.frame_callback(frame) 82 | return True, frame 83 | return False, None 84 | 85 | def release(self) -> None: 86 | """Stop capture and release resources""" 87 | if self.is_running and self.cap is not None: 88 | self.cap.release() 89 | self.is_running = False 90 | self.cap = None 91 | 92 | def set_frame_callback(self, callback: Callable[[np.ndarray], None]) -> None: 93 | """Set callback for frame processing""" 94 | self.frame_callback = callback 95 | -------------------------------------------------------------------------------- /audio/alias_free_torch/filter.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import math 8 | 9 | if 'sinc' in dir(torch): 10 | sinc = torch.sinc 11 | else: 12 | # This code is adopted from adefossez's julius.core.sinc under the MIT License 13 | # https://adefossez.github.io/julius/julius/core.html 14 | # LICENSE is in incl_licenses directory. 15 | def sinc(x: torch.Tensor): 16 | """ 17 | Implementation of sinc, i.e. sin(pi * x) / (pi * x) 18 | __Warning__: Different to julius.sinc, the input is multiplied by `pi`! 19 | """ 20 | return torch.where(x == 0, 21 | torch.tensor(1., device=x.device, dtype=x.dtype), 22 | torch.sin(math.pi * x) / math.pi / x) 23 | 24 | 25 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License 26 | # https://adefossez.github.io/julius/julius/lowpass.html 27 | # LICENSE is in incl_licenses directory. 
28 | def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] 29 | even = (kernel_size % 2 == 0) 30 | half_size = kernel_size // 2 31 | 32 | #For kaiser window 33 | delta_f = 4 * half_width 34 | A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 35 | if A > 50.: 36 | beta = 0.1102 * (A - 8.7) 37 | elif A >= 21.: 38 | beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) 39 | else: 40 | beta = 0. 41 | window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) 42 | 43 | # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio 44 | if even: 45 | time = (torch.arange(-half_size, half_size) + 0.5) 46 | else: 47 | time = torch.arange(kernel_size) - half_size 48 | if cutoff == 0: 49 | filter_ = torch.zeros_like(time) 50 | else: 51 | filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) 52 | # Normalize filter to have sum = 1, otherwise we will have a small leakage 53 | # of the constant component in the input signal. 54 | filter_ /= filter_.sum() 55 | filter = filter_.view(1, 1, kernel_size) 56 | 57 | return filter 58 | 59 | 60 | class LowPassFilter1d(nn.Module): 61 | def __init__(self, 62 | cutoff=0.5, 63 | half_width=0.6, 64 | stride: int = 1, 65 | padding: bool = True, 66 | padding_mode: str = 'replicate', 67 | kernel_size: int = 12): 68 | # kernel_size should be even number for stylegan3 setup, 69 | # in this implementation, odd number is also possible. 70 | super().__init__() 71 | if cutoff < -0.: 72 | raise ValueError("Minimum cutoff must be larger than zero.") 73 | if cutoff > 0.5: 74 | raise ValueError("A cutoff above 0.5 does not make sense.") 75 | self.kernel_size = kernel_size 76 | self.even = (kernel_size % 2 == 0) 77 | self.pad_left = kernel_size // 2 - int(self.even) 78 | self.pad_right = kernel_size // 2 79 | self.stride = stride 80 | self.padding = padding 81 | self.padding_mode = padding_mode 82 | filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) 83 | self.register_buffer("filter", filter) 84 | 85 | #input [B, C, T] 86 | def forward(self, x): 87 | _, C, _ = x.shape 88 | 89 | if self.padding: 90 | x = F.pad(x, (self.pad_left, self.pad_right), 91 | mode=self.padding_mode) 92 | out = F.conv1d(x, self.filter.expand(C, -1, -1), 93 | stride=self.stride, groups=C) 94 | 95 | return out -------------------------------------------------------------------------------- /video/modules/processors/frame/face_enhancer.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | import cv2 3 | import threading 4 | import gfpgan 5 | import os 6 | 7 | import modules.globals 8 | import modules.processors.frame.core 9 | from modules.core import update_status 10 | from modules.face_analyser import get_one_face 11 | from modules.typing import Frame, Face 12 | import platform 13 | import torch 14 | from modules.utilities import ( 15 | conditional_download, 16 | is_image, 17 | is_video, 18 | ) 19 | 20 | GFPGAN_PATH = "" 21 | FACE_ENHANCER = None 22 | FACE_ENHANCER_UPSCALE = 0.4 23 | THREAD_SEMAPHORE = threading.Semaphore() 24 | THREAD_LOCK = threading.Lock() 25 | NAME = "DLC.FACE-ENHANCER" 26 | 27 | abs_dir = os.path.dirname(os.path.abspath(__file__)) 28 | models_dir = os.path.join( 29 | os.path.dirname(os.path.dirname(os.path.dirname(abs_dir))), "models" 30 | ) 31 | 32 | 33 | def pre_check() -> bool: 34 | download_directory_path = models_dir 35 | conditional_download( 36 | download_directory_path, 37 | [ 38 | "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth" 39 | 
], 40 | ) 41 | return True 42 | 43 | 44 | def pre_start() -> bool: 45 | if not is_image(modules.globals.target_path) and not is_video( 46 | modules.globals.target_path 47 | ): 48 | update_status("Select an image or video for target path.", NAME) 49 | return False 50 | return True 51 | 52 | 53 | def get_face_enhancer() -> Any: 54 | global FACE_ENHANCER, FACE_ENHANCER_UPSCALE, GFPGAN_PATH 55 | 56 | with THREAD_LOCK: 57 | if FACE_ENHANCER is None: 58 | if not os.path.isabs(GFPGAN_PATH): 59 | model_path = os.path.join(models_dir, os.path.basename(GFPGAN_PATH)) 60 | else: 61 | model_path = GFPGAN_PATH 62 | 63 | match platform.system(): 64 | case "Darwin": # Mac OS 65 | if torch.backends.mps.is_available(): 66 | mps_device = torch.device("mps") 67 | FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=FACE_ENHANCER_UPSCALE, device=mps_device) # type: ignore[attr-defined] 68 | else: 69 | FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=FACE_ENHANCER_UPSCALE) # type: ignore[attr-defined] 70 | case _: # Other OS 71 | FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=FACE_ENHANCER_UPSCALE) # type: ignore[attr-defined] 72 | 73 | return FACE_ENHANCER 74 | 75 | 76 | def enhance_face(temp_frame: Frame) -> Frame: 77 | with THREAD_SEMAPHORE: 78 | _, _, temp_frame = get_face_enhancer().enhance(temp_frame, paste_back=True) 79 | return temp_frame 80 | 81 | 82 | def process_frame(source_face: Face, temp_frame: Frame) -> Frame: 83 | target_face = get_one_face(temp_frame) 84 | if target_face: 85 | temp_frame = enhance_face(temp_frame) 86 | return temp_frame 87 | 88 | 89 | def process_frames( 90 | source_path: str, temp_frame_paths: List[str], progress: Any = None 91 | ) -> None: 92 | for temp_frame_path in temp_frame_paths: 93 | temp_frame = cv2.imread(temp_frame_path) 94 | result = process_frame(None, temp_frame) 95 | cv2.imwrite(temp_frame_path, result) 96 | if progress: 97 | progress.update(1) 98 | 99 | 100 | def process_image(source_path: str, target_path: str, output_path: str) -> None: 101 | target_frame = cv2.imread(target_path) 102 | result = process_frame(None, target_frame) 103 | cv2.imwrite(output_path, result) 104 | 105 | 106 | def process_video(source_path: str, temp_frame_paths: List[str]) -> None: 107 | modules.processors.frame.core.process_video(None, temp_frame_paths, process_frames) 108 | 109 | 110 | def process_frame_v2(temp_frame: Frame) -> Frame: 111 | target_face = get_one_face(temp_frame) 112 | if target_face: 113 | temp_frame = enhance_face(temp_frame) 114 | return temp_frame 115 | -------------------------------------------------------------------------------- /audio/utils/data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchaudio 4 | from torchaudio.transforms import MelSpectrogram 5 | from module.utils import parse_filelist 6 | from torch.nn import functional as F 7 | np.random.seed(1234) 8 | 9 | class AudioDataset(torch.utils.data.Dataset): 10 | """ 11 | Provides dataset management for given filelist. 
12 | """ 13 | def __init__(self, config, training=True): 14 | super(AudioDataset, self).__init__() 15 | self.config = config 16 | self.hop_length = config.data.hop_length 17 | self.training = training 18 | self.mel_length = config.train.segment_size // config.data.hop_length 19 | self.segment_length = config.train.segment_size 20 | self.sample_rate = config.data.sampling_rate 21 | 22 | self.filelist_path = config.data.train_filelist_path \ 23 | if self.training else config.data.test_filelist_path 24 | self.audio_paths = parse_filelist(self.filelist_path) \ 25 | if self.training else parse_filelist(self.filelist_path)[:101] 26 | 27 | self.f0_norm_paths = parse_filelist(self.filelist_path.replace('_wav', '_f0_norm')) 28 | self.f0_paths = parse_filelist(self.filelist_path.replace('_wav', '_f0')) 29 | 30 | 31 | def load_audio_to_torch(self, audio_path): 32 | audio, sample_rate = torchaudio.load(audio_path) 33 | 34 | if not self.training: 35 | p = (audio.shape[-1] // 1280 + 1) * 1280 - audio.shape[-1] 36 | audio = F.pad(audio, (0, p), mode='constant').data 37 | return audio.squeeze(), sample_rate 38 | 39 | def __getitem__(self, index): 40 | audio_path = self.audio_paths[index] 41 | f0_norm_path = self.f0_norm_paths[index] 42 | f0_path = self.f0_paths[index] 43 | 44 | audio, sample_rate = self.load_audio_to_torch(audio_path) 45 | f0_norm = torch.load(f0_norm_path) 46 | f0 = torch.load(f0_path) 47 | 48 | assert sample_rate == self.sample_rate, \ 49 | f"""Got path to audio of sampling rate {sample_rate}, \ 50 | but required {self.sample_rate} according config.""" 51 | 52 | if not self.training: 53 | return audio, f0_norm, f0 54 | 55 | if audio.shape[-1] > self.segment_length: 56 | max_f0_start = f0.shape[-1] - self.segment_length//80 57 | 58 | f0_start = np.random.randint(0, max_f0_start) 59 | f0_norm_seg = f0_norm[:, f0_start:f0_start + self.segment_length // 80] 60 | f0_seg = f0[:, f0_start:f0_start + self.segment_length // 80] 61 | 62 | audio_start = f0_start*80 63 | segment = audio[audio_start:audio_start + self.segment_length] 64 | 65 | if segment.shape[-1] < self.segment_length: 66 | segment = F.pad(segment, (0, self.segment_length - segment.shape[-1]), 'constant') 67 | length = torch.LongTensor([self.mel_length]) 68 | 69 | else: 70 | segment = F.pad(audio, (0, self.segment_length - audio.shape[-1]), 'constant') 71 | length = torch.LongTensor([audio.shape[-1] // self.hop_length]) 72 | 73 | f0_norm_seg = F.pad(f0_norm, (0, self.segment_length // 80 - f0_norm.shape[-1]), 'constant') 74 | 75 | f0_seg = F.pad(f0, (0, self.segment_length // 80 - f0.shape[-1]), 'constant') 76 | 77 | return segment, f0_norm_seg, f0_seg, length 78 | 79 | def __len__(self): 80 | return len(self.audio_paths) 81 | 82 | def sample_test_batch(self, size): 83 | idx = np.random.choice(range(len(self)), size=size, replace=False) 84 | test_batch = [] 85 | for index in idx: 86 | test_batch.append(self.__getitem__(index)) 87 | return test_batch 88 | 89 | class MelSpectrogramFixed(torch.nn.Module): 90 | """In order to remove padding of torchaudio package + add log10 scale.""" 91 | 92 | def __init__(self, **kwargs): 93 | super(MelSpectrogramFixed, self).__init__() 94 | self.torchaudio_backend = MelSpectrogram(**kwargs) 95 | 96 | def forward(self, x): 97 | outputs = torch.log(self.torchaudio_backend(x) + 0.001) 98 | 99 | return outputs[..., :-1] -------------------------------------------------------------------------------- /video/wrapper.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | import cv2 3 | import numpy as np 4 | from insightface.app import FaceAnalysis 5 | import insightface 6 | from PIL import Image 7 | from modules import face_analyser 8 | from modules.processors.frame import face_enhancer, face_swapper 9 | import time 10 | import os 11 | 12 | torch.cuda.empty_cache() 13 | 14 | 15 | class Wrapper: 16 | def __init__( 17 | self, 18 | source_image="./image.jpg", 19 | gfpgan_path="models/GFPGANv1.3.pth", 20 | inswapper_path="models/inswapper_128_fp16.onnx", 21 | upscale=0.4, 22 | disable_face_enhancement=False, 23 | ): 24 | self.face_analyzer = FaceAnalysis( 25 | name="buffalo_l", 26 | providers=[ 27 | "CUDAExecutionProvider", 28 | "CoreMLExecutionProvider", 29 | "CPUExecutionProvider", 30 | ], 31 | provider_options=[{"device_id": 0}, {"device_id": 0}, {"device_id": 0}], 32 | ) 33 | self.face_analyzer.prepare(ctx_id=0, det_size=(640, 640), det_thresh=0.5) 34 | 35 | self.face_swapper = insightface.model_zoo.get_model( 36 | inswapper_path, 37 | providers=[ 38 | "CUDAExecutionProvider", 39 | "CoreMLExecutionProvider", 40 | "CPUExecutionProvider", 41 | ], 42 | provider_options=[{"device_id": 0}, {"device_id": 0}, {"device_id": 0}], 43 | ) 44 | face_swapper.INSWAPPER_PATH = inswapper_path 45 | 46 | self.disable_face_enhancement = disable_face_enhancement 47 | if not self.disable_face_enhancement: 48 | self.face_enhancer = self.load_model(gfpgan_path) 49 | face_enhancer.FACE_ENHANCER_UPSCALE = upscale 50 | face_enhancer.GFPGAN_PATH = gfpgan_path 51 | 52 | self.source_face = face_analyser.get_one_face(cv2.imread(source_image)) 53 | 54 | def load_model(self, path): 55 | if torch.cuda.is_available(): 56 | device = "cuda" 57 | torch.cuda.set_device(0) 58 | elif torch.backends.mps.is_available(): 59 | device = "mps" 60 | else: 61 | device = "cpu" 62 | 63 | print(f"Loading model on: {device.upper()}") 64 | model = torch.load(path, map_location=device) 65 | return model 66 | 67 | def update_config(self, new_image_path, new_upscale, new_disable_face): 68 | print("Updating configuration...") 69 | new_image = cv2.imread(new_image_path) 70 | if new_image is not None: 71 | self.source_face = face_analyser.get_one_face(new_image) 72 | print("Source face updated.") 73 | else: 74 | print("Error: Unable to load new image.") 75 | 76 | face_enhancer.FACE_ENHANCER_UPSCALE = new_upscale 77 | print(f"Upscale factor updated to {new_upscale}") 78 | self.disable_face_enhancement = new_disable_face 79 | print(f"Disable face enhancement value updated to {new_disable_face}") 80 | 81 | def generate(self, frame): 82 | start_time = time.time() 83 | target_face = face_analyser.get_one_face(frame) 84 | elapsed_time = time.time() - start_time 85 | print(f"1. Face detection: {elapsed_time:.4f} seconds") 86 | 87 | start_time = time.time() 88 | if self.source_face and target_face: 89 | tmp_frame = face_swapper.swap_face(self.source_face, target_face, frame) 90 | else: 91 | tmp_frame = frame 92 | elapsed_time = time.time() - start_time 93 | print(f"2. Face swapper: {elapsed_time:.4f} seconds") 94 | 95 | if not self.disable_face_enhancement: 96 | start_time = time.time() 97 | processed_frame = face_enhancer.process_frame(None, tmp_frame) 98 | elapsed_time = time.time() - start_time 99 | print(f"3. 
Face enhancer: {elapsed_time:.4f} seconds") 100 | else: 101 | processed_frame = tmp_frame 102 | 103 | if isinstance(processed_frame, Image.Image): 104 | processed_frame = np.array(processed_frame) 105 | 106 | if processed_frame is not None and isinstance(processed_frame, np.ndarray): 107 | return processed_frame 108 | else: 109 | print("Error: Processed frame is invalid.") 110 | return frame 111 | -------------------------------------------------------------------------------- /audio/vocoder/activations.py: -------------------------------------------------------------------------------- 1 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch 5 | from torch import nn, sin, pow 6 | from torch.nn import Parameter 7 | 8 | 9 | class Snake(nn.Module): 10 | ''' 11 | Implementation of a sine-based periodic activation function 12 | Shape: 13 | - Input: (B, C, T) 14 | - Output: (B, C, T), same shape as the input 15 | Parameters: 16 | - alpha - trainable parameter 17 | References: 18 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 19 | https://arxiv.org/abs/2006.08195 20 | Examples: 21 | >>> a1 = snake(256) 22 | >>> x = torch.randn(256) 23 | >>> x = a1(x) 24 | ''' 25 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 26 | ''' 27 | Initialization. 28 | INPUT: 29 | - in_features: shape of the input 30 | - alpha: trainable parameter 31 | alpha is initialized to 1 by default, higher values = higher-frequency. 32 | alpha will be trained along with the rest of your model. 33 | ''' 34 | super(Snake, self).__init__() 35 | self.in_features = in_features 36 | 37 | # initialize alpha 38 | self.alpha_logscale = alpha_logscale 39 | if self.alpha_logscale: # log scale alphas initialized to zeros 40 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 41 | else: # linear scale alphas initialized to ones 42 | self.alpha = Parameter(torch.ones(in_features) * alpha) 43 | 44 | self.alpha.requires_grad = alpha_trainable 45 | 46 | self.no_div_by_zero = 0.000000001 47 | 48 | def forward(self, x): 49 | ''' 50 | Forward pass of the function. 51 | Applies the function to the input elementwise. 52 | Snake ∶= x + 1/a * sin^2 (xa) 53 | ''' 54 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 55 | if self.alpha_logscale: 56 | alpha = torch.exp(alpha) 57 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 58 | 59 | return x 60 | 61 | 62 | class SnakeBeta(nn.Module): 63 | ''' 64 | A modified Snake function which uses separate parameters for the magnitude of the periodic components 65 | Shape: 66 | - Input: (B, C, T) 67 | - Output: (B, C, T), same shape as the input 68 | Parameters: 69 | - alpha - trainable parameter that controls frequency 70 | - beta - trainable parameter that controls magnitude 71 | References: 72 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 73 | https://arxiv.org/abs/2006.08195 74 | Examples: 75 | >>> a1 = snakebeta(256) 76 | >>> x = torch.randn(256) 77 | >>> x = a1(x) 78 | ''' 79 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 80 | ''' 81 | Initialization. 
82 | INPUT: 83 | - in_features: shape of the input 84 | - alpha - trainable parameter that controls frequency 85 | - beta - trainable parameter that controls magnitude 86 | alpha is initialized to 1 by default, higher values = higher-frequency. 87 | beta is initialized to 1 by default, higher values = higher-magnitude. 88 | alpha will be trained along with the rest of your model. 89 | ''' 90 | super(SnakeBeta, self).__init__() 91 | self.in_features = in_features 92 | 93 | # initialize alpha 94 | self.alpha_logscale = alpha_logscale 95 | if self.alpha_logscale: # log scale alphas initialized to zeros 96 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 97 | self.beta = Parameter(torch.zeros(in_features) * alpha) 98 | else: # linear scale alphas initialized to ones 99 | self.alpha = Parameter(torch.ones(in_features) * alpha) 100 | self.beta = Parameter(torch.ones(in_features) * alpha) 101 | 102 | self.alpha.requires_grad = alpha_trainable 103 | self.beta.requires_grad = alpha_trainable 104 | 105 | self.no_div_by_zero = 0.000000001 106 | 107 | def forward(self, x): 108 | ''' 109 | Forward pass of the function. 110 | Applies the function to the input elementwise. 111 | SnakeBeta ∶= x + 1/b * sin^2 (xa) 112 | ''' 113 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 114 | beta = self.beta.unsqueeze(0).unsqueeze(-1) 115 | if self.alpha_logscale: 116 | alpha = torch.exp(alpha) 117 | beta = torch.exp(beta) 118 | x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 119 | 120 | return x 121 | -------------------------------------------------------------------------------- /video/server.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import zmq 3 | import msgpack 4 | import msgpack_numpy as m 5 | import numpy as np 6 | from wrapper import Wrapper 7 | import time 8 | import argparse 9 | import os 10 | import json 11 | 12 | m.patch() 13 | 14 | ZMQ_RECEIVE_ADDRESS = "tcp://0.0.0.0:5558" 15 | ZMQ_SEND_ADDRESS = "tcp://0.0.0.0:5559" 16 | ZMQ_UPDATE_ADDRESS = "tcp://0.0.0.0:5560" 17 | 18 | 19 | def process_frame(frame, wrapper): 20 | frame_resized = cv2.resize(frame, (1280, 720), interpolation=cv2.INTER_AREA) 21 | processed_frame = wrapper.generate(frame_resized) 22 | return processed_frame 23 | 24 | 25 | def main(): 26 | parser = argparse.ArgumentParser(description="Real-Time Deepfake Pipeline Server") 27 | parser.add_argument( 28 | "--source_image", 29 | type=str, 30 | default="./image.jpg", 31 | help="Path to the source image.", 32 | ) 33 | parser.add_argument( 34 | "--gfpgan_path", 35 | type=str, 36 | default="models/GFPGANv1.3.pth", 37 | help="Path to the GFPGAN model file.", 38 | ) 39 | parser.add_argument( 40 | "--inswapper_path", 41 | type=str, 42 | default="models/inswapper_128_fp16.onnx", 43 | help="Path to the inswapper model file.", 44 | ) 45 | parser.add_argument( 46 | "--upscale", 47 | type=float, 48 | default=0.4, 49 | help="Upscale factor for GFPGAN face enhancement.", 50 | ) 51 | parser.add_argument( 52 | "--disable_face_enhancement", 53 | action="store_true", 54 | help="Disable face enhancement (GFPGAN) if set.", 55 | ) 56 | args = parser.parse_args() 57 | 58 | print("Initializing wrapper...") 59 | wrapper = Wrapper( 60 | source_image=args.source_image, 61 | gfpgan_path=args.gfpgan_path, 62 | inswapper_path=args.inswapper_path, 63 | upscale=args.upscale, 64 | disable_face_enhancement=args.disable_face_enhancement, 65 | ) 66 | 67 | context = zmq.Context() 68 | receiver = 
context.socket(zmq.PULL) 69 | receiver.bind(ZMQ_RECEIVE_ADDRESS) 70 | sender = context.socket(zmq.PUSH) 71 | sender.bind(ZMQ_SEND_ADDRESS) 72 | update_socket = context.socket(zmq.REP) 73 | update_socket.bind(ZMQ_UPDATE_ADDRESS) 74 | 75 | poller = zmq.Poller() 76 | poller.register(receiver, zmq.POLLIN) 77 | poller.register(update_socket, zmq.POLLIN) 78 | 79 | print("Server is running and waiting for frames and update commands...") 80 | 81 | try: 82 | while True: 83 | socks = dict(poller.poll(timeout=50)) 84 | 85 | if update_socket in socks and socks[update_socket] == zmq.POLLIN: 86 | message = update_socket.recv() 87 | try: 88 | update_data = json.loads(message.decode("utf-8")) 89 | new_image_path = update_data.get("source_image") 90 | new_upscale = update_data.get("upscale") 91 | new_disable_face = update_data.get("disable_face_enhancement") 92 | if new_image_path and new_upscale and new_disable_face is not None: 93 | print("Received update command:") 94 | print(f" - New source image: {new_image_path}") 95 | print(f" - New upscale factor: {new_upscale}") 96 | print( 97 | f" - New disable face enhancement value: {new_disable_face}" 98 | ) 99 | wrapper.update_config( 100 | new_image_path, float(new_upscale), new_disable_face 101 | ) 102 | update_socket.send_string("Update successful") 103 | else: 104 | update_socket.send_string("Invalid update command") 105 | except Exception as e: 106 | print(f"Error processing update command: {e}") 107 | update_socket.send_string("Error in update command") 108 | 109 | if receiver in socks and socks[receiver] == zmq.POLLIN: 110 | start_time = time.time() 111 | data = receiver.recv() 112 | compressed_frame = np.frombuffer(msgpack.unpackb(data), dtype=np.uint8) 113 | frame = cv2.imdecode(compressed_frame, cv2.IMREAD_REDUCED_COLOR_2) 114 | 115 | processed_frame = process_frame(frame, wrapper) 116 | processed_frame = cv2.resize( 117 | processed_frame, 118 | (frame.shape[1], frame.shape[0]), 119 | interpolation=cv2.INTER_CUBIC, 120 | ) 121 | 122 | _, encoded_frame = cv2.imencode( 123 | ".jpg", processed_frame, [cv2.IMWRITE_JPEG_QUALITY, 80] 124 | ) 125 | sender.send(msgpack.packb(encoded_frame.tobytes())) 126 | elapsed_time = time.time() - start_time 127 | print(f"Processing time: {elapsed_time:.4f} seconds") 128 | 129 | except KeyboardInterrupt: 130 | print("Server interrupted.") 131 | 132 | 133 | if __name__ == "__main__": 134 | main() 135 | -------------------------------------------------------------------------------- /video/modules/ui.json: -------------------------------------------------------------------------------- 1 | { 2 | "CTk": { 3 | "fg_color": ["gray95", "gray10"] 4 | }, 5 | "CTkToplevel": { 6 | "fg_color": ["gray95", "gray10"] 7 | }, 8 | "CTkFrame": { 9 | "corner_radius": 0, 10 | "border_width": 0, 11 | "fg_color": ["gray90", "gray13"], 12 | "top_fg_color": ["gray85", "gray16"], 13 | "border_color": ["gray65", "gray28"] 14 | }, 15 | "CTkButton": { 16 | "corner_radius": 0, 17 | "border_width": 0, 18 | "fg_color": ["#2aa666", "#1f538d"], 19 | "hover_color": ["#3cb666", "#14375e"], 20 | "border_color": ["#3e4a40", "#949A9F"], 21 | "text_color": ["#f3faf6", "#f3faf6"], 22 | "text_color_disabled": ["gray74", "gray60"] 23 | }, 24 | "CTkLabel": { 25 | "corner_radius": 0, 26 | "fg_color": "transparent", 27 | "text_color": ["gray14", "gray84"] 28 | }, 29 | "CTkEntry": { 30 | "corner_radius": 0, 31 | "border_width": 2, 32 | "fg_color": ["#F9F9FA", "#343638"], 33 | "border_color": ["#979DA2", "#565B5E"], 34 | "text_color": ["gray14", "gray84"], 35 | 
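# A hedged, illustrative sketch of the wire protocol bound by the server above
# (this is not the project's own client): raw frames are JPEG-encoded,
# msgpack-packed and PUSHed to port 5558; processed frames are PULLed back from
# port 5559; configuration updates go as JSON over the REQ/REP pair on port 5560.
# "localhost" stands in for the server's address.
import json
import cv2
import msgpack
import msgpack_numpy as m
import numpy as np
import zmq

m.patch()
ctx = zmq.Context()
push = ctx.socket(zmq.PUSH); push.connect("tcp://localhost:5558")
pull = ctx.socket(zmq.PULL); pull.connect("tcp://localhost:5559")
req = ctx.socket(zmq.REQ); req.connect("tcp://localhost:5560")

frame = cv2.imread("some_frame.jpg")                        # hypothetical input frame
_, buf = cv2.imencode(".jpg", frame)                        # server expects JPEG bytes
push.send(msgpack.packb(buf.tobytes()))                     # feed the PULL socket on :5558
data = msgpack.unpackb(pull.recv())                         # processed JPEG bytes from :5559
out = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_COLOR)

req.send_string(json.dumps({"source_image": "./image.jpg",  # hot-swap the source face
                            "upscale": 0.4,
                            "disable_face_enhancement": False}))
print(req.recv_string())                                    # "Update successful" or an error string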
"placeholder_text_color": ["gray52", "gray62"] 36 | }, 37 | "CTkCheckbox": { 38 | "corner_radius": 0, 39 | "border_width": 3, 40 | "fg_color": ["#2aa666", "#1f538d"], 41 | "border_color": ["#3e4a40", "#949A9F"], 42 | "hover_color": ["#3cb666", "#14375e"], 43 | "checkmark_color": ["#f3faf6", "gray90"], 44 | "text_color": ["gray14", "gray84"], 45 | "text_color_disabled": ["gray60", "gray45"] 46 | }, 47 | "CTkSwitch": { 48 | "corner_radius": 1000, 49 | "border_width": 3, 50 | "button_length": 0, 51 | "fg_color": ["#939BA2", "#4A4D50"], 52 | "progress_color": ["#2aa666", "#1f538d"], 53 | "button_color": ["gray36", "#D5D9DE"], 54 | "button_hover_color": ["gray20", "gray100"], 55 | "text_color": ["gray14", "gray84"], 56 | "text_color_disabled": ["gray60", "gray45"] 57 | }, 58 | "CTkRadiobutton": { 59 | "corner_radius": 1000, 60 | "border_width_checked": 6, 61 | "border_width_unchecked": 3, 62 | "fg_color": ["#2aa666", "#1f538d"], 63 | "border_color": ["#3e4a40", "#949A9F"], 64 | "hover_color": ["#3cb666", "#14375e"], 65 | "text_color": ["gray14", "gray84"], 66 | "text_color_disabled": ["gray60", "gray45"] 67 | }, 68 | "CTkProgressBar": { 69 | "corner_radius": 1000, 70 | "border_width": 0, 71 | "fg_color": ["#939BA2", "#4A4D50"], 72 | "progress_color": ["#2aa666", "#1f538d"], 73 | "border_color": ["gray", "gray"] 74 | }, 75 | "CTkSlider": { 76 | "corner_radius": 1000, 77 | "button_corner_radius": 1000, 78 | "border_width": 6, 79 | "button_length": 0, 80 | "fg_color": ["#939BA2", "#4A4D50"], 81 | "progress_color": ["gray40", "#AAB0B5"], 82 | "button_color": ["#2aa666", "#1f538d"], 83 | "button_hover_color": ["#3cb666", "#14375e"] 84 | }, 85 | "CTkOptionMenu": { 86 | "corner_radius": 0, 87 | "fg_color": ["#2aa666", "#1f538d"], 88 | "button_color": ["#3cb666", "#14375e"], 89 | "button_hover_color": ["#234567", "#1e2c40"], 90 | "text_color": ["#f3faf6", "#f3faf6"], 91 | "text_color_disabled": ["gray74", "gray60"] 92 | }, 93 | "CTkComboBox": { 94 | "corner_radius": 0, 95 | "border_width": 2, 96 | "fg_color": ["#F9F9FA", "#343638"], 97 | "border_color": ["#979DA2", "#565B5E"], 98 | "button_color": ["#979DA2", "#565B5E"], 99 | "button_hover_color": ["#6E7174", "#7A848D"], 100 | "text_color": ["gray14", "gray84"], 101 | "text_color_disabled": ["gray50", "gray45"] 102 | }, 103 | "CTkScrollbar": { 104 | "corner_radius": 1000, 105 | "border_spacing": 4, 106 | "fg_color": "transparent", 107 | "button_color": ["gray55", "gray41"], 108 | "button_hover_color": ["gray40", "gray53"] 109 | }, 110 | "CTkSegmentedButton": { 111 | "corner_radius": 0, 112 | "border_width": 2, 113 | "fg_color": ["#979DA2", "gray29"], 114 | "selected_color": ["#2aa666", "#1f538d"], 115 | "selected_hover_color": ["#3cb666", "#14375e"], 116 | "unselected_color": ["#979DA2", "gray29"], 117 | "unselected_hover_color": ["gray70", "gray41"], 118 | "text_color": ["#f3faf6", "#f3faf6"], 119 | "text_color_disabled": ["gray74", "gray60"] 120 | }, 121 | "CTkTextbox": { 122 | "corner_radius": 0, 123 | "border_width": 0, 124 | "fg_color": ["gray100", "gray20"], 125 | "border_color": ["#979DA2", "#565B5E"], 126 | "text_color": ["gray14", "gray84"], 127 | "scrollbar_button_color": ["gray55", "gray41"], 128 | "scrollbar_button_hover_color": ["gray40", "gray53"] 129 | }, 130 | "CTkScrollableFrame": { 131 | "label_fg_color": ["gray80", "gray21"] 132 | }, 133 | "DropdownMenu": { 134 | "fg_color": ["gray90", "gray20"], 135 | "hover_color": ["gray75", "gray28"], 136 | "text_color": ["gray14", "gray84"] 137 | }, 138 | "CTkFont": { 139 | "macOS": { 140 | 
"family": "Avenir", 141 | "size": 18, 142 | "weight": "normal" 143 | }, 144 | "Windows": { 145 | "family": "Corbel", 146 | "size": 18, 147 | "weight": "normal" 148 | }, 149 | "Linux": { 150 | "family": "Montserrat", 151 | "size": 18, 152 | "weight": "normal" 153 | } 154 | }, 155 | "URL": { 156 | "text_color": ["gray74", "gray60"] 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /audio/module/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def init_weights(m, mean=0.0, std=0.01): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(mean, std) 12 | 13 | 14 | def get_padding(kernel_size, dilation=1): 15 | return int((kernel_size*dilation - dilation)/2) 16 | 17 | 18 | def convert_pad_shape(pad_shape): 19 | l = pad_shape[::-1] 20 | pad_shape = [item for sublist in l for item in sublist] 21 | return pad_shape 22 | 23 | 24 | def intersperse(lst, item): 25 | result = [item] * (len(lst) * 2 + 1) 26 | result[1::2] = lst 27 | return result 28 | 29 | 30 | def kl_divergence(m_p, logs_p, m_q, logs_q): 31 | """KL(P||Q)""" 32 | kl = (logs_q - logs_p) - 0.5 33 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q) 34 | return kl 35 | 36 | 37 | def rand_gumbel(shape): 38 | """Sample from the Gumbel distribution, protect from overflows.""" 39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 40 | return -torch.log(-torch.log(uniform_samples)) 41 | 42 | 43 | def rand_gumbel_like(x): 44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 45 | return g 46 | 47 | 48 | def slice_segments(x, ids_str, segment_size=4): 49 | ret = torch.zeros_like(x[:, :, :segment_size]) 50 | for i in range(x.size(0)): 51 | idx_str = ids_str[i] 52 | idx_end = idx_str + segment_size 53 | ret[i] = x[i, :, idx_str:idx_end] 54 | return ret 55 | 56 | def slice_segments_audio(x, ids_str, segment_size=4): 57 | ret = torch.zeros_like(x[:, :segment_size]) 58 | for i in range(x.size(0)): 59 | idx_str = ids_str[i] 60 | idx_end = idx_str + segment_size 61 | ret[i] = x[i, idx_str:idx_end] 62 | return ret 63 | 64 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 65 | b, d, t = x.size() 66 | if x_lengths is None: 67 | x_lengths = t 68 | ids_str_max = x_lengths - segment_size + 1 69 | ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(dtype=torch.long) 70 | ret = slice_segments(x, ids_str, segment_size) 71 | return ret, ids_str 72 | 73 | 74 | def get_timing_signal_1d( 75 | length, channels, min_timescale=1.0, max_timescale=1.0e4): 76 | position = torch.arange(length, dtype=torch.float) 77 | num_timescales = channels // 2 78 | log_timescale_increment = ( 79 | math.log(float(max_timescale) / float(min_timescale)) / 80 | (num_timescales - 1)) 81 | inv_timescales = min_timescale * torch.exp( 82 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) 83 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 84 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 85 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 86 | signal = signal.view(1, channels, length) 87 | return signal 88 | 89 | 90 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 91 | b, channels, length = x.size() 92 | signal = get_timing_signal_1d(length, 
channels, min_timescale, max_timescale) 93 | return x + signal.to(dtype=x.dtype, device=x.device) 94 | 95 | 96 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 97 | b, channels, length = x.size() 98 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 99 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 100 | 101 | 102 | def subsequent_mask(length): 103 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 104 | return mask 105 | 106 | 107 | @torch.jit.script 108 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 109 | n_channels_int = n_channels[0] 110 | in_act = input_a + input_b 111 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 112 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 113 | acts = t_act * s_act 114 | return acts 115 | 116 | 117 | def convert_pad_shape(pad_shape): 118 | l = pad_shape[::-1] 119 | pad_shape = [item for sublist in l for item in sublist] 120 | return pad_shape 121 | 122 | 123 | def shift_1d(x): 124 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 125 | return x 126 | 127 | 128 | def sequence_mask(length, max_length=None): 129 | if max_length is None: 130 | max_length = length.max() 131 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 132 | return x.unsqueeze(0) < length.unsqueeze(1) 133 | 134 | 135 | def generate_path(duration, mask): 136 | """ 137 | duration: [b, 1, t_x] 138 | mask: [b, 1, t_y, t_x] 139 | """ 140 | device = duration.device 141 | 142 | b, _, t_y, t_x = mask.shape 143 | cum_duration = torch.cumsum(duration, -1) 144 | 145 | cum_duration_flat = cum_duration.view(b * t_x) 146 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 147 | path = path.view(b, t_x, t_y) 148 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 149 | path = path.unsqueeze(1).transpose(2,3) * mask 150 | return path 151 | 152 | 153 | def clip_grad_value_(parameters, clip_value, norm_type=2): 154 | if isinstance(parameters, torch.Tensor): 155 | parameters = [parameters] 156 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 157 | norm_type = float(norm_type) 158 | if clip_value is not None: 159 | clip_value = float(clip_value) 160 | 161 | total_norm = 0 162 | for p in parameters: 163 | param_norm = p.grad.data.norm(norm_type) 164 | total_norm += param_norm.item() ** norm_type 165 | if clip_value is not None: 166 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 167 | total_norm = total_norm ** (1. 
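# A small sanity-check sketch for sequence_mask defined above (torch is already
# imported in this module):
lengths = torch.tensor([2, 4])
mask = sequence_mask(lengths, max_length=5)
# -> tensor([[ True,  True, False, False, False],
#            [ True,  True,  True,  True, False]])
# Each row keeps the first `length` positions and masks out the padded tail.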
/ norm_type) 168 | return total_norm 169 | -------------------------------------------------------------------------------- /audio/model/diffusion_module.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from einops import rearrange 4 | 5 | from model.base import BaseModule 6 | 7 | 8 | class Mish(BaseModule): 9 | def forward(self, x): 10 | return x * torch.tanh(torch.nn.functional.softplus(x)) 11 | 12 | 13 | class Upsample(BaseModule): 14 | def __init__(self, dim): 15 | super(Upsample, self).__init__() 16 | self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1) 17 | 18 | def forward(self, x): 19 | return self.conv(x) 20 | 21 | 22 | class Downsample(BaseModule): 23 | def __init__(self, dim): 24 | super(Downsample, self).__init__() 25 | self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1) 26 | 27 | def forward(self, x): 28 | return self.conv(x) 29 | 30 | 31 | class Rezero(BaseModule): 32 | def __init__(self, fn): 33 | super(Rezero, self).__init__() 34 | self.fn = fn 35 | self.g = torch.nn.Parameter(torch.zeros(1)) 36 | 37 | def forward(self, x): 38 | return self.fn(x) * self.g 39 | 40 | 41 | class Block(BaseModule): 42 | def __init__(self, dim, dim_out, groups=8): 43 | super(Block, self).__init__() 44 | self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3, 45 | padding=1), torch.nn.GroupNorm( 46 | groups, dim_out), Mish()) 47 | 48 | def forward(self, x, mask): 49 | output = self.block(x * mask) 50 | return output * mask 51 | 52 | 53 | class ResnetBlock(BaseModule): 54 | def __init__(self, dim, dim_out, time_emb_dim, groups=8): 55 | super(ResnetBlock, self).__init__() 56 | self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, 57 | dim_out)) 58 | 59 | self.block1 = Block(dim, dim_out, groups=groups) 60 | self.block2 = Block(dim_out, dim_out, groups=groups) 61 | if dim != dim_out: 62 | self.res_conv = torch.nn.Conv2d(dim, dim_out, 1) 63 | else: 64 | self.res_conv = torch.nn.Identity() 65 | 66 | def forward(self, x, mask, time_emb): 67 | h = self.block1(x, mask) 68 | h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1) 69 | h = self.block2(h, mask) 70 | output = h + self.res_conv(x * mask) 71 | return output 72 | 73 | 74 | class LinearAttention(BaseModule): 75 | def __init__(self, dim, heads=4, dim_head=32): 76 | super(LinearAttention, self).__init__() 77 | self.heads = heads 78 | hidden_dim = dim_head * heads 79 | self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 80 | self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1) 81 | 82 | def forward(self, x): 83 | b, c, h, w = x.shape 84 | qkv = self.to_qkv(x) 85 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', 86 | heads = self.heads, qkv=3) 87 | k = k.softmax(dim=-1) 88 | context = torch.einsum('bhdn,bhen->bhde', k, v) 89 | out = torch.einsum('bhde,bhdn->bhen', context, q) 90 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', 91 | heads=self.heads, h=h, w=w) 92 | return self.to_out(out) 93 | 94 | 95 | class Residual(BaseModule): 96 | def __init__(self, fn): 97 | super(Residual, self).__init__() 98 | self.fn = fn 99 | 100 | def forward(self, x, *args, **kwargs): 101 | output = self.fn(x, *args, **kwargs) + x 102 | return output 103 | 104 | 105 | class SinusoidalPosEmb(BaseModule): 106 | def __init__(self, dim): 107 | super(SinusoidalPosEmb, self).__init__() 108 | self.dim = dim 109 | 110 | def forward(self, x): 111 | device = x.device 112 | half_dim = self.dim // 2 113 | emb = math.log(10000) / (half_dim - 1) 114 
| emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) 115 | emb = 1000.0 * x.unsqueeze(1) * emb.unsqueeze(0) 116 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 117 | return emb 118 | 119 | 120 | class RefBlock(BaseModule): 121 | def __init__(self, out_dim, time_emb_dim): 122 | super(RefBlock, self).__init__() 123 | base_dim = out_dim // 4 124 | self.mlp1 = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, 125 | base_dim)) 126 | self.mlp2 = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, 127 | 2 * base_dim)) 128 | self.block11 = torch.nn.Sequential(torch.nn.Conv2d(1, 2 * base_dim, 129 | 3, 1, 1), torch.nn.InstanceNorm2d(2 * base_dim, affine=True), 130 | torch.nn.GLU(dim=1)) 131 | self.block12 = torch.nn.Sequential(torch.nn.Conv2d(base_dim, 2 * base_dim, 132 | 3, 1, 1), torch.nn.InstanceNorm2d(2 * base_dim, affine=True), 133 | torch.nn.GLU(dim=1)) 134 | self.block21 = torch.nn.Sequential(torch.nn.Conv2d(base_dim, 4 * base_dim, 135 | 3, 1, 1), torch.nn.InstanceNorm2d(4 * base_dim, affine=True), 136 | torch.nn.GLU(dim=1)) 137 | self.block22 = torch.nn.Sequential(torch.nn.Conv2d(2 * base_dim, 4 * base_dim, 138 | 3, 1, 1), torch.nn.InstanceNorm2d(4 * base_dim, affine=True), 139 | torch.nn.GLU(dim=1)) 140 | self.block31 = torch.nn.Sequential(torch.nn.Conv2d(2 * base_dim, 8 * base_dim, 141 | 3, 1, 1), torch.nn.InstanceNorm2d(8 * base_dim, affine=True), 142 | torch.nn.GLU(dim=1)) 143 | self.block32 = torch.nn.Sequential(torch.nn.Conv2d(4 * base_dim, 8 * base_dim, 144 | 3, 1, 1), torch.nn.InstanceNorm2d(8 * base_dim, affine=True), 145 | torch.nn.GLU(dim=1)) 146 | self.final_conv = torch.nn.Conv2d(4 * base_dim, out_dim, 1) 147 | 148 | def forward(self, x, mask, time_emb): 149 | y = self.block11(x * mask) 150 | y = self.block12(y * mask) 151 | y += self.mlp1(time_emb).unsqueeze(-1).unsqueeze(-1) 152 | y = self.block21(y * mask) 153 | y = self.block22(y * mask) 154 | y += self.mlp2(time_emb).unsqueeze(-1).unsqueeze(-1) 155 | y = self.block31(y * mask) 156 | y = self.block32(y * mask) 157 | y = self.final_conv(y * mask) 158 | return (y * mask).sum((2, 3)) / (mask.sum((2, 3)) * x.shape[2]) 159 | -------------------------------------------------------------------------------- /video/modules/utilities.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import mimetypes 3 | import os 4 | import platform 5 | import shutil 6 | import ssl 7 | import subprocess 8 | import urllib 9 | from pathlib import Path 10 | from typing import List, Any 11 | from tqdm import tqdm 12 | 13 | import modules.globals 14 | 15 | TEMP_FILE = "temp.mp4" 16 | TEMP_DIRECTORY = "temp" 17 | 18 | # monkey patch ssl for mac 19 | if platform.system().lower() == "darwin": 20 | ssl._create_default_https_context = ssl._create_unverified_context 21 | 22 | 23 | def run_ffmpeg(args: List[str]) -> bool: 24 | commands = [ 25 | "ffmpeg", 26 | "-hide_banner", 27 | "-hwaccel", 28 | "auto", 29 | "-loglevel", 30 | modules.globals.log_level, 31 | ] 32 | commands.extend(args) 33 | try: 34 | subprocess.check_output(commands, stderr=subprocess.STDOUT) 35 | return True 36 | except Exception: 37 | pass 38 | return False 39 | 40 | 41 | def detect_fps(target_path: str) -> float: 42 | command = [ 43 | "ffprobe", 44 | "-v", 45 | "error", 46 | "-select_streams", 47 | "v:0", 48 | "-show_entries", 49 | "stream=r_frame_rate", 50 | "-of", 51 | "default=noprint_wrappers=1:nokey=1", 52 | target_path, 53 | ] 54 | output = 
subprocess.check_output(command).decode().strip().split("/") 55 | try: 56 | numerator, denominator = map(int, output) 57 | return numerator / denominator 58 | except Exception: 59 | pass 60 | return 30.0 61 | 62 | 63 | def extract_frames(target_path: str) -> None: 64 | temp_directory_path = get_temp_directory_path(target_path) 65 | run_ffmpeg( 66 | [ 67 | "-i", 68 | target_path, 69 | "-pix_fmt", 70 | "rgb24", 71 | os.path.join(temp_directory_path, "%04d.png"), 72 | ] 73 | ) 74 | 75 | 76 | def create_video(target_path: str, fps: float = 30.0) -> None: 77 | temp_output_path = get_temp_output_path(target_path) 78 | temp_directory_path = get_temp_directory_path(target_path) 79 | run_ffmpeg( 80 | [ 81 | "-r", 82 | str(fps), 83 | "-i", 84 | os.path.join(temp_directory_path, "%04d.png"), 85 | "-c:v", 86 | modules.globals.video_encoder, 87 | "-crf", 88 | str(modules.globals.video_quality), 89 | "-pix_fmt", 90 | "yuv420p", 91 | "-vf", 92 | "colorspace=bt709:iall=bt601-6-625:fast=1", 93 | "-y", 94 | temp_output_path, 95 | ] 96 | ) 97 | 98 | 99 | def restore_audio(target_path: str, output_path: str) -> None: 100 | temp_output_path = get_temp_output_path(target_path) 101 | done = run_ffmpeg( 102 | [ 103 | "-i", 104 | temp_output_path, 105 | "-i", 106 | target_path, 107 | "-c:v", 108 | "copy", 109 | "-map", 110 | "0:v:0", 111 | "-map", 112 | "1:a:0", 113 | "-y", 114 | output_path, 115 | ] 116 | ) 117 | if not done: 118 | move_temp(target_path, output_path) 119 | 120 | 121 | def get_temp_frame_paths(target_path: str) -> List[str]: 122 | temp_directory_path = get_temp_directory_path(target_path) 123 | return glob.glob((os.path.join(glob.escape(temp_directory_path), "*.png"))) 124 | 125 | 126 | def get_temp_directory_path(target_path: str) -> str: 127 | target_name, _ = os.path.splitext(os.path.basename(target_path)) 128 | target_directory_path = os.path.dirname(target_path) 129 | return os.path.join(target_directory_path, TEMP_DIRECTORY, target_name) 130 | 131 | 132 | def get_temp_output_path(target_path: str) -> str: 133 | temp_directory_path = get_temp_directory_path(target_path) 134 | return os.path.join(temp_directory_path, TEMP_FILE) 135 | 136 | 137 | def normalize_output_path(source_path: str, target_path: str, output_path: str) -> Any: 138 | if source_path and target_path: 139 | source_name, _ = os.path.splitext(os.path.basename(source_path)) 140 | target_name, target_extension = os.path.splitext(os.path.basename(target_path)) 141 | if os.path.isdir(output_path): 142 | return os.path.join( 143 | output_path, source_name + "-" + target_name + target_extension 144 | ) 145 | return output_path 146 | 147 | 148 | def create_temp(target_path: str) -> None: 149 | temp_directory_path = get_temp_directory_path(target_path) 150 | Path(temp_directory_path).mkdir(parents=True, exist_ok=True) 151 | 152 | 153 | def move_temp(target_path: str, output_path: str) -> None: 154 | temp_output_path = get_temp_output_path(target_path) 155 | if os.path.isfile(temp_output_path): 156 | if os.path.isfile(output_path): 157 | os.remove(output_path) 158 | shutil.move(temp_output_path, output_path) 159 | 160 | 161 | def clean_temp(target_path: str) -> None: 162 | temp_directory_path = get_temp_directory_path(target_path) 163 | parent_directory_path = os.path.dirname(temp_directory_path) 164 | if not modules.globals.keep_frames and os.path.isdir(temp_directory_path): 165 | shutil.rmtree(temp_directory_path) 166 | if os.path.exists(parent_directory_path) and not os.listdir(parent_directory_path): 167 | 
os.rmdir(parent_directory_path) 168 | 169 | 170 | def has_image_extension(image_path: str) -> bool: 171 | return image_path.lower().endswith(("png", "jpg", "jpeg")) 172 | 173 | 174 | def is_image(image_path: str) -> bool: 175 | if image_path and os.path.isfile(image_path): 176 | mimetype, _ = mimetypes.guess_type(image_path) 177 | return bool(mimetype and mimetype.startswith("image/")) 178 | return False 179 | 180 | 181 | def is_video(video_path: str) -> bool: 182 | if video_path and os.path.isfile(video_path): 183 | mimetype, _ = mimetypes.guess_type(video_path) 184 | return bool(mimetype and mimetype.startswith("video/")) 185 | return False 186 | 187 | 188 | def conditional_download(download_directory_path: str, urls: List[str]) -> None: 189 | if not os.path.exists(download_directory_path): 190 | os.makedirs(download_directory_path) 191 | for url in urls: 192 | download_file_path = os.path.join( 193 | download_directory_path, os.path.basename(url) 194 | ) 195 | if not os.path.exists(download_file_path): 196 | request = urllib.request.urlopen(url) # type: ignore[attr-defined] 197 | total = int(request.headers.get("Content-Length", 0)) 198 | with tqdm( 199 | total=total, 200 | desc="Downloading", 201 | unit="B", 202 | unit_scale=True, 203 | unit_divisor=1024, 204 | ) as progress: 205 | urllib.request.urlretrieve(url, download_file_path, reporthook=lambda count, block_size, total_size: progress.update(block_size)) # type: ignore[attr-defined] 206 | 207 | 208 | def resolve_relative_path(path: str) -> str: 209 | return os.path.abspath(os.path.join(os.path.dirname(__file__), path)) 210 | -------------------------------------------------------------------------------- /video/modules/face_analyser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from typing import Any 4 | import insightface 5 | 6 | import cv2 7 | import numpy as np 8 | import modules.globals 9 | from tqdm import tqdm 10 | from modules.typing import Frame 11 | from modules.cluster_analysis import find_cluster_centroids, find_closest_centroid 12 | from modules.utilities import get_temp_directory_path, create_temp, extract_frames, clean_temp, get_temp_frame_paths 13 | from pathlib import Path 14 | 15 | FACE_ANALYSER = None 16 | 17 | 18 | def get_face_analyser() -> Any: 19 | global FACE_ANALYSER 20 | 21 | if FACE_ANALYSER is None: 22 | FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=modules.globals.execution_providers) 23 | FACE_ANALYSER.prepare(ctx_id=0, det_size=(640, 640)) 24 | return FACE_ANALYSER 25 | 26 | 27 | def get_one_face(frame: Frame) -> Any: 28 | face = get_face_analyser().get(frame) 29 | try: 30 | return min(face, key=lambda x: x.bbox[0]) 31 | except ValueError: 32 | return None 33 | 34 | 35 | def get_many_faces(frame: Frame) -> Any: 36 | try: 37 | return get_face_analyser().get(frame) 38 | except IndexError: 39 | return None 40 | 41 | def has_valid_map() -> bool: 42 | for map in modules.globals.souce_target_map: 43 | if "source" in map and "target" in map: 44 | return True 45 | return False 46 | 47 | def default_source_face() -> Any: 48 | for map in modules.globals.souce_target_map: 49 | if "source" in map: 50 | return map['source']['face'] 51 | return None 52 | 53 | def simplify_maps() -> Any: 54 | centroids = [] 55 | faces = [] 56 | for map in modules.globals.souce_target_map: 57 | if "source" in map and "target" in map: 58 | centroids.append(map['target']['face'].normed_embedding) 59 | 
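# A hedged sketch of how the ffmpeg helpers above are meant to compose for a
# full video pass (paths are hypothetical; modules.globals.video_encoder,
# video_quality and keep_frames must be configured beforehand):
target = "input.mp4"
fps = detect_fps(target)                  # falls back to 30.0 if ffprobe output is unusable
create_temp(target)                       # <dir>/temp/<name>/
extract_frames(target)                    # dumps %04d.png frames into the temp directory
frame_paths = get_temp_frame_paths(target)
# ... run a frame processor over frame_paths in place ...
create_video(target, fps)                 # re-encodes the processed frames into temp.mp4
restore_audio(target, "output.mp4")       # muxes the original audio back, or just moves temp.mp4
clean_temp(target)                        # removes the temp frames unless keep_frames is set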
faces.append(map['source']['face']) 60 | 61 | modules.globals.simple_map = {'source_faces': faces, 'target_embeddings': centroids} 62 | return None 63 | 64 | def add_blank_map() -> Any: 65 | try: 66 | max_id = -1 67 | if len(modules.globals.souce_target_map) > 0: 68 | max_id = max(modules.globals.souce_target_map, key=lambda x: x['id'])['id'] 69 | 70 | modules.globals.souce_target_map.append({ 71 | 'id' : max_id + 1 72 | }) 73 | except ValueError: 74 | return None 75 | 76 | def get_unique_faces_from_target_image() -> Any: 77 | try: 78 | modules.globals.souce_target_map = [] 79 | target_frame = cv2.imread(modules.globals.target_path) 80 | many_faces = get_many_faces(target_frame) 81 | i = 0 82 | 83 | for face in many_faces: 84 | x_min, y_min, x_max, y_max = face['bbox'] 85 | modules.globals.souce_target_map.append({ 86 | 'id' : i, 87 | 'target' : { 88 | 'cv2' : target_frame[int(y_min):int(y_max), int(x_min):int(x_max)], 89 | 'face' : face 90 | } 91 | }) 92 | i = i + 1 93 | except ValueError: 94 | return None 95 | 96 | 97 | def get_unique_faces_from_target_video() -> Any: 98 | try: 99 | modules.globals.souce_target_map = [] 100 | frame_face_embeddings = [] 101 | face_embeddings = [] 102 | 103 | print('Creating temp resources...') 104 | clean_temp(modules.globals.target_path) 105 | create_temp(modules.globals.target_path) 106 | print('Extracting frames...') 107 | extract_frames(modules.globals.target_path) 108 | 109 | temp_frame_paths = get_temp_frame_paths(modules.globals.target_path) 110 | 111 | i = 0 112 | for temp_frame_path in tqdm(temp_frame_paths, desc="Extracting face embeddings from frames"): 113 | temp_frame = cv2.imread(temp_frame_path) 114 | many_faces = get_many_faces(temp_frame) 115 | 116 | for face in many_faces: 117 | face_embeddings.append(face.normed_embedding) 118 | 119 | frame_face_embeddings.append({'frame': i, 'faces': many_faces, 'location': temp_frame_path}) 120 | i += 1 121 | 122 | centroids = find_cluster_centroids(face_embeddings) 123 | 124 | for frame in frame_face_embeddings: 125 | for face in frame['faces']: 126 | closest_centroid_index, _ = find_closest_centroid(centroids, face.normed_embedding) 127 | face['target_centroid'] = closest_centroid_index 128 | 129 | for i in range(len(centroids)): 130 | modules.globals.souce_target_map.append({ 131 | 'id' : i 132 | }) 133 | 134 | temp = [] 135 | for frame in tqdm(frame_face_embeddings, desc=f"Mapping frame embeddings to centroids-{i}"): 136 | temp.append({'frame': frame['frame'], 'faces': [face for face in frame['faces'] if face['target_centroid'] == i], 'location': frame['location']}) 137 | 138 | modules.globals.souce_target_map[i]['target_faces_in_frame'] = temp 139 | 140 | # dump_faces(centroids, frame_face_embeddings) 141 | default_target_face() 142 | except ValueError: 143 | return None 144 | 145 | 146 | def default_target_face(): 147 | for map in modules.globals.souce_target_map: 148 | best_face = None 149 | best_frame = None 150 | for frame in map['target_faces_in_frame']: 151 | if len(frame['faces']) > 0: 152 | best_face = frame['faces'][0] 153 | best_frame = frame 154 | break 155 | 156 | for frame in map['target_faces_in_frame']: 157 | for face in frame['faces']: 158 | if face['det_score'] > best_face['det_score']: 159 | best_face = face 160 | best_frame = frame 161 | 162 | x_min, y_min, x_max, y_max = best_face['bbox'] 163 | 164 | target_frame = cv2.imread(best_frame['location']) 165 | map['target'] = { 166 | 'cv2' : target_frame[int(y_min):int(y_max), int(x_min):int(x_max)], 167 | 'face' : best_face 168 | 
} 169 | 170 | 171 | def dump_faces(centroids: Any, frame_face_embeddings: list): 172 | temp_directory_path = get_temp_directory_path(modules.globals.target_path) 173 | 174 | for i in range(len(centroids)): 175 | if os.path.exists(temp_directory_path + f"/{i}") and os.path.isdir(temp_directory_path + f"/{i}"): 176 | shutil.rmtree(temp_directory_path + f"/{i}") 177 | Path(temp_directory_path + f"/{i}").mkdir(parents=True, exist_ok=True) 178 | 179 | for frame in tqdm(frame_face_embeddings, desc=f"Copying faces to temp/./{i}"): 180 | temp_frame = cv2.imread(frame['location']) 181 | 182 | j = 0 183 | for face in frame['faces']: 184 | if face['target_centroid'] == i: 185 | x_min, y_min, x_max, y_max = face['bbox'] 186 | 187 | if temp_frame[int(y_min):int(y_max), int(x_min):int(x_max)].size > 0: 188 | cv2.imwrite(temp_directory_path + f"/{i}/{frame['frame']}_{j}.png", temp_frame[int(y_min):int(y_max), int(x_min):int(x_max)]) 189 | j += 1 -------------------------------------------------------------------------------- /audio/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | import numpy as np 9 | import torch 10 | from scipy.io.wavfile import read 11 | MATPLOTLIB_FLAG = False 12 | 13 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 14 | logger = logging 15 | 16 | from torchaudio.transforms import MelSpectrogram 17 | 18 | class MelSpectrogramFixed(torch.nn.Module): 19 | def __init__(self, **kwargs): 20 | super(MelSpectrogramFixed, self).__init__() 21 | self.torchaudio_backend = MelSpectrogram(**kwargs) 22 | 23 | def forward(self, x): 24 | outputs = torch.log(self.torchaudio_backend(x) + 0.001) 25 | return outputs[..., :-1] 26 | 27 | def load_checkpoint(checkpoint_path, model, optimizer=None): 28 | assert os.path.isfile(checkpoint_path) 29 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 30 | iteration = checkpoint_dict['iteration'] 31 | learning_rate = checkpoint_dict['learning_rate'] 32 | if optimizer is not None: 33 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 34 | saved_state_dict = checkpoint_dict['model'] 35 | if hasattr(model, 'module'): 36 | state_dict = model.module.state_dict() 37 | else: 38 | state_dict = model.state_dict() 39 | new_state_dict = {} 40 | for k, v in state_dict.items(): 41 | try: 42 | new_state_dict[k] = saved_state_dict[k] 43 | except: 44 | logger.info("%s is not in the checkpoint" % k) 45 | new_state_dict[k] = v 46 | if hasattr(model, 'module'): 47 | model.module.load_state_dict(new_state_dict) 48 | else: 49 | model.load_state_dict(new_state_dict) 50 | logger.info("Loaded checkpoint '{}' (iteration {})".format( 51 | checkpoint_path, iteration)) 52 | return model, optimizer, learning_rate, iteration 53 | 54 | 55 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 56 | logger.info("Saving model and optimizer state at iteration {} to {}".format( 57 | iteration, checkpoint_path)) 58 | if hasattr(model, 'module'): 59 | state_dict = model.module.state_dict() 60 | else: 61 | state_dict = model.state_dict() 62 | torch.save({'model': state_dict, 63 | 'iteration': iteration, 64 | 'optimizer': optimizer.state_dict(), 65 | 'learning_rate': learning_rate}, checkpoint_path) 66 | 67 | 68 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): 69 | for k, v in scalars.items(): 70 | 
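# A short hedged sketch of the face-analyser helpers above (the detector weights
# must already be downloaded and execution providers set in modules.globals;
# the image path is illustrative):
frame = cv2.imread("some_image.jpg")
faces = get_many_faces(frame)             # every detected face, or None
face = get_one_face(frame)                # left-most face (smallest bbox x), or None
if face is not None:
    print(face.bbox, face.det_score)      # insightface Face attributes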
writer.add_scalar(k, v, global_step) 71 | for k, v in histograms.items(): 72 | writer.add_histogram(k, v, global_step) 73 | for k, v in images.items(): 74 | writer.add_image(k, v, global_step, dataformats='HWC') 75 | for k, v in audios.items(): 76 | writer.add_audio(k, v, global_step, audio_sampling_rate) 77 | 78 | 79 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 80 | f_list = glob.glob(os.path.join(dir_path, regex)) 81 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 82 | x = f_list[-1] 83 | print(x) 84 | return x 85 | 86 | 87 | def load_wav_to_torch(full_path): 88 | sampling_rate, data = read(full_path) 89 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 90 | 91 | 92 | def load_filepaths_and_text(filename, split="|"): 93 | with open(filename, encoding='utf-8') as f: 94 | filepaths_and_text = [line.strip().split(split) for line in f] 95 | return filepaths_and_text 96 | 97 | def get_hparams(init=True): 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument('-c', '--config', type=str, required=True, 100 | help='JSON file for configuration') 101 | parser.add_argument('-m', '--model', type=str, required=True, 102 | help='Model name') 103 | 104 | args = parser.parse_args() 105 | model_dir = os.path.join("/workspace/raid/ha0/logs_diffhier", args.model) 106 | 107 | if not os.path.exists(model_dir): 108 | os.makedirs(model_dir) 109 | 110 | config_path = args.config 111 | config_save_path = os.path.join(model_dir, "config.json") 112 | if init: 113 | with open(config_path, "r") as f: 114 | data = f.read() 115 | with open(config_save_path, "w") as f: 116 | f.write(data) 117 | else: 118 | with open(config_save_path, "r") as f: 119 | data = f.read() 120 | config = json.loads(data) 121 | 122 | hparams = HParams(**config) 123 | hparams.model_dir = model_dir 124 | return hparams 125 | 126 | def get_hparams_from_dir(model_dir): 127 | config_save_path = os.path.join(model_dir, "config.json") 128 | with open(config_save_path, "r") as f: 129 | data = f.read() 130 | config = json.loads(data) 131 | 132 | hparams = HParams(**config) 133 | hparams.model_dir = model_dir 134 | return hparams 135 | 136 | def get_hparams_from_file(config_path): 137 | with open(config_path, "r") as f: 138 | data = f.read() 139 | config = json.loads(data) 140 | 141 | hparams = HParams(**config) 142 | return hparams 143 | 144 | def check_git_hash(model_dir): 145 | source_dir = os.path.dirname(os.path.realpath(__file__)) 146 | if not os.path.exists(os.path.join(source_dir, ".git")): 147 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( 148 | source_dir 149 | )) 150 | return 151 | 152 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 153 | 154 | path = os.path.join(model_dir, "githash") 155 | if os.path.exists(path): 156 | saved_hash = open(path).read() 157 | if saved_hash != cur_hash: 158 | logger.warn("git hash values are different. 
{}(saved) != {}(current)".format( 159 | saved_hash[:8], cur_hash[:8])) 160 | else: 161 | open(path, "w").write(cur_hash) 162 | 163 | 164 | def get_logger(model_dir, filename="train.log"): 165 | global logger 166 | logger = logging.getLogger(os.path.basename(model_dir)) 167 | logger.setLevel(logging.DEBUG) 168 | 169 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 170 | if not os.path.exists(model_dir): 171 | os.makedirs(model_dir) 172 | h = logging.FileHandler(os.path.join(model_dir, filename)) 173 | h.setLevel(logging.DEBUG) 174 | h.setFormatter(formatter) 175 | logger.addHandler(h) 176 | return logger 177 | 178 | def parse_filelist(filelist_path): 179 | with open(filelist_path, 'r') as f: 180 | filelist = [line.strip() for line in f.readlines()] 181 | return filelist 182 | 183 | 184 | def parse_filelist_and_spk_id(filelist_path, split="|"): 185 | with open(filelist_path, encoding='utf-8') as f: 186 | filepaths_and_spkid = [line.strip().split(split) for line in f] 187 | return filepaths_and_spkid 188 | 189 | 190 | class HParams(): 191 | def __init__(self, **kwargs): 192 | for k, v in kwargs.items(): 193 | if type(v) == dict: 194 | v = HParams(**v) 195 | self[k] = v 196 | 197 | def keys(self): 198 | return self.__dict__.keys() 199 | 200 | def items(self): 201 | return self.__dict__.items() 202 | 203 | def values(self): 204 | return self.__dict__.values() 205 | 206 | def __len__(self): 207 | return len(self.__dict__) 208 | 209 | def __getitem__(self, key): 210 | return getattr(self, key) 211 | 212 | def __setitem__(self, key, value): 213 | return setattr(self, key, value) 214 | 215 | def __contains__(self, key): 216 | return key in self.__dict__ 217 | 218 | def __repr__(self): 219 | return self.__dict__.__repr__() 220 | -------------------------------------------------------------------------------- /audio/module/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | import numpy as np 9 | from scipy.io.wavfile import read 10 | import torch 11 | 12 | MATPLOTLIB_FLAG = False 13 | 14 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 15 | logger = logging 16 | 17 | 18 | def load_checkpoint(checkpoint_path, model, optimizer=None): 19 | assert os.path.isfile(checkpoint_path) 20 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 21 | iteration = checkpoint_dict['iteration'] 22 | learning_rate = checkpoint_dict['learning_rate'] 23 | if optimizer is not None: 24 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 25 | saved_state_dict = checkpoint_dict['model'] 26 | if hasattr(model, 'module'): 27 | state_dict = model.module.state_dict() 28 | else: 29 | state_dict = model.state_dict() 30 | new_state_dict = {} 31 | for k, v in state_dict.items(): 32 | try: 33 | new_state_dict[k] = saved_state_dict[k] 34 | except: 35 | logger.info("%s is not in the checkpoint" % k) 36 | new_state_dict[k] = v 37 | if hasattr(model, 'module'): 38 | model.module.load_state_dict(new_state_dict) 39 | else: 40 | model.load_state_dict(new_state_dict) 41 | logger.info("Loaded checkpoint '{}' (iteration {})".format( 42 | checkpoint_path, iteration)) 43 | return model, optimizer, learning_rate, iteration 44 | 45 | 46 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 47 | logger.info("Saving model and optimizer state at iteration {} to {}".format( 48 | iteration, 
checkpoint_path)) 49 | if hasattr(model, 'module'): 50 | state_dict = model.module.state_dict() 51 | else: 52 | state_dict = model.state_dict() 53 | torch.save({'model': state_dict, 54 | 'iteration': iteration, 55 | 'optimizer': optimizer.state_dict(), 56 | 'learning_rate': learning_rate}, checkpoint_path) 57 | 58 | 59 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): 60 | for k, v in scalars.items(): 61 | writer.add_scalar(k, v, global_step) 62 | for k, v in histograms.items(): 63 | writer.add_histogram(k, v, global_step) 64 | for k, v in images.items(): 65 | writer.add_image(k, v, global_step, dataformats='HWC') 66 | for k, v in audios.items(): 67 | writer.add_audio(k, v, global_step, audio_sampling_rate) 68 | 69 | 70 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 71 | f_list = glob.glob(os.path.join(dir_path, regex)) 72 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 73 | x = f_list[-1] 74 | print(x) 75 | return x 76 | 77 | 78 | def load_wav_to_torch(full_path): 79 | sampling_rate, data = read(full_path) 80 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 81 | 82 | 83 | def load_filepaths_and_text(filename, split="|"): 84 | with open(filename, encoding='utf-8') as f: 85 | filepaths_and_text = [line.strip().split(split) for line in f] 86 | return filepaths_and_text 87 | 88 | 89 | def get_hparams(init=True): 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument('-c', '--config', type=str, required=True, 92 | help='JSON file for configuration') 93 | parser.add_argument('-m', '--model', type=str, required=True, 94 | help='Model name') 95 | 96 | args = parser.parse_args() 97 | model_dir = os.path.join("/workspace/raid/data/ha0/logs_rfhiervc", args.model) 98 | 99 | if not os.path.exists(model_dir): 100 | os.makedirs(model_dir) 101 | 102 | config_path = args.config 103 | config_save_path = os.path.join(model_dir, "config.json") 104 | if init: 105 | with open(config_path, "r") as f: 106 | data = f.read() 107 | with open(config_save_path, "w") as f: 108 | f.write(data) 109 | else: 110 | with open(config_save_path, "r") as f: 111 | data = f.read() 112 | config = json.loads(data) 113 | 114 | hparams = HParams(**config) 115 | hparams.model_dir = model_dir 116 | return hparams 117 | 118 | 119 | def get_hparams_from_dir(model_dir): 120 | config_save_path = os.path.join(model_dir, "config.json") 121 | with open(config_save_path, "r") as f: 122 | data = f.read() 123 | config = json.loads(data) 124 | 125 | hparams = HParams(**config) 126 | hparams.model_dir = model_dir 127 | return hparams 128 | 129 | 130 | def get_hparams_from_file(config_path): 131 | with open(config_path, "r") as f: 132 | data = f.read() 133 | config = json.loads(data) 134 | 135 | hparams = HParams(**config) 136 | return hparams 137 | 138 | 139 | def check_git_hash(model_dir): 140 | source_dir = os.path.dirname(os.path.realpath(__file__)) 141 | if not os.path.exists(os.path.join(source_dir, ".git")): 142 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( 143 | source_dir 144 | )) 145 | return 146 | 147 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 148 | 149 | path = os.path.join(model_dir, "githash") 150 | if os.path.exists(path): 151 | saved_hash = open(path).read() 152 | if saved_hash != cur_hash: 153 | logger.warn("git hash values are different. 
{}(saved) != {}(current)".format( 154 | saved_hash[:8], cur_hash[:8])) 155 | else: 156 | open(path, "w").write(cur_hash) 157 | 158 | 159 | def get_logger(model_dir, filename="train.log"): 160 | global logger 161 | logger = logging.getLogger(os.path.basename(model_dir)) 162 | logger.setLevel(logging.DEBUG) 163 | 164 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 165 | if not os.path.exists(model_dir): 166 | os.makedirs(model_dir) 167 | h = logging.FileHandler(os.path.join(model_dir, filename)) 168 | h.setLevel(logging.DEBUG) 169 | h.setFormatter(formatter) 170 | logger.addHandler(h) 171 | return logger 172 | 173 | 174 | def parse_filelist(filelist_path): 175 | with open(filelist_path, 'r') as f: 176 | filelist = [line.strip() for line in f.readlines()] 177 | return filelist 178 | 179 | 180 | def parse_filelist_and_spk_id(filelist_path, split="|"): 181 | with open(filelist_path, encoding='utf-8') as f: 182 | filepaths_and_spkid = [line.strip().split(split) for line in f] 183 | return filepaths_and_spkid 184 | 185 | 186 | class HParams(): 187 | def __init__(self, **kwargs): 188 | for k, v in kwargs.items(): 189 | if type(v) == dict: 190 | v = HParams(**v) 191 | self[k] = v 192 | 193 | def keys(self): 194 | return self.__dict__.keys() 195 | 196 | def items(self): 197 | return self.__dict__.items() 198 | 199 | def values(self): 200 | return self.__dict__.values() 201 | 202 | def __len__(self): 203 | return len(self.__dict__) 204 | 205 | def __getitem__(self, key): 206 | return getattr(self, key) 207 | 208 | def __setitem__(self, key, value): 209 | return setattr(self, key, value) 210 | 211 | def __contains__(self, key): 212 | return key in self.__dict__ 213 | 214 | def __repr__(self): 215 | return self.__dict__.__repr__() 216 | 217 | 218 | def sequence_mask(length, max_length=None): 219 | if max_length is None: 220 | max_length = length.max() 221 | x = torch.arange(int(max_length), dtype=length.dtype, device=length.device) 222 | return x.unsqueeze(0) < length.unsqueeze(1) 223 | 224 | def convert_pad_shape(pad_shape): 225 | l = pad_shape[::-1] 226 | pad_shape = [item for sublist in l for item in sublist] 227 | return pad_shape 228 | 229 | def fix_len_compatibility(length, num_downsamplings_in_unet=2): 230 | while True: 231 | if length % (2**num_downsamplings_in_unet) == 0: 232 | return length 233 | length += 1 234 | 235 | -------------------------------------------------------------------------------- /audio/model/diffusion_f0.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from model.base import BaseModule 6 | from model.diffusion_module import * 7 | from math import sqrt 8 | 9 | Linear = nn.Linear 10 | ConvTranspose2d = nn.ConvTranspose2d 11 | 12 | def Conv1d(*args, **kwargs): 13 | layer = nn.Conv1d(*args, **kwargs) 14 | nn.init.kaiming_normal_(layer.weight) 15 | return layer 16 | 17 | class ResidualBlock(nn.Module): 18 | def __init__(self, n_mels, residual_channels, dilation, dim_base): 19 | super().__init__() 20 | self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation) 21 | self.diffusion_projection = Linear(dim_base, residual_channels) 22 | self.conditioner_projection = Conv1d(n_mels, 2 * residual_channels, 1) 23 | self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1) 24 | 25 | def forward(self, x, diffusion_step, conditioner, x_mask): 26 | 
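# A minimal sketch of how the HParams container above wraps a nested config
# dict (the keys are illustrative, not taken from the shipped config files):
cfg = {"train": {"batch_size": 32}, "model": {"hidden_size": 256}}
hps = HParams(**cfg)
print(hps.model.hidden_size)              # 256  - nested dicts become nested HParams
print(hps["train"].batch_size)            # 32   - item access is routed to attributes
print("model" in hps)                     # True - __contains__ checks the attribute dict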
diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) 27 | y = x + diffusion_step 28 | 29 | conditioner = self.conditioner_projection(conditioner) 30 | y = self.dilated_conv(y*x_mask) + conditioner 31 | 32 | gate, filter = torch.chunk(y, 2, dim=1) 33 | y = torch.sigmoid(gate) * torch.tanh(filter) 34 | 35 | y = self.output_projection(y*x_mask) 36 | residual, skip = torch.chunk(y, 2, dim=1) 37 | return (x + residual) / sqrt(2.0), skip 38 | 39 | 40 | class GradLogPEstimator(BaseModule): 41 | def __init__(self, dim_base, dim_cond, res_layer=30, res_ch=64, dilation_cycle=10): 42 | super(GradLogPEstimator, self).__init__() 43 | 44 | self.time_pos_emb = SinusoidalPosEmb(dim_base) 45 | self.mlp = torch.nn.Sequential(torch.nn.Linear(dim_base, dim_base * 4), 46 | Mish(), 47 | torch.nn.Linear(dim_base * 4, dim_base), 48 | Mish()) 49 | 50 | cond_total = dim_base + 256 + 128 51 | self.cond_block = torch.nn.Sequential(Conv1d(cond_total, 4 * dim_cond, 1), 52 | Mish(), 53 | Conv1d(4 * dim_cond, dim_cond, 1), 54 | Mish()) 55 | 56 | self.input_projection = torch.nn.Sequential(Conv1d(1, res_ch, 1), Mish()) 57 | self.residual_layers = nn.ModuleList([ 58 | ResidualBlock(dim_cond, res_ch, 2 ** (i % dilation_cycle), dim_base) 59 | for i in range(res_layer) 60 | ]) 61 | self.skip_projection = torch.nn.Sequential(Conv1d(res_ch, res_ch, 1), Mish()) 62 | self.output_projection = Conv1d(res_ch, 1, 1) 63 | nn.init.zeros_(self.output_projection.weight) 64 | 65 | def forward(self, x, x_mask, f0, spk, t): 66 | condition = self.time_pos_emb(t) 67 | t = self.mlp(condition) 68 | x = self.input_projection(x) * x_mask 69 | 70 | condition = torch.cat([f0, condition.unsqueeze(-1).expand(-1, -1, f0.size(2)), spk.expand(-1, -1, f0.size(2))], 1) 71 | condition = self.cond_block(condition)*x_mask 72 | 73 | skip = None 74 | for layer in self.residual_layers: 75 | x, skip_connection = layer(x, t, condition, x_mask) 76 | skip = skip_connection * x_mask if skip is None else (skip_connection + skip) * x_mask 77 | 78 | x = skip / sqrt(len(self.residual_layers)) 79 | x = self.skip_projection(x) * x_mask 80 | x = self.output_projection(x) * x_mask 81 | 82 | return x 83 | 84 | @torch.no_grad() 85 | def infer(self, x, x_mask, f0, spk, t): 86 | condition = self.time_pos_emb(t) 87 | t = self.mlp(condition) 88 | x = self.input_projection(x) * x_mask 89 | 90 | condition = torch.cat([f0, condition.unsqueeze(-1).expand(-1, -1, f0.size(2)), spk.expand(-1, -1, f0.size(2))], 1) 91 | condition = self.cond_block(condition)*x_mask 92 | 93 | skip = None 94 | for layer in self.residual_layers: 95 | x, skip_connection = layer(x, t, condition, x_mask) 96 | skip = skip_connection * x_mask if skip is None else (skip_connection + skip) * x_mask 97 | 98 | x = skip / sqrt(len(self.residual_layers)) 99 | x = self.skip_projection(x) * x_mask 100 | x = self.output_projection(x) * x_mask 101 | 102 | return x 103 | 104 | class Diffusion(BaseModule): 105 | def __init__(self, n_feats, dim, dim_spk, beta_min, beta_max): 106 | super(Diffusion, self).__init__() 107 | self.estimator_f0 = GradLogPEstimator(dim, dim_spk) 108 | 109 | self.n_feats = n_feats 110 | self.dim_unet = dim 111 | self.dim_spk = dim_spk 112 | self.beta_min = beta_min 113 | self.beta_max = beta_max 114 | 115 | def get_beta(self, t): 116 | beta = self.beta_min + (self.beta_max - self.beta_min) * t 117 | return beta 118 | 119 | def get_gamma(self, s, t, p=1.0, use_torch=False): 120 | beta_integral = self.beta_min + 0.5 * (self.beta_max - self.beta_min) * (t + s) 121 | beta_integral *= 
(t - s) 122 | if use_torch: 123 | gamma = torch.exp(-0.5 * p * beta_integral).unsqueeze(-1).unsqueeze(-1) 124 | else: 125 | gamma = math.exp(-0.5 * p * beta_integral) 126 | return gamma 127 | 128 | def get_mu(self, s, t): 129 | a = self.get_gamma(s, t) 130 | b = 1.0 - self.get_gamma(0, s, p=2.0) 131 | c = 1.0 - self.get_gamma(0, t, p=2.0) 132 | return a * b / c 133 | 134 | def get_nu(self, s, t): 135 | a = self.get_gamma(0, s) 136 | b = 1.0 - self.get_gamma(s, t, p=2.0) 137 | c = 1.0 - self.get_gamma(0, t, p=2.0) 138 | return a * b / c 139 | 140 | def get_sigma(self, s, t): 141 | a = 1.0 - self.get_gamma(0, s, p=2.0) 142 | b = 1.0 - self.get_gamma(s, t, p=2.0) 143 | c = 1.0 - self.get_gamma(0, t, p=2.0) 144 | return math.sqrt(a * b / c) 145 | 146 | def compute_diffused_z_pr(self, x0, mask, z_pr, t, use_torch=False): 147 | x0_weight = self.get_gamma(0, t, use_torch=use_torch) 148 | z_pr_weight = 1.0 - x0_weight 149 | xt_z_pr = x0 * x0_weight + z_pr * z_pr_weight 150 | return xt_z_pr * mask 151 | 152 | def forward_diffusion(self, x0, mask, src_out, t): 153 | xt_src = self.compute_diffused_z_pr(x0, mask, src_out, t, use_torch=True) 154 | variance = 1.0 - self.get_gamma(0, t, p=2.0, use_torch=True) 155 | z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, requires_grad=False) 156 | xt_src = xt_src + z * torch.sqrt(variance) 157 | 158 | return xt_src * mask, z * mask 159 | 160 | @torch.no_grad() 161 | def reverse(self, z, mask, y_hat, z_f0, spk, ts): 162 | h = 1.0 / ts 163 | xt = z * mask 164 | 165 | for i in range(ts): 166 | t = 1.0 - i * h 167 | time = t * torch.ones(z.shape[0], dtype=z.dtype, device=z.device) 168 | beta_t = self.get_beta(t) 169 | 170 | kappa = self.get_gamma(0, t - h) * (1.0 - self.get_gamma(t - h, t, p=2.0)) 171 | kappa /= (self.get_gamma(0, t) * beta_t * h) 172 | kappa -= 1.0 173 | omega = self.get_nu(t - h, t) / self.get_gamma(0, t) 174 | omega += self.get_mu(t - h, t) 175 | omega -= (0.5 * beta_t * h + 1.0) 176 | sigma = self.get_sigma(t - h, t) 177 | 178 | dxt = (y_hat - xt) * (0.5 * beta_t * h + omega) 179 | dxt -= (self.estimator_f0.infer(xt, mask, z_f0, spk, time)) * (1.0 + kappa) * (beta_t * h) 180 | dxt += torch.randn_like(z, device=z.device) * sigma 181 | xt = (xt - dxt) * mask 182 | 183 | return xt 184 | 185 | 186 | def compute_loss(self, x0, mask, x0_hat, spk, f0, t): 187 | xt, z = self.forward_diffusion(x0, mask, x0_hat, t) 188 | z_estimation = self.estimator_f0(xt, mask, f0, spk, t) 189 | z_estimation *= torch.sqrt(1.0 - self.get_gamma(0, t, p=2.0, use_torch=True)) 190 | loss = torch.sum((z_estimation + z) ** 2) / (torch.sum(mask)) 191 | 192 | return loss 193 | 194 | def compute_t(self, x0, mask, x0_hat, f0, spk, offset=1e-5): 195 | b = x0.shape[0] 196 | t = torch.rand(b, dtype=x0.dtype, device=x0.device, requires_grad=False) 197 | t = torch.clamp(t, offset, 1.0 - offset) 198 | 199 | return self.compute_loss(x0, mask, x0_hat, spk, f0, t) 200 | -------------------------------------------------------------------------------- /audio/model/diffhiervc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from model.base import BaseModule 6 | from model.diffusion_mel import Diffusion as Mel_Diffusion 7 | from model.diffusion_f0 import Diffusion as F0_Diffusion 8 | from model.styleencoder import StyleEncoder 9 | 10 | import copy 11 | import transformers 12 | import typing as tp 13 | 14 | from module.modules import * 15 | from module.utils 
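# A hedged reading of the noise-schedule helpers above: with the linear schedule
# beta(u) = beta_min + (beta_max - beta_min) * u, the integral of beta over
# [s, t] is (beta_min + 0.5 * (beta_max - beta_min) * (t + s)) * (t - s),
# so get_gamma(s, t, p) computes exp(-0.5 * p * \int_s^t beta(u) du).
# get_mu, get_nu and get_sigma combine these decay factors into the posterior
# mean/variance terms used by reverse(), and compute_diffused_z_pr interpolates
# between the clean target x0 and the prior prediction with weight gamma(0, t).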
import * 16 | 17 | 18 | class Wav2vec2(torch.nn.Module): 19 | def __init__(self, layer=12): 20 | super().__init__() 21 | self.wav2vec2 = transformers.Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-xls-r-300m") 22 | for param in self.wav2vec2.parameters(): 23 | param.requires_grad = False 24 | param.grad = None 25 | self.wav2vec2.eval() 26 | self.feature_layer = layer 27 | 28 | @torch.no_grad() 29 | def forward(self, x): 30 | outputs = self.wav2vec2(x.squeeze(1), output_hidden_states=True) 31 | y = outputs.hidden_states[self.feature_layer] 32 | 33 | return y.permute((0, 2, 1)) 34 | 35 | class Encoder(nn.Module): 36 | def __init__(self, 37 | in_channels, 38 | hidden_channels, 39 | kernel_size, 40 | dilation_rate, 41 | n_layers, 42 | mel_size=80, 43 | gin_channels=0, 44 | p_dropout=0): 45 | super().__init__() 46 | self.in_channels = in_channels 47 | self.hidden_channels = hidden_channels 48 | self.kernel_size = kernel_size 49 | self.dilation_rate = dilation_rate 50 | self.n_layers = n_layers 51 | self.gin_channels = gin_channels 52 | self.p_dropout = p_dropout 53 | 54 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 55 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, p_dropout=p_dropout) 56 | self.proj = nn.Conv1d(hidden_channels, mel_size, 1) 57 | 58 | def forward(self, x, x_mask, g=None): 59 | x = self.pre(x * x_mask) * x_mask 60 | x = self.enc(x, x_mask, g=g) 61 | x = self.proj(x) * x_mask 62 | 63 | return x 64 | 65 | 66 | class SynthesizerTrn(nn.Module): 67 | def __init__(self, hidden_size): 68 | super().__init__() 69 | self.emb_c = nn.Conv1d(1024, hidden_size, 1) 70 | self.emb_c_f0 = nn.Conv1d(1024, hidden_size, 1) 71 | self.emb_f0 = nn.Conv1d(1, hidden_size, kernel_size=9, stride=4, padding=4) 72 | self.emb_norm_f0 = nn.Conv1d(1, hidden_size, 1) 73 | self.emb_g = StyleEncoder(in_dim=80, hidden_dim=256, out_dim=256) 74 | 75 | self.mel_enc_c = Encoder(hidden_size, hidden_size, 5, 1, 8, 80, gin_channels=256, p_dropout=0) 76 | self.mel_enc_f = Encoder(hidden_size, hidden_size, 5, 1, 8, 80, gin_channels=256, p_dropout=0) 77 | self.f0_enc = Encoder(hidden_size, hidden_size, 5, 1, 8, 128, gin_channels=256, p_dropout=0) 78 | self.proj = nn.Conv1d(hidden_size, 1, 1) 79 | 80 | def forward(self, x_mel, w2v, norm_f0, f0, x_mask, f0_mask): 81 | content = self.emb_c(w2v) 82 | content_f = self.emb_c_f0(w2v) 83 | f0 = self.emb_f0(f0) 84 | norm_f0 = self.emb_norm_f0(norm_f0) 85 | 86 | g = self.emb_g(x_mel, x_mask).unsqueeze(-1) 87 | y_cont = self.mel_enc_c(F.relu(content), x_mask, g=g) 88 | y_f0 = self.mel_enc_f(F.relu(f0), x_mask, g=g) 89 | y_mel = y_cont + y_f0 90 | 91 | content_f = F.interpolate(content_f, norm_f0.shape[-1]) 92 | enc_f0 = self.f0_enc(F.relu(content_f+norm_f0), f0_mask, g=g) 93 | y_f0_hat = self.proj(enc_f0) 94 | 95 | return g, y_mel, enc_f0, y_f0_hat 96 | 97 | def spk_embedding(self, mel, length): 98 | x_mask = torch.unsqueeze(commons.sequence_mask(length, mel.size(-1)), 1).to(mel.dtype) 99 | 100 | return self.emb_g(mel, x_mask).unsqueeze(-1) 101 | 102 | def mel_predictor(self, w2v, x_mask, spk, pred_f0): 103 | content = self.emb_c(w2v) 104 | pred_f0 = self.emb_f0(pred_f0) 105 | 106 | y_cont = self.mel_enc_c(F.relu(content), x_mask, g=spk) 107 | y_f0 = self.mel_enc_f(F.relu(pred_f0), x_mask, g=spk) 108 | y_mel = y_cont + y_f0 109 | 110 | return y_mel 111 | 112 | def f0_predictor(self, w2v, x_f0_norm, y_mel, y_mask, f0_mask): 113 | content_f = self.emb_c_f0(w2v) 114 | norm_f0 = self.emb_norm_f0(x_f0_norm) 115 | g = 
self.emb_g(y_mel, y_mask).unsqueeze(-1) 116 | content_f = F.interpolate(content_f, norm_f0.shape[-1]) 117 | 118 | enc_f0 = self.f0_enc(F.relu(content_f+norm_f0), f0_mask, g=g) 119 | y_f0_hat = self.proj(enc_f0) 120 | 121 | return g, y_f0_hat, enc_f0 122 | 123 | 124 | class DiffHierVC(BaseModule): 125 | def __init__(self, n_feats, spk_dim, dec_dim, beta_min, beta_max, hps): 126 | super(DiffHierVC, self).__init__() 127 | self.n_feats = n_feats 128 | self.spk_dim = spk_dim 129 | self.dec_dim = dec_dim 130 | self.beta_min = beta_min 131 | self.beta_max = beta_max 132 | 133 | self.encoder = SynthesizerTrn(hps.model.hidden_size) 134 | self.f0_dec = F0_Diffusion(n_feats, 64, spk_dim, beta_min, beta_max) 135 | self.mel_dec = Mel_Diffusion(n_feats, dec_dim, spk_dim, beta_min, beta_max) 136 | 137 | @torch.no_grad() 138 | def forward(self, x, w2v, norm_y_f0, f0_x, x_length, n_timesteps, mode='ml'): 139 | x_mask = sequence_mask(x_length, x.size(2)).unsqueeze(1).to(x.dtype) 140 | f0_mask = sequence_mask(x_length*4, x.size(2)*4).unsqueeze(1).to(x.dtype) 141 | 142 | max_length = int(x_length.max()) 143 | spk, y_mel, h_f0, y_f0_hat = self.encoder(x, w2v, norm_y_f0, f0_x, x_mask, f0_mask) 144 | f0_mean_x = self.f0_dec.compute_diffused_z_pr(f0_x, f0_mask, y_f0_hat, 1.0) 145 | 146 | z_f0 = f0_mean_x * f0_mask 147 | z_f0 += torch.randn_like(z_f0, device=z_f0.device) 148 | o_f0 = self.f0_dec.reverse(z_f0, f0_mask, y_f0_hat*f0_mask, h_f0*f0_mask, spk, n_timesteps) 149 | 150 | z_mel = self.mel_dec.compute_diffused_z_pr(x, x_mask, y_mel, 1.0) 151 | z_mel += torch.randn_like(z_mel, device=z_mel.device) 152 | 153 | o_mel = self.mel_dec.reverse(z_mel, x_mask, y_mel, spk, n_timesteps) 154 | 155 | return y_f0_hat, y_mel, o_f0, o_mel[:, :, :max_length] 156 | 157 | def infer_vc(self, x, x_w2v, x_f0_norm, x_f0, x_length, y, y_length, diffpitch_ts, diffvoice_ts): 158 | x_mask = sequence_mask(x_length, x.size(2)).unsqueeze(1).to(x.dtype) 159 | y_mask = sequence_mask(y_length, y.size(2)).unsqueeze(1).to(y.dtype) 160 | f0_mask = sequence_mask(x_length*4, x.size(2)*4).unsqueeze(1).to(x.dtype) 161 | 162 | spk, y_f0_hat, enc_f0 = self.encoder.f0_predictor(x_w2v, x_f0_norm, y, y_mask, f0_mask) 163 | 164 | # Diff-Pitch 165 | z_f0 = self.f0_dec.compute_diffused_z_pr(x_f0, f0_mask, y_f0_hat, 1.0) 166 | z_f0 += torch.randn_like(z_f0, device=z_f0.device) 167 | pred_f0 = self.f0_dec.reverse(z_f0, f0_mask, y_f0_hat*f0_mask, enc_f0*f0_mask, spk, ts=diffpitch_ts) 168 | f0_zeros_mask = (x_f0 == 0) 169 | pred_f0[f0_zeros_mask.expand_as(pred_f0)] = 0 170 | 171 | # Diff-Voice 172 | y_mel = self.encoder.mel_predictor(x_w2v, x_mask, spk, pred_f0) 173 | z_mel = self.mel_dec.compute_diffused_z_pr(x, x_mask, y_mel, 1.0) 174 | z_mel += torch.randn_like(z_mel, device=z_mel.device) 175 | o_mel = self.mel_dec.reverse(z_mel, x_mask, y_mel, spk, ts=diffvoice_ts) 176 | 177 | return o_mel[:, :, :x_length] 178 | 179 | 180 | def compute_loss(self, x, w2v_x, norm_f0_x, f0_x, x_length): 181 | x_mask = sequence_mask(x_length, x.size(2)).unsqueeze(1).to(x.dtype) 182 | f0_mask = sequence_mask(x_length*4, x.size(2)*4).unsqueeze(1).to(x.dtype) 183 | 184 | spk, y_mel, y_f0, y_f0_hat = self.encoder(x, w2v_x, norm_f0_x, f0_x, x_mask, f0_mask) 185 | 186 | f0_loss = torch.sum(torch.abs(f0_x - y_f0_hat)*f0_mask) / (torch.sum(f0_mask)) 187 | mel_loss = torch.sum(torch.abs(x - y_mel)*x_mask) / (torch.sum(x_mask) * self.n_feats) 188 | 189 | f0_diff_loss = self.f0_dec.compute_t(f0_x, f0_mask, y_f0_hat, y_f0, spk) 190 | mel_diff_loss, mel_recon_loss = 
self.mel_dec.compute_t(x, x_mask, y_mel, spk) 191 | 192 | return mel_diff_loss, mel_recon_loss, f0_diff_loss, mel_loss, f0_loss 193 | 194 | 195 | -------------------------------------------------------------------------------- /audio/module/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | import numpy as np 4 | 5 | DEFAULT_MIN_BIN_WIDTH = 1e-3 6 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 7 | DEFAULT_MIN_DERIVATIVE = 1e-3 8 | 9 | 10 | def piecewise_rational_quadratic_transform(inputs, 11 | unnormalized_widths, 12 | unnormalized_heights, 13 | unnormalized_derivatives, 14 | inverse=False, 15 | tails=None, 16 | tail_bound=1., 17 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 18 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 19 | min_derivative=DEFAULT_MIN_DERIVATIVE): 20 | 21 | if tails is None: 22 | spline_fn = rational_quadratic_spline 23 | spline_kwargs = {} 24 | else: 25 | spline_fn = unconstrained_rational_quadratic_spline 26 | spline_kwargs = { 27 | 'tails': tails, 28 | 'tail_bound': tail_bound 29 | } 30 | 31 | outputs, logabsdet = spline_fn( 32 | inputs=inputs, 33 | unnormalized_widths=unnormalized_widths, 34 | unnormalized_heights=unnormalized_heights, 35 | unnormalized_derivatives=unnormalized_derivatives, 36 | inverse=inverse, 37 | min_bin_width=min_bin_width, 38 | min_bin_height=min_bin_height, 39 | min_derivative=min_derivative, 40 | **spline_kwargs 41 | ) 42 | return outputs, logabsdet 43 | 44 | 45 | def searchsorted(bin_locations, inputs, eps=1e-6): 46 | bin_locations[..., -1] += eps 47 | return torch.sum( 48 | inputs[..., None] >= bin_locations, 49 | dim=-1 50 | ) - 1 51 | 52 | 53 | def unconstrained_rational_quadratic_spline(inputs, 54 | unnormalized_widths, 55 | unnormalized_heights, 56 | unnormalized_derivatives, 57 | inverse=False, 58 | tails='linear', 59 | tail_bound=1., 60 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 61 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 62 | min_derivative=DEFAULT_MIN_DERIVATIVE): 63 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 64 | outside_interval_mask = ~inside_interval_mask 65 | 66 | outputs = torch.zeros_like(inputs) 67 | logabsdet = torch.zeros_like(inputs) 68 | 69 | if tails == 'linear': 70 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 71 | constant = np.log(np.exp(1 - min_derivative) - 1) 72 | unnormalized_derivatives[..., 0] = constant 73 | unnormalized_derivatives[..., -1] = constant 74 | 75 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 76 | logabsdet[outside_interval_mask] = 0 77 | else: 78 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 79 | 80 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 81 | inputs=inputs[inside_interval_mask], 82 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 83 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 84 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 85 | inverse=inverse, 86 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 87 | min_bin_width=min_bin_width, 88 | min_bin_height=min_bin_height, 89 | min_derivative=min_derivative 90 | ) 91 | 92 | return outputs, logabsdet 93 | 94 | def rational_quadratic_spline(inputs, 95 | unnormalized_widths, 96 | unnormalized_heights, 97 | unnormalized_derivatives, 98 | inverse=False, 99 | left=0., right=1., bottom=0., top=1., 100 | 
min_bin_width=DEFAULT_MIN_BIN_WIDTH, 101 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 102 | min_derivative=DEFAULT_MIN_DERIVATIVE): 103 | if torch.min(inputs) < left or torch.max(inputs) > right: 104 | raise ValueError('Input to a transform is not within its domain') 105 | 106 | num_bins = unnormalized_widths.shape[-1] 107 | 108 | if min_bin_width * num_bins > 1.0: 109 | raise ValueError('Minimal bin width too large for the number of bins') 110 | if min_bin_height * num_bins > 1.0: 111 | raise ValueError('Minimal bin height too large for the number of bins') 112 | 113 | widths = F.softmax(unnormalized_widths, dim=-1) 114 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 115 | cumwidths = torch.cumsum(widths, dim=-1) 116 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 117 | cumwidths = (right - left) * cumwidths + left 118 | cumwidths[..., 0] = left 119 | cumwidths[..., -1] = right 120 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 121 | 122 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 123 | 124 | heights = F.softmax(unnormalized_heights, dim=-1) 125 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 126 | cumheights = torch.cumsum(heights, dim=-1) 127 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 128 | cumheights = (top - bottom) * cumheights + bottom 129 | cumheights[..., 0] = bottom 130 | cumheights[..., -1] = top 131 | heights = cumheights[..., 1:] - cumheights[..., :-1] 132 | 133 | if inverse: 134 | bin_idx = searchsorted(cumheights, inputs)[..., None] 135 | else: 136 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 137 | 138 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 139 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 140 | 141 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 142 | delta = heights / widths 143 | input_delta = delta.gather(-1, bin_idx)[..., 0] 144 | 145 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 146 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 147 | 148 | input_heights = heights.gather(-1, bin_idx)[..., 0] 149 | 150 | if inverse: 151 | a = (((inputs - input_cumheights) * (input_derivatives 152 | + input_derivatives_plus_one 153 | - 2 * input_delta) 154 | + input_heights * (input_delta - input_derivatives))) 155 | b = (input_heights * input_derivatives 156 | - (inputs - input_cumheights) * (input_derivatives 157 | + input_derivatives_plus_one 158 | - 2 * input_delta)) 159 | c = - input_delta * (inputs - input_cumheights) 160 | 161 | discriminant = b.pow(2) - 4 * a * c 162 | assert (discriminant >= 0).all() 163 | 164 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 165 | outputs = root * input_bin_widths + input_cumwidths 166 | 167 | theta_one_minus_theta = root * (1 - root) 168 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 169 | * theta_one_minus_theta) 170 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) 171 | + 2 * input_delta * theta_one_minus_theta 172 | + input_derivatives * (1 - root).pow(2)) 173 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 174 | 175 | return outputs, -logabsdet 176 | else: 177 | theta = (inputs - input_cumwidths) / input_bin_widths 178 | theta_one_minus_theta = theta * (1 - theta) 179 | 180 | numerator = input_heights * (input_delta * theta.pow(2) 181 | + input_derivatives * theta_one_minus_theta) 182 | denominator = 
input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 183 | * theta_one_minus_theta) 184 | outputs = input_cumheights + numerator / denominator 185 | 186 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) 187 | + 2 * input_delta * theta_one_minus_theta 188 | + input_derivatives * (1 - theta).pow(2)) 189 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 190 | 191 | return outputs, logabsdet 192 | -------------------------------------------------------------------------------- /audio/augmentation/aug.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchaudio.functional as AF 5 | from .peq import ParametricEqualizer 6 | 7 | class Augment(nn.Module): 8 | def __init__(self, h): 9 | super().__init__() 10 | self.config = h 11 | self.coder = LinearPredictiveCoding( 12 | 32, h.data.win_length, h.data.hop_length) 13 | self.peq = ParametricEqualizer( 14 | h.data.sampling_rate, h.data.win_length) 15 | self.register_buffer( 16 | 'window', 17 | torch.hann_window(h.data.win_length), 18 | persistent=False) 19 | f_min, f_max, peaks = 60, 10000, 8 20 | self.register_buffer( 21 | 'peak_centers', 22 | f_min * (f_max / f_min) ** (torch.arange(peaks) / (peaks - 1)), 23 | persistent=False) 24 | 25 | def forward(self, 26 | wavs: torch.Tensor, 27 | mode: str = 'linear', 28 | ): 29 | """Augment the audio signal, random pitch, formant shift and PEQ. 30 | Args: 31 | wavs: [torch.float32; [B, T]], audio signal. 32 | mode: interpolation mode, `linear` or `nearest`. 33 | """ 34 | auxs = {} 35 | fft = torch.stft( 36 | wavs, 37 | self.config.data.filter_length, 38 | self.config.data.hop_length, 39 | self.config.data.win_length, 40 | self.window, 41 | return_complex=True) 42 | 43 | power, gain = self.sample(wavs) # for fs, ps 44 | 45 | if power is not None: 46 | q_min, q_max = 2, 5 47 | q = q_min * (q_max / q_min) ** power 48 | 49 | if gain is None: 50 | gain = torch.zeros_like(q[:, :-2]) 51 | 52 | bsize = wavs.shape[0] 53 | center = self.peak_centers[None].repeat(bsize, 1) 54 | peaks = torch.prod( 55 | self.peq.peaking_equalizer(center, gain, q[:, :-2]), dim=1) 56 | lowpass = self.peq.low_shelving(60, q[:, -2]) 57 | highpass = self.peq.high_shelving(10000, q[:, -1]) 58 | 59 | filters = peaks * highpass * lowpass 60 | fft = fft * filters[..., None] 61 | auxs.update({'peaks': peaks, 'highpass': highpass, 'lowpass': lowpass}) 62 | 63 | # Formant shifting and Pitch shifting 64 | fs_ratio = 1.4 65 | ps_ratio = 2.0 66 | 67 | code = self.coder.from_stft(fft / fft.abs().mean(dim=1)[:, None].clamp_min(1e-7)) 68 | filter_ = self.coder.envelope(code) 69 | source = fft.transpose(1, 2) / (filter_ + 1e-7) 70 | 71 | bsize = wavs.shape[0] 72 | def sampler(ratio): 73 | shifts = torch.rand(bsize, device=wavs.device) * (ratio - 1.) + 1. 
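            # Descriptive note on the augmentation sampler defined here: `sampler` draws a
            # shift factor uniformly from [1, ratio] and, for the batch entries flipped on
            # the following lines, takes its reciprocal, so upward and downward shifts are
            # sampled symmetrically. fs_ratio (1.4) bounds the formant shift applied to the
            # LPC spectral envelope, while ps_ratio (2.0) bounds the pitch shift applied to
            # the residual source before envelope and source are recombined.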
74 | flip = torch.rand(bsize) < 0.5 75 | shifts[flip] = shifts[flip] ** -1 76 | return shifts 77 | 78 | fs_shift = sampler(fs_ratio) 79 | ps_shift = sampler(ps_ratio) 80 | 81 | source = fft.transpose(1, 2) / (filter_ + 1e-7) 82 | 83 | filter_ = self.interp(filter_, fs_shift, mode=mode) 84 | source = self.interp(source, ps_shift, mode=mode) 85 | 86 | fft = (source * filter_).transpose(1, 2) 87 | out = torch.istft( 88 | fft, 89 | self.config.data.filter_length, 90 | self.config.data.hop_length, 91 | self.config.data.win_length, 92 | self.window) 93 | out = out / out.max(dim=-1, keepdim=True).values.clamp_min(1e-7) 94 | 95 | return out 96 | 97 | def sample(self, wavs: torch.Tensor): 98 | bsize, _ = wavs.shape 99 | 100 | # parametric equalizer 101 | peaks = 8 102 | # quality factor 103 | power = torch.rand(bsize, peaks + 2, device=wavs.device) 104 | # gains 105 | g_min, g_max = -12, 12 106 | gain = torch.rand(bsize, peaks, device=wavs.device) * (g_max - g_min) + g_min 107 | 108 | return power, gain 109 | 110 | @staticmethod 111 | def complex_interp(inputs: torch.Tensor, *args, **kwargs): 112 | mag = F.interpolate(inputs.abs(), *args, **kwargs) 113 | angle = F.interpolate(inputs.angle(), *args, **kwargs) 114 | return torch.polar(mag, angle) 115 | 116 | def interp(self, inputs: torch.Tensor, shifts: torch.Tensor, mode: str): 117 | """Interpolate the channel axis with dynamic shifts. 118 | Args: 119 | inputs: [torch.complex64; [B, T, C]], input tensor. 120 | shifts: [torch.float32; [B]], shift factor. 121 | mode: interpolation mode. 122 | Returns: 123 | [torch.complex64; [B, T, C]], interpolated. 124 | """ 125 | INTERPOLATION = { 126 | torch.float32: F.interpolate, 127 | torch.complex64: Augment.complex_interp} 128 | assert inputs.dtype in INTERPOLATION, 'unsupported interpolation' 129 | interp_fn = INTERPOLATION[inputs.dtype] 130 | 131 | _, _, channels = inputs.shape 132 | 133 | interp = [ 134 | interp_fn( 135 | f[None], scale_factor=s.item(), mode=mode)[..., :channels] 136 | for f, s in zip(inputs, shifts)] 137 | 138 | return torch.cat([ 139 | F.pad(f, [0, channels - f.shape[-1]]) 140 | for f in interp], dim=0) 141 | 142 | 143 | class LinearPredictiveCoding(nn.Module): 144 | """LPC: Linear-predictive coding supports. 145 | """ 146 | 147 | def __init__(self, num_code: int, windows: int, strides: int): 148 | """Initializer. 149 | Args: 150 | num_code: the number of the coefficients. 151 | windows: size of the windows. 152 | strides: the number of the frames between adjacent windows. 153 | """ 154 | super().__init__() 155 | self.num_code = num_code 156 | self.windows = windows 157 | self.strides = strides 158 | 159 | def forward(self, inputs: torch.Tensor): 160 | """Compute the linear-predictive coefficients from inputs. 161 | Args: 162 | inputs: [torch.float32; [B, T]], audio signal. 163 | Returns: 164 | [torch.float32; [B, T / strides, num_code]], coefficients. 165 | """ 166 | w = self.windows 167 | frames = F.pad(inputs, [0, w]).unfold(-1, w, self.strides) 168 | corrcoef = LinearPredictiveCoding.autocorr(frames) 169 | 170 | return LinearPredictiveCoding.solve_toeplitz( 171 | corrcoef[..., :self.num_code + 1]) 172 | 173 | def from_stft(self, inputs: torch.Tensor): 174 | """Compute the linear-predictive coefficients from STFT. 175 | Args: 176 | inputs: [torch.complex64; [B, windows // 2 + 1, T / strides]], fourier features. 177 | Returns: 178 | [torch.float32; [B, T / strides, num_code]], linear-predictive coefficient. 
179 | """ 180 | corrcoef = torch.fft.irfft(inputs.abs().square(), dim=1) 181 | 182 | return LinearPredictiveCoding.solve_toeplitz( 183 | corrcoef[:, :self.num_code + 1].transpose(1, 2)) 184 | 185 | def envelope(self, lpc: torch.Tensor): 186 | """LPC to spectral envelope. 187 | Args: 188 | lpc: [torch.float32; [..., num_code]], coefficients. 189 | Returns: 190 | [torch.float32; [..., windows // 2 + 1]], filters. 191 | """ 192 | denom = torch.fft.rfft(-F.pad(lpc, [1, 0], value=1.), self.windows, dim=-1).abs() 193 | # for preventing zero-division 194 | denom[(denom.abs() - 1e-7) < 0] = 1. 195 | return denom ** -1 196 | 197 | @staticmethod 198 | def autocorr(wavs: torch.Tensor): 199 | """Compute the autocorrelation. 200 | Args: audio signal. 201 | Returns: auto-correlation. 202 | """ 203 | fft = torch.fft.rfft(wavs, dim=-1) 204 | return torch.fft.irfft(fft.abs().square(), dim=-1) 205 | 206 | @staticmethod 207 | def solve_toeplitz(corrcoef: torch.Tensor): 208 | """Solve the toeplitz matrix. 209 | Args: 210 | corrcoef: [torch.float32; [..., num_code + 1]], auto-correlation. 211 | Returns: 212 | [torch.float32; [..., num_code]], solutions. 213 | """ 214 | 215 | solutions = F.pad( 216 | (-corrcoef[..., 1] / corrcoef[..., 0].clamp_min(1e-7))[..., None], 217 | [1, 0], value=1.) 218 | 219 | extra = corrcoef[..., 0] + corrcoef[..., 1] * solutions[..., 1] 220 | 221 | ## solve residuals 222 | num_code = corrcoef.shape[-1] - 1 223 | for k in range(1, num_code): 224 | lambda_value = ( 225 | -solutions[..., :k + 1] 226 | * torch.flip(corrcoef[..., 1:k + 2], dims=[-1]) 227 | ).sum(dim=-1) / extra.clamp_min(1e-7) 228 | aug = F.pad(solutions, [0, 1]) 229 | solutions = aug + lambda_value[..., None] * torch.flip(aug, dims=[-1]) 230 | extra = (1. - lambda_value ** 2) * extra 231 | 232 | return solutions[..., 1:] -------------------------------------------------------------------------------- /audio/model/diffusion_mel.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import random 4 | import numpy as np 5 | from torch.nn import functional as F 6 | 7 | from model.base import BaseModule 8 | from model.diffusion_module import * 9 | 10 | 11 | class GradLogPEstimator(BaseModule): 12 | def __init__(self, dim_base, dim_cond, dim_mults=(1, 2, 4)): 13 | super(GradLogPEstimator, self).__init__() 14 | 15 | dims = [2 + dim_cond, *map(lambda m: dim_base * m, dim_mults)] 16 | in_out = list(zip(dims[:-1], dims[1:])) 17 | 18 | self.time_pos_emb = SinusoidalPosEmb(dim_base) 19 | self.mlp = torch.nn.Sequential(torch.nn.Linear(dim_base, dim_base * 4), 20 | Mish(), torch.nn.Linear(dim_base * 4, dim_base)) 21 | cond_total = dim_base + 256 22 | self.cond_block = torch.nn.Sequential(torch.nn.Linear(cond_total, 4 * dim_cond), 23 | Mish(), torch.nn.Linear(4 * dim_cond, dim_cond)) 24 | 25 | self.downs = torch.nn.ModuleList([]) 26 | self.ups = torch.nn.ModuleList([]) 27 | num_resolutions = len(in_out) 28 | 29 | for ind, (dim_in, dim_out) in enumerate(in_out): 30 | is_last = ind >= (num_resolutions - 1) 31 | self.downs.append(torch.nn.ModuleList([ 32 | ResnetBlock(dim_in, dim_out, time_emb_dim=dim_base), 33 | ResnetBlock(dim_out, dim_out, time_emb_dim=dim_base), 34 | Residual(Rezero(LinearAttention(dim_out))), 35 | Downsample(dim_out) if not is_last else torch.nn.Identity()])) 36 | 37 | mid_dim = dims[-1] 38 | self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim_base) 39 | self.mid_attn = Residual(Rezero(LinearAttention(mid_dim))) 40 | self.mid_block2 = 
ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim_base) 41 | 42 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 43 | self.ups.append(torch.nn.ModuleList([ 44 | ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim_base), 45 | ResnetBlock(dim_in, dim_in, time_emb_dim=dim_base), 46 | Residual(Rezero(LinearAttention(dim_in))), 47 | Upsample(dim_in)])) 48 | 49 | self.m_final_block = Block(dim_base, dim_base) 50 | self.m_final_conv = torch.nn.Conv2d(dim_base, 1, 1) 51 | 52 | self.z_final_block = Block(dim_base, dim_base) 53 | self.z_final_conv = torch.nn.Conv2d(dim_base, 1, 1) 54 | 55 | def forward(self, x, x_mask, enc_out, spk, t): 56 | condition = self.time_pos_emb(t) 57 | t = self.mlp(condition) 58 | 59 | x = torch.stack([enc_out, x], 1) 60 | x_mask = x_mask.unsqueeze(1) 61 | 62 | condition = torch.cat([condition, spk.squeeze(2)], 1) 63 | condition = self.cond_block(condition).unsqueeze(-1).unsqueeze(-1) 64 | 65 | condition = torch.cat(x.shape[2] * [condition], 2) 66 | condition = torch.cat(x.shape[3] * [condition], 3) 67 | x = torch.cat([x, condition], 1) 68 | 69 | hiddens = [] 70 | masks = [x_mask] 71 | 72 | for resnet1, resnet2, attn, downsample in self.downs: 73 | mask_down = masks[-1] 74 | x = resnet1(x, mask_down, t) 75 | x = resnet2(x, mask_down, t) 76 | x = attn(x) 77 | hiddens.append(x) 78 | x = downsample(x * mask_down) 79 | masks.append(mask_down[:, :, :, ::2]) 80 | 81 | masks = masks[:-1] 82 | mask_mid = masks[-1] 83 | x = self.mid_block1(x, mask_mid, t) 84 | x = self.mid_attn(x) 85 | x = self.mid_block2(x, mask_mid, t) 86 | 87 | for resnet1, resnet2, attn, upsample in self.ups: 88 | mask_up = masks.pop() 89 | x = torch.cat((x, hiddens.pop()), dim=1) 90 | x = resnet1(x, mask_up, t) 91 | x = resnet2(x, mask_up, t) 92 | x = attn(x) 93 | x = upsample(x * mask_up) 94 | 95 | m_x = self.m_final_block(x, x_mask) 96 | m_output = self.m_final_conv(m_x * x_mask) 97 | 98 | z_x = self.z_final_block(x, x_mask) 99 | z_output = self.z_final_conv(z_x * x_mask) 100 | 101 | return (m_output * x_mask).squeeze(1), (z_output * x_mask).squeeze(1) 102 | 103 | 104 | class Diffusion(BaseModule): 105 | def __init__(self, n_feats, dim_unet, dim_spk, beta_min, beta_max): 106 | super(Diffusion, self).__init__() 107 | self.estimator = GradLogPEstimator(dim_unet, dim_spk) 108 | 109 | self.n_feats = n_feats 110 | self.dim_unet = dim_unet 111 | self.dim_spk = dim_spk 112 | self.beta_min = beta_min 113 | self.beta_max = beta_max 114 | 115 | def get_beta(self, t): 116 | beta = self.beta_min + (self.beta_max - self.beta_min) * t 117 | return beta 118 | 119 | def get_gamma(self, s, t, p=1.0, use_torch=False): 120 | beta_integral = self.beta_min + 0.5 * (self.beta_max - self.beta_min) * (t + s) 121 | beta_integral *= (t - s) 122 | if use_torch: 123 | gamma = torch.exp(-0.5 * p * beta_integral).unsqueeze(-1).unsqueeze(-1) 124 | else: 125 | gamma = math.exp(-0.5 * p * beta_integral) 126 | return gamma 127 | 128 | def get_mu(self, s, t): 129 | a = self.get_gamma(s, t) 130 | b = 1.0 - self.get_gamma(0, s, p=2.0) 131 | c = 1.0 - self.get_gamma(0, t, p=2.0) 132 | return a * b / c 133 | 134 | def get_nu(self, s, t): 135 | a = self.get_gamma(0, s) 136 | b = 1.0 - self.get_gamma(s, t, p=2.0) 137 | c = 1.0 - self.get_gamma(0, t, p=2.0) 138 | return a * b / c 139 | 140 | def get_sigma(self, s, t): 141 | a = 1.0 - self.get_gamma(0, s, p=2.0) 142 | b = 1.0 - self.get_gamma(s, t, p=2.0) 143 | c = 1.0 - self.get_gamma(0, t, p=2.0) 144 | return math.sqrt(a * b / c) 145 | 146 | def compute_diffused_z_pr(self, x0, mask, 
z_pr, t, use_torch=False): 147 | x0_weight = self.get_gamma(0, t, use_torch=use_torch) 148 | z_pr_weight = 1.0 - x0_weight 149 | xt_z_pr = x0 * x0_weight + z_pr * z_pr_weight 150 | return xt_z_pr * mask 151 | 152 | 153 | @torch.no_grad() 154 | def reverse(self, z, mask, z_pr, spk, ts): 155 | h = 1.0 / ts 156 | xt = z * mask 157 | 158 | for i in range(ts): 159 | t = 1.0 - i * h 160 | time = t * torch.ones(z.shape[0], dtype=z.dtype, device=z.device) 161 | beta_t = self.get_beta(t) 162 | 163 | kappa = self.get_gamma(0, t - h) * (1.0 - self.get_gamma(t - h, t, p=2.0)) 164 | kappa /= (self.get_gamma(0, t) * beta_t * h) 165 | kappa -= 1.0 166 | omega = self.get_nu(t - h, t) / self.get_gamma(0, t) 167 | omega += self.get_mu(t - h, t) 168 | omega -= (0.5 * beta_t * h + 1.0) 169 | sigma = self.get_sigma(t - h, t) 170 | 171 | dxt = (z_pr - xt) * (0.5 * beta_t * h + omega) 172 | tmp, dxt_ = self.estimator(xt, mask, z_pr, spk, time) 173 | dxt -= dxt_ * (1.0 + kappa) * (beta_t * h) 174 | dxt += torch.randn_like(z, device=z.device) * sigma 175 | xt = (xt - dxt) * mask 176 | 177 | return xt 178 | 179 | @torch.no_grad() 180 | def forward(self, z, mask, enc_out, spk, n_timesteps, mode): 181 | return self.reverse_diffusion(z, mask, enc_out, spk, n_timesteps, mode) 182 | 183 | def random_masking(self, xt, num, frame): 184 | xt_mask = torch.ones_like(xt) 185 | x0_mask = torch.ones_like(xt) 186 | for _ in range(num): 187 | idx = random.randint(0, xt.size(1)-frame) 188 | xt[:, idx:idx+frame, :] = 0 189 | xt_mask[:, idx:idx+frame, :] = 0 190 | x0_mask -= xt_mask 191 | 192 | return xt, xt_mask, x0_mask 193 | 194 | 195 | def compute_diffused_z_pr(self, x0, mask, z_pr, t, use_torch=False): 196 | x0_weight = self.get_gamma(0, t, use_torch=use_torch) 197 | z_pr_weight = 1.0 - x0_weight 198 | xt_z_pr = x0 * x0_weight + z_pr * z_pr_weight 199 | return xt_z_pr * mask 200 | 201 | 202 | def forward_diffusion(self, x0, mask, enc_out, t): 203 | xt = self.compute_diffused_z_pr(x0, mask, enc_out, t, use_torch=True) 204 | variance = 1.0 - self.get_gamma(0, t, p=2.0, use_torch=True) 205 | z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, requires_grad=False) 206 | xt = xt + z * torch.sqrt(variance) 207 | 208 | return xt * mask, z * mask 209 | 210 | def compute_loss(self, x0, mask, enc_out, spk, t): 211 | xt, z = self.forward_diffusion(x0, mask, enc_out, t) 212 | masked_xt, xt_mask, x0_mask = self.random_masking(xt, num=4, frame=8) 213 | 214 | m_estimation, z_estimation = self.estimator(masked_xt, mask, enc_out, spk, t) 215 | m_estimation *= torch.sqrt(1.0 - self.get_gamma(0, t, p=2.0, use_torch=True)) 216 | z_estimation *= torch.sqrt(1.0 - self.get_gamma(0, t, p=2.0, use_torch=True)) 217 | diff_loss = torch.sum((z_estimation*xt_mask + z) ** 2) / (torch.sum(mask) * self.n_feats) 218 | recon_loss = F.l1_loss(x0*x0_mask, m_estimation*x0_mask) 219 | 220 | return diff_loss, recon_loss 221 | 222 | def compute_t(self, x0, mask, enc_out, spk, offset=1e-5): 223 | b = x0.shape[0] 224 | t = torch.rand(b, dtype=x0.dtype, device=x0.device, requires_grad=False) 225 | t = torch.clamp(t, offset, 1.0 - offset) 226 | 227 | return self.compute_loss(x0, mask, enc_out, spk, t) 228 | -------------------------------------------------------------------------------- /audio/vocoder/hifigan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import vocoder.modules as modules 5 | 6 | from torch.nn import Conv1d, 
ConvTranspose1d, AvgPool1d, Conv2d 7 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 8 | from module.commons import * 9 | from torch.cuda.amp import autocast 10 | import torchaudio 11 | from einops import rearrange 12 | import typing as tp 13 | 14 | def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)): 15 | return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2) 16 | 17 | class Generator(torch.nn.Module): 18 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): 19 | super(Generator, self).__init__() 20 | self.num_kernels = len(resblock_kernel_sizes) 21 | self.num_upsamples = len(upsample_rates) 22 | self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) 23 | resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 24 | 25 | self.ups = nn.ModuleList() 26 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 27 | self.ups.append(weight_norm( 28 | ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), 29 | k, u, padding=(k-u)//2))) 30 | 31 | self.resblocks = nn.ModuleList() 32 | for i in range(len(self.ups)): 33 | ch = upsample_initial_channel//(2**(i+1)) 34 | for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): 35 | self.resblocks.append(resblock(ch, k, d)) 36 | 37 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) 38 | self.ups.apply(init_weights) 39 | 40 | if gin_channels != 0: 41 | self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) 42 | 43 | def forward(self, x, g=None): 44 | x = self.conv_pre(x) 45 | if g is not None: 46 | x = x + self.cond(g) 47 | 48 | for i in range(self.num_upsamples): 49 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 50 | x = self.ups[i](x) 51 | xs = None 52 | for j in range(self.num_kernels): 53 | if xs is None: 54 | xs = self.resblocks[i*self.num_kernels+j](x) 55 | else: 56 | xs += self.resblocks[i*self.num_kernels+j](x) 57 | x = xs / self.num_kernels 58 | x = F.leaky_relu(x) 59 | x = self.conv_post(x) 60 | x = torch.tanh(x) 61 | 62 | return x 63 | 64 | def remove_weight_norm(self): 65 | print('Removing weight norm...') 66 | for l in self.ups: 67 | remove_weight_norm(l) 68 | for l in self.resblocks: 69 | l.remove_weight_norm() 70 | 71 | class DiscriminatorS(torch.nn.Module): 72 | def __init__(self, use_spectral_norm=False): 73 | super(DiscriminatorS, self).__init__() 74 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 75 | self.convs = nn.ModuleList([ 76 | norm_f(Conv1d(1, 16, 15, 1, padding=7)), 77 | norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), 78 | norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), 79 | norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), 80 | norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), 81 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 82 | ]) 83 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 84 | 85 | def forward(self, x): 86 | fmap = [] 87 | 88 | for l in self.convs: 89 | x = l(x) 90 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 91 | fmap.append(x) 92 | x = self.conv_post(x) 93 | fmap.append(x) 94 | x = torch.flatten(x, 1, -1) 95 | 96 | return x, fmap 97 | 98 | class DiscriminatorP(torch.nn.Module): 99 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 100 | super(DiscriminatorP, 
self).__init__() 101 | self.period = period 102 | self.use_spectral_norm = use_spectral_norm 103 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 104 | self.convs = nn.ModuleList([ 105 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 106 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 107 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 108 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 109 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), 110 | ]) 111 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 112 | 113 | def forward(self, x): 114 | fmap = [] 115 | 116 | # 1d to 2d 117 | b, c, t = x.shape 118 | if t % self.period != 0: # pad first 119 | n_pad = self.period - (t % self.period) 120 | x = F.pad(x, (0, n_pad), "reflect") 121 | t = t + n_pad 122 | x = x.view(b, c, t // self.period, self.period) 123 | 124 | for l in self.convs: 125 | x = l(x) 126 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 127 | fmap.append(x) 128 | x = self.conv_post(x) 129 | fmap.append(x) 130 | x = torch.flatten(x, 1, -1) 131 | 132 | return x, fmap 133 | 134 | class DiscriminatorR(torch.nn.Module): 135 | def __init__(self, resolution, use_spectral_norm=False): 136 | super(DiscriminatorR, self).__init__() 137 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 138 | 139 | n_fft, hop_length, win_length = resolution 140 | self.spec_transform = torchaudio.transforms.Spectrogram( 141 | n_fft=n_fft, hop_length=hop_length, win_length=win_length, window_fn=torch.hann_window, 142 | normalized=True, center=False, pad_mode=None, power=None) 143 | 144 | self.convs = nn.ModuleList([ 145 | norm_f(nn.Conv2d(2, 32, (3, 9), padding=(1, 4))), 146 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), 147 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(2,1), padding=(2, 4))), 148 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(4,1), padding=(4, 4))), 149 | norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), 150 | ]) 151 | self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) 152 | 153 | def forward(self, y): 154 | fmap = [] 155 | 156 | x = self.spec_transform(y) # [B, 2, Freq, Frames, 2] 157 | x = torch.cat([x.real, x.imag], dim=1) 158 | x = rearrange(x, 'b c w t -> b c t w') 159 | 160 | for l in self.convs: 161 | x = l(x) 162 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 163 | fmap.append(x) 164 | x = self.conv_post(x) 165 | fmap.append(x) 166 | x = torch.flatten(x, 1, -1) 167 | 168 | return x, fmap 169 | 170 | 171 | class MultiPeriodDiscriminator(torch.nn.Module): 172 | def __init__(self, use_spectral_norm=False): 173 | super(MultiPeriodDiscriminator, self).__init__() 174 | # periods = [2,3,5,7,11] 175 | # resolutions = [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]] 176 | resolutions = [[2048, 512, 2048], [1024, 256, 1024], [512, 128, 512], [256, 64, 256], [128, 32, 128]] 177 | 178 | discs = [DiscriminatorR(resolutions[i], use_spectral_norm=use_spectral_norm) for i in range(len(resolutions))] 179 | # discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] 180 | # discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] 181 | self.discriminators = nn.ModuleList(discs) 182 | 183 | def forward(self, y, y_hat): 184 | y_d_rs = [] 185 | y_d_gs 
= [] 186 | fmap_rs = [] 187 | fmap_gs = [] 188 | for i, d in enumerate(self.discriminators): 189 | y_d_r, fmap_r = d(y) 190 | y_d_g, fmap_g = d(y_hat) 191 | y_d_rs.append(y_d_r) 192 | y_d_gs.append(y_d_g) 193 | fmap_rs.append(fmap_r) 194 | fmap_gs.append(fmap_g) 195 | 196 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 197 | 198 | class HiFi(nn.Module): 199 | """ 200 | Synthesizer for Training 201 | """ 202 | 203 | def __init__(self, 204 | 205 | spec_channels, 206 | segment_size, 207 | inter_channels, 208 | hidden_channels, 209 | filter_channels, 210 | n_heads, 211 | n_layers, 212 | kernel_size, 213 | p_dropout, 214 | resblock, 215 | resblock_kernel_sizes, 216 | resblock_dilation_sizes, 217 | upsample_rates, 218 | upsample_initial_channel, 219 | upsample_kernel_sizes, 220 | **kwargs): 221 | 222 | super().__init__() 223 | self.spec_channels = spec_channels 224 | self.inter_channels = inter_channels 225 | self.hidden_channels = hidden_channels 226 | self.filter_channels = filter_channels 227 | self.n_heads = n_heads 228 | self.n_layers = n_layers 229 | self.kernel_size = kernel_size 230 | self.p_dropout = p_dropout 231 | self.resblock = resblock 232 | self.resblock_kernel_sizes = resblock_kernel_sizes 233 | self.resblock_dilation_sizes = resblock_dilation_sizes 234 | self.upsample_rates = upsample_rates 235 | self.upsample_initial_channel = upsample_initial_channel 236 | self.upsample_kernel_sizes = upsample_kernel_sizes 237 | self.segment_size = segment_size 238 | 239 | self.dec = Generator(spec_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes) 240 | 241 | def forward(self, x): 242 | 243 | y = self.dec(x) 244 | return y 245 | 246 | def infer(self, x, max_len=None): 247 | 248 | o = self.dec(x[:,:,:max_len]) 249 | return o 250 | 251 | -------------------------------------------------------------------------------- /audio/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.nn import functional as F 4 | from torch.nn.parallel import DistributedDataParallel as DDP 5 | 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | from torch.cuda.amp import autocast, GradScaler 9 | 10 | from torch.utils.data import DataLoader 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch.utils.data.distributed import DistributedSampler 13 | 14 | import random 15 | import commons 16 | import utils 17 | 18 | from augmentation.aug import Augment 19 | from model.diffhiervc import Wav2vec2, DiffHierVC 20 | from data_loader import AudioDataset, MelSpectrogramFixed 21 | from vocoder.hifigan import HiFi 22 | from torch.utils.data import DataLoader 23 | 24 | torch.backends.cudnn.benchmark = True 25 | global_step = 0 26 | 27 | def get_param_num(model): 28 | num_param = sum(param.numel() for param in model.parameters()) 29 | return num_param 30 | 31 | def main(): 32 | """Assume Single Node Multi GPUs Training Only""" 33 | assert torch.cuda.is_available(), "CPU training is not allowed." 
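    # Single-node DDP bootstrap: one worker process is spawned per visible GPU via
    # mp.spawn, and the workers rendezvous through the env:// init method on
    # localhost using a port randomized in [50000, 50100], so concurrent runs are
    # unlikely to collide on the same port.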
34 | 35 | n_gpus = torch.cuda.device_count() 36 | port = 50000 + random.randint(0, 100) 37 | os.environ['MASTER_ADDR'] = 'localhost' 38 | os.environ['MASTER_PORT'] = str(port) 39 | 40 | hps = utils.get_hparams() 41 | mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,)) 42 | 43 | def run(rank, n_gpus, hps): 44 | global global_step 45 | if rank == 0: 46 | logger = utils.get_logger(hps.model_dir) 47 | logger.info(hps) 48 | utils.check_git_hash(hps.model_dir) 49 | writer = SummaryWriter(log_dir=hps.model_dir) 50 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) 51 | 52 | dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank) 53 | torch.manual_seed(hps.train.seed) 54 | torch.cuda.set_device(rank) 55 | 56 | mel_fn = MelSpectrogramFixed( 57 | sample_rate=hps.data.sampling_rate, 58 | n_fft=hps.data.filter_length, 59 | win_length=hps.data.win_length, 60 | hop_length=hps.data.hop_length, 61 | f_min=hps.data.mel_fmin, 62 | f_max=hps.data.mel_fmax, 63 | n_mels=hps.data.n_mel_channels, 64 | window_fn=torch.hann_window 65 | ).cuda(rank) 66 | 67 | train_dataset = AudioDataset(hps, training=True) 68 | train_sampler = DistributedSampler(train_dataset) if n_gpus > 1 else None 69 | train_loader = DataLoader( 70 | train_dataset, batch_size=hps.train.batch_size, num_workers=32, 71 | sampler=train_sampler, drop_last=True, persistent_workers=True, pin_memory=True 72 | ) 73 | 74 | if rank == 0: 75 | test_dataset = AudioDataset(hps, training=False) 76 | eval_loader = DataLoader(test_dataset, batch_size=1) 77 | 78 | w2v = Wav2vec2().cuda(rank) 79 | aug = Augment(hps).cuda(rank) 80 | 81 | model = DiffHierVC(hps.data.n_mel_channels, hps.diffusion.spk_dim, 82 | hps.diffusion.dec_dim, hps.diffusion.beta_min, hps.diffusion.beta_max, hps).cuda() 83 | 84 | net_v = HiFi( 85 | hps.data.n_mel_channels, 86 | hps.train.segment_size // hps.data.hop_length, 87 | **hps.model).cuda() 88 | path_ckpt = './vocoder/voc_hifigan.pth' 89 | 90 | utils.load_checkpoint(path_ckpt, net_v, None) 91 | net_v.eval() 92 | net_v.dec.remove_weight_norm() 93 | 94 | if rank == 0: 95 | num_param = get_param_num(model.encoder) 96 | print('[Encoder] number of Parameters:', num_param) 97 | num_param = get_param_num(model.f0_dec) 98 | print('[F0 Decoder] number of Parameters:', num_param) 99 | num_param = get_param_num(model.mel_dec) 100 | print('[Mel Decoder] number of Parameters:', num_param) 101 | 102 | optimizer = torch.optim.AdamW( 103 | model.parameters(), 104 | hps.train.learning_rate, 105 | betas=hps.train.betas, 106 | eps=hps.train.eps) 107 | 108 | model = DDP(model, device_ids=[rank]) 109 | 110 | try: 111 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), model, optimizer) 112 | global_step = (epoch_str - 1) * len(train_loader) 113 | except: 114 | epoch_str = 1 115 | global_step = 0 116 | 117 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) 118 | scaler = GradScaler(enabled=hps.train.fp16_run) 119 | 120 | for epoch in range(epoch_str, hps.train.epochs + 1): 121 | if rank == 0: 122 | train_and_evaluate(rank, epoch, hps, [model, mel_fn, w2v, aug, net_v], optimizer, 123 | scheduler_g, scaler, [train_loader, eval_loader], logger, [writer, writer_eval]) 124 | else: 125 | train_and_evaluate(rank, epoch, hps, [model, mel_fn, w2v, aug, net_v], optimizer, 126 | scheduler_g, scaler, [train_loader, None], None, None) 127 | scheduler_g.step() 128 | 129 | def train_and_evaluate(rank, 
epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers): 130 | model, mel_fn, w2v, aug, net_v = nets 131 | optimizer = optims 132 | scheduler_g = schedulers 133 | train_loader, eval_loader = loaders 134 | 135 | if writers is not None: 136 | writer, writer_eval = writers 137 | global global_step 138 | 139 | train_loader.sampler.set_epoch(epoch) 140 | model.train() 141 | for batch_idx, (x, norm_f0, x_f0, length) in enumerate(train_loader): 142 | x = x.cuda(rank, non_blocking=True) 143 | norm_f0 = norm_f0.cuda(rank, non_blocking=True) 144 | x_f0 = x_f0.cuda(rank, non_blocking=True) 145 | length = length.cuda(rank, non_blocking=True).squeeze() 146 | 147 | mel_x = mel_fn(x) 148 | aug_x = aug(x) 149 | nan_x = torch.isnan(aug_x).any() 150 | x = x if nan_x else aug_x 151 | x_pad = F.pad(x, (40, 40), "reflect") 152 | 153 | w2v_x = w2v(x_pad) 154 | f0_x = torch.log(x_f0+1) 155 | 156 | optimizer.zero_grad() 157 | loss_mel_diff, loss_mel_diff_rec, loss_f0_diff, loss_mel, loss_f0 = model.module.compute_loss(mel_x, w2v_x, norm_f0, f0_x, length) 158 | loss_gen_all = loss_mel_diff + loss_mel_diff_rec + loss_f0_diff + loss_mel*hps.train.c_mel + loss_f0 159 | 160 | if hps.train.fp16_run: 161 | scaler.scale(loss_gen_all).backward() 162 | scaler.unscale_(optimizer) 163 | grad_norm_g = commons.clip_grad_value_(model.parameters(), None) 164 | scaler.step(optimizer) 165 | scaler.update() 166 | else: 167 | loss_gen_all.backward() 168 | grad_norm_g = commons.clip_grad_value_(model.parameters(), None) 169 | optimizer.step() 170 | 171 | if rank == 0: 172 | if global_step % hps.train.log_interval == 0: 173 | lr = optimizer.param_groups[0]['lr'] 174 | losses = [loss_mel_diff, loss_f0_diff] 175 | logger.info('Train Epoch: {} [{:.0f}%]'.format( 176 | epoch, 177 | 100. 
* batch_idx / len(train_loader))) 178 | logger.info([x.item() for x in losses] + [global_step, lr]) 179 | 180 | scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g} 181 | scalar_dict.update({"loss/g/diff": loss_mel_diff, "loss/g/diff_rec": loss_mel_diff_rec, "loss/g/f0_diff": loss_f0_diff, "loss/g/mel": loss_mel, "loss/g/f0": loss_f0}) 182 | 183 | utils.summarize( 184 | writer=writer, 185 | global_step=global_step, 186 | scalars=scalar_dict) 187 | 188 | if global_step % hps.train.eval_interval == 0: 189 | torch.cuda.empty_cache() 190 | evaluate(hps, model, mel_fn, w2v, net_v, eval_loader, writer_eval) 191 | 192 | if global_step % hps.train.save_interval == 0: 193 | utils.save_checkpoint(model, optimizer, hps.train.learning_rate, epoch, 194 | os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) 195 | 196 | global_step += 1 197 | 198 | if rank == 0: 199 | logger.info('====> Epoch: {}'.format(epoch)) 200 | 201 | 202 | def evaluate(hps, model, mel_fn, w2v, net_v, eval_loader, writer_eval): 203 | model.eval() 204 | image_dict = {} 205 | audio_dict = {} 206 | mel_loss = 0 207 | enc_loss = 0 208 | enc_f0_loss = 0 209 | diff_f0_loss = 0 210 | 211 | with torch.no_grad(): 212 | for batch_idx, (y, norm_y_f0, y_f0) in enumerate(eval_loader): 213 | y = y.cuda(0) 214 | norm_y_f0 = norm_y_f0.cuda(0) 215 | y_f0 = y_f0.cuda(0) 216 | 217 | mel_y = mel_fn(y) 218 | f0_y = torch.log(y_f0+1) 219 | length = torch.LongTensor([mel_y.size(2)]).cuda(0) 220 | 221 | y_pad = F.pad(y, (40, 40), "reflect") 222 | w2v_y = w2v(y_pad) 223 | 224 | y_f0_hat, y_mel, o_f0, o_mel = model(mel_y, w2v_y, norm_y_f0, f0_y, length, n_timesteps=6, mode='ml') 225 | 226 | mel_loss += F.l1_loss(mel_y, o_mel).item() 227 | enc_loss += F.l1_loss(mel_y, y_mel).item() 228 | enc_f0_loss += F.l1_loss(f0_y, y_f0_hat).item() 229 | diff_f0_loss += F.l1_loss(f0_y, o_f0).item() 230 | 231 | if batch_idx > 100: 232 | break 233 | if batch_idx <= 4: 234 | y_hat = net_v(o_mel) 235 | enc_hat = net_v(y_mel) 236 | 237 | plot_mel = torch.cat([mel_y, o_mel, y_mel], dim=1) 238 | plot_mel = plot_mel.clip(min=-10, max=10) 239 | 240 | image_dict.update({ 241 | "gen/mel_{}".format(batch_idx): utils.plot_spectrogram_to_numpy(plot_mel.squeeze().cpu().numpy()), 242 | "F0/f0_{}".format(batch_idx): 243 | utils.plot_f0_contour_to_numpy(mel_y.repeat_interleave(repeats=4, dim=2).squeeze().cpu().numpy(), 244 | f0s= {'target_f0': y_f0.squeeze().cpu(), 245 | 'enc_f0': (torch.exp(y_f0_hat)-1).squeeze().cpu(), 246 | 'diff_6_f0': (torch.exp(o_f0)-1).squeeze().cpu() 247 | }) 248 | }) 249 | audio_dict.update({ 250 | "gen/audio_{}".format(batch_idx): y_hat.squeeze(), 251 | "gen/enc_audio_{}".format(batch_idx): enc_hat.squeeze() 252 | }) 253 | if global_step == 0: 254 | audio_dict.update({"gt/audio_{}".format(batch_idx): y.squeeze()}) 255 | 256 | mel_loss /= 100 257 | enc_loss /= 100 258 | enc_f0_loss /= 100 259 | diff_f0_loss /= 100 260 | 261 | scalar_dict = {"val/mel": mel_loss, "val/enc_mel": enc_loss, "val/enc_f0": enc_f0_loss, "val/diff_f0": diff_f0_loss} 262 | utils.summarize( 263 | writer=writer_eval, 264 | global_step=global_step, 265 | images=image_dict, 266 | audios=audio_dict, 267 | audio_sampling_rate=hps.data.sampling_rate, 268 | scalars=scalar_dict 269 | ) 270 | model.train() 271 | 272 | 273 | if __name__ == "__main__": 274 | main() 275 | -------------------------------------------------------------------------------- /audio/vocoder/bigvgan.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 6 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 7 | from torch.cuda.amp import autocast 8 | import torchaudio 9 | from einops import rearrange 10 | 11 | from alias_free_torch import * 12 | from module.commons import init_weights, get_padding 13 | import vocoder.modules as modules 14 | import vocoder.activations as activations 15 | 16 | class AMPBlock1(torch.nn.Module): 17 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation=None): 18 | super(AMPBlock1, self).__init__() 19 | 20 | self.convs1 = nn.ModuleList([ 21 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 22 | padding=get_padding(kernel_size, dilation[0]))), 23 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 24 | padding=get_padding(kernel_size, dilation[1]))), 25 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 26 | padding=get_padding(kernel_size, dilation[2]))) 27 | ]) 28 | self.convs1.apply(init_weights) 29 | 30 | self.convs2 = nn.ModuleList([ 31 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 32 | padding=get_padding(kernel_size, 1))), 33 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 34 | padding=get_padding(kernel_size, 1))), 35 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 36 | padding=get_padding(kernel_size, 1))) 37 | ]) 38 | self.convs2.apply(init_weights) 39 | 40 | self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers 41 | 42 | 43 | self.activations = nn.ModuleList([ 44 | Activation1d( 45 | activation=activations.SnakeBeta(channels, alpha_logscale=True)) 46 | for _ in range(self.num_layers) 47 | ]) 48 | 49 | def forward(self, x): 50 | acts1, acts2 = self.activations[::2], self.activations[1::2] 51 | for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): 52 | xt = a1(x) 53 | xt = c1(xt) 54 | xt = a2(xt) 55 | xt = c2(xt) 56 | x = xt + x 57 | 58 | return x 59 | 60 | def remove_weight_norm(self): 61 | for l in self.convs1: 62 | remove_weight_norm(l) 63 | for l in self.convs2: 64 | remove_weight_norm(l) 65 | 66 | 67 | class AMPBlock2(torch.nn.Module): 68 | def __init__(self, channels, kernel_size=3, dilation=(1, 3), activation=None): 69 | super(AMPBlock2, self).__init__() 70 | 71 | 72 | self.convs = nn.ModuleList([ 73 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 74 | padding=get_padding(kernel_size, dilation[0]))), 75 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 76 | padding=get_padding(kernel_size, dilation[1]))) 77 | ]) 78 | self.convs.apply(init_weights) 79 | 80 | self.num_layers = len(self.convs) # total number of conv layers 81 | 82 | if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing 83 | self.activations = nn.ModuleList([ 84 | Activation1d( 85 | activation=activations.Snake(channels, alpha_logscale=True)) 86 | for _ in range(self.num_layers) 87 | ]) 88 | elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing 89 | self.activations = nn.ModuleList([ 90 | Activation1d( 91 | activation=activations.SnakeBeta(channels, alpha_logscale=True)) 92 | for _ in range(self.num_layers) 93 | ]) 94 | else: 95 | 
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") 96 | 97 | def forward(self, x): 98 | for c, a in zip (self.convs, self.activations): 99 | xt = a(x) 100 | xt = c(xt) 101 | x = xt + x 102 | 103 | return x 104 | 105 | def remove_weight_norm(self): 106 | for l in self.convs: 107 | remove_weight_norm(l) 108 | 109 | class Generator(torch.nn.Module): 110 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): 111 | super(Generator, self).__init__() 112 | self.num_kernels = len(resblock_kernel_sizes) 113 | self.num_upsamples = len(upsample_rates) 114 | 115 | self.conv_pre = weight_norm(Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)) 116 | resblock = AMPBlock1 if resblock == '1' else AMPBlock2 117 | 118 | self.ups = nn.ModuleList() 119 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 120 | self.ups.append(weight_norm( 121 | ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), 122 | k, u, padding=(k-u)//2))) 123 | 124 | self.resblocks = nn.ModuleList() 125 | for i in range(len(self.ups)): 126 | ch = upsample_initial_channel//(2**(i+1)) 127 | for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): 128 | self.resblocks.append(resblock(ch, k, d, activation="snakebeta")) 129 | 130 | activation_post = activations.SnakeBeta(ch, alpha_logscale=True) 131 | self.activation_post = Activation1d(activation=activation_post) 132 | 133 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) 134 | self.ups.apply(init_weights) 135 | 136 | if gin_channels != 0: 137 | self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) 138 | 139 | def forward(self, x, g=None): 140 | x = self.conv_pre(x) 141 | if g is not None: 142 | x = x + self.cond(g) 143 | 144 | for i in range(self.num_upsamples): 145 | 146 | x = self.ups[i](x) 147 | xs = None 148 | for j in range(self.num_kernels): 149 | if xs is None: 150 | xs = self.resblocks[i*self.num_kernels+j](x) 151 | else: 152 | xs += self.resblocks[i*self.num_kernels+j](x) 153 | x = xs / self.num_kernels 154 | 155 | x = self.activation_post(x) 156 | x = self.conv_post(x) 157 | x = torch.tanh(x) 158 | 159 | return x 160 | 161 | def remove_weight_norm(self): 162 | print('Removing weight norm...') 163 | for l in self.ups: 164 | remove_weight_norm(l) 165 | for l in self.resblocks: 166 | l.remove_weight_norm() 167 | 168 | class DiscriminatorP(torch.nn.Module): 169 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 170 | super(DiscriminatorP, self).__init__() 171 | self.period = period 172 | self.use_spectral_norm = use_spectral_norm 173 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 174 | self.convs = nn.ModuleList([ 175 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 176 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 177 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 178 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 179 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), 180 | ]) 181 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 182 | 183 | def forward(self, x): 
184 | fmap = [] 185 | 186 | b, c, t = x.shape 187 | if t % self.period != 0: 188 | n_pad = self.period - (t % self.period) 189 | x = F.pad(x, (0, n_pad), "reflect") 190 | t = t + n_pad 191 | x = x.view(b, c, t // self.period, self.period) 192 | 193 | for l in self.convs: 194 | x = l(x) 195 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 196 | fmap.append(x) 197 | x = self.conv_post(x) 198 | fmap.append(x) 199 | x = torch.flatten(x, 1, -1) 200 | 201 | return x, fmap 202 | 203 | class DiscriminatorR(torch.nn.Module): 204 | def __init__(self, resolution, use_spectral_norm=False): 205 | super(DiscriminatorR, self).__init__() 206 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 207 | 208 | n_fft, hop_length, win_length = resolution 209 | self.spec_transform = torchaudio.transforms.Spectrogram( 210 | n_fft=n_fft, hop_length=hop_length, win_length=win_length, window_fn=torch.hann_window, 211 | normalized=True, center=False, pad_mode=None, power=None) 212 | 213 | self.convs = nn.ModuleList([ 214 | norm_f(nn.Conv2d(2, 32, (3, 9), padding=(1, 4))), 215 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), 216 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(2,1), padding=(2, 4))), 217 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(4,1), padding=(4, 4))), 218 | norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), 219 | ]) 220 | self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) 221 | 222 | def forward(self, y): 223 | fmap = [] 224 | 225 | x = self.spec_transform(y) 226 | x = torch.cat([x.real, x.imag], dim=1) 227 | x = rearrange(x, 'b c w t -> b c t w') 228 | 229 | for l in self.convs: 230 | x = l(x) 231 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 232 | fmap.append(x) 233 | x = self.conv_post(x) 234 | fmap.append(x) 235 | x = torch.flatten(x, 1, -1) 236 | 237 | return x, fmap 238 | 239 | 240 | class MultiPeriodDiscriminator(torch.nn.Module): 241 | def __init__(self, use_spectral_norm=False): 242 | super(MultiPeriodDiscriminator, self).__init__() 243 | periods = [2,3,5,7,11] 244 | resolutions = [[2048, 512, 2048], [1024, 256, 1024], [512, 128, 512], [256, 64, 256], [128, 32, 128]] 245 | 246 | discs = [DiscriminatorR(resolutions[i], use_spectral_norm=use_spectral_norm) for i in range(len(resolutions))] 247 | discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] 248 | self.discriminators = nn.ModuleList(discs) 249 | 250 | def forward(self, y, y_hat): 251 | y_d_rs = [] 252 | y_d_gs = [] 253 | fmap_rs = [] 254 | fmap_gs = [] 255 | for i, d in enumerate(self.discriminators): 256 | y_d_r, fmap_r = d(y) 257 | y_d_g, fmap_g = d(y_hat) 258 | y_d_rs.append(y_d_r) 259 | y_d_gs.append(y_d_g) 260 | fmap_rs.append(fmap_r) 261 | fmap_gs.append(fmap_g) 262 | 263 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 264 | 265 | class BigvGAN(nn.Module): 266 | """ 267 | Synthesizer for Training 268 | """ 269 | 270 | def __init__(self, 271 | 272 | spec_channels, 273 | segment_size, 274 | inter_channels, 275 | hidden_channels, 276 | filter_channels, 277 | n_heads, 278 | n_layers, 279 | kernel_size, 280 | p_dropout, 281 | resblock, 282 | resblock_kernel_sizes, 283 | resblock_dilation_sizes, 284 | upsample_rates, 285 | upsample_initial_channel, 286 | upsample_kernel_sizes, 287 | **kwargs): 288 | 289 | super().__init__() 290 | self.spec_channels = spec_channels 291 | self.inter_channels = inter_channels 292 | self.hidden_channels = hidden_channels 293 | self.filter_channels = filter_channels 294 | self.n_heads = n_heads 295 | 
self.n_layers = n_layers 296 | self.kernel_size = kernel_size 297 | self.p_dropout = p_dropout 298 | self.resblock = resblock 299 | self.resblock_kernel_sizes = resblock_kernel_sizes 300 | self.resblock_dilation_sizes = resblock_dilation_sizes 301 | self.upsample_rates = upsample_rates 302 | self.upsample_initial_channel = upsample_initial_channel 303 | self.upsample_kernel_sizes = upsample_kernel_sizes 304 | self.segment_size = segment_size 305 | 306 | self.dec = Generator(spec_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes) 307 | 308 | def forward(self, x): 309 | 310 | y = self.dec(x) 311 | return y 312 | 313 | def infer(self, x, max_len=None): 314 | 315 | o = self.dec(x[:,:,:max_len]) 316 | return o 317 | 318 | -------------------------------------------------------------------------------- /audio/server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import time 4 | import torch 5 | import numpy as np 6 | from torch.nn import functional as F 7 | import torchaudio 8 | import amfm_decompy.pYAAPT as pYAAPT 9 | import amfm_decompy.basic_tools as basic 10 | from vocoder.bigvgan import BigvGAN 11 | from vocoder.hifigan import HiFi 12 | from model.diffhiervc import DiffHierVC, Wav2vec2 13 | from utils.utils import MelSpectrogramFixed 14 | import utils.utils as utils 15 | from flask import Flask, request, jsonify 16 | import threading 17 | import queue 18 | from functools import lru_cache 19 | 20 | try: 21 | from torch.cuda.amp import autocast, GradScaler 22 | 23 | amp_available = True 24 | except ImportError: 25 | amp_available = False 26 | 27 | app = Flask(__name__) 28 | 29 | logging.basicConfig( 30 | filename='server_log.txt', 31 | level=logging.INFO, 32 | format='%(asctime)s - %(levelname)s - %(message)s', 33 | filemode='a' 34 | ) 35 | logging.info("Server started with optimized configuration.") 36 | 37 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 38 | if torch.cuda.is_available(): 39 | logging.info(f"CUDA available: {torch.cuda.get_device_name(0)}") 40 | torch.backends.cuda.matmul.allow_tf32 = True 41 | torch.backends.cudnn.allow_tf32 = True 42 | torch.backends.cudnn.benchmark = True 43 | torch.cuda.empty_cache() 44 | 45 | seed = 1234 46 | torch.manual_seed(seed) 47 | if torch.cuda.is_available(): 48 | torch.cuda.manual_seed(seed) 49 | np.random.seed(seed) 50 | 51 | current_vocoder_type = 'bigvgan' 52 | noise_floor_dict = { 53 | 'bigvgan': 0.005, 54 | 'hifigan': 0.0003 55 | } 56 | 57 | 58 | @lru_cache(maxsize=10) 59 | def precompute_target_features(audio_path): 60 | logging.info(f"Precomputing target audio features for: {audio_path}") 61 | target_waveform, sr = torchaudio.load(audio_path) 62 | if sr != 16000: 63 | target_waveform = torchaudio.transforms.Resample(sr, 16000)(target_waveform) 64 | target_waveform = target_waveform.to(device).half() 65 | target_mel = mel_fn(target_waveform) 66 | target_length = torch.LongTensor([target_mel.size(-1)]).to(device) 67 | return target_mel, target_length 68 | 69 | 70 | def load_models(vocoder_type='bigvgan'): 71 | global hps, model, net_v, w2v, mel_fn, target_mel, target_length, current_vocoder_type 72 | current_vocoder_type = vocoder_type 73 | config_path = './ckpt/config_bigvgan.json' if vocoder_type == 'bigvgan' else './ckpt/config.json' 74 | hps = utils.get_hparams_from_file(config_path) 75 | 76 | mel_fn = MelSpectrogramFixed( 77 | 
sample_rate=hps.data.sampling_rate, 78 | n_fft=hps.data.filter_length, 79 | win_length=hps.data.win_length, 80 | hop_length=hps.data.hop_length, 81 | f_min=hps.data.mel_fmin, 82 | f_max=hps.data.mel_fmax, 83 | n_mels=hps.data.n_mel_channels, 84 | window_fn=torch.hann_window 85 | ).to(device) 86 | 87 | w2v = Wav2vec2().to(device).half() 88 | model = DiffHierVC( 89 | hps.data.n_mel_channels, 90 | hps.diffusion.spk_dim, 91 | hps.diffusion.dec_dim, 92 | hps.diffusion.beta_min, 93 | hps.diffusion.beta_max, 94 | hps 95 | ).to(device).half() 96 | 97 | model.load_state_dict(torch.load('./ckpt/model_diffhier.pth', map_location=device)) 98 | 99 | try: 100 | model = torch.compile(model, mode='reduce-overhead') 101 | except Exception as e: 102 | logging.warning(f"Model compilation failed: {e}") 103 | 104 | model.eval() 105 | torch.set_grad_enabled(False) 106 | 107 | if vocoder_type == 'bigvgan': 108 | net_v = BigvGAN( 109 | hps.data.n_mel_channels, 110 | hps.train.segment_size // hps.data.hop_length, 111 | **hps.model 112 | ).to(device).half() 113 | utils.load_checkpoint('./vocoder/voc_bigvgan.pth', net_v, None) 114 | else: # hifigan 115 | net_v = HiFi( 116 | hps.data.n_mel_channels, 117 | hps.train.segment_size // hps.data.hop_length, 118 | **hps.model 119 | ).to(device).half() 120 | utils.load_checkpoint('./vocoder/voc_hifigan.pth', net_v, None) 121 | 122 | net_v.eval().dec.remove_weight_norm() 123 | 124 | try: 125 | net_v = torch.compile(net_v, mode='reduce-overhead') 126 | except Exception as e: 127 | logging.warning(f"Vocoder compilation failed: {e}") 128 | 129 | current_target_audio_path = './sample/tar_p239_022.wav' 130 | target_mel, target_length = precompute_target_features(current_target_audio_path) 131 | logging.info(f"Models loaded successfully with {vocoder_type} vocoder.") 132 | 133 | 134 | load_models('bigvgan') 135 | 136 | f0_cache = {} 137 | 138 | 139 | def get_yaapt_f0(audio, sr=16000, interp=False): 140 | audio_hash = hash(audio.tobytes()) 141 | if audio_hash in f0_cache: 142 | return f0_cache[audio_hash] 143 | 144 | to_pad = int(20.0 / 1000 * sr) // 2 145 | f0s = [] 146 | for y in audio.astype(np.float64): 147 | y_pad = np.pad(y.squeeze(), (to_pad, to_pad), "constant", constant_values=0) 148 | pitch = pYAAPT.yaapt(basic.SignalObj(y_pad, sr), 149 | **{'frame_length': 20.0, 'frame_space': 5.0, 150 | 'nccf_thresh1': 0.25, 'tda_frame_length': 25.0}) 151 | f0s.append(pitch.samp_interp[None, None, :] if interp else pitch.samp_values[None, None, :]) 152 | 153 | result = np.vstack(f0s) 154 | f0_cache[audio_hash] = result 155 | 156 | if len(f0_cache) > 100: 157 | f0_cache.pop(next(iter(f0_cache))) 158 | 159 | return result 160 | 161 | 162 | def get_adaptive_diff_params(audio_length): 163 | return 15, 15 164 | 165 | 166 | def process_audio(audio_chunk, sr=16000): 167 | start_total = time.time() 168 | 169 | times = { 170 | 'preprocessing': 0, 171 | 'inference': 0, 172 | 'postprocessing': 0 173 | } 174 | 175 | try: 176 | start_preprocessing = time.time() 177 | audio_tensor = torch.from_numpy(audio_chunk.copy()).float().half().unsqueeze(0).to(device) 178 | p_val = (audio_tensor.shape[-1] // 1280 + 1) * 1280 - audio_tensor.shape[-1] 179 | audio_tensor = F.pad(audio_tensor, (0, p_val)).to(device) 180 | 181 | src_mel = mel_fn(audio_tensor) 182 | src_length = torch.LongTensor([src_mel.size(-1)]).to(device) 183 | 184 | w2v_x = w2v(F.pad(audio_tensor, (40, 40), "reflect")) 185 | times['preprocessing'] = time.time() - start_preprocessing 186 | 187 | start_f0 = time.time() 188 | try: 189 | f0 = 
get_yaapt_f0(audio_tensor.cpu().numpy(), sr) 190 | except Exception as e: 191 | logging.error(f"F0 computation error: {e}") 192 | f0 = np.zeros((1, audio_tensor.shape[-1] // 80), dtype=np.float32) 193 | 194 | f0_x = f0.copy() 195 | f0_x = torch.log(torch.FloatTensor(f0_x + 1)).half().to(device) 196 | 197 | ii = f0 != 0 198 | if np.any(ii): 199 | f0[ii] = (f0[ii] - f0[ii].mean()) / f0[ii].std() 200 | 201 | f0_norm_x = torch.FloatTensor(f0).half().to(device) 202 | 203 | audio_length_sec = audio_tensor.shape[-1] / sr 204 | diffpitch_ts, diffvoice_ts = get_adaptive_diff_params(audio_length_sec) 205 | 206 | start_inference = time.time() 207 | with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16): 208 | c = model.infer_vc( 209 | src_mel, w2v_x, f0_norm_x, f0_x, src_length, 210 | target_mel, target_length, 211 | diffpitch_ts=diffpitch_ts, 212 | diffvoice_ts=diffvoice_ts 213 | ) 214 | converted_audio = net_v(c) 215 | 216 | times['inference'] = time.time() - start_inference 217 | 218 | start_postprocessing = time.time() 219 | converted_audio = converted_audio / torch.max(torch.abs(converted_audio)) 220 | noise_floor = noise_floor_dict.get(current_vocoder_type, 0.001) 221 | logging.info(f"Applying noise floor: {noise_floor} for vocoder: {current_vocoder_type}") 222 | converted_audio = torch.where( 223 | torch.abs(converted_audio) < noise_floor, 224 | torch.zeros_like(converted_audio), 225 | converted_audio 226 | ) 227 | times['postprocessing'] = time.time() - start_postprocessing 228 | 229 | total_time = time.time() - start_total 230 | 231 | logging.info(f""" 232 | Performance Breakdown: 233 | - Total Processing Time: {total_time:.3f}s 234 | - Preprocessing Time: {times['preprocessing']:.3f}s 235 | - F0 Computation Time: {time.time() - start_f0:.3f}s 236 | - Inference Time: {times['inference']:.3f}s 237 | - Postprocessing Time: {times['postprocessing']:.3f}s 238 | """) 239 | 240 | return converted_audio.squeeze().detach().cpu().numpy() 241 | 242 | except Exception as e: 243 | logging.error(f"Audio processing error: {e}") 244 | raise 245 | 246 | 247 | PROCESSING_THREADS = 12 248 | processing_queue = queue.Queue(maxsize=25) 249 | 250 | 251 | def worker(): 252 | while True: 253 | task = processing_queue.get() 254 | if task is None: 255 | break 256 | audio_chunk, sr, response_object = task 257 | try: 258 | processed_audio = process_audio(audio_chunk, sr) 259 | response_object['processed_audio'] = processed_audio.tolist() 260 | response_object['success'] = True 261 | except Exception as e: 262 | logging.error(f"Error in audio processing: {e}") 263 | response_object['error'] = str(e) 264 | response_object['success'] = False 265 | finally: 266 | processing_queue.task_done() 267 | 268 | 269 | worker_threads = [] 270 | for i in range(PROCESSING_THREADS): 271 | t = threading.Thread(target=worker, daemon=True) 272 | t.start() 273 | worker_threads.append(t) 274 | 275 | 276 | @app.route('/convert', methods=['POST']) 277 | def convert(): 278 | try: 279 | data = request.get_data() 280 | audio_chunk = np.frombuffer(data, dtype=np.float32) 281 | sr = 16000 282 | if len(audio_chunk) < 1000: 283 | processed_audio = process_audio(audio_chunk, sr) 284 | return jsonify({'processed_audio': processed_audio.tolist()}) 285 | response_object = {'success': False} 286 | if processing_queue.qsize() >= processing_queue.maxsize - 1: 287 | return jsonify({'error': 'Server overloaded, please try again later'}), 503 288 | processing_queue.put((audio_chunk, sr, response_object)) 289 | timeout = 10 290 | start_wait 
= time.time() 291 | while not response_object.get('success', False) and time.time() - start_wait < timeout: 292 | time.sleep(0.1) 293 | if response_object.get('success', False): 294 | return jsonify({'processed_audio': response_object['processed_audio']}) 295 | elif 'error' in response_object: 296 | return jsonify({'error': response_object['error']}), 500 297 | else: 298 | return jsonify({'error': 'Processing timeout'}), 504 299 | except Exception as e: 300 | logging.error(f"Error processing request: {e}") 301 | return jsonify({'error': str(e)}), 500 302 | 303 | 304 | @app.route('/update_target', methods=['POST']) 305 | def update_target(): 306 | global target_mel, target_length, current_target_audio_path 307 | try: 308 | data = request.get_json() 309 | if "target_filename" in data: 310 | new_target_path = os.path.join("./sample", data["target_filename"]) 311 | if not os.path.exists(new_target_path): 312 | return jsonify({'status': 'error', 'message': f'File not found: {new_target_path}'}), 404 313 | current_target_audio_path = new_target_path 314 | logging.info(f"Target audio path updated to: {current_target_audio_path}") 315 | target_mel, target_length = precompute_target_features(current_target_audio_path) 316 | logging.info("Target audio updated successfully.") 317 | return jsonify({'status': 'success', 'message': 'Target audio updated successfully'}), 200 318 | except Exception as e: 319 | logging.error(f"Error updating target audio: {e}") 320 | return jsonify({'status': 'error', 'message': str(e)}), 500 321 | 322 | 323 | @app.route('/switch_vocoder', methods=['POST']) 324 | def switch_vocoder(): 325 | try: 326 | data = request.get_json() 327 | vocoder_type = data.get('vocoder_type', 'bigvgan') 328 | 329 | if vocoder_type not in ['bigvgan', 'hifigan']: 330 | return jsonify({'status': 'error', 'message': 'Invalid vocoder type'}), 400 331 | 332 | load_models(vocoder_type) 333 | return jsonify({ 334 | 'status': 'success', 335 | 'message': f'Switched to {vocoder_type} vocoder', 336 | 'current_vocoder': current_vocoder_type, 337 | 'noise_floor': noise_floor_dict.get(vocoder_type, 0.007) 338 | }), 200 339 | except Exception as e: 340 | logging.error(f"Error switching vocoder: {e}") 341 | return jsonify({'status': 'error', 'message': str(e)}), 500 342 | 343 | 344 | @app.route('/health', methods=['GET']) 345 | def health(): 346 | return jsonify({"status": "ok"}), 200 347 | 348 | 349 | if __name__ == '__main__': 350 | app.run(host='0.0.0.0', port=5003, debug=False) 351 | 352 | -------------------------------------------------------------------------------- /video/modules/core.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | # single thread doubles cuda performance - needs to be set before torch import 4 | if any(arg.startswith('--execution-provider') for arg in sys.argv): 5 | os.environ['OMP_NUM_THREADS'] = '1' 6 | # reduce tensorflow log level 7 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 8 | import warnings 9 | from typing import List 10 | import platform 11 | import signal 12 | import shutil 13 | import argparse 14 | import torch 15 | import onnxruntime 16 | import tensorflow 17 | 18 | import modules.globals 19 | import modules.metadata 20 | import modules.ui as ui 21 | from modules.processors.frame.core import get_frame_processors_modules 22 | from modules.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, 
clean_temp, normalize_output_path 23 | 24 | if 'ROCMExecutionProvider' in modules.globals.execution_providers: 25 | del torch 26 | 27 | warnings.filterwarnings('ignore', category=FutureWarning, module='insightface') 28 | warnings.filterwarnings('ignore', category=UserWarning, module='torchvision') 29 | 30 | 31 | def parse_args() -> None: 32 | signal.signal(signal.SIGINT, lambda signal_number, frame: destroy()) 33 | program = argparse.ArgumentParser() 34 | program.add_argument('-s', '--source', help='select an source image', dest='source_path') 35 | program.add_argument('-t', '--target', help='select an target image or video', dest='target_path') 36 | program.add_argument('-o', '--output', help='select output file or directory', dest='output_path') 37 | program.add_argument('--frame-processor', help='pipeline of frame processors', dest='frame_processor', default=['face_swapper'], choices=['face_swapper', 'face_enhancer'], nargs='+') 38 | program.add_argument('--keep-fps', help='keep original fps', dest='keep_fps', action='store_true', default=False) 39 | program.add_argument('--keep-audio', help='keep original audio', dest='keep_audio', action='store_true', default=True) 40 | program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true', default=False) 41 | program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true', default=False) 42 | program.add_argument('--nsfw-filter', help='filter the NSFW image or video', dest='nsfw_filter', action='store_true', default=False) 43 | program.add_argument('--map-faces', help='map source target faces', dest='map_faces', action='store_true', default=False) 44 | program.add_argument('--mouth-mask', help='mask the mouth region', dest='mouth_mask', action='store_true', default=False) 45 | program.add_argument('--video-encoder', help='adjust output video encoder', dest='video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9']) 46 | program.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18, choices=range(52), metavar='[0-51]') 47 | program.add_argument('-l', '--lang', help='Ui language', default="en") 48 | program.add_argument('--live-mirror', help='The live camera display as you see it in the front-facing camera frame', dest='live_mirror', action='store_true', default=False) 49 | program.add_argument('--live-resizable', help='The live camera frame is resizable', dest='live_resizable', action='store_true', default=False) 50 | program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int, default=suggest_max_memory()) 51 | program.add_argument('--execution-provider', help='execution provider', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+') 52 | program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads()) 53 | program.add_argument('-v', '--version', action='version', version=f'{modules.metadata.name} {modules.metadata.version}') 54 | 55 | # register deprecated args 56 | program.add_argument('-f', '--face', help=argparse.SUPPRESS, dest='source_path_deprecated') 57 | program.add_argument('--cpu-cores', help=argparse.SUPPRESS, dest='cpu_cores_deprecated', type=int) 58 | program.add_argument('--gpu-vendor', help=argparse.SUPPRESS, dest='gpu_vendor_deprecated') 59 | program.add_argument('--gpu-threads', 
help=argparse.SUPPRESS, dest='gpu_threads_deprecated', type=int) 60 | 61 | args = program.parse_args() 62 | 63 | modules.globals.source_path = args.source_path 64 | modules.globals.target_path = args.target_path 65 | modules.globals.output_path = normalize_output_path(modules.globals.source_path, modules.globals.target_path, args.output_path) 66 | modules.globals.frame_processors = args.frame_processor 67 | modules.globals.headless = args.source_path or args.target_path or args.output_path 68 | modules.globals.keep_fps = args.keep_fps 69 | modules.globals.keep_audio = args.keep_audio 70 | modules.globals.keep_frames = args.keep_frames 71 | modules.globals.many_faces = args.many_faces 72 | modules.globals.mouth_mask = args.mouth_mask 73 | modules.globals.nsfw_filter = args.nsfw_filter 74 | modules.globals.map_faces = args.map_faces 75 | modules.globals.video_encoder = args.video_encoder 76 | modules.globals.video_quality = args.video_quality 77 | modules.globals.live_mirror = args.live_mirror 78 | modules.globals.live_resizable = args.live_resizable 79 | modules.globals.max_memory = args.max_memory 80 | modules.globals.execution_providers = decode_execution_providers(args.execution_provider) 81 | modules.globals.execution_threads = args.execution_threads 82 | modules.globals.lang = args.lang 83 | 84 | #for ENHANCER tumbler: 85 | if 'face_enhancer' in args.frame_processor: 86 | modules.globals.fp_ui['face_enhancer'] = True 87 | else: 88 | modules.globals.fp_ui['face_enhancer'] = False 89 | 90 | # translate deprecated args 91 | if args.source_path_deprecated: 92 | print('\033[33mArgument -f and --face are deprecated. Use -s and --source instead.\033[0m') 93 | modules.globals.source_path = args.source_path_deprecated 94 | modules.globals.output_path = normalize_output_path(args.source_path_deprecated, modules.globals.target_path, args.output_path) 95 | if args.cpu_cores_deprecated: 96 | print('\033[33mArgument --cpu-cores is deprecated. Use --execution-threads instead.\033[0m') 97 | modules.globals.execution_threads = args.cpu_cores_deprecated 98 | if args.gpu_vendor_deprecated == 'apple': 99 | print('\033[33mArgument --gpu-vendor apple is deprecated. Use --execution-provider coreml instead.\033[0m') 100 | modules.globals.execution_providers = decode_execution_providers(['coreml']) 101 | if args.gpu_vendor_deprecated == 'nvidia': 102 | print('\033[33mArgument --gpu-vendor nvidia is deprecated. Use --execution-provider cuda instead.\033[0m') 103 | modules.globals.execution_providers = decode_execution_providers(['cuda']) 104 | if args.gpu_vendor_deprecated == 'amd': 105 | print('\033[33mArgument --gpu-vendor amd is deprecated. Use --execution-provider cuda instead.\033[0m') 106 | modules.globals.execution_providers = decode_execution_providers(['rocm']) 107 | if args.gpu_threads_deprecated: 108 | print('\033[33mArgument --gpu-threads is deprecated. 
Use --execution-threads instead.\033[0m') 109 | modules.globals.execution_threads = args.gpu_threads_deprecated 110 | 111 | 112 | def encode_execution_providers(execution_providers: List[str]) -> List[str]: 113 | return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers] 114 | 115 | 116 | def decode_execution_providers(execution_providers: List[str]) -> List[str]: 117 | return [provider for provider, encoded_execution_provider in zip(onnxruntime.get_available_providers(), encode_execution_providers(onnxruntime.get_available_providers())) 118 | if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)] 119 | 120 | 121 | def suggest_max_memory() -> int: 122 | if platform.system().lower() == 'darwin': 123 | return 4 124 | return 16 125 | 126 | 127 | def suggest_execution_providers() -> List[str]: 128 | return encode_execution_providers(onnxruntime.get_available_providers()) 129 | 130 | 131 | def suggest_execution_threads() -> int: 132 | if 'DmlExecutionProvider' in modules.globals.execution_providers: 133 | return 1 134 | if 'ROCMExecutionProvider' in modules.globals.execution_providers: 135 | return 1 136 | return 8 137 | 138 | 139 | def limit_resources() -> None: 140 | # prevent tensorflow memory leak 141 | gpus = tensorflow.config.experimental.list_physical_devices('GPU') 142 | for gpu in gpus: 143 | tensorflow.config.experimental.set_memory_growth(gpu, True) 144 | # limit memory usage 145 | if modules.globals.max_memory: 146 | memory = modules.globals.max_memory * 1024 ** 3 147 | if platform.system().lower() == 'darwin': 148 | memory = modules.globals.max_memory * 1024 ** 6 149 | if platform.system().lower() == 'windows': 150 | import ctypes 151 | kernel32 = ctypes.windll.kernel32 152 | kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory)) 153 | else: 154 | import resource 155 | resource.setrlimit(resource.RLIMIT_DATA, (memory, memory)) 156 | 157 | 158 | def release_resources() -> None: 159 | if 'CUDAExecutionProvider' in modules.globals.execution_providers: 160 | torch.cuda.empty_cache() 161 | 162 | 163 | def pre_check() -> bool: 164 | if sys.version_info < (3, 9): 165 | update_status('Python version is not supported - please upgrade to 3.9 or higher.') 166 | return False 167 | if not shutil.which('ffmpeg'): 168 | update_status('ffmpeg is not installed.') 169 | return False 170 | return True 171 | 172 | 173 | def update_status(message: str, scope: str = 'DLC.CORE') -> None: 174 | print(f'[{scope}] {message}') 175 | if not modules.globals.headless: 176 | ui.update_status(message) 177 | 178 | def start() -> None: 179 | for frame_processor in get_frame_processors_modules(modules.globals.frame_processors): 180 | if not frame_processor.pre_start(): 181 | return 182 | update_status('Processing...') 183 | # process image to image 184 | if has_image_extension(modules.globals.target_path): 185 | if modules.globals.nsfw_filter and ui.check_and_ignore_nsfw(modules.globals.target_path, destroy): 186 | return 187 | try: 188 | shutil.copy2(modules.globals.target_path, modules.globals.output_path) 189 | except Exception as e: 190 | print("Error copying file:", str(e)) 191 | for frame_processor in get_frame_processors_modules(modules.globals.frame_processors): 192 | update_status('Progressing...', frame_processor.NAME) 193 | frame_processor.process_image(modules.globals.source_path, modules.globals.output_path, modules.globals.output_path) 194 | release_resources() 195 | 
if is_image(modules.globals.target_path): 196 | update_status('Processing to image succeed!') 197 | else: 198 | update_status('Processing to image failed!') 199 | return 200 | # process image to videos 201 | if modules.globals.nsfw_filter and ui.check_and_ignore_nsfw(modules.globals.target_path, destroy): 202 | return 203 | 204 | if not modules.globals.map_faces: 205 | update_status('Creating temp resources...') 206 | create_temp(modules.globals.target_path) 207 | update_status('Extracting frames...') 208 | extract_frames(modules.globals.target_path) 209 | 210 | temp_frame_paths = get_temp_frame_paths(modules.globals.target_path) 211 | for frame_processor in get_frame_processors_modules(modules.globals.frame_processors): 212 | update_status('Progressing...', frame_processor.NAME) 213 | frame_processor.process_video(modules.globals.source_path, temp_frame_paths) 214 | release_resources() 215 | # handles fps 216 | if modules.globals.keep_fps: 217 | update_status('Detecting fps...') 218 | fps = detect_fps(modules.globals.target_path) 219 | update_status(f'Creating video with {fps} fps...') 220 | create_video(modules.globals.target_path, fps) 221 | else: 222 | update_status('Creating video with 30.0 fps...') 223 | create_video(modules.globals.target_path) 224 | # handle audio 225 | if modules.globals.keep_audio: 226 | if modules.globals.keep_fps: 227 | update_status('Restoring audio...') 228 | else: 229 | update_status('Restoring audio might cause issues as fps are not kept...') 230 | restore_audio(modules.globals.target_path, modules.globals.output_path) 231 | else: 232 | move_temp(modules.globals.target_path, modules.globals.output_path) 233 | # clean and validate 234 | clean_temp(modules.globals.target_path) 235 | if is_video(modules.globals.target_path): 236 | update_status('Processing to video succeed!') 237 | else: 238 | update_status('Processing to video failed!') 239 | 240 | 241 | def destroy(to_quit=True) -> None: 242 | if modules.globals.target_path: 243 | clean_temp(modules.globals.target_path) 244 | if to_quit: quit() 245 | 246 | 247 | def run() -> None: 248 | parse_args() 249 | if not pre_check(): 250 | return 251 | for frame_processor in get_frame_processors_modules(modules.globals.frame_processors): 252 | if not frame_processor.pre_check(): 253 | return 254 | limit_resources() 255 | if modules.globals.headless: 256 | start() 257 | else: 258 | window = ui.init(start, destroy, modules.globals.lang) 259 | window.mainloop() 260 | --------------------------------------------------------------------------------
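Editor's note: below is a minimal, hedged sketch of how a client might call the /convert endpoint exposed by audio/server.py above. The repository's own client code is not reproduced in this section, so the function name convert_chunk, the localhost URL, and the timeout are illustrative assumptions; only the payload format (raw float32 PCM at 16 kHz sent as the request body) and the 'processed_audio' JSON key are taken from the server code shown.

# Editor's sketch (not a file from this repository): minimal client for the
# /convert route defined in audio/server.py. The endpoint, payload format and
# response key come from the server code; the names and URL are illustrative.
import numpy as np
import requests


def convert_chunk(audio_chunk: np.ndarray,
                  url: str = "http://localhost:5003/convert") -> np.ndarray:
    """Send one chunk of 16 kHz float32 audio and return the converted samples."""
    payload = audio_chunk.astype(np.float32).tobytes()
    resp = requests.post(url, data=payload, timeout=15)
    resp.raise_for_status()
    return np.asarray(resp.json()["processed_audio"], dtype=np.float32)


if __name__ == "__main__":
    # One second of silence stands in for a real microphone chunk.
    out = convert_chunk(np.zeros(16000, dtype=np.float32))
    print(out.shape)

Start audio/server.py first (it binds to 0.0.0.0:5003); a real-time pipeline would stream short microphone chunks instead of the silent buffer used here, and should handle the 503 (worker queue saturated) and 504 (processing timeout) responses the server can return.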