├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── config ├── denoiser.yaml ├── enhancer_stage1.yaml └── enhancer_stage2.yaml ├── packages.txt ├── pyproject.toml ├── requirements.txt ├── resemble_enhance ├── __init__.py ├── common.py ├── data │ ├── __init__.py │ ├── dataset.py │ ├── distorter │ │ ├── __init__.py │ │ ├── base.py │ │ ├── custom.py │ │ ├── distorter.py │ │ └── sox.py │ └── utils.py ├── denoiser │ ├── __init__.py │ ├── __main__.py │ ├── denoiser.py │ ├── hparams.py │ ├── inference.py │ ├── train.py │ └── unet.py ├── enhancer │ ├── __init__.py │ ├── __main__.py │ ├── download.py │ ├── enhancer.py │ ├── hparams.py │ ├── inference.py │ ├── lcfm │ │ ├── __init__.py │ │ ├── cfm.py │ │ ├── irmae.py │ │ ├── lcfm.py │ │ └── wn.py │ ├── train.py │ └── univnet │ │ ├── __init__.py │ │ ├── alias_free_torch │ │ ├── __init__.py │ │ ├── filter.py │ │ └── resample.py │ │ ├── amp.py │ │ ├── discriminator.py │ │ ├── lvcnet.py │ │ ├── mrstft.py │ │ └── univnet.py ├── hparams.py ├── inference.py ├── melspec.py └── utils │ ├── __init__.py │ ├── control.py │ ├── distributed.py │ ├── engine.py │ ├── logging.py │ ├── train_loop.py │ └── utils.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | /data 2 | /runs 3 | /scripts 4 | /dist 5 | /build 6 | /*.egg-info 7 | /flagged 8 | version.py 9 | __pycache__ 10 | model_repo 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Resemble AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Resemble Enhance
2 | 
3 | [![PyPI](https://img.shields.io/pypi/v/resemble-enhance.svg)](https://pypi.org/project/resemble-enhance/)
4 | [![Hugging Face Space](https://img.shields.io/badge/Hugging%20Face%20%F0%9F%A4%97-Space-yellow)](https://huggingface.co/spaces/ResembleAI/resemble-enhance)
5 | [![License](https://img.shields.io/github/license/resemble-ai/Resemble-Enhance.svg)](https://github.com/resemble-ai/resemble-enhance/blob/main/LICENSE)
6 | [![Webpage](https://img.shields.io/badge/Webpage-Online-brightgreen)](https://www.resemble.ai/enhance/)
7 | 
8 | https://github.com/resemble-ai/resemble-enhance/assets/660224/bc3ec943-e795-4646-b119-cce327c810f1
9 | 
10 | Resemble Enhance is an AI-powered tool that aims to improve the overall quality of speech by performing denoising and enhancement. It consists of two modules: a denoiser, which separates speech from noisy audio, and an enhancer, which further boosts perceptual audio quality by restoring audio distortions and extending the audio bandwidth. Both models are trained on high-quality 44.1kHz speech data, which allows them to enhance your speech at high quality.
11 | 
12 | ## Usage
13 | 
14 | ### Installation
15 | 
16 | Install the stable version:
17 | 
18 | ```bash
19 | pip install resemble-enhance --upgrade
20 | ```
21 | 
22 | Or try the latest pre-release version:
23 | 
24 | ```bash
25 | pip install resemble-enhance --upgrade --pre
26 | ```
27 | 
28 | ### Enhance
29 | 
30 | ```
31 | resemble-enhance in_dir out_dir
32 | ```
33 | 
34 | ### Denoise only
35 | 
36 | ```
37 | resemble-enhance in_dir out_dir --denoise_only
38 | ```
39 | 
40 | ### Web Demo
41 | 
42 | We provide a web demo built with Gradio. You can try it out [here](https://huggingface.co/spaces/ResembleAI/resemble-enhance), or run it locally:
43 | 
44 | ```
45 | python app.py
46 | ```
47 | 
48 | ## Train your own model
49 | 
50 | ### Data Preparation
51 | 
52 | You need to prepare a foreground speech dataset and a background non-speech dataset. In addition, you need to prepare an RIR dataset ([examples](https://github.com/RoyJames/room-impulse-responses)).
53 | 
54 | ```bash
55 | data
56 | ├── fg
57 | │   ├── 00001.wav
58 | │   └── ...
59 | ├── bg
60 | │   ├── 00001.wav
61 | │   └── ...
62 | └── rir
63 |     ├── 00001.npy
64 |     └── ...
65 | ```
66 | 
67 | ### Training
68 | 
69 | #### Denoiser Warmup
70 | 
71 | Though the denoiser is trained jointly with the enhancer, it is recommended to warm it up with a separate training run first.
72 | 
73 | ```bash
74 | python -m resemble_enhance.denoiser.train --yaml config/denoiser.yaml runs/denoiser
75 | ```
76 | 
77 | #### Enhancer
78 | 
79 | Then, you can train the enhancer in two stages. The first stage trains the autoencoder and vocoder, and the second stage trains the latent conditional flow matching (CFM) model.
80 | 
81 | ##### Stage 1
82 | 
83 | ```bash
84 | python -m resemble_enhance.enhancer.train --yaml config/enhancer_stage1.yaml runs/enhancer_stage1
85 | ```
86 | 
87 | ##### Stage 2
88 | 
89 | ```bash
90 | python -m resemble_enhance.enhancer.train --yaml config/enhancer_stage2.yaml runs/enhancer_stage2
91 | ```
92 | 
93 | ## Blog
94 | 
95 | Learn more on our [website](https://www.resemble.ai/enhance/)!
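The CLI covers most use cases, but the denoiser and enhancer can also be driven from Python. Below is a minimal sketch based on the `denoise`/`enhance` functions in `resemble_enhance.enhancer.inference` (the same functions `app.py` uses); the audio file names are placeholders, and the keyword values shown are one reasonable choice within the ranges those functions assert.

```python
import torch
import torchaudio

from resemble_enhance.enhancer.inference import denoise, enhance

device = "cuda" if torch.cuda.is_available() else "cpu"

# "noisy.wav" is a placeholder input path
dwav, sr = torchaudio.load("noisy.wav")
dwav = dwav.mean(dim=0)  # mix down to mono, shape (t,)

# Denoise only: separate speech from the noisy recording
hwav, new_sr = denoise(dwav, sr, device)
torchaudio.save("denoised.wav", hwav.cpu()[None], new_sr)

# Full enhancement: lambd is the denoise strength, tau the CFM prior temperature
hwav, new_sr = enhance(dwav, sr, device, nfe=64, solver="midpoint", lambd=0.9, tau=0.5)
torchaudio.save("enhanced.wav", hwav.cpu()[None], new_sr)
```

On first use, the pretrained `enhancer_stage2` weights are downloaded automatically (see `resemble_enhance/enhancer/download.py`).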
96 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import torch 3 | import torchaudio 4 | 5 | from resemble_enhance.enhancer.inference import denoise, enhance 6 | 7 | if torch.cuda.is_available(): 8 | device = "cuda" 9 | else: 10 | device = "cpu" 11 | 12 | 13 | def _fn(path, solver, nfe, tau, denoising): 14 | if path is None: 15 | return None, None 16 | 17 | solver = solver.lower() 18 | nfe = int(nfe) 19 | lambd = 0.9 if denoising else 0.1 20 | 21 | dwav, sr = torchaudio.load(path) 22 | dwav = dwav.mean(dim=0) 23 | 24 | wav1, new_sr = denoise(dwav, sr, device) 25 | wav2, new_sr = enhance(dwav, sr, device, nfe=nfe, solver=solver, lambd=lambd, tau=tau) 26 | 27 | wav1 = wav1.cpu().numpy() 28 | wav2 = wav2.cpu().numpy() 29 | 30 | return (new_sr, wav1), (new_sr, wav2) 31 | 32 | 33 | def main(): 34 | inputs: list = [ 35 | gr.Audio(type="filepath", label="Input Audio"), 36 | gr.Dropdown(choices=["Midpoint", "RK4", "Euler"], value="Midpoint", label="CFM ODE Solver"), 37 | gr.Slider(minimum=1, maximum=128, value=64, step=1, label="CFM Number of Function Evaluations"), 38 | gr.Slider(minimum=0, maximum=1, value=0.5, step=0.01, label="CFM Prior Temperature"), 39 | gr.Checkbox(value=False, label="Denoise Before Enhancement"), 40 | ] 41 | 42 | outputs: list = [ 43 | gr.Audio(label="Output Denoised Audio"), 44 | gr.Audio(label="Output Enhanced Audio"), 45 | ] 46 | 47 | interface = gr.Interface( 48 | fn=_fn, 49 | title="Resemble Enhance", 50 | description="AI-driven audio enhancement for your audio files, powered by Resemble AI.", 51 | inputs=inputs, 52 | outputs=outputs, 53 | ) 54 | 55 | interface.launch() 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /config/denoiser.yaml: -------------------------------------------------------------------------------- 1 | batch_size_per_gpu: 32 2 | training_seconds: 3.0 3 | -------------------------------------------------------------------------------- /config/enhancer_stage1.yaml: -------------------------------------------------------------------------------- 1 | lcfm_training_mode: ae 2 | load_fg_only: true 3 | batch_size_per_gpu: 16 4 | denoiser_run_dir: runs/denoiser 5 | -------------------------------------------------------------------------------- /config/enhancer_stage2.yaml: -------------------------------------------------------------------------------- 1 | lcfm_training_mode: cfm 2 | batch_size_per_gpu: 32 3 | training_seconds: 3.0 4 | gan_training_start_step: null 5 | lcfm_z_scale: 6 6 | praat_augment_prob: 0.2 7 | denoiser_run_dir: runs/denoiser 8 | enhancer_stage1_run_dir: runs/enhancer_stage1 9 | -------------------------------------------------------------------------------- /packages.txt: -------------------------------------------------------------------------------- 1 | libsox-dev 2 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ['py310'] 4 | 5 | [tool.isort] 6 | line_length = 120 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | celluloid>=0.2.0 2 | deepspeed>=0.12.4 3 | 
librosa>=0.10.1 4 | matplotlib>=3.8.1 5 | numpy>=1.26.2 6 | omegaconf>=2.3.0 7 | pandas>=2.1.3 8 | ptflops>=0.7.1.2 9 | rich>=13.7.0 10 | scipy>=1.11.4 11 | soundfile>=0.12.1 12 | torch>=2.1.1 13 | torchaudio>=2.1.1 14 | torchvision>=0.16.1 15 | tqdm>=4.66.1 16 | resampy>=0.4.2 17 | tabulate>=0.8.10 18 | gradio>=4.8.0 19 | -------------------------------------------------------------------------------- /resemble_enhance/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/resemble-ai/resemble-enhance/8e978149bfe8abab3eb77d965d579a111afdb0ff/resemble_enhance/__init__.py -------------------------------------------------------------------------------- /resemble_enhance/common.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class Normalizer(nn.Module): 10 | def __init__(self, momentum=0.01, eps=1e-9): 11 | super().__init__() 12 | self.momentum = momentum 13 | self.eps = eps 14 | self.running_mean_unsafe: Tensor 15 | self.running_var_unsafe: Tensor 16 | self.register_buffer("running_mean_unsafe", torch.full([], torch.nan)) 17 | self.register_buffer("running_var_unsafe", torch.full([], torch.nan)) 18 | 19 | @property 20 | def started(self): 21 | return not torch.isnan(self.running_mean_unsafe) 22 | 23 | @property 24 | def running_mean(self): 25 | if not self.started: 26 | return torch.zeros_like(self.running_mean_unsafe) 27 | return self.running_mean_unsafe 28 | 29 | @property 30 | def running_std(self): 31 | if not self.started: 32 | return torch.ones_like(self.running_var_unsafe) 33 | return (self.running_var_unsafe + self.eps).sqrt() 34 | 35 | @torch.no_grad() 36 | def _ema(self, a: Tensor, x: Tensor): 37 | return (1 - self.momentum) * a + self.momentum * x 38 | 39 | def update_(self, x): 40 | if not self.started: 41 | self.running_mean_unsafe = x.mean() 42 | self.running_var_unsafe = x.var() 43 | else: 44 | self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean()) 45 | self.running_var_unsafe = self._ema(self.running_var_unsafe, (x - self.running_mean).pow(2).mean()) 46 | 47 | def forward(self, x: Tensor, update=True): 48 | if self.training and update: 49 | self.update_(x) 50 | self.stats = dict(mean=self.running_mean.item(), std=self.running_std.item()) 51 | x = (x - self.running_mean) / self.running_std 52 | return x 53 | 54 | def inverse(self, x: Tensor): 55 | return x * self.running_std + self.running_mean 56 | -------------------------------------------------------------------------------- /resemble_enhance/data/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | from ..hparams import HParams 7 | from .dataset import Dataset 8 | from .utils import mix_fg_bg, rglob_audio_files 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def _create_datasets(hp: HParams, mode, val_size=10, seed=123): 14 | paths = rglob_audio_files(hp.fg_dir) 15 | logger.info(f"Found {len(paths)} audio files in {hp.fg_dir}") 16 | 17 | random.Random(seed).shuffle(paths) 18 | train_paths = paths[:-val_size] 19 | val_paths = paths[-val_size:] 20 | 21 | train_ds = Dataset(train_paths, hp, training=True, mode=mode) 22 | val_ds = Dataset(val_paths, hp, training=False, mode=mode) 23 | 24 | logger.info(f"Train set: 
{len(train_ds)} samples - Val set: {len(val_ds)} samples") 25 | 26 | return train_ds, val_ds 27 | 28 | 29 | def create_dataloaders(hp: HParams, mode): 30 | train_ds, val_ds = _create_datasets(hp=hp, mode=mode) 31 | 32 | train_dl = DataLoader( 33 | train_ds, 34 | batch_size=hp.batch_size_per_gpu, 35 | shuffle=True, 36 | num_workers=hp.nj, 37 | drop_last=True, 38 | collate_fn=train_ds.collate_fn, 39 | ) 40 | val_dl = DataLoader( 41 | val_ds, 42 | batch_size=1, 43 | shuffle=False, 44 | num_workers=hp.nj, 45 | drop_last=False, 46 | collate_fn=val_ds.collate_fn, 47 | ) 48 | return train_dl, val_dl 49 | -------------------------------------------------------------------------------- /resemble_enhance/data/dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | import torchaudio 8 | import torchaudio.functional as AF 9 | from torch.nn.utils.rnn import pad_sequence 10 | from torch.utils.data import Dataset as DatasetBase 11 | 12 | from ..hparams import HParams 13 | from .distorter import Distorter 14 | from .utils import rglob_audio_files 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def _normalize(x): 20 | return x / (np.abs(x).max() + 1e-7) 21 | 22 | 23 | def _collate(batch, key, tensor=True, pad=True): 24 | l = [d[key] for d in batch] 25 | if l[0] is None: 26 | return None 27 | if tensor: 28 | l = [torch.from_numpy(x) for x in l] 29 | if pad: 30 | assert tensor, "Can't pad non-tensor" 31 | l = pad_sequence(l, batch_first=True) 32 | return l 33 | 34 | 35 | def praat_augment(wav, sr): 36 | try: 37 | import parselmouth 38 | except ImportError: 39 | raise ImportError("Please install parselmouth>=0.5.0 to use Praat augmentation") 40 | # "praat-parselmouth @ git+https://github.com/YannickJadoul/Parselmouth@0bbcca69705ed73322f3712b19d71bb3694b2540", 41 | # https://github.com/YannickJadoul/Parselmouth/issues/68 42 | # note that this function may hang if the praat version is 0.4.3 43 | assert wav.ndim == 1, f"wav.ndim must be 1 but got {wav.ndim}" 44 | sound = parselmouth.Sound(wav, sr) 45 | formant_shift_ratio = random.uniform(1.1, 1.5) 46 | pitch_range_factor = random.uniform(0.5, 2.0) 47 | sound = parselmouth.praat.call(sound, "Change gender", 75, 600, formant_shift_ratio, 0, pitch_range_factor, 1.0) 48 | wav = np.array(sound.values)[0].astype(np.float32) 49 | return wav 50 | 51 | 52 | class Dataset(DatasetBase): 53 | def __init__( 54 | self, 55 | fg_paths: list[Path], 56 | hp: HParams, 57 | training=True, 58 | max_retries=100, 59 | silent_fg_prob=0.01, 60 | mode=False, 61 | ): 62 | super().__init__() 63 | 64 | assert mode in ("enhancer", "denoiser"), f"Invalid mode: {mode}" 65 | 66 | self.hp = hp 67 | self.fg_paths = fg_paths 68 | self.bg_paths = rglob_audio_files(hp.bg_dir) 69 | 70 | if len(self.fg_paths) == 0: 71 | raise ValueError(f"No foreground audio files found in {hp.fg_dir}") 72 | 73 | if len(self.bg_paths) == 0: 74 | raise ValueError(f"No background audio files found in {hp.bg_dir}") 75 | 76 | logger.info(f"Found {len(self.fg_paths)} foreground files and {len(self.bg_paths)} background files") 77 | 78 | self.training = training 79 | self.max_retries = max_retries 80 | self.silent_fg_prob = silent_fg_prob 81 | 82 | self.mode = mode 83 | self.distorter = Distorter(hp, training=training, mode=mode) 84 | 85 | def _load_wav(self, path, length=None, random_crop=True): 86 | wav, sr = torchaudio.load(path) 87 | 88 | wav = AF.resample( 
89 | waveform=wav, 90 | orig_freq=sr, 91 | new_freq=self.hp.wav_rate, 92 | lowpass_filter_width=64, 93 | rolloff=0.9475937167399596, 94 | resampling_method="sinc_interp_kaiser", 95 | beta=14.769656459379492, 96 | ) 97 | 98 | wav = wav.float().numpy() 99 | 100 | if wav.ndim == 2: 101 | wav = np.mean(wav, axis=0) 102 | 103 | if length is None and self.training: 104 | length = int(self.hp.training_seconds * self.hp.wav_rate) 105 | 106 | if length is not None: 107 | if random_crop: 108 | start = random.randint(0, max(0, len(wav) - length)) 109 | wav = wav[start : start + length] 110 | else: 111 | wav = wav[:length] 112 | 113 | if length is not None and len(wav) < length: 114 | wav = np.pad(wav, (0, length - len(wav))) 115 | 116 | wav = _normalize(wav) 117 | 118 | return wav 119 | 120 | def _getitem_unsafe(self, index: int): 121 | fg_path = self.fg_paths[index] 122 | 123 | if self.training and random.random() < self.silent_fg_prob: 124 | fg_wav = np.zeros(int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32) 125 | else: 126 | fg_wav = self._load_wav(fg_path) 127 | if random.random() < self.hp.praat_augment_prob and self.training: 128 | fg_wav = praat_augment(fg_wav, self.hp.wav_rate) 129 | 130 | if self.hp.load_fg_only: 131 | bg_wav = None 132 | fg_dwav = None 133 | bg_dwav = None 134 | else: 135 | fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(np.float32) 136 | if self.training: 137 | bg_path = random.choice(self.bg_paths) 138 | else: 139 | # Deterministic for validation 140 | bg_path = self.bg_paths[index % len(self.bg_paths)] 141 | bg_wav = self._load_wav(bg_path, length=len(fg_wav), random_crop=self.training) 142 | bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(np.float32) 143 | 144 | return dict( 145 | fg_wav=fg_wav, 146 | bg_wav=bg_wav, 147 | fg_dwav=fg_dwav, 148 | bg_dwav=bg_dwav, 149 | ) 150 | 151 | def __getitem__(self, index: int): 152 | for i in range(self.max_retries): 153 | try: 154 | return self._getitem_unsafe(index) 155 | except Exception as e: 156 | if i == self.max_retries - 1: 157 | raise RuntimeError(f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries") from e 158 | logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping") 159 | index = np.random.randint(0, len(self)) 160 | 161 | def __len__(self): 162 | return len(self.fg_paths) 163 | 164 | @staticmethod 165 | def collate_fn(batch): 166 | return dict( 167 | fg_wavs=_collate(batch, "fg_wav"), 168 | bg_wavs=_collate(batch, "bg_wav"), 169 | fg_dwavs=_collate(batch, "fg_dwav"), 170 | bg_dwavs=_collate(batch, "bg_dwav"), 171 | ) 172 | -------------------------------------------------------------------------------- /resemble_enhance/data/distorter/__init__.py: -------------------------------------------------------------------------------- 1 | from .distorter import Distorter 2 | -------------------------------------------------------------------------------- /resemble_enhance/data/distorter/base.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | import random 4 | import time 5 | import warnings 6 | 7 | import numpy as np 8 | 9 | _DEBUG = bool(os.environ.get("DEBUG", False)) 10 | 11 | 12 | class Effect: 13 | def apply(self, wav: np.ndarray, sr: int): 14 | """ 15 | Args: 16 | wav: (T) 17 | sr: sample rate 18 | Returns: 19 | wav: (T) with the same sample rate of `sr` 20 | """ 21 | raise NotImplementedError 22 | 23 | def __call__(self, wav: np.ndarray, sr: int): 24 | """ 
25 | Args: 26 | wav: (T) 27 | sr: sample rate 28 | Returns: 29 | wav: (T) with the same sample rate of `sr` 30 | """ 31 | assert len(wav.shape) == 1, wav.shape 32 | 33 | if _DEBUG: 34 | start = time.time() 35 | else: 36 | start = None 37 | 38 | shape = wav.shape 39 | assert wav.ndim == 1, f"{self}: Expected wav.ndim == 1, got {wav.ndim}." 40 | wav = self.apply(wav, sr) 41 | assert shape == wav.shape, f"{self}: {shape} != {wav.shape}." 42 | 43 | if start is not None: 44 | end = time.time() 45 | print(f"{self.__class__.__name__}: {end - start:.3f} sec") 46 | 47 | return wav 48 | 49 | 50 | class Chain(Effect): 51 | def __init__(self, *effects): 52 | super().__init__() 53 | 54 | self.effects = effects 55 | 56 | def apply(self, wav, sr): 57 | for effect in self.effects: 58 | wav = effect(wav, sr) 59 | return wav 60 | 61 | 62 | class Maybe(Effect): 63 | def __init__(self, prob, effect): 64 | super().__init__() 65 | 66 | self.prob = prob 67 | self.effect = effect 68 | 69 | if _DEBUG: 70 | warnings.warn("DEBUG mode is on. Maybe -> Must.") 71 | self.prob = 1 72 | 73 | def apply(self, wav, sr): 74 | if random.random() > self.prob: 75 | return wav 76 | return self.effect(wav, sr) 77 | 78 | 79 | class Choice(Effect): 80 | def __init__(self, *effects, **kwargs): 81 | super().__init__() 82 | self.effects = effects 83 | self.kwargs = kwargs 84 | 85 | def apply(self, wav, sr): 86 | return np.random.choice(self.effects, **self.kwargs)(wav, sr) 87 | 88 | 89 | class Permutation(Effect): 90 | def __init__(self, *effects, n: int | None = None): 91 | super().__init__() 92 | self.effects = effects 93 | self.n = n 94 | 95 | def apply(self, wav, sr): 96 | if self.n is None: 97 | n = np.random.binomial(len(self.effects), 0.5) 98 | else: 99 | n = self.n 100 | if n == 0: 101 | return wav 102 | perms = itertools.permutations(self.effects, n) 103 | effects = random.choice(list(perms)) 104 | return Chain(*effects)(wav, sr) 105 | -------------------------------------------------------------------------------- /resemble_enhance/data/distorter/custom.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from dataclasses import dataclass 4 | from functools import cached_property 5 | from pathlib import Path 6 | 7 | import librosa 8 | import numpy as np 9 | from scipy import signal 10 | 11 | from ..utils import walk_paths 12 | from .base import Effect 13 | 14 | _logger = logging.getLogger(__name__) 15 | 16 | 17 | @dataclass 18 | class RandomRIR(Effect): 19 | rir_dir: Path | None 20 | rir_rate: int = 44_000 21 | rir_suffix: str = ".npy" 22 | deterministic: bool = False 23 | 24 | @cached_property 25 | def rir_paths(self): 26 | if self.rir_dir is None: 27 | return [] 28 | return list(walk_paths(self.rir_dir, self.rir_suffix)) 29 | 30 | def _sample_rir(self): 31 | if len(self.rir_paths) == 0: 32 | return None 33 | 34 | if self.deterministic: 35 | rir_path = self.rir_paths[0] 36 | else: 37 | rir_path = random.choice(self.rir_paths) 38 | 39 | rir = np.squeeze(np.load(rir_path)) 40 | assert isinstance(rir, np.ndarray) 41 | 42 | return rir 43 | 44 | def apply(self, wav, sr): 45 | # ref: https://github.com/haoheliu/voicefixer_main/blob/b06e07c945ac1d309b8a57ddcd599ca376b98cd9/dataloaders/augmentation/magical_effects.py#L158 46 | 47 | if len(self.rir_paths) == 0: 48 | return wav 49 | 50 | length = len(wav) 51 | 52 | wav = librosa.resample(wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast") 53 | rir = self._sample_rir() 54 | 55 | wav = 
signal.convolve(wav, rir, mode="same") 56 | 57 | actlev = np.max(np.abs(wav)) 58 | if actlev > 0.99: 59 | wav = (wav / actlev) * 0.98 60 | 61 | wav = librosa.resample(wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast") 62 | 63 | if abs(length - len(wav)) > 10: 64 | _logger.warning(f"length mismatch: {length} vs {len(wav)}") 65 | 66 | if length > len(wav): 67 | wav = np.pad(wav, (0, length - len(wav))) 68 | elif length < len(wav): 69 | wav = wav[:length] 70 | 71 | return wav 72 | 73 | 74 | class RandomGaussianNoise(Effect): 75 | def __init__(self, alpha_range=(0.8, 1)): 76 | super().__init__() 77 | self.alpha_range = alpha_range 78 | 79 | def apply(self, wav, sr): 80 | noise = np.random.randn(*wav.shape) 81 | noise_energy = np.sum(noise**2) 82 | wav_energy = np.sum(wav**2) 83 | noise = noise * np.sqrt(wav_energy / noise_energy) 84 | alpha = random.uniform(*self.alpha_range) 85 | return wav * alpha + noise * (1 - alpha) 86 | -------------------------------------------------------------------------------- /resemble_enhance/data/distorter/distorter.py: -------------------------------------------------------------------------------- 1 | from ...hparams import HParams 2 | from .base import Chain, Choice, Permutation 3 | from .custom import RandomGaussianNoise, RandomRIR 4 | 5 | 6 | class Distorter(Chain): 7 | def __init__(self, hp: HParams, training: bool = False, mode: str = "enhancer"): 8 | # Lazy import 9 | from .sox import RandomBandpassDistorter, RandomEqualizer, RandomLowpassDistorter, RandomOverdrive, RandomReverb 10 | 11 | if training: 12 | permutation = Permutation( 13 | RandomRIR(hp.rir_dir), 14 | RandomReverb(), 15 | RandomGaussianNoise(), 16 | RandomOverdrive(), 17 | RandomEqualizer(), 18 | Choice( 19 | RandomLowpassDistorter(), 20 | RandomBandpassDistorter(), 21 | ), 22 | ) 23 | if mode == "denoiser": 24 | super().__init__(permutation) 25 | else: 26 | # 80%: distortion, 20%: clean 27 | super().__init__(Choice(permutation, Chain(), p=[0.8, 0.2])) 28 | else: 29 | super().__init__( 30 | RandomRIR(hp.rir_dir, deterministic=True), 31 | RandomReverb(deterministic=True), 32 | ) 33 | -------------------------------------------------------------------------------- /resemble_enhance/data/distorter/sox.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import warnings 5 | from functools import partial 6 | 7 | import numpy as np 8 | import torch 9 | 10 | try: 11 | import augment 12 | except ImportError: 13 | raise ImportError( 14 | "augment is not installed, please install it first using:" 15 | "\npip install git+https://github.com/facebookresearch/WavAugment@54afcdb00ccc852c2f030f239f8532c9562b550e" 16 | ) 17 | 18 | from .base import Effect 19 | 20 | _logger = logging.getLogger(__name__) 21 | _DEBUG = bool(os.environ.get("DEBUG", False)) 22 | 23 | 24 | class AttachableEffect(Effect): 25 | def attach(self, chain: augment.EffectChain) -> augment.EffectChain: 26 | raise NotImplementedError 27 | 28 | def apply(self, wav: np.ndarray, sr: int): 29 | chain = augment.EffectChain() 30 | chain = self.attach(chain) 31 | tensor = torch.from_numpy(wav)[None].float() # (1, T) 32 | tensor = chain.apply(tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr}) 33 | wav = tensor.numpy()[0] # (T,) 34 | return wav 35 | 36 | 37 | class SoxEffect(AttachableEffect): 38 | def __init__(self, effect_name: str, *args, **kwargs): 39 | self.effect_name = effect_name 40 | self.args = args 41 | self.kwargs = 
kwargs 42 | 43 | def attach(self, chain: augment.EffectChain) -> augment.EffectChain: 44 | _logger.debug(f"Attaching {self.effect_name} with {self.args} and {self.kwargs}") 45 | if not hasattr(chain, self.effect_name): 46 | raise ValueError(f"EffectChain has no attribute {self.effect_name}") 47 | return getattr(chain, self.effect_name)(*self.args, **self.kwargs) 48 | 49 | 50 | class Maybe(AttachableEffect): 51 | """ 52 | Attach an effect with a probability. 53 | """ 54 | 55 | def __init__(self, prob: float, effect: AttachableEffect): 56 | self.prob = prob 57 | self.effect = effect 58 | if _DEBUG: 59 | warnings.warn("DEBUG mode is on. Maybe -> Must.") 60 | self.prob = 1 61 | 62 | def attach(self, chain: augment.EffectChain) -> augment.EffectChain: 63 | if random.random() > self.prob: 64 | return chain 65 | return self.effect.attach(chain) 66 | 67 | 68 | class Chain(AttachableEffect): 69 | """ 70 | Attach a chain of effects. 71 | """ 72 | 73 | def __init__(self, *effects: AttachableEffect): 74 | self.effects = effects 75 | 76 | def attach(self, chain: augment.EffectChain) -> augment.EffectChain: 77 | for effect in self.effects: 78 | chain = effect.attach(chain) 79 | return chain 80 | 81 | 82 | class Choice(AttachableEffect): 83 | """ 84 | Attach one of the effects randomly. 85 | """ 86 | 87 | def __init__(self, *effects: AttachableEffect): 88 | self.effects = effects 89 | 90 | def attach(self, chain: augment.EffectChain) -> augment.EffectChain: 91 | return random.choice(self.effects).attach(chain) 92 | 93 | 94 | class Generator: 95 | def __call__(self) -> str: 96 | raise NotImplementedError 97 | 98 | 99 | class Uniform(Generator): 100 | def __init__(self, low, high): 101 | self.low = low 102 | self.high = high 103 | 104 | def __call__(self) -> str: 105 | return str(random.uniform(self.low, self.high)) 106 | 107 | 108 | class Randint(Generator): 109 | def __init__(self, low, high): 110 | self.low = low 111 | self.high = high 112 | 113 | def __call__(self) -> str: 114 | return str(random.randint(self.low, self.high)) 115 | 116 | 117 | class Concat(Generator): 118 | def __init__(self, *parts: Generator | str): 119 | self.parts = parts 120 | 121 | def __call__(self): 122 | return "".join([part if isinstance(part, str) else part() for part in self.parts]) 123 | 124 | 125 | class RandomLowpassDistorter(SoxEffect): 126 | def __init__(self, low=2000, high=16000): 127 | super().__init__("sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high))) 128 | 129 | 130 | class RandomBandpassDistorter(SoxEffect): 131 | def __init__(self, low=100, high=1000, min_width=2000, max_width=4000): 132 | super().__init__("sinc", "-n", Randint(50, 200), partial(self._fn, low, high, min_width, max_width)) 133 | 134 | @staticmethod 135 | def _fn(low, high, min_width, max_width): 136 | start = random.randint(low, high) 137 | stop = start + random.randint(min_width, max_width) 138 | return f"{start}-{stop}" 139 | 140 | 141 | class RandomEqualizer(SoxEffect): 142 | def __init__(self, low=100, high=4000, q_low=1, q_high=5, db_low: int = -30, db_high: int = 30): 143 | super().__init__( 144 | "equalizer", 145 | Uniform(low, high), 146 | lambda: f"{random.randint(q_low, q_high)}q", 147 | lambda: random.randint(db_low, db_high), 148 | ) 149 | 150 | 151 | class RandomOverdrive(SoxEffect): 152 | def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80): 153 | super().__init__("overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high)) 154 | 155 | 156 | class RandomReverb(Chain): 157 | def 
__init__(self, deterministic=False): 158 | super().__init__( 159 | SoxEffect( 160 | "reverb", 161 | Uniform(50, 50) if deterministic else Uniform(0, 100), 162 | Uniform(50, 50) if deterministic else Uniform(0, 100), 163 | Uniform(50, 50) if deterministic else Uniform(0, 100), 164 | ), 165 | SoxEffect("channels", 1), 166 | ) 167 | 168 | 169 | class Flanger(SoxEffect): 170 | def __init__(self): 171 | super().__init__("flanger") 172 | 173 | 174 | class Phaser(SoxEffect): 175 | def __init__(self): 176 | super().__init__("phaser") 177 | -------------------------------------------------------------------------------- /resemble_enhance/data/utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Callable 3 | 4 | from torch import Tensor 5 | 6 | 7 | def walk_paths(root, suffix): 8 | for path in Path(root).iterdir(): 9 | if path.is_dir(): 10 | yield from walk_paths(path, suffix) 11 | elif path.suffix == suffix: 12 | yield path 13 | 14 | 15 | def rglob_audio_files(path: Path): 16 | return list(walk_paths(path, ".wav")) + list(walk_paths(path, ".flac")) 17 | 18 | 19 | def mix_fg_bg(fg: Tensor, bg: Tensor, alpha: float | Callable[..., float] = 0.5, eps=1e-7): 20 | """ 21 | Args: 22 | fg: (b, t) 23 | bg: (b, t) 24 | """ 25 | assert bg.shape == fg.shape, f"bg.shape != fg.shape: {bg.shape} != {fg.shape}" 26 | fg = fg / (fg.abs().max(dim=-1, keepdim=True).values + eps) 27 | bg = bg / (bg.abs().max(dim=-1, keepdim=True).values + eps) 28 | 29 | fg_energy = fg.pow(2).sum(dim=-1, keepdim=True) 30 | bg_energy = bg.pow(2).sum(dim=-1, keepdim=True) 31 | 32 | fg = fg / (fg_energy + eps).sqrt() 33 | bg = bg / (bg_energy + eps).sqrt() 34 | 35 | if callable(alpha): 36 | alpha = alpha() 37 | 38 | assert 0 <= alpha <= 1, f"alpha must be between 0 and 1: {alpha}" 39 | 40 | mx = alpha * fg + (1 - alpha) * bg 41 | mx = mx / (mx.abs().max(dim=-1, keepdim=True).values + eps) 42 | 43 | return mx 44 | -------------------------------------------------------------------------------- /resemble_enhance/denoiser/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/resemble-ai/resemble-enhance/8e978149bfe8abab3eb77d965d579a111afdb0ff/resemble_enhance/denoiser/__init__.py -------------------------------------------------------------------------------- /resemble_enhance/denoiser/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import torch 5 | import torchaudio 6 | 7 | from .inference import denoise 8 | 9 | 10 | @torch.inference_mode() 11 | def main(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("in_dir", type=Path, help="Path to input audio folder") 14 | parser.add_argument("out_dir", type=Path, help="Output folder") 15 | parser.add_argument("--run_dir", type=Path, default="runs/denoiser", help="Path to run folder") 16 | parser.add_argument("--suffix", type=str, default=".wav", help="File suffix") 17 | parser.add_argument("--device", type=str, default="cuda", help="Device") 18 | args = parser.parse_args() 19 | 20 | for path in args.in_dir.glob(f"**/*{args.suffix}"): 21 | print(f"Processing {path} ..") 22 | dwav, sr = torchaudio.load(path) 23 | hwav, sr = denoise(dwav[0], sr, args.run_dir, args.device) 24 | out_path = args.out_dir / path.relative_to(args.in_dir) 25 | out_path.parent.mkdir(parents=True, exist_ok=True) 26 | torchaudio.save(out_path, 
hwav[None], sr) 27 | 28 | 29 | if __name__ == "__main__": 30 | main() 31 | -------------------------------------------------------------------------------- /resemble_enhance/denoiser/denoiser.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import Tensor, nn 6 | 7 | from ..melspec import MelSpectrogram 8 | from .hparams import HParams 9 | from .unet import UNet 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def _normalize(x: Tensor) -> Tensor: 15 | return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7) 16 | 17 | 18 | class Denoiser(nn.Module): 19 | @property 20 | def stft_cfg(self) -> dict: 21 | hop_size = self.hp.hop_size 22 | return dict(hop_length=hop_size, n_fft=hop_size * 4, win_length=hop_size * 4) 23 | 24 | @property 25 | def n_fft(self): 26 | return self.stft_cfg["n_fft"] 27 | 28 | @property 29 | def eps(self): 30 | return 1e-7 31 | 32 | def __init__(self, hp: HParams): 33 | super().__init__() 34 | self.hp = hp 35 | self.net = UNet(input_dim=3, output_dim=3) 36 | self.mel_fn = MelSpectrogram(hp) 37 | 38 | self.dummy: Tensor 39 | self.register_buffer("dummy", torch.zeros(1), persistent=False) 40 | 41 | def to_mel(self, x: Tensor, drop_last=True): 42 | """ 43 | Args: 44 | x: (b t), wavs 45 | Returns: 46 | o: (b c t), mels 47 | """ 48 | if drop_last: 49 | return self.mel_fn(x)[..., :-1] # (b d t) 50 | return self.mel_fn(x) 51 | 52 | def _stft(self, x): 53 | """ 54 | Args: 55 | x: (b t) 56 | Returns: 57 | mag: (b f t) in [0, inf) 58 | cos: (b f t) in [-1, 1] 59 | sin: (b f t) in [-1, 1] 60 | """ 61 | dtype = x.dtype 62 | device = x.device 63 | 64 | if x.is_mps: 65 | x = x.cpu() 66 | 67 | window = torch.hann_window(self.stft_cfg["win_length"], device=x.device) 68 | s = torch.stft(x.float(), **self.stft_cfg, window=window, return_complex=True) # (b f t+1) 69 | 70 | s = s[..., :-1] # (b f t) 71 | 72 | mag = s.abs() # (b f t) 73 | 74 | phi = s.angle() # (b f t) 75 | cos = phi.cos() # (b f t) 76 | sin = phi.sin() # (b f t) 77 | 78 | mag = mag.to(dtype=dtype, device=device) 79 | cos = cos.to(dtype=dtype, device=device) 80 | sin = sin.to(dtype=dtype, device=device) 81 | 82 | return mag, cos, sin 83 | 84 | def _istft(self, mag: Tensor, cos: Tensor, sin: Tensor): 85 | """ 86 | Args: 87 | mag: (b f t) in [0, inf) 88 | cos: (b f t) in [-1, 1] 89 | sin: (b f t) in [-1, 1] 90 | Returns: 91 | x: (b t) 92 | """ 93 | device = mag.device 94 | dtype = mag.dtype 95 | 96 | if mag.is_mps: 97 | mag = mag.cpu() 98 | cos = cos.cpu() 99 | sin = sin.cpu() 100 | 101 | real = mag * cos # (b f t) 102 | imag = mag * sin # (b f t) 103 | 104 | s = torch.complex(real, imag) # (b f t) 105 | 106 | if s.isnan().any(): 107 | logger.warning("NaN detected in ISTFT input.") 108 | 109 | s = F.pad(s, (0, 1), "replicate") # (b f t+1) 110 | 111 | window = torch.hann_window(self.stft_cfg["win_length"], device=s.device) 112 | x = torch.istft(s, **self.stft_cfg, window=window, return_complex=False) 113 | 114 | if x.isnan().any(): 115 | logger.warning("NaN detected in ISTFT output, set to zero.") 116 | x = torch.where(x.isnan(), torch.zeros_like(x), x) 117 | 118 | x = x.to(dtype=dtype, device=device) 119 | 120 | return x 121 | 122 | def _magphase(self, real, imag): 123 | mag = (real.pow(2) + imag.pow(2) + self.eps).sqrt() 124 | cos = real / mag 125 | sin = imag / mag 126 | return mag, cos, sin 127 | 128 | def _predict(self, mag: Tensor, cos: Tensor, sin: Tensor): 129 | """ 130 | Args: 
131 |             mag: (b f t)
132 |             cos: (b f t)
133 |             sin: (b f t)
134 |         Returns:
135 |             mag_mask: (b f t) in [0, 1], magnitude mask
136 |             cos_res: (b f t) in [-1, 1], phase residual
137 |             sin_res: (b f t) in [-1, 1], phase residual
138 |         """
139 |         x = torch.stack([mag, cos, sin], dim=1)  # (b 3 f t)
140 |         mag_mask, real, imag = self.net(x).unbind(1)  # (b 3 f t)
141 |         mag_mask = mag_mask.sigmoid()  # (b f t)
142 |         real = real.tanh()  # (b f t)
143 |         imag = imag.tanh()  # (b f t)
144 |         _, cos_res, sin_res = self._magphase(real, imag)  # (b f t)
145 |         return mag_mask, cos_res, sin_res
146 | 
147 |     def _separate(self, mag, cos, sin, mag_mask, cos_res, sin_res):
148 |         """Ref: https://audio-agi.github.io/Separate-Anything-You-Describe/AudioSep_arXiv.pdf"""
149 |         sep_mag = F.relu(mag * mag_mask)
150 |         sep_cos = cos * cos_res - sin * sin_res
151 |         sep_sin = sin * cos_res + cos * sin_res
152 |         return sep_mag, sep_cos, sep_sin
153 | 
154 |     def forward(self, x: Tensor, y: Tensor | None = None):
155 |         """
156 |         Args:
157 |             x: (b t), a mixed audio
158 |             y: (b t), a fg audio
159 |         """
160 |         assert x.dim() == 2, f"Expected (b t), got {x.size()}"
161 |         x = x.to(self.dummy)
162 |         x = _normalize(x)
163 | 
164 |         if y is not None:
165 |             assert y.dim() == 2, f"Expected (b t), got {y.size()}"
166 |             y = y.to(self.dummy)
167 |             y = _normalize(y)
168 | 
169 |         mag, cos, sin = self._stft(x)  # (b f t)
170 |         mag_mask, cos_res, sin_res = self._predict(mag, cos, sin)
171 |         sep_mag, sep_cos, sep_sin = self._separate(mag, cos, sin, mag_mask, cos_res, sin_res)
172 | 
173 |         o = self._istft(sep_mag, sep_cos, sep_sin)
174 | 
175 |         npad = x.shape[-1] - o.shape[-1]
176 |         o = F.pad(o, (0, npad))
177 | 
178 |         if y is not None:
179 |             self.losses = dict(l1=F.l1_loss(o, y))
180 | 
181 |         return o
182 | 
--------------------------------------------------------------------------------
/resemble_enhance/denoiser/hparams.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | 
3 | from ..hparams import HParams as HParamsBase
4 | 
5 | 
6 | @dataclass(frozen=True)
7 | class HParams(HParamsBase):
8 |     batch_size_per_gpu: int = 128
9 |     distort_prob: float = 0.5
10 | 
--------------------------------------------------------------------------------
/resemble_enhance/denoiser/inference.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from functools import cache
3 | 
4 | import torch
5 | 
6 | from ..inference import inference
7 | from .train import Denoiser, HParams
8 | 
9 | logger = logging.getLogger(__name__)
10 | 
11 | 
12 | @cache
13 | def load_denoiser(run_dir, device):
14 |     if run_dir is None:
15 |         return Denoiser(HParams())
16 |     hp = HParams.load(run_dir)
17 |     denoiser = Denoiser(hp)
18 |     path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
19 |     state_dict = torch.load(path, map_location="cpu")["module"]
20 |     denoiser.load_state_dict(state_dict)
21 |     denoiser.eval()
22 |     denoiser.to(device)
23 |     return denoiser
24 | 
25 | 
26 | @torch.inference_mode()
27 | def denoise(dwav, sr, run_dir, device):
28 |     denoiser = load_denoiser(run_dir, device)
29 |     return inference(model=denoiser, dwav=dwav, sr=sr, device=device)
30 | 
--------------------------------------------------------------------------------
/resemble_enhance/denoiser/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | from functools import partial
4 | from pathlib import Path
5 | 
6 | import
soundfile 7 | import torch 8 | from deepspeed import DeepSpeedConfig 9 | from torch import Tensor 10 | from tqdm import tqdm 11 | 12 | from ..data import create_dataloaders, mix_fg_bg 13 | from ..utils import Engine, TrainLoop, save_mels, setup_logging, tree_map 14 | from ..utils.distributed import is_local_leader 15 | from .denoiser import Denoiser 16 | from .hparams import HParams 17 | 18 | 19 | def load_G(run_dir: Path, hp: HParams | None = None, training=True): 20 | if hp is None: 21 | hp = HParams.load(run_dir) 22 | assert isinstance(hp, HParams) 23 | model = Denoiser(hp) 24 | engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "G") 25 | if training: 26 | engine.load_checkpoint() 27 | else: 28 | engine.load_checkpoint(load_optimizer_states=False, load_lr_scheduler_states=False) 29 | return engine 30 | 31 | 32 | def save_wav(path: Path, wav: Tensor, rate: int): 33 | wav = wav.detach().cpu().numpy() 34 | soundfile.write(path, wav, samplerate=rate) 35 | 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument("run_dir", type=Path) 40 | parser.add_argument("--yaml", type=Path, default=None) 41 | parser.add_argument("--device", type=str, default="cuda") 42 | args = parser.parse_args() 43 | 44 | setup_logging(args.run_dir) 45 | hp = HParams.load(args.run_dir, yaml=args.yaml) 46 | 47 | if is_local_leader(): 48 | hp.save_if_not_exists(args.run_dir) 49 | hp.print() 50 | 51 | train_dl, val_dl = create_dataloaders(hp, mode="denoiser") 52 | 53 | def feed_G(engine: Engine, batch: dict[str, Tensor]): 54 | alpha_fn = lambda: random.uniform(*hp.mix_alpha_range) 55 | if random.random() < hp.distort_prob: 56 | fg_wavs = batch["fg_dwavs"] 57 | else: 58 | fg_wavs = batch["fg_wavs"] 59 | mx_dwavs = mix_fg_bg(fg_wavs, batch["bg_dwavs"], alpha=alpha_fn) 60 | pred = engine(mx_dwavs, fg_wavs) 61 | losses = engine.gather_attribute("losses", prefix="losses") 62 | return pred, losses 63 | 64 | @torch.no_grad() 65 | def eval_fn(engine: Engine, eval_dir, n_saved=10): 66 | model = engine.module 67 | model.eval() 68 | 69 | step = engine.global_step 70 | 71 | for i, batch in enumerate(tqdm(val_dl), 1): 72 | batch = tree_map(lambda x: x.to(args.device) if isinstance(x, Tensor) else x, batch) 73 | 74 | fg_dwavs = batch["fg_dwavs"] # 1 t 75 | mx_dwavs = mix_fg_bg(fg_dwavs, batch["bg_dwavs"]) 76 | pred_fg_dwavs = model(mx_dwavs) # 1 t 77 | 78 | mx_mels = model.to_mel(mx_dwavs) # 1 c t 79 | fg_mels = model.to_mel(fg_dwavs) # 1 c t 80 | pred_fg_mels = model.to_mel(pred_fg_dwavs) # 1 c t 81 | 82 | rate = model.hp.wav_rate 83 | get_path = lambda suffix: eval_dir / f"step_{step:08}_{i:03}{suffix}" 84 | 85 | save_wav(get_path("_input.wav"), mx_dwavs[0], rate=rate) 86 | save_wav(get_path("_predict.wav"), pred_fg_dwavs[0], rate=rate) 87 | save_wav(get_path("_target.wav"), fg_dwavs[0], rate=rate) 88 | 89 | save_mels( 90 | get_path(".png"), 91 | cond_mel=mx_mels[0].cpu().numpy(), 92 | pred_mel=pred_fg_mels[0].cpu().numpy(), 93 | targ_mel=fg_mels[0].cpu().numpy(), 94 | ) 95 | 96 | if i >= n_saved: 97 | break 98 | 99 | train_loop = TrainLoop( 100 | run_dir=args.run_dir, 101 | train_dl=train_dl, 102 | load_G=partial(load_G, hp=hp), 103 | device=args.device, 104 | feed_G=feed_G, 105 | eval_fn=eval_fn, 106 | ) 107 | 108 | train_loop.run(max_steps=hp.max_steps) 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /resemble_enhance/denoiser/unet.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torch import nn 3 | 4 | 5 | class PreactResBlock(nn.Sequential): 6 | def __init__(self, dim): 7 | super().__init__( 8 | nn.GroupNorm(dim // 16, dim), 9 | nn.GELU(), 10 | nn.Conv2d(dim, dim, 3, padding=1), 11 | nn.GroupNorm(dim // 16, dim), 12 | nn.GELU(), 13 | nn.Conv2d(dim, dim, 3, padding=1), 14 | ) 15 | 16 | def forward(self, x): 17 | return x + super().forward(x) 18 | 19 | 20 | class UNetBlock(nn.Module): 21 | def __init__(self, input_dim, output_dim=None, scale_factor=1.0): 22 | super().__init__() 23 | if output_dim is None: 24 | output_dim = input_dim 25 | self.pre_conv = nn.Conv2d(input_dim, output_dim, 3, padding=1) 26 | self.res_block1 = PreactResBlock(output_dim) 27 | self.res_block2 = PreactResBlock(output_dim) 28 | self.downsample = self.upsample = nn.Identity() 29 | if scale_factor > 1: 30 | self.upsample = nn.Upsample(scale_factor=scale_factor) 31 | elif scale_factor < 1: 32 | self.downsample = nn.Upsample(scale_factor=scale_factor) 33 | 34 | def forward(self, x, h=None): 35 | """ 36 | Args: 37 | x: (b c h w), last output 38 | h: (b c h w), skip output 39 | Returns: 40 | o: (b c h w), output 41 | s: (b c h w), skip output 42 | """ 43 | x = self.upsample(x) 44 | if h is not None: 45 | assert x.shape == h.shape, f"{x.shape} != {h.shape}" 46 | x = x + h 47 | x = self.pre_conv(x) 48 | x = self.res_block1(x) 49 | x = self.res_block2(x) 50 | return self.downsample(x), x 51 | 52 | 53 | class UNet(nn.Module): 54 | def __init__(self, input_dim, output_dim, hidden_dim=16, num_blocks=4, num_middle_blocks=2): 55 | super().__init__() 56 | self.input_dim = input_dim 57 | self.output_dim = output_dim 58 | self.input_proj = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 59 | self.encoder_blocks = nn.ModuleList( 60 | [ 61 | UNetBlock(input_dim=hidden_dim * 2**i, output_dim=hidden_dim * 2 ** (i + 1), scale_factor=0.5) 62 | for i in range(num_blocks) 63 | ] 64 | ) 65 | self.middle_blocks = nn.ModuleList( 66 | [UNetBlock(input_dim=hidden_dim * 2**num_blocks) for _ in range(num_middle_blocks)] 67 | ) 68 | self.decoder_blocks = nn.ModuleList( 69 | [ 70 | UNetBlock(input_dim=hidden_dim * 2 ** (i + 1), output_dim=hidden_dim * 2**i, scale_factor=2) 71 | for i in reversed(range(num_blocks)) 72 | ] 73 | ) 74 | self.head = nn.Sequential( 75 | nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), 76 | nn.GELU(), 77 | nn.Conv2d(hidden_dim, output_dim, 1), 78 | ) 79 | 80 | @property 81 | def scale_factor(self): 82 | return 2 ** len(self.encoder_blocks) 83 | 84 | def pad_to_fit(self, x): 85 | """ 86 | Args: 87 | x: (b c h w), input 88 | Returns: 89 | x: (b c h' w'), padded input 90 | """ 91 | hpad = (self.scale_factor - x.shape[2] % self.scale_factor) % self.scale_factor 92 | wpad = (self.scale_factor - x.shape[3] % self.scale_factor) % self.scale_factor 93 | return F.pad(x, (0, wpad, 0, hpad)) 94 | 95 | def forward(self, x): 96 | """ 97 | Args: 98 | x: (b c h w), input 99 | Returns: 100 | o: (b c h w), output 101 | """ 102 | shape = x.shape 103 | 104 | x = self.pad_to_fit(x) 105 | x = self.input_proj(x) 106 | 107 | s_list = [] 108 | for block in self.encoder_blocks: 109 | x, s = block(x) 110 | s_list.append(s) 111 | 112 | for block in self.middle_blocks: 113 | x, _ = block(x) 114 | 115 | for block, s in zip(self.decoder_blocks, reversed(s_list)): 116 | x, _ = block(x, s) 117 | 118 | x = self.head(x) 119 | x = x[..., : shape[2], : shape[3]] 120 | 121 | return x 122 | 123 | def test(self, 
shape=(3, 512, 256)): 124 | import ptflops 125 | 126 | macs, params = ptflops.get_model_complexity_info( 127 | self, 128 | shape, 129 | as_strings=True, 130 | print_per_layer_stat=True, 131 | verbose=True, 132 | ) 133 | 134 | print(f"macs: {macs}") 135 | print(f"params: {params}") 136 | 137 | 138 | def main(): 139 | model = UNet(3, 3) 140 | model.test() 141 | 142 | 143 | if __name__ == "__main__": 144 | main() 145 | -------------------------------------------------------------------------------- /resemble_enhance/enhancer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/resemble-ai/resemble-enhance/8e978149bfe8abab3eb77d965d579a111afdb0ff/resemble_enhance/enhancer/__init__.py -------------------------------------------------------------------------------- /resemble_enhance/enhancer/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import time 4 | from pathlib import Path 5 | 6 | import torch 7 | import torchaudio 8 | from tqdm import tqdm 9 | 10 | from .inference import denoise, enhance 11 | 12 | 13 | @torch.inference_mode() 14 | def main(): 15 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 16 | parser.add_argument("in_dir", type=Path, help="Path to input audio folder") 17 | parser.add_argument("out_dir", type=Path, help="Output folder") 18 | parser.add_argument( 19 | "--run_dir", 20 | type=Path, 21 | default=None, 22 | help="Path to the enhancer run folder, if None, use the default model", 23 | ) 24 | parser.add_argument( 25 | "--suffix", 26 | type=str, 27 | default=".wav", 28 | help="Audio file suffix", 29 | ) 30 | parser.add_argument( 31 | "--device", 32 | type=str, 33 | default="cuda", 34 | help="Device to use for computation, recommended to use CUDA", 35 | ) 36 | parser.add_argument( 37 | "--denoise_only", 38 | action="store_true", 39 | help="Only apply denoising without enhancement", 40 | ) 41 | parser.add_argument( 42 | "--lambd", 43 | type=float, 44 | default=1.0, 45 | help="Denoise strength for enhancement (0.0 to 1.0)", 46 | ) 47 | parser.add_argument( 48 | "--tau", 49 | type=float, 50 | default=0.5, 51 | help="CFM prior temperature (0.0 to 1.0)", 52 | ) 53 | parser.add_argument( 54 | "--solver", 55 | type=str, 56 | default="midpoint", 57 | choices=["midpoint", "rk4", "euler"], 58 | help="Numerical solver to use", 59 | ) 60 | parser.add_argument( 61 | "--nfe", 62 | type=int, 63 | default=64, 64 | help="Number of function evaluations", 65 | ) 66 | parser.add_argument( 67 | "--parallel_mode", 68 | action="store_true", 69 | help="Shuffle the audio paths and skip the existing ones, enabling multiple jobs to run in parallel", 70 | ) 71 | 72 | args = parser.parse_args() 73 | 74 | device = args.device 75 | 76 | if device == "cuda" and not torch.cuda.is_available(): 77 | print("CUDA is not available but --device is set to cuda, using CPU instead") 78 | device = "cpu" 79 | 80 | start_time = time.perf_counter() 81 | 82 | run_dir = args.run_dir 83 | 84 | paths = sorted(args.in_dir.glob(f"**/*{args.suffix}")) 85 | 86 | if args.parallel_mode: 87 | random.shuffle(paths) 88 | 89 | if len(paths) == 0: 90 | print(f"No {args.suffix} files found in the following path: {args.in_dir}") 91 | return 92 | 93 | pbar = tqdm(paths) 94 | 95 | for path in pbar: 96 | out_path = args.out_dir / path.relative_to(args.in_dir) 97 | if args.parallel_mode and out_path.exists(): 98 | continue 99 | 
pbar.set_description(f"Processing {out_path}") 100 | dwav, sr = torchaudio.load(path) 101 | dwav = dwav.mean(0) 102 | if args.denoise_only: 103 | hwav, sr = denoise( 104 | dwav=dwav, 105 | sr=sr, 106 | device=device, 107 | run_dir=args.run_dir, 108 | ) 109 | else: 110 | hwav, sr = enhance( 111 | dwav=dwav, 112 | sr=sr, 113 | device=device, 114 | nfe=args.nfe, 115 | solver=args.solver, 116 | lambd=args.lambd, 117 | tau=args.tau, 118 | run_dir=run_dir, 119 | ) 120 | out_path.parent.mkdir(parents=True, exist_ok=True) 121 | torchaudio.save(out_path, hwav[None], sr) 122 | 123 | # Cool emoji effect saying the job is done 124 | elapsed_time = time.perf_counter() - start_time 125 | print(f"🌟 Enhancement done! {len(paths)} files processed in {elapsed_time:.2f}s") 126 | 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /resemble_enhance/enhancer/download.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | import torch 5 | 6 | RUN_NAME = "enhancer_stage2" 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def get_source_url(relpath): 12 | return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true" 13 | 14 | 15 | def get_target_path(relpath: str | Path, run_dir: str | Path | None = None): 16 | if run_dir is None: 17 | run_dir = Path(__file__).parent.parent / "model_repo" / RUN_NAME 18 | return Path(run_dir) / relpath 19 | 20 | 21 | def download(run_dir: str | Path | None = None): 22 | relpaths = ["hparams.yaml", "ds/G/latest", "ds/G/default/mp_rank_00_model_states.pt"] 23 | for relpath in relpaths: 24 | path = get_target_path(relpath, run_dir=run_dir) 25 | if path.exists(): 26 | continue 27 | url = get_source_url(relpath) 28 | path.parent.mkdir(parents=True, exist_ok=True) 29 | torch.hub.download_url_to_file(url, str(path)) 30 | return get_target_path("", run_dir=run_dir) 31 | -------------------------------------------------------------------------------- /resemble_enhance/enhancer/enhancer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | import torch 6 | from torch import Tensor, nn 7 | from torch.distributions import Beta 8 | 9 | from ..common import Normalizer 10 | from ..denoiser.inference import load_denoiser 11 | from ..melspec import MelSpectrogram 12 | from ..utils.distributed import global_leader_only 13 | from ..utils.train_loop import TrainLoop 14 | from .hparams import HParams 15 | from .lcfm import CFM, IRMAE, LCFM 16 | from .univnet import UnivNet 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def _maybe(fn): 22 | def _fn(*args): 23 | if args[0] is None: 24 | return None 25 | return fn(*args) 26 | 27 | return _fn 28 | 29 | 30 | def _normalize_wav(x: Tensor): 31 | return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7) 32 | 33 | 34 | class Enhancer(nn.Module): 35 | def __init__(self, hp: HParams): 36 | super().__init__() 37 | self.hp = hp 38 | 39 | n_mels = self.hp.num_mels 40 | vocoder_input_dim = n_mels + self.hp.vocoder_extra_dim 41 | latent_dim = self.hp.lcfm_latent_dim 42 | 43 | self.lcfm = LCFM( 44 | IRMAE( 45 | input_dim=n_mels, 46 | output_dim=vocoder_input_dim, 47 | latent_dim=latent_dim, 48 | ), 49 | CFM( 50 | cond_dim=n_mels, 51 | output_dim=self.hp.lcfm_latent_dim, 52 | solver_nfe=self.hp.cfm_solver_nfe, 
53 | solver_method=self.hp.cfm_solver_method, 54 | time_mapping_divisor=self.hp.cfm_time_mapping_divisor, 55 | ), 56 | z_scale=self.hp.lcfm_z_scale, 57 | ) 58 | 59 | self.lcfm.set_mode_(self.hp.lcfm_training_mode) 60 | 61 | self.mel_fn = MelSpectrogram(hp) 62 | self.vocoder = UnivNet(self.hp, vocoder_input_dim) 63 | self.denoiser = load_denoiser(self.hp.denoiser_run_dir, "cpu") 64 | self.normalizer = Normalizer() 65 | 66 | self._eval_lambd = 0.0 67 | 68 | self.dummy: Tensor 69 | self.register_buffer("dummy", torch.zeros(1)) 70 | 71 | if self.hp.enhancer_stage1_run_dir is not None: 72 | pretrained_path = self.hp.enhancer_stage1_run_dir / "ds/G/default/mp_rank_00_model_states.pt" 73 | self._load_pretrained(pretrained_path) 74 | 75 | logger.info(f"{self.__class__.__name__} summary") 76 | logger.info(f"{self.summarize()}") 77 | 78 | def _load_pretrained(self, path): 79 | # Clone is necessary as otherwise it holds a reference to the original model 80 | cfm_state_dict = {k: v.clone() for k, v in self.lcfm.cfm.state_dict().items()} 81 | denoiser_state_dict = {k: v.clone() for k, v in self.denoiser.state_dict().items()} 82 | state_dict = torch.load(path, map_location="cpu")["module"] 83 | self.load_state_dict(state_dict, strict=False) 84 | self.lcfm.cfm.load_state_dict(cfm_state_dict) # Reset cfm 85 | self.denoiser.load_state_dict(denoiser_state_dict) # Reset denoiser 86 | logger.info(f"Loaded pretrained model from {path}") 87 | 88 | def summarize(self): 89 | npa_train = lambda m: sum(p.numel() for p in m.parameters() if p.requires_grad) 90 | npa = lambda m: sum(p.numel() for p in m.parameters()) 91 | rows = [] 92 | for name, module in self.named_children(): 93 | rows.append(dict(name=name, trainable=npa_train(module), total=npa(module))) 94 | rows.append(dict(name="total", trainable=npa_train(self), total=npa(self))) 95 | df = pd.DataFrame(rows) 96 | return df.to_markdown(index=False) 97 | 98 | def to_mel(self, x: Tensor, drop_last=True): 99 | """ 100 | Args: 101 | x: (b t), wavs 102 | Returns: 103 | o: (b c t), mels 104 | """ 105 | if drop_last: 106 | return self.mel_fn(x)[..., :-1] # (b d t) 107 | return self.mel_fn(x) 108 | 109 | @global_leader_only 110 | @torch.no_grad() 111 | def _visualize(self, original_mel, denoised_mel): 112 | loop = TrainLoop.get_running_loop() 113 | if loop is None or loop.global_step % 100 != 0: 114 | return 115 | 116 | plt.figure(figsize=(6, 6)) 117 | plt.subplot(211) 118 | plt.title("Original") 119 | plt.imshow(original_mel[0].cpu().numpy(), origin="lower", interpolation="none") 120 | plt.subplot(212) 121 | plt.title("Denoised") 122 | plt.imshow(denoised_mel[0].cpu().numpy(), origin="lower", interpolation="none") 123 | plt.tight_layout() 124 | 125 | path = loop.get_running_loop_viz_path("input", ".png") 126 | plt.savefig(path, dpi=300) 127 | 128 | def _may_denoise(self, x: Tensor, y: Tensor | None = None): 129 | if self.hp.lcfm_training_mode == "cfm": 130 | return self.denoiser(x, y) 131 | return x 132 | 133 | def configurate_(self, nfe, solver, lambd, tau): 134 | """ 135 | Args: 136 | nfe: number of function evaluations 137 | solver: solver method 138 | lambd: denoiser strength [0, 1] 139 | tau: prior temperature [0, 1] 140 | """ 141 | self.lcfm.cfm.solver.configurate_(nfe, solver) 142 | self.lcfm.eval_tau_(tau) 143 | self._eval_lambd = lambd 144 | 145 | def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None): 146 | """ 147 | Args: 148 | x: (b t), mix wavs (fg + bg) 149 | y: (b t), fg clean wavs 150 | z: (b t), fg distorted wavs 151 | 
Returns: 152 | o: (b t), reconstructed wavs 153 | """ 154 | assert x.dim() == 2, f"Expected (b t), got {x.size()}" 155 | assert y is None or y.dim() == 2, f"Expected (b t), got {y.size()}" 156 | 157 | if self.hp.lcfm_training_mode == "cfm": 158 | self.normalizer.eval() 159 | 160 | x = _normalize_wav(x) 161 | y = _maybe(_normalize_wav)(y) 162 | z = _maybe(_normalize_wav)(z) 163 | 164 | x_mel_original = self.normalizer(self.to_mel(x), update=False) # (b d t) 165 | 166 | if self.hp.lcfm_training_mode == "cfm": 167 | if self.training: 168 | lambd = Beta(0.2, 0.2).sample(x.shape[:1]).to(x.device) 169 | lambd = lambd[:, None, None] 170 | x_mel_denoised = self.normalizer(self.to_mel(self._may_denoise(x, z)), update=False) 171 | x_mel_denoised = x_mel_denoised.detach() 172 | x_mel_denoised = lambd * x_mel_denoised + (1 - lambd) * x_mel_original 173 | self._visualize(x_mel_original, x_mel_denoised) 174 | else: 175 | lambd = self._eval_lambd 176 | if lambd == 0: 177 | x_mel_denoised = x_mel_original 178 | else: 179 | x_mel_denoised = self.normalizer(self.to_mel(self._may_denoise(x, z)), update=False) 180 | x_mel_denoised = x_mel_denoised.detach() 181 | x_mel_denoised = lambd * x_mel_denoised + (1 - lambd) * x_mel_original 182 | else: 183 | x_mel_denoised = x_mel_original 184 | 185 | y_mel = _maybe(self.to_mel)(y) # (b d t) 186 | y_mel = _maybe(self.normalizer)(y_mel) 187 | 188 | if self.hp.force_gaussian_prior: 189 | lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=None) # (b d t) 190 | else: 191 | lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=x_mel_original) # (b d t) 192 | 193 | if lcfm_decoded is None: 194 | o = None 195 | else: 196 | o = self.vocoder(lcfm_decoded, y) 197 | 198 | return o 199 | -------------------------------------------------------------------------------- /resemble_enhance/enhancer/hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | 4 | from ..hparams import HParams as HParamsBase 5 | 6 | 7 | @dataclass(frozen=True) 8 | class HParams(HParamsBase): 9 | cfm_solver_method: str = "midpoint" 10 | cfm_solver_nfe: int = 64 11 | cfm_time_mapping_divisor: int = 4 12 | univnet_nc: int = 96 13 | 14 | lcfm_latent_dim: int = 64 15 | lcfm_training_mode: str = "ae" 16 | # This value should be carefully tuned when training. 
It is best to estimate it from the latent vectors first.
17 |     lcfm_z_scale: float = 5
18 | 
19 |     vocoder_extra_dim: int = 32
20 | 
21 |     gan_training_start_step: int | None = 5_000
22 |     enhancer_stage1_run_dir: Path | None = None
23 | 
24 |     denoiser_run_dir: Path | None = None
25 | 
26 |     # Enabling this increases training stability (but also disables changing eval_tau)
27 |     force_gaussian_prior: bool = False
28 | 
--------------------------------------------------------------------------------
/resemble_enhance/enhancer/inference.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from functools import cache
3 | from pathlib import Path
4 | 
5 | import torch
6 | 
7 | from ..inference import inference
8 | from .download import download
9 | from .train import Enhancer, HParams
10 | 
11 | logger = logging.getLogger(__name__)
12 | 
13 | 
14 | @cache
15 | def load_enhancer(run_dir: str | Path | None, device):
16 |     run_dir = download(run_dir)
17 |     hp = HParams.load(run_dir)
18 |     enhancer = Enhancer(hp)
19 |     path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
20 |     state_dict = torch.load(path, map_location="cpu")["module"]
21 |     enhancer.load_state_dict(state_dict)
22 |     enhancer.eval()
23 |     enhancer.to(device)
24 |     return enhancer
25 | 
26 | 
27 | @torch.inference_mode()
28 | def denoise(dwav, sr, device, run_dir=None):
29 |     enhancer = load_enhancer(run_dir, device)
30 |     return inference(model=enhancer.denoiser, dwav=dwav, sr=sr, device=device)
31 | 
32 | 
33 | @torch.inference_mode()
34 | def enhance(dwav, sr, device, nfe=32, solver="midpoint", lambd=0.5, tau=0.5, run_dir=None):
35 |     assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
36 |     assert solver in ("midpoint", "rk4", "euler"), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
37 |     assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
38 |     assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"
39 |     enhancer = load_enhancer(run_dir, device)
40 |     enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
41 |     return inference(model=enhancer, dwav=dwav, sr=sr, device=device)
42 | 
--------------------------------------------------------------------------------
/resemble_enhance/enhancer/lcfm/__init__.py:
--------------------------------------------------------------------------------
1 | from .irmae import IRMAE
2 | from .lcfm import CFM, LCFM
3 | 
--------------------------------------------------------------------------------
/resemble_enhance/enhancer/lcfm/cfm.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from dataclasses import dataclass
3 | from functools import partial
4 | from typing import Protocol
5 | 
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import scipy
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import Tensor, nn
12 | from tqdm import trange
13 | 
14 | from .wn import WN
15 | 
16 | logger = logging.getLogger(__name__)
17 | 
18 | 
19 | class VelocityField(Protocol):
20 |     def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor: ...
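# A VelocityField is any callable with the keyword-only signature above; the
# CFM model below passes a closure over its WN network as the field. A minimal
# usage sketch (toy field and hypothetical shapes, not from this repo):
#
#   >>> solver = Solver(method="rk4", nfe=64)
#   >>> f = lambda *, t, ψt, dt: -ψt  # toy field: exponential decay
#   >>> ψ1 = solver(f=f, ψ0=torch.randn(1, 64, 100))
#
# The solver integrates from t=0 to t=1 with nfe function evaluations in total
# (rk4 spends 4 per step, midpoint 2, euler 1; see n_steps below).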
21 | 22 | 23 | class Solver: 24 | def __init__( 25 | self, 26 | method="midpoint", 27 | nfe=32, 28 | viz_name="solver", 29 | viz_every=100, 30 | mel_fn=None, 31 | time_mapping_divisor=4, 32 | verbose=False, 33 | ): 34 | self.configurate_(nfe=nfe, method=method) 35 | 36 | self.verbose = verbose 37 | self.viz_every = viz_every 38 | self.viz_name = viz_name 39 | 40 | self._camera = None 41 | self._mel_fn = mel_fn 42 | self._time_mapping = partial(self.exponential_decay_mapping, n=time_mapping_divisor) 43 | 44 | def configurate_(self, nfe=None, method=None): 45 | if nfe is None: 46 | nfe = self.nfe 47 | 48 | if method is None: 49 | method = self.method 50 | 51 | if nfe == 1 and method in ("midpoint", "rk4"): 52 | logger.warning(f"1 NFE is not supported for {method}, using euler method instead.") 53 | method = "euler" 54 | 55 | self.nfe = nfe 56 | self.method = method 57 | 58 | @property 59 | def time_mapping(self): 60 | return self._time_mapping 61 | 62 | @staticmethod 63 | def exponential_decay_mapping(t, n=4): 64 | """ 65 | Args: 66 | n: target step 67 | """ 68 | 69 | def h(t, a): 70 | return (a**t - 1) / (a - 1) 71 | 72 | # Solve h(1/n) = 0.5 73 | a = float(scipy.optimize.fsolve(lambda a: h(1 / n, a) - 0.5, x0=0)) 74 | 75 | t = h(t, a=a) 76 | 77 | return t 78 | 79 | @torch.no_grad() 80 | def _maybe_camera_snap(self, *, ψt, t): 81 | camera = self._camera 82 | if camera is not None: 83 | if ψt.shape[1] == 1: 84 | # Waveform, b 1 t, plot every 100 samples 85 | plt.subplot(211) 86 | plt.plot(ψt.detach().cpu().numpy()[0, 0, ::100], color="blue") 87 | if self._mel_fn is not None: 88 | plt.subplot(212) 89 | mel = self._mel_fn(ψt.detach().cpu().numpy()[0, 0]) 90 | plt.imshow(mel, origin="lower", interpolation="none") 91 | elif ψt.shape[1] == 2: 92 | # Complex 93 | plt.subplot(121) 94 | plt.imshow( 95 | ψt.detach().cpu().numpy()[0, 0], 96 | origin="lower", 97 | interpolation="none", 98 | ) 99 | plt.subplot(122) 100 | plt.imshow( 101 | ψt.detach().cpu().numpy()[0, 1], 102 | origin="lower", 103 | interpolation="none", 104 | ) 105 | else: 106 | # Spectrogram, b c t 107 | plt.imshow(ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none") 108 | ax = plt.gca() 109 | ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center") 110 | camera.snap() 111 | 112 | @staticmethod 113 | def _euler_step(t, ψt, dt, f: VelocityField): 114 | return ψt + dt * f(t=t, ψt=ψt, dt=dt) 115 | 116 | @staticmethod 117 | def _midpoint_step(t, ψt, dt, f: VelocityField): 118 | return ψt + dt * f(t=t + dt / 2, ψt=ψt + dt * f(t=t, ψt=ψt, dt=dt) / 2, dt=dt) 119 | 120 | @staticmethod 121 | def _rk4_step(t, ψt, dt, f: VelocityField): 122 | k1 = f(t=t, ψt=ψt, dt=dt) 123 | k2 = f(t=t + dt / 2, ψt=ψt + dt * k1 / 2, dt=dt) 124 | k3 = f(t=t + dt / 2, ψt=ψt + dt * k2 / 2, dt=dt) 125 | k4 = f(t=t + dt, ψt=ψt + dt * k3, dt=dt) 126 | return ψt + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6 127 | 128 | @property 129 | def _step(self): 130 | if self.method == "euler": 131 | return self._euler_step 132 | elif self.method == "midpoint": 133 | return self._midpoint_step 134 | elif self.method == "rk4": 135 | return self._rk4_step 136 | else: 137 | raise ValueError(f"Unknown method: {self.method}") 138 | 139 | def get_running_train_loop(self): 140 | try: 141 | # Lazy import 142 | from ...utils.train_loop import TrainLoop 143 | 144 | return TrainLoop.get_running_loop() 145 | except ImportError: 146 | return None 147 | 148 | @property 149 | def visualizing(self): 150 | loop = self.get_running_train_loop() 151 | if loop is None: 152 | return 
153 |         out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
154 |         return loop.global_step % self.viz_every == 0 and not out_path.exists()
155 | 
156 |     def _reset_camera(self):
157 |         try:
158 |             from celluloid import Camera
159 | 
160 |             self._camera = Camera(plt.figure())
161 |         except ImportError:  # celluloid is optional; skip the animation if unavailable
162 |             pass
163 | 
164 |     def _maybe_dump_camera(self):
165 |         camera = self._camera
166 |         loop = self.get_running_train_loop()
167 |         if camera is not None and loop is not None:
168 |             animation = camera.animate()
169 |             out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
170 |             out_path.parent.mkdir(exist_ok=True, parents=True)
171 |             animation.save(out_path, writer="pillow", fps=4)
172 |             plt.close()
173 |             self._camera = None
174 | 
175 |     @property
176 |     def n_steps(self):
177 |         n = self.nfe
178 |         if self.method == "euler":
179 |             pass
180 |         elif self.method == "midpoint":
181 |             n //= 2
182 |         elif self.method == "rk4":
183 |             n //= 4
184 |         else:
185 |             raise ValueError(f"Unknown method: {self.method}")
186 |         return n
187 | 
188 |     def solve(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
189 |         ts = self._time_mapping(np.linspace(t0, t1, self.n_steps + 1))
190 | 
191 |         if self.visualizing:
192 |             self._reset_camera()
193 | 
194 |         if self.verbose:
195 |             steps = trange(self.n_steps, desc="CFM inference")
196 |         else:
197 |             steps = range(self.n_steps)
198 | 
199 |         ψt = ψ0
200 | 
201 |         for i in steps:
202 |             dt = ts[i + 1] - ts[i]
203 |             t = ts[i]
204 |             self._maybe_camera_snap(ψt=ψt, t=t)
205 |             ψt = self._step(t=t, ψt=ψt, dt=dt, f=f)
206 | 
207 |         self._maybe_camera_snap(ψt=ψt, t=ts[-1])
208 | 
209 |         ψ1 = ψt
210 |         del ψt
211 | 
212 |         self._maybe_dump_camera()
213 | 
214 |         return ψ1
215 | 
216 |     def __call__(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
217 |         return self.solve(f=f, ψ0=ψ0, t0=t0, t1=t1)
218 | 
219 | 
220 | class SinusoidalTimeEmbedding(nn.Module):
221 |     def __init__(self, d_embed):
222 |         super().__init__()
223 |         self.d_embed = d_embed
224 |         assert d_embed % 2 == 0
225 | 
226 |     def forward(self, t):
227 |         t = t.unsqueeze(-1)  # ... 1
228 |         p = torch.linspace(0, 4, self.d_embed // 2).to(t)
229 |         while p.dim() < t.dim():
230 |             p = p.unsqueeze(0)  # ... d/2
231 |         sin = torch.sin(t * 10**p)
232 |         cos = torch.cos(t * 10**p)
233 |         return torch.cat([sin, cos], dim=-1)
234 | 
235 | 
236 | @dataclass(eq=False)
237 | class CFM(nn.Module):
238 |     """
239 |     A conditional flow matching (CFM) model for diffusion-style generation.
240 | 
241 |     ψ0 stands for the Gaussian noise, and ψ1 is the data point.
242 | 
243 |     Here we follow the CFM style:
244 |     The generation process (reverse process) is from t=0 to t=1.
245 |     The forward process is from t=1 to t=0.
246 |     """
247 | 
248 |     cond_dim: int
249 |     output_dim: int
250 |     time_emb_dim: int = 128
251 |     viz_name: str = "cfm"
252 |     solver_nfe: int = 32
253 |     solver_method: str = "midpoint"
254 |     time_mapping_divisor: int = 4
255 | 
256 |     def __post_init__(self):
257 |         super().__init__()
258 |         self.solver = Solver(
259 |             viz_name=self.viz_name,
260 |             viz_every=1,
261 |             nfe=self.solver_nfe,
262 |             method=self.solver_method,
263 |             time_mapping_divisor=self.time_mapping_divisor,
264 |         )
265 |         self.emb = SinusoidalTimeEmbedding(self.time_emb_dim)
266 |         self.net = WN(
267 |             input_dim=self.output_dim,
268 |             output_dim=self.output_dim,
269 |             local_dim=self.cond_dim,
270 |             global_dim=self.time_emb_dim,
271 |         )
272 | 
273 |     def _perturb(self, ψ1: Tensor, t: Tensor | None = None):
274 |         """
275 |         Perturb ψ1 to ψt.
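        Note: this hook is left unimplemented; the interpolation actually used is
        _to_ψt below, ψt = t * ψ1 + (1 - t) * ψ0 + σ * ε (Eq 22), with the target
        velocity u = ψ1 - ψ0 (Eq 21).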
276 | """ 277 | raise NotImplementedError 278 | 279 | def _sample_ψ0(self, x: Tensor): 280 | """ 281 | Args: 282 | x: (b c t), which implies the shape of ψ0 283 | """ 284 | shape = list(x.shape) 285 | shape[1] = self.output_dim 286 | if self.training: 287 | g = None 288 | else: 289 | g = torch.Generator(device=x.device) 290 | g.manual_seed(0) # deterministic sampling during eval 291 | ψ0 = torch.randn(shape, device=x.device, dtype=x.dtype, generator=g) 292 | return ψ0 293 | 294 | @property 295 | def sigma(self): 296 | return 1e-4 297 | 298 | def _to_ψt(self, *, ψ1: Tensor, ψ0: Tensor, t: Tensor): 299 | """ 300 | Eq (22) 301 | """ 302 | while t.dim() < ψ1.dim(): 303 | t = t.unsqueeze(-1) 304 | μ = t * ψ1 + (1 - t) * ψ0 305 | return μ + torch.randn_like(μ) * self.sigma 306 | 307 | def _to_u(self, *, ψ1, ψ0: Tensor): 308 | """ 309 | Eq (21) 310 | """ 311 | return ψ1 - ψ0 312 | 313 | def _to_v(self, *, ψt, x, t: float | Tensor): 314 | """ 315 | Args: 316 | ψt: (b c t) 317 | x: (b c t) 318 | t: (b) 319 | Returns: 320 | v: (b c t) 321 | """ 322 | if isinstance(t, (float, int)): 323 | t = torch.full(ψt.shape[:1], t).to(ψt) 324 | t = t.clamp(0, 1) # [0, 1) 325 | g = self.emb(t) # (b d) 326 | v = self.net(ψt, l=x, g=g) 327 | return v 328 | 329 | def compute_losses(self, x, y, ψ0) -> dict: 330 | """ 331 | Args: 332 | x: (b c t) 333 | y: (b c t) 334 | Returns: 335 | losses: dict 336 | """ 337 | t = torch.rand(len(x), device=x.device, dtype=x.dtype) 338 | t = self.solver.time_mapping(t) 339 | 340 | if ψ0 is None: 341 | ψ0 = self._sample_ψ0(x) 342 | 343 | ψt = self._to_ψt(ψ1=y, t=t, ψ0=ψ0) 344 | 345 | v = self._to_v(ψt=ψt, t=t, x=x) 346 | u = self._to_u(ψ1=y, ψ0=ψ0) 347 | 348 | losses = dict(l1=F.l1_loss(v, u)) 349 | 350 | return losses 351 | 352 | @torch.inference_mode() 353 | def sample(self, x, ψ0=None, t0=0.0): 354 | """ 355 | Args: 356 | x: (b c t) 357 | Returns: 358 | y: (b ... 
t)
359 |         """
360 |         if ψ0 is None:
361 |             ψ0 = self._sample_ψ0(x)
362 |         f = lambda t, ψt, dt: self._to_v(ψt=ψt, t=t, x=x)
363 |         ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0)
364 |         return ψ1
365 | 
366 |     def forward(self, x: Tensor, y: Tensor | None = None, ψ0: Tensor | None = None, t0=0.0):
367 |         if y is None:
368 |             y = self.sample(x, ψ0=ψ0, t0=t0)
369 |         else:
370 |             self.losses = self.compute_losses(x, y, ψ0=ψ0)
371 |         return y
372 | 
--------------------------------------------------------------------------------
/resemble_enhance/enhancer/lcfm/irmae.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from dataclasses import dataclass
3 | 
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch import Tensor, nn
7 | from torch.nn.utils.parametrizations import weight_norm
8 | 
9 | from ...common import Normalizer
10 | 
11 | logger = logging.getLogger(__name__)
12 | 
13 | 
14 | @dataclass
15 | class IRMAEOutput:
16 |     latent: Tensor  # latent vector
17 |     decoded: Tensor | None  # decoder output, includes the extra dims
18 | 
19 | 
20 | class ResBlock(nn.Sequential):
21 |     def __init__(self, channels, dilations=[1, 2, 4, 8]):
22 |         wn = weight_norm
23 |         super().__init__(
24 |             nn.GroupNorm(32, channels),
25 |             nn.GELU(),
26 |             wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[0])),
27 |             nn.GroupNorm(32, channels),
28 |             nn.GELU(),
29 |             wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[1])),
30 |             nn.GroupNorm(32, channels),
31 |             nn.GELU(),
32 |             wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[2])),
33 |             nn.GroupNorm(32, channels),
34 |             nn.GELU(),
35 |             wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[3])),
36 |         )
37 | 
38 |     def forward(self, x: Tensor):
39 |         return x + super().forward(x)
40 | 
41 | 
42 | class IRMAE(nn.Module):
43 |     def __init__(
44 |         self,
45 |         input_dim,
46 |         output_dim,
47 |         latent_dim,
48 |         hidden_dim=1024,
49 |         num_irms=4,
50 |     ):
51 |         """
52 |         Args:
53 |             input_dim: input dimension
54 |             output_dim: output dimension
55 |             latent_dim: latent dimension
56 |             hidden_dim: hidden layer dimension
57 |             num_irms: number of implicit rank minimization (IRM) matrices
58 | 
59 |         """
60 |         self.input_dim = input_dim
61 |         super().__init__()
62 | 
63 |         self.encoder = nn.Sequential(
64 |             nn.Conv1d(input_dim, hidden_dim, 3, padding="same"),
65 |             *[ResBlock(hidden_dim) for _ in range(4)],
66 |             # Try to obtain compact representation (https://proceedings.neurips.cc/paper/2020/file/a9078e8653368c9c291ae2f8b74012e7-Paper.pdf)
67 |             *[nn.Conv1d(hidden_dim if i == 0 else latent_dim, latent_dim, 1, bias=False) for i in range(num_irms)],
68 |             nn.Tanh(),
69 |         )
70 | 
71 |         self.decoder = nn.Sequential(
72 |             nn.Conv1d(latent_dim, hidden_dim, 3, padding="same"),
73 |             *[ResBlock(hidden_dim) for _ in range(4)],
74 |             nn.Conv1d(hidden_dim, output_dim, 1),
75 |         )
76 | 
77 |         self.head = nn.Sequential(
78 |             nn.Conv1d(output_dim, hidden_dim, 3, padding="same"),
79 |             nn.GELU(),
80 |             nn.Conv1d(hidden_dim, input_dim, 1),
81 |         )
82 | 
83 |         self.estimator = Normalizer()
84 | 
85 |     def encode(self, x):
86 |         """
87 |         Args:
88 |             x: (b c t) tensor
89 |         """
90 |         z = self.encoder(x)  # (b c t)
91 |         _ = self.estimator(z)  # Estimate the global mean and std of z
92 |         self.stats = {}
93 |         self.stats["z_mean"] = z.mean().item()
94 |         self.stats["z_std"] = z.std().item()
95 |         self.stats["z_abs_68"] = z.abs().quantile(0.6827).item()
96 |         self.stats["z_abs_95"] =
z.abs().quantile(0.9545).item() 97 | self.stats["z_abs_99"] = z.abs().quantile(0.9973).item() 98 | return z 99 | 100 | def decode(self, z): 101 | """ 102 | Args: 103 | z: (b c t) tensor 104 | """ 105 | return self.decoder(z) 106 | 107 | def forward(self, x, skip_decoding=False): 108 | """ 109 | Args: 110 | x: (b c t) tensor 111 | skip_decoding: if True, skip the decoding step 112 | """ 113 | z = self.encode(x) # q(z|x) 114 | 115 | if skip_decoding: 116 | # This speeds up the training in cfm only mode 117 | decoded = None 118 | else: 119 | decoded = self.decode(z) # p(x|z) 120 | predicted = self.head(decoded) 121 | self.losses = dict(mse=F.mse_loss(predicted, x)) 122 | 123 | return IRMAEOutput(latent=z, decoded=decoded) 124 | -------------------------------------------------------------------------------- /resemble_enhance/enhancer/lcfm/lcfm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from enum import Enum 3 | 4 | import matplotlib.pyplot as plt 5 | import torch 6 | import torch.nn as nn 7 | from torch import Tensor, nn 8 | 9 | from .cfm import CFM 10 | from .irmae import IRMAE, IRMAEOutput 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def freeze_(module): 16 | for p in module.parameters(): 17 | p.requires_grad_(False) 18 | 19 | 20 | class LCFM(nn.Module): 21 | class Mode(Enum): 22 | AE = "ae" 23 | CFM = "cfm" 24 | 25 | def __init__(self, ae: IRMAE, cfm: CFM, z_scale: float = 1.0): 26 | super().__init__() 27 | self.ae = ae 28 | self.cfm = cfm 29 | self.z_scale = z_scale 30 | self._mode = None 31 | self._eval_tau = 0.5 32 | 33 | @property 34 | def mode(self): 35 | return self._mode 36 | 37 | def set_mode_(self, mode): 38 | mode = self.Mode(mode) 39 | self._mode = mode 40 | 41 | if mode == mode.AE: 42 | freeze_(self.cfm) 43 | logger.info("Freeze cfm") 44 | elif mode == mode.CFM: 45 | freeze_(self.ae) 46 | logger.info("Freeze ae (encoder and decoder)") 47 | else: 48 | raise ValueError(f"Unknown training mode: {mode}") 49 | 50 | def get_running_train_loop(self): 51 | try: 52 | # Lazy import 53 | from ...utils.train_loop import TrainLoop 54 | 55 | return TrainLoop.get_running_loop() 56 | except ImportError: 57 | return None 58 | 59 | @property 60 | def global_step(self): 61 | loop = self.get_running_train_loop() 62 | if loop is None: 63 | return None 64 | return loop.global_step 65 | 66 | @torch.no_grad() 67 | def _visualize(self, x, y, y_): 68 | loop = self.get_running_train_loop() 69 | if loop is None: 70 | return 71 | 72 | plt.subplot(221) 73 | plt.imshow(y[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none") 74 | plt.title("GT") 75 | 76 | plt.subplot(222) 77 | y_ = y_[:, : y.shape[1]] 78 | plt.imshow(y_[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none") 79 | plt.title("Posterior") 80 | 81 | plt.subplot(223) 82 | z_ = self.cfm(x) 83 | y__ = self.ae.decode(z_) 84 | y__ = y__[:, : y.shape[1]] 85 | plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none") 86 | plt.title("C-Prior") 87 | del y__ 88 | 89 | plt.subplot(224) 90 | z_ = torch.randn_like(z_) 91 | y__ = self.ae.decode(z_) 92 | y__ = y__[:, : y.shape[1]] 93 | plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none") 94 | plt.title("Prior") 95 | del z_, y__ 96 | 97 | path = loop.make_current_step_viz_path("recon", ".png") 98 | path.parent.mkdir(exist_ok=True, parents=True) 99 | plt.tight_layout() 100 | plt.savefig(path, dpi=500) 101 
| plt.close() 102 | 103 | def _scale(self, z: Tensor): 104 | return z * self.z_scale 105 | 106 | def _unscale(self, z: Tensor): 107 | return z / self.z_scale 108 | 109 | def eval_tau_(self, tau): 110 | self._eval_tau = tau 111 | 112 | def forward(self, x, y: Tensor | None = None, ψ0: Tensor | None = None): 113 | """ 114 | Args: 115 | x: (b d t), condition mel 116 | y: (b d t), target mel 117 | ψ0: (b d t), starting mel 118 | """ 119 | if self.mode == self.Mode.CFM: 120 | self.ae.eval() # Always set to eval when training cfm 121 | 122 | if ψ0 is not None: 123 | ψ0 = self._scale(self.ae.encode(ψ0)) 124 | if self.training: 125 | tau = torch.rand_like(ψ0[:, :1, :1]) 126 | else: 127 | tau = self._eval_tau 128 | ψ0 = tau * torch.randn_like(ψ0) + (1 - tau) * ψ0 129 | 130 | if y is None: 131 | if self.mode == self.Mode.AE: 132 | with torch.no_grad(): 133 | training = self.ae.training 134 | self.ae.eval() 135 | z = self.ae.encode(x) 136 | self.ae.train(training) 137 | else: 138 | z = self._unscale(self.cfm(x, ψ0=ψ0)) 139 | 140 | h = self.ae.decode(z) 141 | else: 142 | ae_output: IRMAEOutput = self.ae(y, skip_decoding=self.mode == self.Mode.CFM) 143 | 144 | if self.mode == self.Mode.CFM: 145 | _ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0) 146 | 147 | h = ae_output.decoded 148 | 149 | if h is not None and self.global_step is not None and self.global_step % 100 == 0: 150 | self._visualize(x[:1], y[:1], h[:1]) 151 | 152 | return h 153 | -------------------------------------------------------------------------------- /resemble_enhance/enhancer/lcfm/wn.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | @torch.jit.script 11 | def _fused_tanh_sigmoid(h): 12 | a, b = h.chunk(2, dim=1) 13 | h = a.tanh() * b.sigmoid() 14 | return h 15 | 16 | 17 | class WNLayer(nn.Module): 18 | """ 19 | A DiffWave-like WN 20 | """ 21 | 22 | def __init__(self, hidden_dim, local_dim, global_dim, kernel_size, dilation): 23 | super().__init__() 24 | 25 | local_output_dim = hidden_dim * 2 26 | 27 | if global_dim is not None: 28 | self.gconv = nn.Conv1d(global_dim, hidden_dim, 1) 29 | 30 | if local_dim is not None: 31 | self.lconv = nn.Conv1d(local_dim, local_output_dim, 1) 32 | 33 | self.dconv = nn.Conv1d(hidden_dim, local_output_dim, kernel_size, dilation=dilation, padding="same") 34 | 35 | self.out = nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size=1) 36 | 37 | def forward(self, z, l, g): 38 | identity = z 39 | 40 | if g is not None: 41 | if g.dim() == 2: 42 | g = g.unsqueeze(-1) 43 | z = z + self.gconv(g) 44 | 45 | z = self.dconv(z) 46 | 47 | if l is not None: 48 | z = z + self.lconv(l) 49 | 50 | z = _fused_tanh_sigmoid(z) 51 | 52 | h = self.out(z) 53 | 54 | z, s = h.chunk(2, dim=1) 55 | 56 | o = (z + identity) / math.sqrt(2) 57 | 58 | return o, s 59 | 60 | 61 | class WN(nn.Module): 62 | def __init__( 63 | self, 64 | input_dim, 65 | output_dim, 66 | local_dim=None, 67 | global_dim=None, 68 | n_layers=30, 69 | kernel_size=3, 70 | dilation_cycle=5, 71 | hidden_dim=512, 72 | ): 73 | super().__init__() 74 | assert kernel_size % 2 == 1 75 | assert hidden_dim % 2 == 0 76 | 77 | self.input_dim = input_dim 78 | self.hidden_dim = hidden_dim 79 | self.local_dim = local_dim 80 | self.global_dim = global_dim 81 | 82 | self.start = nn.Conv1d(input_dim, hidden_dim, 1) 83 | if local_dim is not None: 84 | self.local_norm = nn.InstanceNorm1d(local_dim) 85 | 
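        # Dilations cycle through 1, 2, 4, 8, 16 (2 ** (i % dilation_cycle) with
        # the default dilation_cycle=5), so n_layers=30 repeats the cycle six
        # times. A standalone usage sketch (hypothetical dims, mirroring how CFM
        # wires WN as its velocity network):
        #
        #   >>> net = WN(input_dim=64, output_dim=64, local_dim=128, global_dim=128)
        #   >>> o = net(torch.randn(1, 64, 100), l=torch.randn(1, 128, 100), g=torch.randn(1, 128))
        #   >>> o.shape
        #   torch.Size([1, 64, 100])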
86 | self.layers = nn.ModuleList( 87 | [ 88 | WNLayer( 89 | hidden_dim=hidden_dim, 90 | local_dim=local_dim, 91 | global_dim=global_dim, 92 | kernel_size=kernel_size, 93 | dilation=2 ** (i % dilation_cycle), 94 | ) 95 | for i in range(n_layers) 96 | ] 97 | ) 98 | 99 | self.end = nn.Conv1d(hidden_dim, output_dim, 1) 100 | 101 | def forward(self, z, l=None, g=None): 102 | """ 103 | Args: 104 | z: input (b c t) 105 | l: local condition (b c t) 106 | g: global condition (b d) 107 | """ 108 | z = self.start(z) 109 | 110 | if l is not None: 111 | l = self.local_norm(l) 112 | 113 | # Skips 114 | s_list = [] 115 | 116 | for layer in self.layers: 117 | z, s = layer(z, l, g) 118 | s_list.append(s) 119 | 120 | s_list = torch.stack(s_list, dim=0).sum(dim=0) 121 | s_list = s_list / math.sqrt(len(self.layers)) 122 | 123 | o = self.end(s_list) 124 | 125 | return o 126 | 127 | def summarize(self, length=100): 128 | from ptflops import get_model_complexity_info 129 | 130 | x = torch.randn(1, self.input_dim, length) 131 | 132 | macs, params = get_model_complexity_info( 133 | self, 134 | (self.input_dim, length), 135 | as_strings=True, 136 | print_per_layer_stat=True, 137 | verbose=True, 138 | ) 139 | 140 | print(f"Input shape: {x.shape}") 141 | print(f"Computational complexity: {macs}") 142 | print(f"Number of parameters: {params}") 143 | 144 | 145 | if __name__ == "__main__": 146 | model = WN(input_dim=64, output_dim=64) 147 | model.summarize() 148 | -------------------------------------------------------------------------------- /resemble_enhance/enhancer/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from functools import partial 4 | from pathlib import Path 5 | 6 | import soundfile 7 | import torch 8 | from deepspeed import DeepSpeedConfig 9 | from torch import Tensor 10 | from tqdm import tqdm 11 | 12 | from ..data import create_dataloaders, mix_fg_bg 13 | from ..utils import Engine, TrainLoop, save_mels, setup_logging, tree_map 14 | from ..utils.distributed import is_local_leader 15 | from .enhancer import Enhancer 16 | from .hparams import HParams 17 | from .univnet.discriminator import Discriminator 18 | 19 | 20 | def load_G(run_dir: Path, hp: HParams | None = None, training=True): 21 | if hp is None: 22 | hp = HParams.load(run_dir) 23 | assert isinstance(hp, HParams) 24 | model = Enhancer(hp) 25 | engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "G") 26 | if training: 27 | engine.load_checkpoint() 28 | else: 29 | engine.load_checkpoint(load_optimizer_states=False, load_lr_scheduler_states=False) 30 | return engine 31 | 32 | 33 | def load_D(run_dir: Path, hp: HParams | None): 34 | if hp is None: 35 | hp = HParams.load(run_dir) 36 | assert isinstance(hp, HParams) 37 | model = Discriminator(hp) 38 | engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "D") 39 | engine.load_checkpoint() 40 | return engine 41 | 42 | 43 | def save_wav(path: Path, wav: Tensor, rate: int): 44 | wav_numpy = wav.detach().cpu().numpy() 45 | soundfile.write(path, wav_numpy, samplerate=rate) 46 | 47 | 48 | def main(): 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("run_dir", type=Path) 51 | parser.add_argument("--yaml", type=Path, default=None) 52 | parser.add_argument("--device", type=str, default="cuda") 53 | args = parser.parse_args() 54 | 55 | setup_logging(args.run_dir) 56 | hp = HParams.load(args.run_dir, 
yaml=args.yaml) 57 | 58 | if is_local_leader(): 59 | hp.save_if_not_exists(args.run_dir) 60 | hp.print() 61 | 62 | train_dl, val_dl = create_dataloaders(hp, mode="enhancer") 63 | 64 | def feed_G(engine: Engine, batch: dict[str, Tensor]): 65 | if hp.lcfm_training_mode == "ae": 66 | pred = engine(batch["fg_wavs"], batch["fg_wavs"]) 67 | elif hp.lcfm_training_mode == "cfm": 68 | alpha_fn = lambda: random.uniform(*hp.mix_alpha_range) 69 | mx_dwavs = mix_fg_bg(batch["fg_dwavs"], batch["bg_dwavs"], alpha=alpha_fn) 70 | pred = engine(mx_dwavs, batch["fg_wavs"], batch["fg_dwavs"]) 71 | else: 72 | raise ValueError(f"Unknown training mode: {hp.lcfm_training_mode}") 73 | losses = engine.gather_attribute("losses") 74 | return pred, losses 75 | 76 | def feed_D(engine: Engine, batch: dict | None, fake: Tensor): 77 | if batch is None: 78 | losses = engine(fake=fake) 79 | else: 80 | losses = engine(fake=fake, real=batch["fg_wavs"]) 81 | return losses 82 | 83 | @torch.no_grad() 84 | def eval_fn(engine: Engine, eval_dir, n_saved=10): 85 | assert isinstance(hp, HParams) 86 | 87 | model = engine.module 88 | model.eval() 89 | 90 | step = engine.global_step 91 | 92 | for i, batch in enumerate(tqdm(val_dl), 1): 93 | batch = tree_map(lambda x: x.to(args.device) if isinstance(x, Tensor) else x, batch) 94 | 95 | fg_wavs = batch["fg_wavs"] # 1 t 96 | 97 | if hp.lcfm_training_mode == "ae": 98 | in_dwavs = fg_wavs 99 | elif hp.lcfm_training_mode == "cfm": 100 | in_dwavs = mix_fg_bg(fg_wavs, batch["bg_dwavs"]) 101 | else: 102 | raise ValueError(f"Unknown training mode: {hp.lcfm_training_mode}") 103 | 104 | pred_fg_wavs = model(in_dwavs) # 1 t 105 | 106 | in_mels = model.to_mel(in_dwavs) # 1 c t 107 | fg_mels = model.to_mel(fg_wavs) # 1 c t 108 | pred_fg_mels = model.to_mel(pred_fg_wavs) # 1 c t 109 | 110 | rate = model.hp.wav_rate 111 | get_path = lambda suffix: eval_dir / f"step_{step:08}_{i:03}{suffix}" 112 | 113 | save_wav(get_path("_input.wav"), in_dwavs[0], rate=rate) 114 | save_wav(get_path("_predict.wav"), pred_fg_wavs[0], rate=rate) 115 | save_wav(get_path("_target.wav"), fg_wavs[0], rate=rate) 116 | 117 | save_mels( 118 | get_path(".png"), 119 | cond_mel=in_mels[0].cpu().numpy(), 120 | pred_mel=pred_fg_mels[0].cpu().numpy(), 121 | targ_mel=fg_mels[0].cpu().numpy(), 122 | ) 123 | 124 | if i >= n_saved: 125 | break 126 | 127 | train_loop = TrainLoop( 128 | run_dir=args.run_dir, 129 | train_dl=train_dl, 130 | load_G=partial(load_G, hp=hp), 131 | load_D=partial(load_D, hp=hp), 132 | device=args.device, 133 | feed_G=feed_G, 134 | feed_D=feed_D, 135 | eval_fn=eval_fn, 136 | gan_training_start_step=hp.gan_training_start_step, 137 | ) 138 | 139 | train_loop.run(max_steps=hp.max_steps) 140 | 141 | 142 | if __name__ == "__main__": 143 | main() 144 | -------------------------------------------------------------------------------- /resemble_enhance/enhancer/univnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .univnet import UnivNet 2 | -------------------------------------------------------------------------------- /resemble_enhance/enhancer/univnet/alias_free_torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 
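# This subpackage provides the anti-aliased resampling used by the vocoder:
# `filter` defines a windowed-sinc LowPassFilter1d and `resample` builds
# UpSample1d / DownSample1d on top of it (consumed by the Snake-based AMP
# blocks in amp.py).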
3 | 4 | from .filter import * 5 | from .resample import * 6 | -------------------------------------------------------------------------------- /resemble_enhance/enhancer/univnet/alias_free_torch/filter.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import math 8 | 9 | if 'sinc' in dir(torch): 10 | sinc = torch.sinc 11 | else: 12 | # This code is adopted from adefossez's julius.core.sinc under the MIT License 13 | # https://adefossez.github.io/julius/julius/core.html 14 | # LICENSE is in incl_licenses directory. 15 | def sinc(x: torch.Tensor): 16 | """ 17 | Implementation of sinc, i.e. sin(pi * x) / (pi * x) 18 | __Warning__: Different to julius.sinc, the input is multiplied by `pi`! 19 | """ 20 | return torch.where(x == 0, 21 | torch.tensor(1., device=x.device, dtype=x.dtype), 22 | torch.sin(math.pi * x) / math.pi / x) 23 | 24 | 25 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License 26 | # https://adefossez.github.io/julius/julius/lowpass.html 27 | # LICENSE is in incl_licenses directory. 28 | def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] 29 | even = (kernel_size % 2 == 0) 30 | half_size = kernel_size // 2 31 | 32 | #For kaiser window 33 | delta_f = 4 * half_width 34 | A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 35 | if A > 50.: 36 | beta = 0.1102 * (A - 8.7) 37 | elif A >= 21.: 38 | beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) 39 | else: 40 | beta = 0. 41 | window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) 42 | 43 | # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio 44 | if even: 45 | time = (torch.arange(-half_size, half_size) + 0.5) 46 | else: 47 | time = torch.arange(kernel_size) - half_size 48 | if cutoff == 0: 49 | filter_ = torch.zeros_like(time) 50 | else: 51 | filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) 52 | # Normalize filter to have sum = 1, otherwise we will have a small leakage 53 | # of the constant component in the input signal. 54 | filter_ /= filter_.sum() 55 | filter = filter_.view(1, 1, kernel_size) 56 | 57 | return filter 58 | 59 | 60 | class LowPassFilter1d(nn.Module): 61 | def __init__(self, 62 | cutoff=0.5, 63 | half_width=0.6, 64 | stride: int = 1, 65 | padding: bool = True, 66 | padding_mode: str = 'replicate', 67 | kernel_size: int = 12): 68 | # kernel_size should be even number for stylegan3 setup, 69 | # in this implementation, odd number is also possible. 
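        # cutoff and half_width are in normalized frequency (1.0 = the sampling
        # rate), so cutoff=0.5 is Nyquist. A sketch of the anti-aliasing filter
        # that DownSample1d(ratio=2) constructs with the defaults in resample.py:
        #
        #   >>> lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=2)
        #   >>> y = lpf(torch.randn(1, 4, 128))  # (B, C, T) -> (B, C, T // 2)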
70 | super().__init__() 71 | if cutoff < -0.: 72 | raise ValueError("Minimum cutoff must be larger than zero.") 73 | if cutoff > 0.5: 74 | raise ValueError("A cutoff above 0.5 does not make sense.") 75 | self.kernel_size = kernel_size 76 | self.even = (kernel_size % 2 == 0) 77 | self.pad_left = kernel_size // 2 - int(self.even) 78 | self.pad_right = kernel_size // 2 79 | self.stride = stride 80 | self.padding = padding 81 | self.padding_mode = padding_mode 82 | filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) 83 | self.register_buffer("filter", filter) 84 | 85 | #input [B, C, T] 86 | def forward(self, x): 87 | _, C, _ = x.shape 88 | 89 | if self.padding: 90 | x = F.pad(x, (self.pad_left, self.pad_right), 91 | mode=self.padding_mode) 92 | out = F.conv1d(x, self.filter.expand(C, -1, -1), 93 | stride=self.stride, groups=C) 94 | 95 | return out 96 | -------------------------------------------------------------------------------- /resemble_enhance/enhancer/univnet/alias_free_torch/resample.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from .filter import LowPassFilter1d 7 | from .filter import kaiser_sinc_filter1d 8 | 9 | 10 | class UpSample1d(nn.Module): 11 | def __init__(self, ratio=2, kernel_size=None): 12 | super().__init__() 13 | self.ratio = ratio 14 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 15 | self.stride = ratio 16 | self.pad = self.kernel_size // ratio - 1 17 | self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 18 | self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 19 | filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, 20 | half_width=0.6 / ratio, 21 | kernel_size=self.kernel_size) 22 | self.register_buffer("filter", filter) 23 | 24 | # x: [B, C, T] 25 | def forward(self, x): 26 | _, C, _ = x.shape 27 | 28 | x = F.pad(x, (self.pad, self.pad), mode='replicate') 29 | x = self.ratio * F.conv_transpose1d( 30 | x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) 31 | x = x[..., self.pad_left:-self.pad_right] 32 | 33 | return x 34 | 35 | 36 | class DownSample1d(nn.Module): 37 | def __init__(self, ratio=2, kernel_size=None): 38 | super().__init__() 39 | self.ratio = ratio 40 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 41 | self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, 42 | half_width=0.6 / ratio, 43 | stride=ratio, 44 | kernel_size=self.kernel_size) 45 | 46 | def forward(self, x): 47 | xx = self.lowpass(x) 48 | 49 | return xx 50 | -------------------------------------------------------------------------------- /resemble_enhance/enhancer/univnet/amp.py: -------------------------------------------------------------------------------- 1 | # Refer from https://github.com/NVIDIA/BigVGAN 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch import nn 8 | from torch.nn.utils.parametrizations import weight_norm 9 | 10 | from .alias_free_torch import DownSample1d, UpSample1d 11 | 12 | 13 | class SnakeBeta(nn.Module): 14 | """ 15 | A modified Snake function which uses separate parameters for the magnitude of the periodic components 16 | Shape: 17 | - Input: (B, C, T) 18 | - Output: (B, C, T), same shape as the input 19 | Parameters: 20 | - alpha - 
trainable parameter that controls frequency 21 | - beta - trainable parameter that controls magnitude 22 | References: 23 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 24 | https://arxiv.org/abs/2006.08195 25 | Examples: 26 | >>> a1 = snakebeta(256) 27 | >>> x = torch.randn(256) 28 | >>> x = a1(x) 29 | """ 30 | 31 | def __init__(self, in_features, alpha=1.0, clamp=(1e-2, 50)): 32 | """ 33 | Initialization. 34 | INPUT: 35 | - in_features: shape of the input 36 | - alpha - trainable parameter that controls frequency 37 | - beta - trainable parameter that controls magnitude 38 | alpha is initialized to 1 by default, higher values = higher-frequency. 39 | beta is initialized to 1 by default, higher values = higher-magnitude. 40 | alpha will be trained along with the rest of your model. 41 | """ 42 | super().__init__() 43 | self.in_features = in_features 44 | self.log_alpha = nn.Parameter(torch.zeros(in_features) + math.log(alpha)) 45 | self.log_beta = nn.Parameter(torch.zeros(in_features) + math.log(alpha)) 46 | self.clamp = clamp 47 | 48 | def forward(self, x): 49 | """ 50 | Forward pass of the function. 51 | Applies the function to the input elementwise. 52 | SnakeBeta ∶= x + 1/b * sin^2 (xa) 53 | """ 54 | alpha = self.log_alpha.exp().clamp(*self.clamp) 55 | alpha = alpha[None, :, None] 56 | 57 | beta = self.log_beta.exp().clamp(*self.clamp) 58 | beta = beta[None, :, None] 59 | 60 | x = x + (1.0 / beta) * (x * alpha).sin().pow(2) 61 | 62 | return x 63 | 64 | 65 | class UpActDown(nn.Module): 66 | def __init__( 67 | self, 68 | act, 69 | up_ratio: int = 2, 70 | down_ratio: int = 2, 71 | up_kernel_size: int = 12, 72 | down_kernel_size: int = 12, 73 | ): 74 | super().__init__() 75 | self.up_ratio = up_ratio 76 | self.down_ratio = down_ratio 77 | self.act = act 78 | self.upsample = UpSample1d(up_ratio, up_kernel_size) 79 | self.downsample = DownSample1d(down_ratio, down_kernel_size) 80 | 81 | def forward(self, x): 82 | # x: [B,C,T] 83 | x = self.upsample(x) 84 | x = self.act(x) 85 | x = self.downsample(x) 86 | return x 87 | 88 | 89 | class AMPBlock(nn.Sequential): 90 | def __init__(self, channels, *, kernel_size=3, dilations=(1, 3, 5)): 91 | super().__init__(*(self._make_layer(channels, kernel_size, d) for d in dilations)) 92 | 93 | def _make_layer(self, channels, kernel_size, dilation): 94 | return nn.Sequential( 95 | weight_norm(nn.Conv1d(channels, channels, kernel_size, dilation=dilation, padding="same")), 96 | UpActDown(act=SnakeBeta(channels)), 97 | weight_norm(nn.Conv1d(channels, channels, kernel_size, padding="same")), 98 | ) 99 | 100 | def forward(self, x): 101 | return x + super().forward(x) 102 | -------------------------------------------------------------------------------- /resemble_enhance/enhancer/univnet/discriminator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import Tensor, nn 6 | from torch.nn.utils.parametrizations import weight_norm 7 | 8 | from ..hparams import HParams 9 | from .mrstft import get_stft_cfgs 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class PeriodNetwork(nn.Module): 15 | def __init__(self, period): 16 | super().__init__() 17 | self.period = period 18 | wn = weight_norm 19 | self.convs = nn.ModuleList( 20 | [ 21 | wn(nn.Conv2d(1, 64, (5, 1), (3, 1), padding=(2, 0))), 22 | wn(nn.Conv2d(64, 128, (5, 1), (3, 1), padding=(2, 0))), 23 | 
wn(nn.Conv2d(128, 256, (5, 1), (3, 1), padding=(2, 0))), 24 | wn(nn.Conv2d(256, 512, (5, 1), (3, 1), padding=(2, 0))), 25 | wn(nn.Conv2d(512, 1024, (5, 1), 1, padding=(2, 0))), 26 | ] 27 | ) 28 | self.conv_post = wn(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 29 | 30 | def forward(self, x): 31 | """ 32 | Args: 33 | x: [B, 1, T] 34 | """ 35 | assert x.dim() == 3, f"(B, 1, T) is expected, but got {x.shape}." 36 | 37 | # 1d to 2d 38 | b, c, t = x.shape 39 | if t % self.period != 0: # pad first 40 | n_pad = self.period - (t % self.period) 41 | x = F.pad(x, (0, n_pad), "reflect") 42 | t = t + n_pad 43 | x = x.view(b, c, t // self.period, self.period) 44 | 45 | for l in self.convs: 46 | x = l(x) 47 | x = F.leaky_relu(x, 0.2) 48 | x = self.conv_post(x) 49 | x = torch.flatten(x, 1, -1) 50 | 51 | return x 52 | 53 | 54 | class SpecNetwork(nn.Module): 55 | def __init__(self, stft_cfg: dict): 56 | super().__init__() 57 | wn = weight_norm 58 | self.stft_cfg = stft_cfg 59 | self.convs = nn.ModuleList( 60 | [ 61 | wn(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))), 62 | wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), 63 | wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), 64 | wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), 65 | wn(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), 66 | ] 67 | ) 68 | self.conv_post = wn(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) 69 | 70 | def forward(self, x): 71 | """ 72 | Args: 73 | x: [B, 1, T] 74 | """ 75 | x = self.spectrogram(x) 76 | x = x.unsqueeze(1) 77 | for l in self.convs: 78 | x = l(x) 79 | x = F.leaky_relu(x, 0.2) 80 | x = self.conv_post(x) 81 | x = x.flatten(1, -1) 82 | return x 83 | 84 | def spectrogram(self, x): 85 | """ 86 | Args: 87 | x: [B, 1, T] 88 | """ 89 | x = x.squeeze(1) 90 | dtype = x.dtype 91 | stft_cfg = dict(self.stft_cfg) 92 | x = torch.stft(x.float(), center=False, return_complex=False, **stft_cfg) 93 | mag = x.norm(p=2, dim=-1) # [B, F, TT] 94 | mag = mag.to(dtype) # [B, F, TT] 95 | return mag 96 | 97 | 98 | class MD(nn.ModuleList): 99 | def __init__(self, l: list): 100 | super().__init__([self._create_network(x) for x in l]) 101 | self._loss_type = None 102 | 103 | def loss_type_(self, loss_type): 104 | self._loss_type = loss_type 105 | 106 | def _create_network(self, _): 107 | raise NotImplementedError 108 | 109 | def _forward_each(self, d, x, y): 110 | assert self._loss_type is not None, "loss_type is not set." 
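        # Hinge loss (used for the discriminator step) pushes real scores above
        # +1 and fake scores below -1; the wgan branch is used for the generator
        # step (no real batch), where minimizing -d(fake) raises the fake scores.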
111 | loss_type = self._loss_type 112 | 113 | if loss_type == "hinge": 114 | if y == 0: 115 | # d(x) should be small -> -1 116 | loss = F.relu(1 + d(x)).mean() 117 | elif y == 1: 118 | # d(x) should be large -> 1 119 | loss = F.relu(1 - d(x)).mean() 120 | else: 121 | raise ValueError(f"Invalid y: {y}") 122 | elif loss_type == "wgan": 123 | if y == 0: 124 | loss = d(x).mean() 125 | elif y == 1: 126 | loss = -d(x).mean() 127 | else: 128 | raise ValueError(f"Invalid y: {y}") 129 | else: 130 | raise ValueError(f"Invalid loss_type: {loss_type}") 131 | 132 | return loss 133 | 134 | def forward(self, x, y) -> Tensor: 135 | losses = [self._forward_each(d, x, y) for d in self] 136 | return torch.stack(losses).mean() 137 | 138 | 139 | class MPD(MD): 140 | def __init__(self): 141 | super().__init__([2, 3, 7, 13, 17]) 142 | 143 | def _create_network(self, period): 144 | return PeriodNetwork(period) 145 | 146 | 147 | class MRD(MD): 148 | def __init__(self, stft_cfgs): 149 | super().__init__(stft_cfgs) 150 | 151 | def _create_network(self, stft_cfg): 152 | return SpecNetwork(stft_cfg) 153 | 154 | 155 | class Discriminator(nn.Module): 156 | @property 157 | def wav_rate(self): 158 | return self.hp.wav_rate 159 | 160 | def __init__(self, hp: HParams): 161 | super().__init__() 162 | self.hp = hp 163 | self.stft_cfgs = get_stft_cfgs(hp) 164 | self.mpd = MPD() 165 | self.mrd = MRD(self.stft_cfgs) 166 | self.dummy_float: Tensor 167 | self.register_buffer("dummy_float", torch.zeros(0), persistent=False) 168 | 169 | def loss_type_(self, loss_type): 170 | self.mpd.loss_type_(loss_type) 171 | self.mrd.loss_type_(loss_type) 172 | 173 | def forward(self, fake, real=None): 174 | """ 175 | Args: 176 | fake: [B T] 177 | real: [B T] 178 | """ 179 | fake = fake.to(self.dummy_float) 180 | 181 | if real is None: 182 | self.loss_type_("wgan") 183 | else: 184 | length_difference = (fake.shape[-1] - real.shape[-1]) / real.shape[-1] 185 | assert length_difference < 0.05, f"length_difference should be smaller than 5%" 186 | 187 | self.loss_type_("hinge") 188 | real = real.to(self.dummy_float) 189 | 190 | fake = fake[..., : real.shape[-1]] 191 | real = real[..., : fake.shape[-1]] 192 | 193 | losses = {} 194 | 195 | assert fake.dim() == 2, f"(B, T) is expected, but got {fake.shape}." 196 | assert real is None or real.dim() == 2, f"(B, T) is expected, but got {real.shape}." 
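        # Both the multi-period (MPD) and multi-resolution (MRD) discriminators
        # expect a channel axis, hence (B, T) -> (B, 1, T) below. A real batch
        # triggers the hinge-loss discriminator update; without one, the
        # wgan-style generator update runs on the fake batch alone.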
197 | 198 | fake = fake.unsqueeze(1) 199 | 200 | if real is None: 201 | losses["mpd"] = self.mpd(fake, 1) 202 | losses["mrd"] = self.mrd(fake, 1) 203 | else: 204 | real = real.unsqueeze(1) 205 | losses["mpd_fake"] = self.mpd(fake, 0) 206 | losses["mpd_real"] = self.mpd(real, 1) 207 | losses["mrd_fake"] = self.mrd(fake, 0) 208 | losses["mrd_real"] = self.mrd(real, 1) 209 | 210 | return losses 211 | -------------------------------------------------------------------------------- /resemble_enhance/enhancer/univnet/lvcnet.py: -------------------------------------------------------------------------------- 1 | """ refer from https://github.com/zceng/LVCNet """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from torch.nn.utils.parametrizations import weight_norm 8 | 9 | from .amp import AMPBlock 10 | 11 | 12 | class KernelPredictor(torch.nn.Module): 13 | """Kernel predictor for the location-variable convolutions""" 14 | 15 | def __init__( 16 | self, 17 | cond_channels, 18 | conv_in_channels, 19 | conv_out_channels, 20 | conv_layers, 21 | conv_kernel_size=3, 22 | kpnet_hidden_channels=64, 23 | kpnet_conv_size=3, 24 | kpnet_dropout=0.0, 25 | kpnet_nonlinear_activation="LeakyReLU", 26 | kpnet_nonlinear_activation_params={"negative_slope": 0.1}, 27 | ): 28 | """ 29 | Args: 30 | cond_channels (int): number of channel for the conditioning sequence, 31 | conv_in_channels (int): number of channel for the input sequence, 32 | conv_out_channels (int): number of channel for the output sequence, 33 | conv_layers (int): number of layers 34 | """ 35 | super().__init__() 36 | 37 | self.conv_in_channels = conv_in_channels 38 | self.conv_out_channels = conv_out_channels 39 | self.conv_kernel_size = conv_kernel_size 40 | self.conv_layers = conv_layers 41 | 42 | kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w 43 | kpnet_bias_channels = conv_out_channels * conv_layers # l_b 44 | 45 | self.input_conv = nn.Sequential( 46 | weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)), 47 | getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), 48 | ) 49 | 50 | self.residual_convs = nn.ModuleList() 51 | padding = (kpnet_conv_size - 1) // 2 52 | for _ in range(3): 53 | self.residual_convs.append( 54 | nn.Sequential( 55 | nn.Dropout(kpnet_dropout), 56 | weight_norm( 57 | nn.Conv1d( 58 | kpnet_hidden_channels, 59 | kpnet_hidden_channels, 60 | kpnet_conv_size, 61 | padding=padding, 62 | bias=True, 63 | ) 64 | ), 65 | getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), 66 | weight_norm( 67 | nn.Conv1d( 68 | kpnet_hidden_channels, 69 | kpnet_hidden_channels, 70 | kpnet_conv_size, 71 | padding=padding, 72 | bias=True, 73 | ) 74 | ), 75 | getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), 76 | ) 77 | ) 78 | self.kernel_conv = weight_norm( 79 | nn.Conv1d( 80 | kpnet_hidden_channels, 81 | kpnet_kernel_channels, 82 | kpnet_conv_size, 83 | padding=padding, 84 | bias=True, 85 | ) 86 | ) 87 | self.bias_conv = weight_norm( 88 | nn.Conv1d( 89 | kpnet_hidden_channels, 90 | kpnet_bias_channels, 91 | kpnet_conv_size, 92 | padding=padding, 93 | bias=True, 94 | ) 95 | ) 96 | 97 | def forward(self, c): 98 | """ 99 | Args: 100 | c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) 101 | """ 102 | batch, _, cond_length = c.shape 103 | c = self.input_conv(c) 104 | for residual_conv in self.residual_convs: 105 
| residual_conv.to(c.device) 106 | c = c + residual_conv(c) 107 | k = self.kernel_conv(c) 108 | b = self.bias_conv(c) 109 | kernels = k.contiguous().view( 110 | batch, 111 | self.conv_layers, 112 | self.conv_in_channels, 113 | self.conv_out_channels, 114 | self.conv_kernel_size, 115 | cond_length, 116 | ) 117 | bias = b.contiguous().view( 118 | batch, 119 | self.conv_layers, 120 | self.conv_out_channels, 121 | cond_length, 122 | ) 123 | 124 | return kernels, bias 125 | 126 | 127 | class LVCBlock(torch.nn.Module): 128 | """the location-variable convolutions""" 129 | 130 | def __init__( 131 | self, 132 | in_channels, 133 | cond_channels, 134 | stride, 135 | dilations=[1, 3, 9, 27], 136 | lReLU_slope=0.2, 137 | conv_kernel_size=3, 138 | cond_hop_length=256, 139 | kpnet_hidden_channels=64, 140 | kpnet_conv_size=3, 141 | kpnet_dropout=0.0, 142 | add_extra_noise=False, 143 | downsampling=False, 144 | ): 145 | super().__init__() 146 | 147 | self.add_extra_noise = add_extra_noise 148 | 149 | self.cond_hop_length = cond_hop_length 150 | self.conv_layers = len(dilations) 151 | self.conv_kernel_size = conv_kernel_size 152 | 153 | self.kernel_predictor = KernelPredictor( 154 | cond_channels=cond_channels, 155 | conv_in_channels=in_channels, 156 | conv_out_channels=2 * in_channels, 157 | conv_layers=len(dilations), 158 | conv_kernel_size=conv_kernel_size, 159 | kpnet_hidden_channels=kpnet_hidden_channels, 160 | kpnet_conv_size=kpnet_conv_size, 161 | kpnet_dropout=kpnet_dropout, 162 | kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}, 163 | ) 164 | 165 | if downsampling: 166 | self.convt_pre = nn.Sequential( 167 | nn.LeakyReLU(lReLU_slope), 168 | weight_norm(nn.Conv1d(in_channels, in_channels, 2 * stride + 1, padding="same")), 169 | nn.AvgPool1d(stride, stride), 170 | ) 171 | else: 172 | if stride == 1: 173 | self.convt_pre = nn.Sequential( 174 | nn.LeakyReLU(lReLU_slope), 175 | weight_norm(nn.Conv1d(in_channels, in_channels, 1)), 176 | ) 177 | else: 178 | self.convt_pre = nn.Sequential( 179 | nn.LeakyReLU(lReLU_slope), 180 | weight_norm( 181 | nn.ConvTranspose1d( 182 | in_channels, 183 | in_channels, 184 | 2 * stride, 185 | stride=stride, 186 | padding=stride // 2 + stride % 2, 187 | output_padding=stride % 2, 188 | ) 189 | ), 190 | ) 191 | 192 | self.amp_block = AMPBlock(in_channels) 193 | 194 | self.conv_blocks = nn.ModuleList() 195 | for d in dilations: 196 | self.conv_blocks.append( 197 | nn.Sequential( 198 | nn.LeakyReLU(lReLU_slope), 199 | weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size, dilation=d, padding="same")), 200 | nn.LeakyReLU(lReLU_slope), 201 | ) 202 | ) 203 | 204 | def forward(self, x, c): 205 | """forward propagation of the location-variable convolutions. 
206 |         Args:
207 |             x (Tensor): the input sequence (batch, in_channels, in_length)
208 |             c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
209 | 
210 |         Returns:
211 |             Tensor: the output sequence (batch, in_channels, in_length)
212 |         """
213 |         _, in_channels, _ = x.shape  # (B, c_g, L')
214 | 
215 |         x = self.convt_pre(x)  # (B, c_g, stride * L')
216 | 
217 |         # Add one amp block just after the upsampling
218 |         x = self.amp_block(x)  # (B, c_g, stride * L')
219 | 
220 |         kernels, bias = self.kernel_predictor(c)
221 | 
222 |         if self.add_extra_noise:
223 |             # Add extra noise to part of the feature
224 |             a, b = x.chunk(2, dim=1)
225 |             b = b + torch.randn_like(b) * 0.1
226 |             x = torch.cat([a, b], dim=1)
227 | 
228 |         for i, conv in enumerate(self.conv_blocks):
229 |             output = conv(x)  # (B, c_g, stride * L')
230 | 
231 |             k = kernels[:, i, :, :, :, :]  # (B, c_g, 2 * c_g, kernel_size, cond_length)
232 |             b = bias[:, i, :, :]  # (B, 2 * c_g, cond_length)
233 | 
234 |             output = self.location_variable_convolution(
235 |                 output, k, b, hop_size=self.cond_hop_length
236 |             )  # (B, 2 * c_g, stride * L'): LVC
237 |             x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
238 |                 output[:, in_channels:, :]
239 |             )  # (B, c_g, stride * L'): GAU
240 | 
241 |         return x
242 | 
243 |     def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
244 |         """Perform the location-variable convolution on the input sequence (x) using the local convolution kernel.
245 |         Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), tested on NVIDIA V100.
246 |         Args:
247 |             x (Tensor): the input sequence (batch, in_channels, in_length).
248 |             kernel (Tensor): the local convolution kernel (batch, in_channels, out_channels, kernel_size, kernel_length)
249 |             bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
250 |             dilation (int): the dilation of convolution.
251 |             hop_size (int): the hop_size of the conditioning sequence.
252 |         Returns:
253 |             (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
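        Note: rather than looping over conv1d calls, the implementation below
        unfolds x into per-hop windows and applies a different kernel to each
        window in a single einsum ("bildsk,biokl->bolsd").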
254 | """ 255 | batch, _, in_length = x.shape 256 | batch, _, out_channels, kernel_size, kernel_length = kernel.shape 257 | 258 | assert in_length == ( 259 | kernel_length * hop_size 260 | ), f"length of (x, kernel) is not matched, {in_length} != {kernel_length} * {hop_size}" 261 | 262 | padding = dilation * int((kernel_size - 1) / 2) 263 | x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding) 264 | x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) 265 | 266 | if hop_size < dilation: 267 | x = F.pad(x, (0, dilation), "constant", 0) 268 | x = x.unfold( 269 | 3, dilation, dilation 270 | ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) 271 | x = x[:, :, :, :, :hop_size] 272 | x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) 273 | x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) 274 | 275 | o = torch.einsum("bildsk,biokl->bolsd", x, kernel) 276 | o = o.to(memory_format=torch.channels_last_3d) 277 | bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d) 278 | o = o + bias 279 | o = o.contiguous().view(batch, out_channels, -1) 280 | 281 | return o 282 | -------------------------------------------------------------------------------- /resemble_enhance/enhancer/univnet/mrstft.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | from ..hparams import HParams 12 | 13 | 14 | def _make_stft_cfg(hop_length, win_length=None): 15 | if win_length is None: 16 | win_length = 4 * hop_length 17 | n_fft = 2 ** (win_length - 1).bit_length() 18 | return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length) 19 | 20 | 21 | def get_stft_cfgs(hp: HParams): 22 | assert hp.wav_rate == 44100, f"wav_rate must be 44100, got {hp.wav_rate}" 23 | return [_make_stft_cfg(h) for h in (100, 256, 512)] 24 | 25 | 26 | def stft(x, n_fft, hop_length, win_length, window): 27 | dtype = x.dtype 28 | x = torch.stft(x.float(), n_fft, hop_length, win_length, window, return_complex=True) 29 | x = x.abs().to(dtype) 30 | x = x.transpose(2, 1) # (b f t) -> (b t f) 31 | return x 32 | 33 | 34 | class SpectralConvergengeLoss(nn.Module): 35 | def forward(self, x_mag, y_mag): 36 | """Calculate forward propagation. 37 | Args: 38 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 39 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 40 | Returns: 41 | Tensor: Spectral convergence loss value. 42 | """ 43 | return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") 44 | 45 | 46 | class LogSTFTMagnitudeLoss(nn.Module): 47 | def forward(self, x_mag, y_mag): 48 | """Calculate forward propagation. 49 | Args: 50 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 51 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 52 | Returns: 53 | Tensor: Log STFT magnitude loss value. 
54 | """ 55 | return F.l1_loss(torch.log1p(x_mag), torch.log1p(y_mag)) 56 | 57 | 58 | class STFTLoss(nn.Module): 59 | def __init__(self, hp, stft_cfg: dict, window="hann_window"): 60 | super().__init__() 61 | self.hp = hp 62 | self.stft_cfg = stft_cfg 63 | self.spectral_convergenge_loss = SpectralConvergengeLoss() 64 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() 65 | self.register_buffer("window", getattr(torch, window)(stft_cfg["win_length"]), persistent=False) 66 | 67 | def forward(self, x, y): 68 | """Calculate forward propagation. 69 | Args: 70 | x (Tensor): Predicted signal (B, T). 71 | y (Tensor): Groundtruth signal (B, T). 72 | Returns: 73 | Tensor: Spectral convergence loss value. 74 | Tensor: Log STFT magnitude loss value. 75 | """ 76 | stft_cfg = dict(self.stft_cfg) 77 | x_mag = stft(x, **stft_cfg, window=self.window) # (b t) -> (b t f) 78 | y_mag = stft(y, **stft_cfg, window=self.window) 79 | sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) 80 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) 81 | return dict(sc=sc_loss, mag=mag_loss) 82 | 83 | 84 | class MRSTFTLoss(nn.Module): 85 | def __init__(self, hp: HParams, window="hann_window"): 86 | """Initialize Multi resolution STFT loss module. 87 | Args: 88 | resolutions (list): List of (FFT size, hop size, window length). 89 | window (str): Window function type. 90 | """ 91 | super().__init__() 92 | stft_cfgs = get_stft_cfgs(hp) 93 | self.stft_losses = nn.ModuleList() 94 | self.hp = hp 95 | for c in stft_cfgs: 96 | self.stft_losses += [STFTLoss(hp, c, window=window)] 97 | 98 | def forward(self, x, y): 99 | """Calculate forward propagation. 100 | Args: 101 | x (Tensor): Predicted signal (b t). 102 | y (Tensor): Groundtruth signal (b t). 103 | Returns: 104 | Tensor: Multi resolution spectral convergence loss value. 105 | Tensor: Multi resolution log STFT magnitude loss value. 106 | """ 107 | assert x.dim() == 2 and y.dim() == 2, f"(b t) is expected, but got {x.shape} and {y.shape}." 
108 | 
109 |         dtype = x.dtype
110 | 
111 |         x = x.float()
112 |         y = y.float()
113 | 
114 |         # Align length
115 |         x = x[..., : y.shape[-1]]
116 |         y = y[..., : x.shape[-1]]
117 | 
118 |         losses = {}
119 | 
120 |         for f in self.stft_losses:
121 |             d = f(x, y)
122 |             for k, v in d.items():
123 |                 losses.setdefault(k, []).append(v)
124 | 
125 |         for k, v in losses.items():
126 |             losses[k] = torch.stack(v, dim=0).mean().to(dtype)
127 | 
128 |         return losses
129 | 
--------------------------------------------------------------------------------
/resemble_enhance/enhancer/univnet/univnet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import Tensor, nn
5 | from torch.nn.utils.parametrizations import weight_norm
6 | 
7 | from ..hparams import HParams
8 | from .lvcnet import LVCBlock
9 | from .mrstft import MRSTFTLoss
10 | 
11 | 
12 | class UnivNet(nn.Module):
13 |     @property
14 |     def d_noise(self):
15 |         return 128
16 | 
17 |     @property
18 |     def strides(self):
19 |         return [7, 5, 4, 3]
20 | 
21 |     @property
22 |     def dilations(self):
23 |         return [1, 3, 9, 27]
24 | 
25 |     @property
26 |     def nc(self):
27 |         return self.hp.univnet_nc
28 | 
29 |     @property
30 |     def scale_factor(self) -> int:
31 |         return self.hp.hop_size
32 | 
33 |     def __init__(self, hp: HParams, d_input):
34 |         super().__init__()
35 |         self.d_input = d_input
36 | 
37 |         self.hp = hp
38 | 
39 |         self.blocks = nn.ModuleList(
40 |             [
41 |                 LVCBlock(
42 |                     self.nc,
43 |                     d_input,
44 |                     stride=stride,
45 |                     dilations=self.dilations,
46 |                     cond_hop_length=hop_length,
47 |                     kpnet_conv_size=3,
48 |                 )
49 |                 for stride, hop_length in zip(self.strides, np.cumprod(self.strides))
50 |             ]
51 |         )
52 | 
53 |         self.conv_pre = weight_norm(nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect"))
54 | 
55 |         self.conv_post = nn.Sequential(
56 |             nn.LeakyReLU(0.2),
57 |             weight_norm(nn.Conv1d(self.nc, 1, 7, padding=3, padding_mode="reflect")),
58 |             nn.Tanh(),
59 |         )
60 | 
61 |         self.mrstft = MRSTFTLoss(hp)
62 | 
63 |     @property
64 |     def eps(self):
65 |         return 1e-5
66 | 
67 |     def forward(self, x: Tensor, y: Tensor | None = None, npad=10):
68 |         """
69 |         Args:
70 |             x: (b c t), acoustic features
71 |             y: (b t), waveform
72 |         Returns:
73 |             z: (b t), waveform
74 |         """
75 |         assert x.ndim == 3, "x must be 3D tensor"
76 |         assert y is None or y.ndim == 2, "y must be 2D tensor"
77 |         assert x.shape[1] == self.d_input, f"x.shape[1] must be {self.d_input}, but got {x.shape}"
78 |         assert npad >= 0, "npad must be non-negative"
79 | 
80 |         x = F.pad(x, (0, npad), "constant", 0)
81 |         z = torch.randn(x.shape[0], self.d_noise, x.shape[2]).to(x)
82 |         z = self.conv_pre(z)  # (b c t)
83 | 
84 |         for block in self.blocks:
85 |             z = block(z, x)  # (b c t)
86 | 
87 |         z = self.conv_post(z)  # (b 1 t)
88 |         z = z[..., : z.shape[-1] - self.scale_factor * npad]  # trim the padded tail (also correct when npad == 0)
89 |         z = z.squeeze(1)  # (b t)
90 | 
91 |         if y is not None:
92 |             self.losses = self.mrstft(z, y)
93 | 
94 |         return z
95 | 
--------------------------------------------------------------------------------
/resemble_enhance/hparams.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from dataclasses import asdict, dataclass
3 | from pathlib import Path
4 | 
5 | from omegaconf import OmegaConf
6 | from rich.console import Console
7 | from rich.panel import Panel
8 | from rich.table import Table
9 | 
10 | logger = logging.getLogger(__name__)
11 | 
12 | console = Console()
13 | 
14 | 
15 | def _make_stft_cfg(hop_length, win_length=None):
16 |     if win_length is None:
17 |         win_length = 4 * hop_length
18 |     n_fft = 2 ** (win_length - 1).bit_length()
19 |     return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
20 | 
21 | 
22 | def _build_rich_table(rows, columns, title=None):
23 |     table = Table(title=title, header_style=None)
24 |     for column in columns:
25 |         table.add_column(column.capitalize(), justify="left")
26 |     for row in rows:
27 |         table.add_row(*map(str, row))
28 |     return Panel(table, expand=False)
29 | 
30 | 
31 | def _rich_print_dict(d, title="Config", key="Key", value="Value"):
32 |     console.print(_build_rich_table(d.items(), [key, value], title))
33 | 
34 | 
35 | @dataclass(frozen=True)
36 | class HParams:
37 |     # Dataset
38 |     fg_dir: Path = Path("data/fg")
39 |     bg_dir: Path = Path("data/bg")
40 |     rir_dir: Path = Path("data/rir")
41 |     load_fg_only: bool = False
42 |     praat_augment_prob: float = 0
43 | 
44 |     # Audio settings
45 |     wav_rate: int = 44_100
46 |     n_fft: int = 2048
47 |     win_size: int = 2048
48 |     hop_size: int = 420  # 9.5ms
49 |     num_mels: int = 128
50 |     stft_magnitude_min: float = 1e-4
51 |     preemphasis: float = 0.97
52 |     mix_alpha_range: tuple[float, float] = (0.2, 0.8)
53 | 
54 |     # Training
55 |     nj: int = 64
56 |     training_seconds: float = 1.0
57 |     batch_size_per_gpu: int = 16
58 |     min_lr: float = 1e-5
59 |     max_lr: float = 1e-4
60 |     warmup_steps: int = 1000
61 |     max_steps: int = 1_000_000
62 |     gradient_clipping: float = 1.0
63 | 
64 |     @property
65 |     def deepspeed_config(self):
66 |         return {
67 |             "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
68 |             "optimizer": {
69 |                 "type": "Adam",
70 |                 "params": {"lr": float(self.min_lr)},
71 |             },
72 |             "scheduler": {
73 |                 "type": "WarmupDecayLR",
74 |                 "params": {
75 |                     "warmup_min_lr": float(self.min_lr),
76 |                     "warmup_max_lr": float(self.max_lr),
77 |                     "warmup_num_steps": self.warmup_steps,
78 |                     "total_num_steps": self.max_steps,
79 |                     "warmup_type": "linear",
80 |                 },
81 |             },
82 |             "gradient_clipping": self.gradient_clipping,
83 |         }
84 | 
85 |     @property
86 |     def stft_cfgs(self):
87 |         assert self.wav_rate == 44_100, f"wav_rate must be 44_100, got {self.wav_rate}"
88 |         return [_make_stft_cfg(h) for h in (100, 256, 512)]
89 | 
90 |     @classmethod
91 |     def from_yaml(cls, path: Path) -> "HParams":
92 |         logger.info(f"Reading hparams from {path}")
93 |         # First merge to fix types (e.g., str -> Path)
94 |         return cls(**dict(OmegaConf.merge(cls(), OmegaConf.load(path))))
95 | 
96 |     def save_if_not_exists(self, run_dir: Path):
97 |         path = run_dir / "hparams.yaml"
98 |         if path.exists():
99 |             logger.info(f"{path} already exists, not saving")
100 |             return
101 |         path.parent.mkdir(parents=True, exist_ok=True)
102 |         OmegaConf.save(asdict(self), str(path))
103 | 
104 |     @classmethod
105 |     def load(cls, run_dir, yaml: Path | None = None):
106 |         hps = []
107 | 
108 |         if (run_dir / "hparams.yaml").exists():
109 |             hps.append(cls.from_yaml(run_dir / "hparams.yaml"))
110 | 
111 |         if yaml is not None:
112 |             hps.append(cls.from_yaml(yaml))
113 | 
114 |         if len(hps) == 0:
115 |             hps.append(cls())
116 | 
117 |         for hp in hps[1:]:
118 |             if hp != hps[0]:
119 |                 errors = {}
120 |                 for k, v in asdict(hp).items():
121 |                     if getattr(hps[0], k) != v:
122 |                         errors[k] = f"{getattr(hps[0], k)} != {v}"
123 |                 raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}")
124 | 
125 |         return hps[0]
126 | 
127 |     def print(self):
128 |         _rich_print_dict(asdict(self), title="HParams")
129 | 
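A minimal usage sketch of the `HParams` API above (the `runs/demo` directory is a hypothetical placeholder; `HParams.load` falls back to the dataclass defaults when no `hparams.yaml` exists yet):

```python
from pathlib import Path

from resemble_enhance.hparams import HParams

run_dir = Path("runs/demo")  # hypothetical run directory
run_dir.mkdir(parents=True, exist_ok=True)

hp = HParams.load(run_dir)      # merges run_dir/hparams.yaml if present, else defaults
hp.print()                      # rich-rendered table of every field
hp.save_if_not_exists(run_dir)  # pin the config next to the checkpoints

# The derived DeepSpeed config reflects the dataclass fields:
assert hp.deepspeed_config["scheduler"]["params"]["total_num_steps"] == hp.max_steps
```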
--------------------------------------------------------------------------------
/resemble_enhance/inference.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | 
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.nn.utils.parametrize import remove_parametrizations
7 | from torchaudio.functional import resample
8 | from torchaudio.transforms import MelSpectrogram
9 | from tqdm import trange
10 | 
11 | from .hparams import HParams
12 | 
13 | logger = logging.getLogger(__name__)
14 | 
15 | 
16 | @torch.inference_mode()
17 | def inference_chunk(model, dwav, sr, device, npad=441):
18 |     assert model.hp.wav_rate == sr, f"Expected {model.hp.wav_rate} Hz, got {sr} Hz"
19 |     del sr
20 | 
21 |     length = dwav.shape[-1]
22 |     abs_max = dwav.abs().max().clamp(min=1e-7)
23 | 
24 |     assert dwav.dim() == 1, f"Expected 1D waveform, got {dwav.dim()}D"
25 |     dwav = dwav.to(device)
26 |     dwav = dwav / abs_max  # Normalize
27 |     dwav = F.pad(dwav, (0, npad))
28 |     hwav = model(dwav[None])[0].cpu()  # (T,)
29 |     hwav = hwav[:length]  # Trim padding
30 |     hwav = hwav * abs_max  # Unnormalize
31 | 
32 |     return hwav
33 | 
34 | 
35 | def compute_corr(x, y):  # circular cross-correlation via FFT
36 |     return torch.fft.ifft(torch.fft.fft(x) * torch.fft.fft(y).conj()).abs()
37 | 
38 | 
39 | def compute_offset(chunk1, chunk2, sr=44100):
40 |     """
41 |     Args:
42 |         chunk1: (T,)
43 |         chunk2: (T,)
44 |     Returns:
45 |         offset: int, offset in samples such that chunk1 ~= chunk2.roll(-offset)
46 |     """
47 |     hop_length = sr // 200  # 5 ms resolution
48 |     win_length = hop_length * 4
49 |     n_fft = 2 ** (win_length - 1).bit_length()
50 | 
51 |     mel_fn = MelSpectrogram(
52 |         sample_rate=sr,
53 |         n_fft=n_fft,
54 |         win_length=win_length,
55 |         hop_length=hop_length,
56 |         n_mels=80,
57 |         f_min=0.0,
58 |         f_max=sr // 2,
59 |     )
60 | 
61 |     spec1 = mel_fn(chunk1).log1p()
62 |     spec2 = mel_fn(chunk2).log1p()
63 | 
64 |     corr = compute_corr(spec1, spec2)  # (F, T)
65 |     corr = corr.mean(dim=0)  # (T,)
66 | 
67 |     argmax = corr.argmax().item()
68 | 
69 |     if argmax > len(corr) // 2:
70 |         argmax -= len(corr)
71 | 
72 |     offset = -argmax * hop_length
73 | 
74 |     return offset
75 | 
76 | 
77 | def merge_chunks(chunks, chunk_length, hop_length, sr=44100, length=None):
78 |     signal_length = (len(chunks) - 1) * hop_length + chunk_length
79 |     overlap_length = chunk_length - hop_length
80 |     signal = torch.zeros(signal_length, device=chunks[0].device)
81 | 
82 |     fadein = torch.linspace(0, 1, overlap_length, device=chunks[0].device)
83 |     fadein = torch.cat([fadein, torch.ones(hop_length, device=chunks[0].device)])
84 |     fadeout = torch.linspace(1, 0, overlap_length, device=chunks[0].device)
85 |     fadeout = torch.cat([torch.ones(hop_length, device=chunks[0].device), fadeout])
86 | 
87 |     for i, chunk in enumerate(chunks):
88 |         start = i * hop_length
89 |         end = start + chunk_length
90 | 
91 |         if len(chunk) < chunk_length:
92 |             chunk = F.pad(chunk, (0, chunk_length - len(chunk)))
93 | 
94 |         if i > 0:
95 |             pre_region = chunks[i - 1][-overlap_length:]
96 |             cur_region = chunk[:overlap_length]
97 |             offset = compute_offset(pre_region, cur_region, sr=sr)
98 |             start -= offset
99 |             end -= offset
100 | 
101 |         if i == 0:
102 |             chunk = chunk * fadeout
103 |         elif i == len(chunks) - 1:
104 |             chunk = chunk * fadein
105 |         else:
106 |             chunk = chunk * fadein * fadeout
107 | 
108 |         signal[start:end] += chunk[: len(signal[start:end])]
109 | 
110 |     signal = signal[:length]
111 | 
112 |     return signal
113 | 
114 | 
115 | def remove_weight_norm_recursively(module):
116 |     for _, module in module.named_modules():
117 |         try:
118 |             remove_parametrizations(module, "weight")
119 |         except Exception:
120 |             pass
121 | 
122 | 
123 | def inference(model, dwav, sr, device, chunk_seconds: float = 30.0, overlap_seconds: float = 1.0):
124 |     remove_weight_norm_recursively(model)
125 | 
126 |     hp: HParams = model.hp
127 | 
128 |     dwav = resample(
129 |         dwav,
130 |         orig_freq=sr,
131 |         new_freq=hp.wav_rate,
132 |         lowpass_filter_width=64,
133 |         rolloff=0.9475937167399596,
134 |         resampling_method="sinc_interp_kaiser",
135 |         beta=14.769656459379492,
136 |     )
137 | 
138 |     del sr  # Everything is in hp.wav_rate now
139 | 
140 |     sr = hp.wav_rate
141 | 
142 |     if torch.cuda.is_available():
143 |         torch.cuda.synchronize()
144 | 
145 |     start_time = time.perf_counter()
146 | 
147 |     chunk_length = int(sr * chunk_seconds)
148 |     overlap_length = int(sr * overlap_seconds)
149 |     hop_length = chunk_length - overlap_length
150 | 
151 |     chunks = []
152 |     for start in trange(0, dwav.shape[-1], hop_length):
153 |         chunks.append(inference_chunk(model, dwav[start : start + chunk_length], sr, device))
154 | 
155 |     hwav = merge_chunks(chunks, chunk_length, hop_length, sr=sr, length=dwav.shape[-1])
156 | 
157 |     if torch.cuda.is_available():
158 |         torch.cuda.synchronize()
159 | 
160 |     elapsed_time = time.perf_counter() - start_time
161 |     logger.info(f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz")
162 | 
163 |     return hwav, sr
164 | 
--------------------------------------------------------------------------------
/resemble_enhance/melspec.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torchaudio.transforms import MelSpectrogram as TorchMelSpectrogram
5 | 
6 | from .hparams import HParams
7 | 
8 | 
9 | class MelSpectrogram(nn.Module):
10 |     def __init__(self, hp: HParams):
11 |         """
12 |         Torch implementation of Resemble's mel extraction.
13 |         Note that the values are NOT identical to librosa's implementation
14 |         due to floating point precision.
15 |         """
16 |         super().__init__()
17 |         self.hp = hp
18 |         self.melspec = TorchMelSpectrogram(
19 |             hp.wav_rate,
20 |             n_fft=hp.n_fft,
21 |             win_length=hp.win_size,
22 |             hop_length=hp.hop_size,
23 |             f_min=0,
24 |             f_max=hp.wav_rate // 2,
25 |             n_mels=hp.num_mels,
26 |             power=1,
27 |             normalized=False,
28 |             # NOTE: Following librosa's default.
29 |             pad_mode="constant",
30 |             norm="slaney",
31 |             mel_scale="slaney",
32 |         )
33 |         self.register_buffer("stft_magnitude_min", torch.FloatTensor([hp.stft_magnitude_min]))
34 |         self.min_level_db = 20 * np.log10(hp.stft_magnitude_min)
35 |         self.preemphasis = hp.preemphasis
36 |         self.hop_size = hp.hop_size
37 | 
38 |     def forward(self, wav, pad=True):
39 |         """
40 |         Args:
41 |             wav: [B, T]
42 |         """
43 |         device = wav.device
44 |         if wav.is_mps:
45 |             wav = wav.cpu()
46 |             self.to(wav.device)
47 |         if self.preemphasis > 0:
48 |             wav = torch.nn.functional.pad(wav, [1, 0], value=0)
49 |             wav = wav[..., 1:] - self.preemphasis * wav[..., :-1]
50 |         mel = self.melspec(wav)
51 |         mel = self._amp_to_db(mel)
52 |         mel_normed = self._normalize(mel)
53 |         assert not pad or mel_normed.shape[-1] == 1 + wav.shape[-1] // self.hop_size  # Sanity check
54 |         mel_normed = mel_normed.to(device)
55 |         return mel_normed  # (B, M, T)
56 | 
57 |     def _normalize(self, s, headroom_db=15):
58 |         return (s - self.min_level_db) / (-self.min_level_db + headroom_db)
59 | 
60 |     def _amp_to_db(self, x):
61 |         return x.clamp_min(self.hp.stft_magnitude_min).log10() * 20
62 | 
--------------------------------------------------------------------------------
/resemble_enhance/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .distributed import global_leader_only
2 | from .engine import Engine, gather_attribute
3 | from .logging import setup_logging
4 | from .train_loop import TrainLoop, is_global_leader
5 | from .utils import save_mels, tree_map
6 | 
--------------------------------------------------------------------------------
/resemble_enhance/utils/control.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import selectors
3 | import sys
4 | from functools import cache
5 | 
6 | from .distributed import global_leader_only
7 | 
8 | _logger = logging.getLogger(__name__)
9 | 
10 | 
11 | @cache
12 | def _get_stdin_selector():
13 |     selector = selectors.DefaultSelector()
14 |     selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
15 |     return selector
16 | 
17 | 
18 | @global_leader_only(boardcast_return=True)
19 | def non_blocking_input():
20 |     s = ""
21 |     selector = _get_stdin_selector()
22 |     events = selector.select(timeout=0)
23 |     for key, _ in events:
24 |         s: str = key.fileobj.readline().strip()
25 |         _logger.info(f'Got stdin "{s}".')
26 |     return s
27 | 
--------------------------------------------------------------------------------
/resemble_enhance/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import socket
3 | from functools import cache, partial, wraps
4 | from typing import Callable
5 | 
6 | import deepspeed
7 | import torch
8 | from deepspeed.accelerator import get_accelerator
9 | from torch.distributed import broadcast_object_list
10 | 
11 | 
12 | def get_free_port():
13 |     sock = socket.socket()
14 |     sock.bind(("", 0))
15 |     return sock.getsockname()[1]
16 | 
17 | 
18 | @cache
19 | def fix_unset_envs():
20 |     envs = dict(RANK="0", WORLD_SIZE="1", MASTER_ADDR="localhost", MASTER_PORT=str(get_free_port()), LOCAL_RANK="0")
21 | 
22 |     for key in envs:
23 |         value = os.getenv(key)
24 |         if value is not None:
25 |             return
26 | 
27 |     for key, value in envs.items():
28 |         os.environ[key] = value
29 | 
30 | 
31 | @cache
32 | def init_distributed():
33 |     fix_unset_envs()
34 |     deepspeed.init_distributed(get_accelerator().communication_backend_name())
35 |     torch.cuda.set_device(local_rank())
36 | 
37 | 
38 | def local_rank():
39 |     return int(os.getenv("LOCAL_RANK", 0))
40 | 
41 | 
42 | def global_rank():
43 |     return int(os.getenv("RANK", 0))
44 | 
45 | 
46 | def is_local_leader():
47 |     return local_rank() == 0
48 | 
49 | 
50 | def is_global_leader():
51 |     return global_rank() == 0
52 | 
53 | 
54 | def leader_only(leader_only_type, fn: Callable | None = None, boardcast_return=False) -> Callable:
55 |     """
56 |     Args:
57 |         fn: The function to decorate
58 |         boardcast_return: Whether to broadcast the return value to all processes
59 |             (may cause deadlock if the function calls another decorated function)
60 |     """
61 | 
62 |     def wrapper(fn):
63 |         if hasattr(fn, "__leader_only_type__"):
64 |             raise RuntimeError(f"Function {fn.__name__} has already been decorated with {fn.__leader_only_type__}")
65 | 
66 |         fn.__leader_only_type__ = leader_only_type
67 | 
68 |         if leader_only_type == "local":
69 |             guard_fn = is_local_leader
70 |         elif leader_only_type == "global":
71 |             guard_fn = is_global_leader
72 |         else:
73 |             raise ValueError(f"Unknown leader_only_type: {leader_only_type}")
74 | 
75 |         @wraps(fn)
76 |         def wrapped(*args, **kwargs):
77 |             if boardcast_return:
78 |                 init_distributed()
79 |             obj_list = [None]
80 |             if guard_fn():
81 |                 ret = fn(*args, **kwargs)
82 |                 obj_list[0] = ret
83 |             if boardcast_return:
84 |                 broadcast_object_list(obj_list, src=0)
85 |             return obj_list[0]
86 | 
87 |         return wrapped
88 | 
89 |     if fn is None:
90 |         return wrapper
91 | 
92 |     return wrapper(fn)
93 | 
94 | 
95 | local_leader_only = partial(leader_only, "local")
96 | global_leader_only = partial(leader_only, "global")
97 | 
--------------------------------------------------------------------------------
/resemble_enhance/utils/engine.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 | from functools import cache, partial
4 | from typing import Callable, TypeVar
5 | 
6 | import deepspeed
7 | import pandas as pd
8 | from deepspeed.accelerator import get_accelerator
9 | from deepspeed.runtime.engine import DeepSpeedEngine
10 | from deepspeed.runtime.utils import clip_grad_norm_
11 | from torch import nn
12 | 
13 | from .distributed import fix_unset_envs
14 | 
15 | logger = logging.getLogger(__name__)
16 | 
17 | T = TypeVar("T")
18 | 
19 | 
20 | def flatten_dict(d):
21 |     records = pd.json_normalize(d, sep="/").to_dict(orient="records")
22 |     return records[0] if records else {}
23 | 
24 | 
25 | def _get_named_modules(module, attrname, sep="/"):
26 |     for name, module in module.named_modules():
27 |         name = name.replace(".", sep)
28 |         if hasattr(module, attrname):
29 |             yield name, module
30 | 
31 | 
32 | def gather_attribute(module, attrname, delete=True, prefix=None):
33 |     ret = {}
34 |     for name, module in _get_named_modules(module, attrname):
35 |         ret[name] = getattr(module, attrname)
36 |         if delete:
37 |             try:
38 |                 delattr(module, attrname)
39 |             except Exception as e:
40 |                 raise RuntimeError(f"Failed to delete attribute {attrname} from module {name}") from e
41 |     if prefix:
42 |         ret = {prefix: ret}
43 |     ret = flatten_dict(ret)
44 |     # remove consecutive /
45 |     ret = {re.sub(r"\/+", "/", k): v for k, v in ret.items()}
46 |     return ret
47 | 
48 | 
49 | def dispatch_attribute(module, attrname, value, filter_fn: Callable[[nn.Module], bool] | None = None):
50 |     for _, module in _get_named_modules(module, attrname):
51 |         if filter_fn is None or filter_fn(module):
52 |             setattr(module, attrname, value)
53 | 
54 | 
55 | @cache
56 | def update_deepspeed_logger():
logging.getLogger("DeepSpeed") 58 | logger.setLevel(logging.WARNING) 59 | 60 | 61 | @cache 62 | def init_distributed(): 63 | update_deepspeed_logger() 64 | fix_unset_envs() 65 | deepspeed.init_distributed(get_accelerator().communication_backend_name()) 66 | 67 | 68 | def _try_each(*fns, e=None): 69 | if len(fns) == 0: 70 | raise RuntimeError("All functions failed") 71 | 72 | head, *tails = fns 73 | 74 | try: 75 | return head() 76 | except Exception as e: 77 | logger.warning(f"Tried {head} but failed: {e}, trying next") 78 | return _try_each(*tails) 79 | 80 | 81 | class Engine(DeepSpeedEngine): 82 | def __init__(self, *args, ckpt_dir, **kwargs): 83 | init_distributed() 84 | super().__init__(args=None, *args, **kwargs) 85 | self._ckpt_dir = ckpt_dir 86 | self._frozen_params = set() 87 | self._fp32_grad_norm = None 88 | 89 | @property 90 | def path(self): 91 | return self._ckpt_dir 92 | 93 | def freeze_(self): 94 | for p in self.module.parameters(): 95 | if p.requires_grad: 96 | p.requires_grad_(False) 97 | self._frozen_params.add(p) 98 | 99 | def unfreeze_(self): 100 | for p in self._frozen_params: 101 | p.requires_grad_(True) 102 | self._frozen_params.clear() 103 | 104 | @property 105 | def global_step(self): 106 | return self.global_steps 107 | 108 | def gather_attribute(self, *args, **kwargs): 109 | return gather_attribute(self.module, *args, **kwargs) 110 | 111 | def dispatch_attribute(self, *args, **kwargs): 112 | return dispatch_attribute(self.module, *args, **kwargs) 113 | 114 | def clip_fp32_gradients(self): 115 | self._fp32_grad_norm = clip_grad_norm_( 116 | parameters=self.module.parameters(), 117 | max_norm=self.gradient_clipping(), 118 | mpu=self.mpu, 119 | ) 120 | 121 | def get_grad_norm(self): 122 | grad_norm = self.get_global_grad_norm() 123 | if grad_norm is None: 124 | grad_norm = self._fp32_grad_norm 125 | return grad_norm 126 | 127 | def save_checkpoint(self, *args, **kwargs): 128 | if not self._ckpt_dir.exists(): 129 | self._ckpt_dir.mkdir(parents=True, exist_ok=True) 130 | super().save_checkpoint(save_dir=self._ckpt_dir, *args, **kwargs) 131 | logger.info(f"Saved checkpoint to {self._ckpt_dir}") 132 | 133 | def load_checkpoint(self, *args, **kwargs): 134 | fn = partial(super().load_checkpoint, *args, load_dir=self._ckpt_dir, **kwargs) 135 | return _try_each( 136 | lambda: fn(), 137 | lambda: fn(load_optimizer_states=False), 138 | lambda: fn(load_lr_scheduler_states=False), 139 | lambda: fn(load_optimizer_states=False, load_lr_scheduler_states=False), 140 | lambda: fn( 141 | load_optimizer_states=False, 142 | load_lr_scheduler_states=False, 143 | load_module_strict=False, 144 | ), 145 | ) 146 | -------------------------------------------------------------------------------- /resemble_enhance/utils/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | from rich.logging import RichHandler 5 | 6 | from .distributed import global_leader_only 7 | 8 | 9 | @global_leader_only 10 | def setup_logging(run_dir): 11 | handlers = [] 12 | stdout_handler = RichHandler() 13 | stdout_handler.setLevel(logging.INFO) 14 | handlers.append(stdout_handler) 15 | 16 | if run_dir is not None: 17 | filename = Path(run_dir) / f"log.txt" 18 | filename.parent.mkdir(parents=True, exist_ok=True) 19 | file_handler = logging.FileHandler(filename, mode="a") 20 | file_handler.setLevel(logging.DEBUG) 21 | handlers.append(file_handler) 22 | 23 | # Update all existing loggers 24 | for name in ["DeepSpeed"]: 25 | 
25 |         logger = logging.getLogger(name)
26 |         if isinstance(logger, logging.Logger):
27 |             for handler in list(logger.handlers):
28 |                 logger.removeHandler(handler)
29 |             for handler in handlers:
30 |                 logger.addHandler(handler)
31 | 
32 |     # Set the default logger
33 |     logging.basicConfig(
34 |         level=logging.getLevelName("INFO"),
35 |         format="%(message)s",
36 |         datefmt="[%X]",
37 |         handlers=handlers,
38 |     )
39 | 
--------------------------------------------------------------------------------
/resemble_enhance/utils/train_loop.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import time
4 | from dataclasses import KW_ONLY, dataclass
5 | from pathlib import Path
6 | from typing import Protocol
7 | 
8 | import torch
9 | from torch import Tensor
10 | from torch.utils.data import DataLoader
11 | 
12 | from .control import non_blocking_input
13 | from .distributed import is_global_leader
14 | from .engine import Engine
15 | from .utils import tree_map
16 | 
17 | logger = logging.getLogger(__name__)
18 | 
19 | 
20 | class EvalFn(Protocol):
21 |     def __call__(self, engine: Engine, eval_dir: Path) -> None: ...
22 | 
23 | 
24 | class EngineLoader(Protocol):
25 |     def __call__(self, run_dir: Path) -> Engine: ...
26 | 
27 | 
28 | class GenFeeder(Protocol):
29 |     def __call__(self, engine: Engine, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]: ...
30 | 
31 | 
32 | class DisFeeder(Protocol):
33 |     def __call__(self, engine: Engine, batch: dict[str, Tensor] | None, fake: Tensor) -> dict[str, Tensor]: ...
34 | 
35 | 
36 | @dataclass
37 | class TrainLoop:
38 |     _ = KW_ONLY
39 | 
40 |     run_dir: Path
41 |     train_dl: DataLoader
42 | 
43 |     load_G: EngineLoader
44 |     feed_G: GenFeeder
45 |     load_D: EngineLoader | None = None
46 |     feed_D: DisFeeder | None = None
47 | 
48 |     update_every: int = 5_000
49 |     eval_every: int = 5_000
50 |     backup_steps: tuple[int, ...] = (5_000, 100_000, 500_000)
51 | 
52 |     device: str = "cuda"
53 |     eval_fn: EvalFn | None = None
54 |     gan_training_start_step: int | None = None
55 | 
56 |     @property
57 |     def global_step(self):
58 |         return self.engine_G.global_step  # How many steps have been completed?
59 | 
60 |     @property
61 |     def eval_dir(self) -> Path | None:
62 |         if self.eval_every != 0:
63 |             eval_dir = self.run_dir.joinpath("eval")
64 |             eval_dir.mkdir(exist_ok=True)
65 |         else:
66 |             eval_dir = None
67 |         return eval_dir
68 | 
69 |     @property
70 |     def viz_dir(self) -> Path:
71 |         return Path(self.run_dir / "viz")
72 | 
73 |     def make_current_step_viz_path(self, name: str, suffix: str) -> Path:
74 |         path = (self.viz_dir / name / f"{self.global_step}").with_suffix(suffix)
75 |         path.parent.mkdir(exist_ok=True, parents=True)
76 |         return path
77 | 
78 |     def __post_init__(self):
79 |         engine_G = self.load_G(self.run_dir)
80 |         if self.load_D is None:
81 |             engine_D = None
82 |         else:
83 |             engine_D = self.load_D(self.run_dir)
84 |         self.engine_G = engine_G
85 |         self.engine_D = engine_D
86 | 
87 |     @property
88 |     def model_G(self):
89 |         return self.engine_G.module
90 | 
91 |     @property
92 |     def model_D(self):
93 |         if self.engine_D is None:
94 |             return None
95 |         return self.engine_D.module
96 | 
97 |     def save_checkpoint(self, tag="default"):
98 |         engine_G = self.engine_G
99 |         engine_D = self.engine_D
100 |         engine_G.save_checkpoint(tag=tag)
101 |         if engine_D is not None:
102 |             engine_D.save_checkpoint(tag=tag)
103 | 
104 |     def run(self, max_steps: int = -1):
105 |         self.set_running_loop_(self)
106 | 
107 |         train_dl = self.train_dl
108 |         update_every = self.update_every
109 |         eval_every = self.eval_every
110 |         device = self.device
111 |         eval_fn = self.eval_fn
112 | 
113 |         engine_G = self.engine_G
114 |         engine_D = self.engine_D
115 |         eval_dir = self.eval_dir
116 | 
117 |         init_step = self.global_step
118 | 
119 |         logger.info(f"\nTraining from step {init_step} to step {max_steps}")
120 |         warmup_steps = {init_step + x for x in [50, 100, 500]}
121 | 
122 |         engine_G.train()
123 | 
124 |         if engine_D is not None:
125 |             engine_D.train()
126 | 
127 |         gan_start_step = self.gan_training_start_step
128 | 
129 |         while True:
130 |             loss_G = loss_D = 0
131 | 
132 |             for batch in train_dl:
133 |                 torch.cuda.synchronize()
134 |                 start_time = time.time()
135 | 
136 |                 # What's the step after this batch?
137 |                 step = self.global_step + 1
138 | 
139 |                 # Send data to the GPU
140 |                 batch = tree_map(lambda x: x.to(device) if isinstance(x, Tensor) else x, batch)
141 | 
142 |                 stats = {"step": step}
143 | 
144 |                 # Include step == 1 for sanity check
145 |                 gan_started = gan_start_step is not None and (step >= gan_start_step or step == 1)
146 |                 gan_started &= engine_D is not None
147 | 
148 |                 # Generator step
149 |                 fake, losses = self.feed_G(engine=engine_G, batch=batch)
150 | 
151 |                 # Train generator
152 |                 if gan_started:
153 |                     assert engine_D is not None
154 |                     assert self.feed_D is not None
155 | 
156 |                     # Freeze the discriminator to let gradient go through fake
157 |                     engine_D.freeze_()
158 |                     losses |= self.feed_D(engine=engine_D, batch=None, fake=fake)
159 | 
160 |                 loss_G = sum(losses.values())
161 |                 stats |= {f"G/{k}": v.item() for k, v in losses.items()}
162 |                 stats |= {f"G/{k}": v for k, v in engine_G.gather_attribute("stats").items()}
163 |                 del losses
164 | 
165 |                 assert isinstance(loss_G, Tensor)
166 |                 stats["G/loss"] = loss_G.item()
167 |                 stats["G/lr"] = engine_G.get_lr()[0]
168 |                 stats["G/grad_norm"] = engine_G.get_grad_norm() or 0
169 | 
170 |                 if loss_G.isnan().item():
171 |                     logger.error("Generator loss is NaN, skipping step")
172 |                     continue
173 | 
174 |                 engine_G.backward(loss_G)
175 |                 engine_G.step()
176 | 
177 |                 # Discriminator step
178 |                 if gan_started:
179 |                     assert engine_D is not None
180 |                     assert self.feed_D is not None
181 | 
182 |                     engine_D.unfreeze_()
183 |                     losses = self.feed_D(engine=engine_D, batch=batch, fake=fake.detach())
184 |                     del fake
185 | 
186 |                     assert isinstance(losses, dict)
187 |                     loss_D = sum(losses.values())
188 |                     assert isinstance(loss_D, Tensor)
189 | 
190 |                     stats |= {f"D/{k}": v.item() for k, v in losses.items()}
191 |                     stats |= {f"D/{k}": v for k, v in engine_D.gather_attribute("stats").items()}
192 |                     del losses
193 | 
194 |                     if loss_D.isnan().item():
195 |                         logger.error("Discriminator loss is NaN, skipping step")
196 |                         continue
197 | 
198 |                     engine_D.backward(loss_D)
199 |                     engine_D.step()
200 | 
201 |                     stats["D/loss"] = loss_D.item()
202 |                     stats["D/lr"] = engine_D.get_lr()[0]
203 |                     stats["D/grad_norm"] = engine_D.get_grad_norm() or 0
204 | 
205 |                 torch.cuda.synchronize()
206 |                 stats["elapsed_time"] = time.time() - start_time
207 |                 stats = tree_map(lambda x: float(f"{x:.4g}") if isinstance(x, float) else x, stats)
208 |                 logger.info(json.dumps(stats, indent=0))
209 | 
210 |                 command = non_blocking_input()
211 | 
212 |                 evaling = step % eval_every == 0 or step in warmup_steps or command.strip() == "eval"
213 |                 if eval_fn is not None and is_global_leader() and eval_dir is not None and evaling:
214 |                     engine_G.eval()
215 |                     eval_fn(engine_G, eval_dir=eval_dir)
216 |                     engine_G.train()
217 | 
218 |                 if command.strip() == "quit":
219 |                     logger.info("Training paused")
220 |                     self.save_checkpoint("default")
221 |                     return
222 | 
223 |                 if command.strip() == "backup" or step in self.backup_steps:
224 |                     logger.info("Backing up")
225 |                     self.save_checkpoint(tag=f"backup_{step:07d}")
226 | 
227 |                 if step % update_every == 0 or command.strip() == "save":
228 |                     self.save_checkpoint(tag="default")
229 | 
230 |                 if step == max_steps:
231 |                     logger.info("Training finished")
232 |                     self.save_checkpoint(tag="default")
233 |                     return
234 | 
235 |     @classmethod
236 |     def set_running_loop_(cls, loop):
237 |         assert isinstance(loop, cls), f"Expected {cls}, got {type(loop)}"
238 |         cls._running_loop = loop
239 | 
240 |     @classmethod
241 |     def get_running_loop(cls) -> "TrainLoop | None":
242 |         if hasattr(cls, "_running_loop"):
243 |             assert isinstance(cls._running_loop, cls)
244 |             return cls._running_loop
245 |         return None
246 | 
247 |     @classmethod
248 |     def get_running_loop_global_step(cls) -> int | None:
249 |         if loop := cls.get_running_loop():
250 |             return loop.global_step
251 |         return None
252 | 
253 |     @classmethod
254 |     def get_running_loop_viz_path(cls, name: str, suffix: str) -> Path | None:
255 |         if loop := cls.get_running_loop():
256 |             return loop.make_current_step_viz_path(name, suffix)
257 |         return None
258 | 
--------------------------------------------------------------------------------
/resemble_enhance/utils/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, TypeVar, overload
2 | 
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | 
6 | 
7 | def save_mels(path, *, targ_mel, pred_mel, cond_mel):
8 |     n = 3 if cond_mel is None else 4
9 | 
10 |     plt.figure(figsize=(10, n * 4))
11 | 
12 |     i = 1
13 | 
14 |     plt.subplot(n, 1, i)
15 |     plt.imshow(pred_mel, origin="lower", interpolation="none")
16 |     plt.title(f"Pred mel {pred_mel.shape}")
17 |     i += 1
18 | 
19 |     plt.subplot(n, 1, i)
20 |     plt.imshow(targ_mel, origin="lower", interpolation="none")
21 |     plt.title(f"GT mel {targ_mel.shape}")
22 |     i += 1
23 | 
24 |     plt.subplot(n, 1, i)
25 |     pred_mel = pred_mel[:, : targ_mel.shape[1]]
26 |     targ_mel = targ_mel[:, : pred_mel.shape[1]]
27 |     plt.imshow(np.abs(pred_mel - targ_mel), origin="lower", interpolation="none")
28 |     plt.title(f"Diff mel {pred_mel.shape}, mse={np.mean((pred_mel - targ_mel)**2):.4f}")
29 |     i += 1
30 | 
31 |     if cond_mel is not None:
32 |         plt.subplot(n, 1, i)
33 |         plt.imshow(cond_mel, origin="lower", interpolation="none")
34 |         plt.title(f"Cond mel {cond_mel.shape}")
35 |         i += 1
36 | 
37 |     plt.savefig(path, dpi=480)
38 |     plt.close()
39 | 
40 | 
41 | T = TypeVar("T")
42 | 
43 | 
44 | @overload
45 | def tree_map(fn: Callable, x: list[T]) -> list[T]:
46 |     ...
47 | 
48 | 
49 | @overload
50 | def tree_map(fn: Callable, x: tuple[T]) -> tuple[T]:
51 |     ...
52 | 
53 | 
54 | @overload
55 | def tree_map(fn: Callable, x: dict[str, T]) -> dict[str, T]:
56 |     ...
57 | 
58 | 
59 | @overload
60 | def tree_map(fn: Callable, x: T) -> T:
61 |     ...
62 | 
63 | 
64 | def tree_map(fn: Callable, x):
65 |     if isinstance(x, list):
66 |         x = [tree_map(fn, xi) for xi in x]
67 |     elif isinstance(x, tuple):
68 |         x = tuple(tree_map(fn, xi) for xi in x)  # materialize as a tuple (a bare generator would break the overloads)
69 |     elif isinstance(x, dict):
70 |         x = {k: tree_map(fn, v) for k, v in x.items()}
71 |     else:
72 |         x = fn(x)
73 |     return x
74 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from datetime import datetime, timezone
3 | from pathlib import Path
4 | 
5 | from setuptools import find_packages, setup
6 | 
7 | 
8 | def shell(*args):
9 |     out = subprocess.check_output(args)
10 |     return out.decode("ascii").strip()
11 | 
12 | 
13 | def write_version(version_core, pre_release=True):
14 |     if pre_release:
15 |         last_commit_time = shell("git", "log", "-1", "--format=%cd", "--date=iso-strict")
16 |         last_commit_time = datetime.strptime(last_commit_time, "%Y-%m-%dT%H:%M:%S%z")
17 |         last_commit_time = last_commit_time.astimezone(timezone.utc)
18 |         last_commit_time = last_commit_time.strftime("%y%m%d%H%M%S")
19 |         version = f"{version_core}-dev{last_commit_time}"
20 |     else:
21 |         version = version_core
22 | 
23 |     with open(Path("resemble_enhance", "version.py"), "w") as f:
24 |         f.write('__version__ = "{}"\n'.format(version))
25 | 
26 |     return version
27 | 
28 | 
29 | with open("README.md", "r") as f:
30 |     long_description = f.read()
31 | 
32 | 
33 | with open("requirements.txt", "r") as f:
34 |     requirements = f.read().splitlines()
35 | 
36 | setup(
37 |     name="resemble-enhance",
38 |     python_requires=">=3.10",
39 |     version=write_version("0.0.2", pre_release=True),
40 |     description="Speech denoising and enhancement with deep learning",
41 |     long_description=long_description,
42 |     long_description_content_type="text/markdown",
43 |     packages=find_packages(),
44 |     install_requires=requirements,
45 |     url="https://github.com/resemble-ai/resemble-enhance",
46 |     author="Resemble AI",
47 |     author_email="team@resemble.ai",
48 |     entry_points={
49 |         "console_scripts": [
50 |             "resemble-enhance=resemble_enhance.enhancer.__main__:main",
51 |         ]
52 |     },
53 | )
54 | 
--------------------------------------------------------------------------------
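A worked example of the pre-release versioning scheme in `write_version` above (the commit timestamp is illustrative, not taken from the repository):

```python
from datetime import datetime, timezone

# Hypothetical last-commit time, mirroring write_version's pre-release branch.
last_commit_time = datetime(2023, 12, 4, 12, 34, 56, tzinfo=timezone.utc)
version = f"0.0.2-dev{last_commit_time.strftime('%y%m%d%H%M%S')}"
print(version)  # -> 0.0.2-dev231204123456
```

With `pre_release=False`, the version written to `resemble_enhance/version.py` is simply the core version, e.g. `0.0.2`.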