├── README.md
├── sgmse
│   ├── backbones
│   │   ├── ncsnpp_utils
│   │   │   ├── op
│   │   │   │   ├── __init__.py
│   │   │   │   ├── fused_bias_act.cpp
│   │   │   │   ├── upfirdn2d.cpp
│   │   │   │   ├── fused_act.py
│   │   │   │   ├── fused_bias_act_kernel.cu
│   │   │   │   ├── upfirdn2d.py
│   │   │   │   └── upfirdn2d_kernel.cu
│   │   │   ├── utils.py
│   │   │   ├── normalization.py
│   │   │   ├── up_or_down_sampling.py
│   │   │   ├── layerspp.py
│   │   │   └── layers.py
│   │   ├── __init__.py
│   │   ├── shared.py
│   │   ├── convtasnet_utils
│   │   │   └── utils.py
│   │   └── convtasnet.py
│   ├── util
│   │   ├── tensors.py
│   │   ├── registry.py
│   │   ├── inference.py
│   │   ├── graphics.py
│   │   └── other.py
│   ├── sampling
│   │   ├── predictors.py
│   │   ├── __init__.py
│   │   └── correctors.py
│   ├── odes.py
│   └── data_module.py
├── preprocessing
│   ├── utils.py
│   ├── simulate_wind_noise.py
│   ├── nonlinear_mixing.py
│   └── create_data.py
├── evaluate.py
├── utils.py
├── enhancement.py
├── .gitignore
└── train.py

/README.md:
--------------------------------------------------------------------------------
Final version
--------------------------------------------------------------------------------
/sgmse/backbones/ncsnpp_utils/op/__init__.py:
--------------------------------------------------------------------------------
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d
--------------------------------------------------------------------------------
/sgmse/backbones/__init__.py:
--------------------------------------------------------------------------------
from .shared import BackboneRegistry
from .ncsnpp import AutoEncodeNCSNpp, NCSNpp, NCSNppLarge, NCSNpp12M, NCSNpp6M
from .convtasnet import ConvTasNet
from .gagnet import GaGNet

__all__ = ['BackboneRegistry', 'AutoEncodeNCSNpp', 'NCSNpp', 'NCSNppLarge', 'NCSNpp12M', 'NCSNpp6M', 'ConvTasNet', 'GaGNet']
--------------------------------------------------------------------------------
/sgmse/util/tensors.py:
--------------------------------------------------------------------------------
import torch

def batch_broadcast(a, x):
    """Broadcasts a over all dimensions of x, except the batch dimension, which must match."""

    if len(a.shape) != 1:
        a = a.squeeze()
    if len(a.shape) != 1:
        raise ValueError(
            f"Don't know how to batch-broadcast tensor `a` with more than one effective dimension (shape {a.shape})"
        )

    if a.shape[0] != x.shape[0] and a.shape[0] != 1:
        raise ValueError(
            f"Don't know how to batch-broadcast shape {a.shape} over {x.shape} as the batch dimension is not matching")

    out = a.view((x.shape[0], *(1 for _ in range(len(x.shape)-1))))
    return out
--------------------------------------------------------------------------------
/sgmse/backbones/ncsnpp_utils/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
#include <torch/extension.h>


torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
                                int act, int grad, float alpha, float scale);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
                             int act, int grad, float alpha, float scale) {
    CHECK_CUDA(input);
    CHECK_CUDA(bias);

    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
--------------------------------------------------------------------------------
/sgmse/backbones/ncsnpp_utils/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
#include <torch/extension.h>


torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
                           int up_x, int up_y, int down_x, int down_y,
                           int pad_x0, int pad_x1, int pad_y0, int pad_y1);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
                        int up_x, int up_y, int down_x, int down_y,
                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
    CHECK_CUDA(input);
    CHECK_CUDA(kernel);

    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
--------------------------------------------------------------------------------
/sgmse/util/registry.py:
--------------------------------------------------------------------------------
import warnings
from typing import Callable


class Registry:
    def __init__(self, managed_thing: str):
        """
        Create a new registry.

        Args:
            managed_thing: A string describing what type of thing is managed by this registry. Will be used for
                warnings and errors, so it's a good idea to keep this string globally unique and easily understood.
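
        Example (illustrative usage; `FooRegistry` and `Bar` are hypothetical names):

            FooRegistry = Registry("Foo")

            @FooRegistry.register("bar")
            class Bar:
                pass

            FooRegistry.get_by_name("bar")   # -> Bar
            FooRegistry.get_all_names()      # -> ["bar"]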
13 | """ 14 | self.managed_thing = managed_thing 15 | self._registry = {} 16 | 17 | def register(self, name: str) -> Callable: 18 | def inner_wrapper(wrapped_class) -> Callable: 19 | if name in self._registry: 20 | warnings.warn(f"{self.managed_thing} with name '{name}' doubly registered, old class will be replaced.") 21 | self._registry[name] = wrapped_class 22 | return wrapped_class 23 | return inner_wrapper 24 | 25 | def get_by_name(self, name: str): 26 | """Get a managed thing by name.""" 27 | if name in self._registry: 28 | return self._registry[name] 29 | else: 30 | raise ValueError(f"{self.managed_thing} with name '{name}' unknown.") 31 | 32 | def get_all_names(self): 33 | """Get the list of things' names registered to this registry.""" 34 | return list(self._registry.keys()) 35 | -------------------------------------------------------------------------------- /preprocessing/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/bin/python3 2 | 3 | import matplotlib 4 | matplotlib.use('Agg') 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import os 8 | import soundfile as sf 9 | from scipy.signal import resample_poly, fftconvolve, stft 10 | import glob 11 | import itertools, operator 12 | 13 | def obtain_noise_file(noise_dir, i_sample, channels, dataset, sample_rate, len_speech): 14 | 15 | nb_samples = len(os.listdir(noise_dir)) 16 | 17 | if dataset == "wham": 18 | noise_sample, noise_sr = sf.read(os.path.join(noise_dir, os.listdir(noise_dir)[i_sample%nb_samples])) 19 | 20 | if noise_sr != sample_rate: #Resample 21 | noise_sample = resample(noise_sample, noise_sr, sample_rate) 22 | if channels == 1: 23 | noise_sample = noise_sample[:, 0] 24 | 25 | elif dataset == "chime": 26 | noise_types = ["CAF", "PED", "STR", "BUS"] 27 | noise_type = noise_types[np.random.randint(len(noise_types))] 28 | noise_candidates = glob.glob(os.path.join(noise_dir, f"*_{noise_type}.CH1.wav")) 29 | noise_sample_basename = noise_candidates[np.random.randint(len(noise_candidates))][: -8] 30 | noise_sample_ch1, noise_sr = sf.read(noise_sample_basename + ".CH1.wav") 31 | 32 | if noise_sr != sample_rate: #Resample 33 | noise_sample_ch1 = resample(noise_sample_ch1, noise_sr, sample_rate) 34 | 35 | start = np.random.randint(noise_sample_ch1.shape[-1]-len_speech) 36 | noise_sample = np.stack([ sf.read(noise_sample_basename + f".CH{i_ch+1}.wav")[0].squeeze()[start: start+len_speech] 37 | for i_ch in range(channels) ]) 38 | 39 | if noise_sr != sample_rate: #Resample 40 | noise_sample_resampled = np.stack([ resample(noise_sample_ch, noise_sr, sample_rate) for noise_sample_ch in noise_sample ]) 41 | noise_sample = noise_sample_resampled 42 | 43 | elif dataset == "qut": 44 | raise NotImplementedError 45 | 46 | return noise_sample, noise_sr 47 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | 4 | dataset_key = ["WSJ0+Chime3_low", "WSJ0+Reverb","VoiceBank-DEMAND" ] 5 | 6 | # Define the list of checkpoints, datasets, and N values 7 | datasets = {"WSJ0+Chime3_low":"/data/dataset/WSJ0-CHiME3_low_snr", 8 | "WSJ0+Reverb":"/data/dataset/WSJ0-CHiME3_derev", 9 | "VoiceBank-DEMAND": "/data/dataset/VCTK_corpus"} 10 | base_command = "CUDA_VISIBLE_DEVICES=2 python enhancement.py" 11 | 12 | N_values = [1, 2, 3,4,5,6,7,8,9,10,50] # Define the list of N values 13 | pretrained_path = 
"/data/baseline_interspeech2025/storm_pretrained/pretrained_model_by_author" 14 | models_path = os.listdir(pretrained_path) 15 | for model in models_path: 16 | if model =="ncsnpp": 17 | mode = "denoiser-only" 18 | 19 | elif model == "sgmse": 20 | mode = "score-only" 21 | elif model == "storm": 22 | mode = "storm" 23 | model_dataset_path = os.path.join(pretrained_path, model) 24 | model_dataset_list = os.listdir(model_dataset_path) 25 | for dataset_name in model_dataset_list: 26 | if dataset_name in dataset_key: 27 | try: 28 | ckpt_name = os.listdir(os.path.join(pretrained_path, model,dataset_name)) 29 | assert len(ckpt_name) == 1 30 | ckpt_name = ckpt_name[0] 31 | ckpt_path = os.path.join(pretrained_path, model, dataset_name, ckpt_name) 32 | 33 | except FileNotFoundError: 34 | pass 35 | for n in N_values: 36 | enhanced_dir = f"{dataset_name}_{model}_N_{n}" 37 | command = ( 38 | f"{base_command} " 39 | f"--test_dir {datasets[dataset_name]} " 40 | f"--enhanced_dir {enhanced_dir} " 41 | f"--ckpt {ckpt_path} " 42 | f"--mode {mode} " 43 | f"--N {n}" 44 | ) 45 | 46 | if model == "ncsnpp": 47 | if n>1: 48 | continue 49 | 50 | print(f"Running command: {command}") 51 | 52 | # Execute the command 53 | result = subprocess.run(command, shell=True, text=True) 54 | 55 | # Check for errors 56 | if result.returncode != 0: 57 | print(f"Error occurred while processing {dataset} with {checkpoint} and N={N}.") 58 | print(result.stderr) 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /sgmse/util/inference.py: -------------------------------------------------------------------------------- 1 | import sre_compile 2 | import torch 3 | from torchaudio import load 4 | from sgmse.util.other import si_sdr, pad_spec 5 | from pesq import pesq 6 | from tqdm import tqdm 7 | from pystoi import stoi 8 | import numpy as np 9 | 10 | # Settings 11 | snr = 0.5 12 | N = 5 13 | corrector_steps = 1 14 | 15 | # Plotting settings 16 | MAX_VIS_SAMPLES = 10 17 | n_fft = 512 18 | hop_length = 128 19 | 20 | def evaluate_model(model, num_eval_files, spec=False, audio=False, discriminative=False): 21 | 22 | model.eval() 23 | _pesq, _si_sdr, _estoi = 0., 0., 0. 
    if spec:
        noisy_spec_list, estimate_spec_list, clean_spec_list = [], [], []
    if audio:
        noisy_audio_list, estimate_audio_list, clean_audio_list = [], [], []

    for i in range(num_eval_files):
        # Load wavs
        x, y = model.data_module.valid_set.__getitem__(i, raw=True) # d,t
        x_hat = model.enhance(y)

        if x_hat.ndim == 1:
            x_hat = x_hat.unsqueeze(0)

        if x.ndim == 1:
            x = x.unsqueeze(0).cpu().numpy()
            x_hat = x_hat.cpu().numpy()  # already (1, t) thanks to the unsqueeze above
            y = y.unsqueeze(0).cpu().numpy()
        else: # eval only first channel
            x = x[0].unsqueeze(0).cpu().numpy()
            x_hat = x_hat[0].unsqueeze(0).cpu().numpy()
            y = y[0].unsqueeze(0).cpu().numpy()

        _si_sdr += si_sdr(x[0], x_hat[0])
        _pesq += pesq(16000, x[0], x_hat[0], 'wb')
        _estoi += stoi(x[0], x_hat[0], 16000, extended=True)

        y, x_hat, x = torch.from_numpy(y), torch.from_numpy(x_hat), torch.from_numpy(x)
        if spec and i < MAX_VIS_SAMPLES:
            y_stft, x_hat_stft, x_stft = model._stft(y[0]), model._stft(x_hat[0]), model._stft(x[0])
            noisy_spec_list.append(y_stft)
            estimate_spec_list.append(x_hat_stft)
            clean_spec_list.append(x_stft)

        if audio and i < MAX_VIS_SAMPLES:
            noisy_audio_list.append(y[0])
            estimate_audio_list.append(x_hat[0])
            clean_audio_list.append(x[0])

    if spec:
        if audio:
            return _pesq/num_eval_files, _si_sdr/num_eval_files, _estoi/num_eval_files, [noisy_spec_list, estimate_spec_list, clean_spec_list], [noisy_audio_list, estimate_audio_list, clean_audio_list]
        else:
            return _pesq/num_eval_files, _si_sdr/num_eval_files, _estoi/num_eval_files, [noisy_spec_list, estimate_spec_list, clean_spec_list], None
    elif audio and not spec:
        return _pesq/num_eval_files, _si_sdr/num_eval_files, _estoi/num_eval_files, None, [noisy_audio_list, estimate_audio_list, clean_audio_list]
    else:
        return _pesq/num_eval_files, _si_sdr/num_eval_files, _estoi/num_eval_files, None, None
--------------------------------------------------------------------------------
/preprocessing/simulate_wind_noise.py:
--------------------------------------------------------------------------------
"""Example to generate a wind noise signal."""

import numpy as np
# Need access to the WindNoiseGenerator library (file: sc_wind_noise_generator.py) presented in D. Mirabilii et al.,
# "Simulating wind noise with airflow speed-dependent characteristics", Int. Workshop on Acoustic Signal Enhancement, Sept. 2022.
# Please ask the authors, as we are not responsible for the distribution of their code.
from sc_wind_noise_generator import WindNoiseGenerator as wng
import argparse
import os
import shutil
import tqdm

SEED = 100 # Seed for random sequence regeneration

# Parameters
wind_params = {
    "duration": 8,
    "fs": 16000,
    "gustiness_range": [1, 10],
    "wind_profile_magnitude_range": [200, 500],
    "wind_profile_acceptable_transition_threshold": 100
}

parser = argparse.ArgumentParser()

parser.add_argument('--dir', type=str)
parser.add_argument('--n', type=int, help="number of samples")
parser.add_argument('--sr', default=16000, type=int)

args = parser.parse_args()
params = vars(args)
params = {**wind_params, **params}

if os.path.exists(args.dir):
    shutil.rmtree(args.dir)
os.makedirs(args.dir, exist_ok=True)

for i in tqdm.tqdm(range(args.n)):

    # Generate wind profile
    gustiness = np.random.uniform(wind_params["gustiness_range"][0], wind_params["gustiness_range"][1]) # Controls the number of speed points: one yields constant wind, high values yield gusty wind.
    number_points_wind_profile = int(1.5 * gustiness)
    wind_profile = [np.random.uniform(wind_params["wind_profile_magnitude_range"][0], wind_params["wind_profile_magnitude_range"][1])]

    while len(wind_profile) < number_points_wind_profile:
        is_valid = False
        while not is_valid:
            new_point = np.random.uniform(wind_params["wind_profile_magnitude_range"][0], wind_params["wind_profile_magnitude_range"][1])
            is_valid = new_point < wind_profile[-1] + wind_params["wind_profile_acceptable_transition_threshold"] and new_point > wind_profile[-1] - wind_params["wind_profile_acceptable_transition_threshold"]
        wind_profile.append(new_point)

    seed_sample = SEED + i
    # Generate wind noise
    wn = wng(fs=args.sr, duration=wind_params["duration"], generate=True,
             wind_profile=wind_profile,
             gustiness=gustiness, start_seed=seed_sample)
    wn_signal, wind_profile = wn.generate_wind_noise()

    # Save signal in .wav file
    wn.save_signal(wn_signal, filename=os.path.join(args.dir, f'simulated_{i}.wav'), num_ch=1, fs=args.sr)
--------------------------------------------------------------------------------
/sgmse/sampling/predictors.py:
--------------------------------------------------------------------------------
import abc

import torch
import numpy as np

from sgmse.util.registry import Registry


PredictorRegistry = Registry("Predictor")


class Predictor(abc.ABC):
    """The abstract class for a predictor algorithm."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__()
        self.sde = sde
        self.rsde = sde.reverse(score_fn)
        self.score_fn = score_fn
        self.probability_flow = probability_flow

    @abc.abstractmethod
    def update_fn(self, x, t, *args):
        """One update of the predictor.

        Args:
            x: A PyTorch tensor representing the current state.
            t: A PyTorch tensor representing the current time step.
            *args: Possibly additional arguments, in particular `y` for OU processes.

        Returns:
            x: A PyTorch tensor of the next state.
            x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
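
        For instance, the Euler-Maruyama predictor below implements this as
        `x_mean = x + f * dt` and `x = x_mean + g * sqrt(-dt) * z` with `z ~ N(0, I)`.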
34 | """ 35 | pass 36 | 37 | def debug_update_fn(self, x, t, *args): 38 | raise NotImplementedError(f"Debug update function not implemented for predictor {self}.") 39 | 40 | 41 | @PredictorRegistry.register('euler_maruyama') 42 | class EulerMaruyamaPredictor(Predictor): 43 | def __init__(self, sde, score_fn, probability_flow=False): 44 | super().__init__(sde, score_fn, probability_flow=probability_flow) 45 | 46 | def update_fn(self, x, t, *args, **kwargs): 47 | dt = -1. / self.rsde.N 48 | z = torch.randn_like(x) 49 | f, g = self.rsde.sde(x, t, *args, **kwargs) 50 | x_mean = x + f * dt 51 | if g.ndim < x.ndim: 52 | g = g.view( *g.size(), *((1,)*(x.ndim - g.ndim)) ) 53 | x = x_mean + g * np.sqrt(-dt) * z 54 | return x, x_mean 55 | 56 | 57 | @PredictorRegistry.register('reverse_diffusion') 58 | class ReverseDiffusionPredictor(Predictor): 59 | def __init__(self, sde, score_fn, probability_flow=False): 60 | super().__init__(sde, score_fn, probability_flow=probability_flow) 61 | 62 | def update_fn(self, x, t, *args, **kwargs): 63 | f, g = self.rsde.discretize(x, t, *args, **kwargs) 64 | z = torch.randn_like(x) 65 | x_mean = x - f 66 | if g.ndim < x.ndim: 67 | g = g.view( *g.size(), *((1,)*(x.ndim - g.ndim)) ) 68 | x = x_mean + g * z 69 | return x, x_mean 70 | 71 | 72 | @PredictorRegistry.register('none') 73 | class NonePredictor(Predictor): 74 | """An empty predictor that does nothing.""" 75 | 76 | def __init__(self, *args, **kwargs): 77 | pass 78 | 79 | def update_fn(self, x, t, *args, **kwargs): 80 | return x, x 81 | -------------------------------------------------------------------------------- /sgmse/backbones/ncsnpp_utils/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output, empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | grad_bias = grad_input.sum(dim).detach() 39 | 40 | return grad_input, grad_bias 41 | 42 | @staticmethod 43 | def backward(ctx, gradgrad_input, gradgrad_bias): 44 | out, = ctx.saved_tensors 45 | gradgrad_out = fused.fused_bias_act( 46 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 47 | ) 48 | 49 | return gradgrad_out, None, None, None 50 | 51 | 52 | class FusedLeakyReLUFunction(Function): 53 | @staticmethod 54 | def forward(ctx, input, bias, negative_slope, scale): 55 | empty = input.new_empty(0) 56 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 57 | ctx.save_for_backward(out) 58 | ctx.negative_slope = negative_slope 59 | ctx.scale = scale 60 | 61 | return out 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | out, = ctx.saved_tensors 66 | 67 | grad_input, grad_bias = 
            grad_output, out, ctx.negative_slope, ctx.scale
        )

        return grad_input, grad_bias, None, None


class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        self.bias = nn.Parameter(torch.zeros(channel))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    if input.device.type == "cpu":
        rest_dim = [1] * (input.ndim - bias.ndim - 1)
        return (
            F.leaky_relu(
                input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
            )
            * scale
        )

    else:
        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
--------------------------------------------------------------------------------
/sgmse/backbones/ncsnpp_utils/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>

#include <cuda.h>
#include <cuda_runtime.h>


template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

    scalar_t zero = 0.0;

    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
        scalar_t x = p_x[xi];

        if (use_bias) {
            x += p_b[(xi / step_b) % size_b];
        }

        scalar_t ref = use_ref ? p_ref[xi] : zero;

        scalar_t y;

        switch (act * 10 + grad) {
            default:
            case 10: y = x; break;
            case 11: y = x; break;
            case 12: y = 0.0; break;

            case 30: y = (x > 0.0) ? x : x * alpha; break;
            case 31: y = (ref > 0.0) ? x : x * alpha; break;
            case 32: y = 0.0; break;
        }

        out[xi] = y * scale;
    }
}


torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    int curDevice = -1;
    cudaGetDevice(&curDevice);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

    auto x = input.contiguous();
    auto b = bias.contiguous();
    auto ref = refer.contiguous();

    int use_bias = b.numel() ? 1 : 0;
    int use_ref = ref.numel() ? 1 : 0;
    int size_x = x.numel();
    int size_b = b.numel();
    int step_b = 1;

    for (int i = 1 + 1; i < x.dim(); i++) {
        step_b *= x.size(i);
    }

    int loop_x = 4;
    int block_size = 4 * 32;
    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

    auto y = torch::empty_like(x);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(),
            x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(),
            ref.data_ptr<scalar_t>(),
            act,
            grad,
            alpha,
            scale,
            loop_x,
            size_x,
            step_b,
            size_b,
            use_bias,
            use_ref
        );
    });

    return y;
}
--------------------------------------------------------------------------------
/sgmse/util/graphics.py:
--------------------------------------------------------------------------------
import torch
from torchaudio import load
import matplotlib.pyplot as plt
import numpy as np
import os
import soundfile as sf
import glob

# Plotting settings
EPS_graphics = 1e-10
n_fft = 512
hop_length = 128

stft_kwargs = {"n_fft": n_fft, "hop_length": hop_length, "window": torch.hann_window(n_fft), "center": True, "return_complex": True}

def visualize_example(mix, estimate, target, spec_path=".", idx_sample=0, epoch=0, name="", sample_rate=16000, hop_len=128, return_fig=False):
    """Visualize training targets and estimates of the neural network.
    Args:
        - mix: Tensor [F, T]
        - estimate/target: Tensor [F, T]
        - spec_path: directory the figure is written to when return_fig is False
          (this parameter was missing in the original signature although the save path below relies on it)
    """

    if isinstance(mix, torch.Tensor):
        mix = torch.abs(mix).detach().cpu()
        estimate = torch.abs(estimate).detach().cpu()
        target = torch.abs(target).detach().cpu()

    vmin, vmax = -60, 0

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(24, 8))

    freqs = sample_rate/(2*mix.size(-2)) * torch.arange(mix.size(-2))
    frames = hop_len/sample_rate * torch.arange(mix.size(-1))

    ax = axes.flat[0]
    im = ax.pcolormesh(frames, freqs, 20*np.log10(.1*mix + EPS_graphics), vmin=vmin, vmax=vmax, shading="auto", cmap="magma")
    ax.set_xlabel('Time [s]')
    ax.set_ylabel('Frequency [Hz]')
    ax.set_title('Mixed Speech')

    ax = axes.flat[1]
    ax.pcolormesh(frames, freqs, 20*np.log10(.1*estimate + EPS_graphics), vmin=vmin, vmax=vmax, shading="auto", cmap="magma")
    ax.set_xlabel('Time [s]')
    ax.set_ylabel('Frequency [Hz]')
    ax.set_title('Anechoic estimate')

    ax = axes.flat[2]
    ax.pcolormesh(frames, freqs, 20*np.log10(.1*target + EPS_graphics), vmin=vmin, vmax=vmax, shading="auto", cmap="magma")
    ax.set_xlabel('Time [s]')
    ax.set_ylabel('Frequency [Hz]')
    ax.set_title('Anechoic target')

    fig.subplots_adjust(right=0.87)
    cbar_ax = fig.add_axes([0.9, 0.25, 0.005, 0.5])
    fig.colorbar(im, cax=cbar_ax)

    if return_fig:
        return fig
    else:
        plt.savefig(os.path.join(spec_path, f"spectro_{idx_sample}_epoch{epoch}{name}.png"), bbox_inches="tight")
        plt.close()


def visualize_one(estimate, spec_path, name="", sample_rate=16000, hop_len=128, raw=True):
    """Visualize a single spectrogram estimate.
    Args:
        - estimate: Tensor [F, T], or path to a .wav file to load and transform
    """

    if isinstance(estimate, torch.Tensor):
        estimate = torch.abs(estimate).squeeze().detach().cpu()
    elif type(estimate) == str:
        estimate = np.squeeze(sf.read(estimate)[0])
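        # Loaded from file: normalize to a fixed reference level and crop before taking the STFT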
        norm_factor = 0.1/np.max(np.abs(estimate))
        xmax = 6
        estimate = estimate[..., : xmax*sample_rate]
        estimate = torch.stft(torch.from_numpy(norm_factor*estimate), **stft_kwargs)

    vmin, vmax = -60, 0

    freqs = sample_rate/(2*estimate.size(-2)) * torch.arange(estimate.size(-2))
    frames = hop_len/sample_rate * torch.arange(estimate.size(-1))

    fig = plt.figure(figsize=(8, 8))
    im = plt.pcolormesh(frames, freqs, 20*np.log10(estimate.abs() + EPS_graphics), vmin=vmin, vmax=vmax, shading="auto", cmap="magma")

    if raw:
        plt.yticks([])
        plt.tick_params(left="off")
        plt.xticks([])
        plt.tick_params(bottom="off")
    else:
        plt.xlabel('Time [s]')
        plt.ylabel('Frequency [Hz]')
        plt.title('Anechoic estimate')
        cbar_ax = fig.add_axes([0.93, 0.25, 0.03, 0.4])
        fig.colorbar(im, cax=cbar_ax)

    plt.savefig(os.path.join(spec_path, name + ".png"), dpi=300, bbox_inches="tight")
    plt.close()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import scipy.stats
from scipy.signal import butter, sosfilt
import torch
import os
from pesq import pesq
from pystoi import stoi


def si_sdr_components(s_hat, s, n):
    """Decompose an estimate into target, noise and artifact components (SI-SDR decomposition)."""
    # s_target
    alpha_s = np.dot(s_hat, s) / np.linalg.norm(s)**2
    s_target = alpha_s * s

    # e_noise
    alpha_n = np.dot(s_hat, n) / np.linalg.norm(n)**2
    e_noise = alpha_n * n

    # e_art
    e_art = s_hat - s_target - e_noise

    return s_target, e_noise, e_art

def energy_ratios(s_hat, s, n):
    """Compute SI-SDR, SI-SIR and SI-SAR from the component decomposition above."""
    s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)

    si_sdr = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise + e_art)**2)
    si_sir = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise)**2)
    si_sar = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_art)**2)

    return si_sdr, si_sir, si_sar

def mean_conf_int(data, confidence=0.95):
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)
    return m, h

class Method():
    def __init__(self, name, base_dir, metrics):
        self.name = name
        self.base_dir = base_dir
        self.metrics = {}

        for i in range(len(metrics)):
            metric = metrics[i]
            value = []
            self.metrics[metric] = value

    def append(self, metric, value):
        self.metrics[metric].append(value)

    def get_mean_ci(self, metric):
        return mean_conf_int(np.array(self.metrics[metric]))

def hp_filter(signal, cut_off=80, order=10, sr=16000):
    factor = cut_off /sr * 2
    sos = butter(order, factor, 'hp', output='sos')
    filtered = sosfilt(sos, signal)
    return filtered

def si_sdr(s, s_hat):
    alpha = np.dot(s_hat, s)/np.linalg.norm(s)**2
    sdr = 10*np.log10(np.linalg.norm(alpha*s)**2/np.linalg.norm(
        alpha*s - s_hat)**2)
    return sdr

def snr_dB(s,n):
    s_power = 1/len(s)*np.sum(s**2)
    n_power = 1/len(n)*np.sum(n**2)
    snr_dB = 10*np.log10(s_power/n_power)
    return snr_dB

def pad_spec(Y):
    T = Y.size(3)
    if T%64 != 0:
        num_pad = 64-T%64
    else:
        num_pad = 0
    pad2d = torch.nn.ZeroPad2d((0, num_pad, 0, 0))
    return pad2d(Y)
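
# pad_spec usage sketch (illustrative shapes; the multiple of 64 presumably matches the
# downsampling depth of the U-Net-style backbones):
#   Y = torch.randn(1, 1, 256, 100)   # [B, C, F, T]
#   pad_spec(Y).shape                 # -> torch.Size([1, 1, 256, 128])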

def ensure_dir(file_path):
    directory = os.path.dirname(file_path)
    if not os.path.exists(directory):
        os.makedirs(directory)


def print_metrics(x, y, x_hat_list, labels, sr=16000):
    _si_sdr_mix = si_sdr(x, y)
    _pesq_mix = pesq(sr, x, y, 'wb')
    _estoi_mix = stoi(x, y, sr, extended=True)
    print(f'Mixture: PESQ: {_pesq_mix:.2f}, ESTOI: {_estoi_mix:.2f}, SI-SDR: {_si_sdr_mix:.2f}')
    for i, x_hat in enumerate(x_hat_list):
        _si_sdr = si_sdr(x, x_hat)
        _pesq = pesq(sr, x, x_hat, 'wb')
        _estoi = stoi(x, x_hat, sr, extended=True)
        print(f'{labels[i]}: PESQ: {_pesq:.2f}, ESTOI: {_estoi:.2f}, SI-SDR: {_si_sdr:.2f}')

def mean_std(data):
    data = data[~np.isnan(data)]
    mean = np.mean(data)
    std = np.std(data)
    return mean, std

def print_mean_std(data, decimal=2):
    data = np.array(data)
    data = data[~np.isnan(data)]
    mean = np.mean(data)
    std = np.std(data)
    if decimal == 2:
        string = f'{mean:.2f} ± {std:.2f}'
    elif decimal == 1:
        string = f'{mean:.1f} ± {std:.1f}'
    return string
--------------------------------------------------------------------------------
/sgmse/sampling/__init__.py:
--------------------------------------------------------------------------------
# Adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sampling.py
"""Various sampling methods."""
from scipy import integrate
import torch
import numpy as np

from .predictors import Predictor, PredictorRegistry, ReverseDiffusionPredictor
from .correctors import Corrector, CorrectorRegistry


__all__ = [
    'PredictorRegistry', 'CorrectorRegistry', 'Predictor', 'Corrector',
    'get_sampler'
]


def to_flattened_numpy(x):
    """Flatten a torch tensor `x` and convert it to numpy."""
    return x.detach().cpu().numpy().reshape((-1,))


def from_flattened_numpy(x, shape):
    """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
    return torch.from_numpy(x.reshape(shape))




def get_ode_sampler(
    ode, score_fn, y, N, inverse_scaler=None,
    eps=3e-2, device='cuda', **kwargs
):
    """Probability flow ODE sampler, integrated with fixed-step Euler updates.

    Args:
        ode: An object representing the forward probability-flow ODE.
        score_fn: A function (typically a learned model) that predicts the score.
        y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
        N: Number of Euler integration steps.
        inverse_scaler: The inverse data normalizer (currently unused, kept for interface compatibility).
        eps: A `float` number. The reverse-time ODE is integrated down to `eps` for numerical stability.
        device: PyTorch device.

    Returns:
        A sampling function that returns samples and the number of function evaluations during sampling.
50 | """ 51 | # predictor = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False) 52 | # rsde = sde.reverse(score_fn, probability_flow=True) 53 | 54 | # def denoise_update_fn(x): 55 | # vec_eps = torch.ones(x.shape[0], device=x.device) * eps 56 | # _, x = predictor.update_fn(x, vec_eps, y) 57 | # return x 58 | 59 | # def drift_fn(x, t, y): 60 | # """Get the drift function of the reverse-time SDE.""" 61 | # return rsde.sde(x, t, y)[0] 62 | conditioning = kwargs["conditioning"] 63 | def ode_sampler(z=None, **kwargs): 64 | """The probability flow ODE sampler with black-box ODE solver. 65 | 66 | Args: 67 | model: A score model. 68 | z: If present, generate samples from latent code `z`. 69 | Returns: 70 | samples, number of function evaluations. 71 | """ 72 | with torch.no_grad(): 73 | # If not represent, sample the latent code from the prior distibution of the SDE. 74 | x,_ = ode.prior_sampling(y.shape, y) 75 | x = x.to(device) 76 | 77 | def ode_func(t, x): 78 | x = from_flattened_numpy(x, y.shape).to(device).type(torch.complex64) 79 | # print(type(x)) 80 | vec_t = torch.ones(y.shape[0], device=x.device) * t 81 | 82 | drift = score_fn(x, vec_t, conditioning, y) 83 | drift = drift.cpu() 84 | # print(type(drift)) 85 | return to_flattened_numpy(drift) 86 | 87 | # Black-box ODE solver for the probability flow ODE 88 | xt = to_flattened_numpy(x) 89 | timesteps = torch.linspace(ode.T, eps, N, device=y.device) 90 | for i in range(len(timesteps)): 91 | t = timesteps[i] 92 | if i == len(timesteps)-1: 93 | dt = 0-t 94 | else: 95 | dt = timesteps[i+1]-t 96 | # print(type(xt)) 97 | # print(type(dt)) 98 | dt = dt.cpu().numpy() 99 | xt = xt + dt * ode_func(t,xt) 100 | 101 | nfe = N 102 | # print(y.shape) 103 | x = torch.tensor(xt).reshape(y.shape).to(device).type(torch.complex64) 104 | 105 | # Denoising is equivalent to running one predictor step without adding noise 106 | # if denoise: 107 | # x = denoise_update_fn(x) 108 | 109 | # if inverse_scaler is not None: 110 | # x = inverse_scaler(x) 111 | return x, nfe 112 | 113 | return ode_sampler 114 | -------------------------------------------------------------------------------- /sgmse/sampling/correctors.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | 4 | from sgmse import odes 5 | from sgmse.util.registry import Registry 6 | 7 | 8 | CorrectorRegistry = Registry("Corrector") 9 | 10 | 11 | class Corrector(abc.ABC): 12 | """The abstract class for a corrector algorithm.""" 13 | 14 | def __init__(self, sde, score_fn, snr, n_steps): 15 | super().__init__() 16 | self.rsde = sde.reverse(score_fn) 17 | self.score_fn = score_fn 18 | self.snr = snr 19 | self.n_steps = n_steps 20 | 21 | @abc.abstractmethod 22 | def update_fn(self, x, t, *args): 23 | """One update of the corrector. 24 | 25 | Args: 26 | x: A PyTorch tensor representing the current state 27 | t: A PyTorch tensor representing the current time step. 28 | *args: Possibly additional arguments, in particular `y` for OU processes 29 | 30 | Returns: 31 | x: A PyTorch tensor of the next state. 32 | x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising. 
33 | """ 34 | pass 35 | 36 | 37 | @CorrectorRegistry.register(name='langevin') 38 | class LangevinCorrector(Corrector): 39 | def __init__(self, sde, score_fn, snr, n_steps): 40 | super().__init__(sde, score_fn, snr, n_steps) 41 | self.score_fn = score_fn 42 | self.n_steps = n_steps 43 | self.snr = snr 44 | 45 | def update_fn(self, x, t, *args, **kwargs): 46 | target_snr = self.snr 47 | for _ in range(self.n_steps): 48 | if "conditioning" in kwargs.keys() and kwargs["conditioning"] is not None: 49 | grad = self.score_fn(x, t, score_conditioning=kwargs["conditioning"], sde_input=args[0]) 50 | else: 51 | grad = self.score_fn(x, t, *args) 52 | noise = torch.randn_like(x) 53 | grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean() 54 | noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() 55 | step_size = ((target_snr * noise_norm / grad_norm) ** 2 * 2).unsqueeze(0) 56 | if step_size.ndim < x.ndim: 57 | step_size = step_size.view( *step_size.size(), *((1,)*(x.ndim - step_size.ndim)) ) 58 | x_mean = x + step_size * grad 59 | x = x_mean + noise * torch.sqrt(step_size * 2) 60 | 61 | return x, x_mean 62 | 63 | 64 | @CorrectorRegistry.register(name='ald') 65 | class AnnealedLangevinDynamics(Corrector): 66 | """The original annealed Langevin dynamics predictor in NCSN/NCSNv2.""" 67 | def __init__(self, sde, score_fn, snr, n_steps): 68 | super().__init__(sde, score_fn, snr, n_steps) 69 | if not isinstance(sde, (sdes.OUVESDE)): 70 | raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") 71 | self.sde = sde 72 | self.score_fn = score_fn 73 | self.snr = snr 74 | self.n_steps = n_steps 75 | 76 | def update_fn(self, x, t, *args, **kwargs): 77 | n_steps = self.n_steps 78 | target_snr = self.snr 79 | std = self.sde.marginal_prob(x, t, *args)[1] 80 | 81 | for i in range(n_steps): 82 | if "conditioning" in kwargs.keys() and kwargs["conditioning"] is not None: 83 | grad = self.score_fn(x, t, score_conditioning=kwargs["conditioning"], sde_input=args[0]) 84 | else: 85 | grad = self.score_fn(x, t, *args) 86 | noise = torch.randn_like(x) 87 | step_size = (target_snr * std) ** 2 * 2 88 | if step_size.ndim < x.ndim: 89 | step_size = step_size.view( *step_size.size(), *((1,)*(x.ndim - step_size.ndim)) ) 90 | x_mean = x + step_size * grad 91 | x = x_mean + noise * torch.sqrt(step_size * 2) 92 | 93 | return x, x_mean 94 | 95 | 96 | @CorrectorRegistry.register(name='none') 97 | class NoneCorrector(Corrector): 98 | """An empty corrector that does nothing.""" 99 | 100 | def __init__(self, *args, **kwargs): 101 | self.snr = 0 102 | self.n_steps = 0 103 | pass 104 | 105 | def update_fn(self, x, t, *args, **kwargs): 106 | return x, x 107 | -------------------------------------------------------------------------------- /enhancement.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | from soundfile import read, write 4 | import pandas as pd 5 | from tensorboard import summary 6 | from tqdm import tqdm 7 | from torchaudio import load, save 8 | import torch 9 | import os 10 | from argparse import ArgumentParser 11 | import time 12 | 13 | from sgmse.backbones.shared import BackboneRegistry 14 | from sgmse.data_module import SpecsDataModule 15 | from sgmse.odes import ODERegistry 16 | from sgmse.model import StochasticRegenerationModel, ScoreModel, DiscriminativeModel 17 | 18 | from sgmse.util.other import * 19 | 20 | import matplotlib.pyplot as plt 21 | from os.path import join 22 | from 
from pesq import pesq
from pystoi import stoi
EPS_LOG = 1e-10
sr = 16000
# Tags
base_parser = ArgumentParser(add_help=False)
parser = ArgumentParser()
for parser_ in (base_parser, parser):
    parser_.add_argument("--test_dir", type=str, required=True, help="Directory containing your corrupted files to enhance.")
    parser_.add_argument("--enhanced_dir", type=str, required=True, help="Where to write your cleaned files.")
    parser_.add_argument("--ckpt", type=str, required=True)
    parser_.add_argument("--mode", required=True, choices=["score-only", "denoiser-only", "storm"])
    parser_.add_argument("--N", type=int, default=5, help="Number of reverse steps")

args = parser.parse_args()

# os.makedirs(args.enhanced_dir, exist_ok=True)

# Checkpoint
checkpoint_file = args.ckpt

# Settings
model_sr = 16000

# Load score model
if args.mode == "storm":
    model_cls = StochasticRegenerationModel
elif args.mode == "score-only":
    model_cls = ScoreModel
elif args.mode == "denoiser-only":
    model_cls = DiscriminativeModel

model = model_cls.load_from_checkpoint(
    checkpoint_file, base_dir="",
    batch_size=1, num_workers=0, kwargs=dict(gpu=False)
)
model.eval(no_ema=False)
model.cuda()
# num_parameters = count_parameters(model)
# print(f"Total number of trainable parameters: {num_parameters}")


clean_dir = join(args.test_dir, "test", "clean")
noisy_dir = join(args.test_dir, "test", "noisy")
target_dir = f"/data/{args.enhanced_dir}/"
ensure_dir(target_dir + "files/")

noisy_files = sorted(glob.glob('{}/*.wav'.format(noisy_dir)))

data = {"filename": [], "pesq": [], "estoi": [], "si_sdr": [], "si_sir": [], "si_sar": []}
for cnt, noisy_file in tqdm(enumerate(noisy_files)):
    filename = noisy_file.split('/')[-1]

    # Load wav
    x, _ = load(join(clean_dir, filename))
    y, _ = load(noisy_file)

    x_hat = model.enhance(y, N=args.N)
    x_hat = x_hat.squeeze().cpu().numpy()

    x = x.squeeze().cpu().numpy()
    y = y.squeeze().cpu().numpy()
    n = y - x

    # Write enhanced wav file
    write(target_dir + "files/" + filename, x_hat, 16000)

    # Append metrics to data frame
    data["filename"].append(filename)
    try:
        p = pesq(sr, x, x_hat, 'wb')
    except:
        p = float("nan")
    data["pesq"].append(p)
    data["estoi"].append(stoi(x, x_hat, sr, extended=True))
    data["si_sdr"].append(energy_ratios(x_hat, x, n)[0])
    data["si_sir"].append(energy_ratios(x_hat, x, n)[1])
    data["si_sar"].append(energy_ratios(x_hat, x, n)[2])

# Save results as DataFrame
df = pd.DataFrame(data)
df.to_csv(join(target_dir, "_results.csv"), index=False)

# Save average results
text_file = join(target_dir, "_avg_results.txt")
with open(text_file, 'w') as file:
    file.write("PESQ: {} \n".format(print_mean_std(data["pesq"])))
    file.write("ESTOI: {} \n".format(print_mean_std(data["estoi"])))
    file.write("SI-SDR: {} \n".format(print_mean_std(data["si_sdr"])))
    file.write("SI-SIR: {} \n".format(print_mean_std(data["si_sir"])))
    file.write("SI-SAR: {} \n".format(print_mean_std(data["si_sar"])))

# Save settings
text_file = join(target_dir, "_settings.txt")
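# Record the exact checkpoint and step count used for this run, so results stay traceable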
with open(text_file, 'w') as file:
    file.write("checkpoint file: {}\n".format(args.ckpt))
    file.write("N: {}\n".format(args.N))
--------------------------------------------------------------------------------
/sgmse/backbones/shared.py:
--------------------------------------------------------------------------------
import functools
import numpy as np

import torch
import torch.nn as nn
# from asteroid import complex_nn

from sgmse.util.registry import Registry


BackboneRegistry = Registry("Backbone")


class GaussianFourierProjection(nn.Module):
    """Gaussian random features for encoding time steps."""

    def __init__(self, embed_dim, scale=16, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        if not complex_valued:
            # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
            # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
            # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly,
            # and this halving is not necessary.
            embed_dim = embed_dim // 2
        # Randomly sample weights during initialization. These weights are fixed
        # during optimization and are not trainable.
        self.W = nn.Parameter(torch.randn(embed_dim) * scale, requires_grad=False)

    def forward(self, t):
        t_proj = t[:, None] * self.W[None, :] * 2*np.pi
        if self.complex_valued:
            return torch.exp(1j * t_proj)
        else:
            return torch.cat([torch.sin(t_proj), torch.cos(t_proj)], dim=-1)


class DiffusionStepEmbedding(nn.Module):
    """Diffusion-step embedding as in DiffWave / Vaswani et al. 2017."""

    def __init__(self, embed_dim, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        if not complex_valued:
            # (same real-valued sin+cos concatenation and embed_dim halving as in GaussianFourierProjection above)
            embed_dim = embed_dim // 2
        self.embed_dim = embed_dim

    def forward(self, t):
        fac = 10**(4*torch.arange(self.embed_dim, device=t.device) / (self.embed_dim-1))
        inner = t[:, None] * fac[None, :]
        if self.complex_valued:
            return torch.exp(1j * inner)
        else:
            return torch.cat([torch.sin(inner), torch.cos(inner)], dim=-1)


class ComplexLinear(nn.Module):
    """A potentially complex-valued linear layer.
    Reduces to a regular linear layer if `complex_valued=False`."""
    def __init__(self, input_dim, output_dim, complex_valued):
        super().__init__()
        self.complex_valued = complex_valued
        if self.complex_valued:
            self.re = nn.Linear(input_dim, output_dim)
            self.im = nn.Linear(input_dim, output_dim)
        else:
            self.lin = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        if self.complex_valued:
            return (self.re(x.real) - self.im(x.imag)) + 1j*(self.re(x.imag) + self.im(x.real))
        else:
            return self.lin(x)


class FeatureMapDense(nn.Module):
    """A fully connected layer that reshapes outputs to feature maps."""

    def __init__(self, input_dim, output_dim, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        self.dense = ComplexLinear(input_dim, output_dim, complex_valued=complex_valued)

    def forward(self, x):
        return self.dense(x)[..., None, None]


# class ArgsComplexMultiplicationWrapper(nn.Module):
#     """Adapted from `asteroid`'s `complex_nn.py`, allowing args/kwargs to be passed through forward().

#     Make a complex-valued module `F` from a real-valued module `f` by applying
#     complex multiplication rules:

#     F(a + i b) = f1(a) - f1(b) + i (f2(b) + f2(a))

#     where `f1`, `f2` are instances of `f` that do *not* share weights.

#     Args:
#         module_cls (callable): A class or function that returns a Torch module/functional.
#             Constructor of `f` in the formula above. Called 2x with `*args`, `**kwargs`,
#             to construct the real and imaginary component modules.
#     """

#     def __init__(self, module_cls, *args, **kwargs):
#         super().__init__()
#         self.re_module = module_cls(*args, **kwargs)
#         self.im_module = module_cls(*args, **kwargs)

#     def forward(self, x: complex_nn.ComplexTensor, *args, **kwargs) -> complex_nn.ComplexTensor:
#         return complex_nn.torch_complex_from_reim(
#             self.re_module(x.real, *args, **kwargs) - self.im_module(x.imag, *args, **kwargs),
#             self.re_module(x.imag, *args, **kwargs) + self.im_module(x.real, *args, **kwargs),
#         )


# ComplexConv2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.Conv2d)
# ComplexConvTranspose2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.ConvTranspose2d)
--------------------------------------------------------------------------------
/sgmse/util/other.py:
--------------------------------------------------------------------------------
import numpy as np
import scipy.stats
import torch
import csv
import os
import glob
import tqdm
import torchaudio
import matplotlib.pyplot as plt
import time
from pydub import AudioSegment
import scipy.signal as ss

stft_kwargs = {"n_fft": 510, "hop_length": 128, "window": torch.hann_window(510), "return_complex": True}

def lsd(s_hat, s, eps=1e-10):
    S_hat, S = torch.stft(torch.from_numpy(s_hat), **stft_kwargs), torch.stft(torch.from_numpy(s), **stft_kwargs)
    logPowerS_hat, logPowerS = 2*torch.log(eps + torch.abs(S_hat)), 2*torch.log(eps + torch.abs(S))
    return torch.mean( torch.sqrt(torch.mean(torch.abs( logPowerS_hat - logPowerS ))) ).item()

def si_sdr_components(s_hat, s, n, eps=1e-10):
    # s_target
    alpha_s = np.dot(s_hat, s) / (eps + np.linalg.norm(s)**2)
    s_target = alpha_s * s

    # e_noise
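    # (same least-squares projection as for s_target above, but onto the noise signal n)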
    alpha_n = np.dot(s_hat, n) / (eps + np.linalg.norm(n)**2)
    e_noise = alpha_n * n

    # e_art
    e_art = s_hat - s_target - e_noise

    return s_target, e_noise, e_art

def energy_ratios(s_hat, s, n, eps=1e-10):
    """Compute SI-SDR, SI-SIR and SI-SAR from the si_sdr_components decomposition."""
    s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)

    si_sdr = 10*np.log10(eps + np.linalg.norm(s_target)**2 / (eps + np.linalg.norm(e_noise + e_art)**2))
    si_sir = 10*np.log10(eps + np.linalg.norm(s_target)**2 / (eps + np.linalg.norm(e_noise)**2))
    si_sar = 10*np.log10(eps + np.linalg.norm(s_target)**2 / (eps + np.linalg.norm(e_art)**2))

    return si_sdr, si_sir, si_sar

def mean_conf_int(data, confidence=0.95):
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)
    return m, h

def mean_std(data):
    data = data[~np.isnan(data)]
    mean = np.mean(data)
    std = np.std(data)
    return mean, std

class Method():
    def __init__(self, name, base_dir, metrics):
        self.name = name
        self.base_dir = base_dir
        self.metrics = {}

        for i in range(len(metrics)):
            metric = metrics[i]
            value = []
            self.metrics[metric] = value

    def append(self, metric, value):
        self.metrics[metric].append(value)

    def get_mean_ci(self, metric):
        return mean_conf_int(np.array(self.metrics[metric]))

def hp_filter(signal, cut_off=80, order=10, sr=16000):
    factor = cut_off /sr * 2
    sos = ss.butter(order, factor, 'hp', output='sos')
    filtered = ss.sosfilt(sos, signal)
    return filtered

def si_sdr(s, s_hat):
    alpha = np.dot(s_hat, s)/np.linalg.norm(s)**2
    sdr = 10*np.log10(np.linalg.norm(alpha*s)**2/np.linalg.norm(
        alpha*s - s_hat)**2)
    return sdr

def si_sdr_torch(s, s_hat):
    min_len = min(s.size(-1), s_hat.size(-1))
    s, s_hat = s[..., : min_len], s_hat[..., : min_len]
    alpha = torch.dot(s_hat, s)/torch.norm(s)**2
    sdr = 10*torch.log10(1e-10 + torch.norm(alpha*s)**2/(1e-10 + torch.norm(
        alpha*s - s_hat)**2))
    return sdr

def snr_dB(s,n):
    s_power = 1/len(s)*np.sum(s**2)
    n_power = 1/len(n)*np.sum(n**2)
    snr_dB = 10*np.log10(s_power/n_power)
    return snr_dB

def pad_spec(Y):
    T = Y.size(3)
    if T%64 != 0:
        num_pad = 64-T%64
    else:
        num_pad = 0
    pad2d = torch.nn.ZeroPad2d((0, num_pad, 0, 0))
    return pad2d(Y)

# def pad_time(Y):
#     padding_target = 8320
#     T = Y.size(2)
#     if T%padding_target != 0:
#         num_pad = padding_target-T%padding_target
#     else:
#         num_pad = 0
#     pad2d = torch.nn.ZeroPad2d((0, num_pad, 0, 0))
#     return pad2d(Y)



def init_exp_csv_samples(output_path, tag_metric):
    with open(output_path, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file, delimiter=",")
        fieldnames = ["Filename", "Length", "T60", "iSNR"] + tag_metric
        writer.writerow(fieldnames)

def snr_scale_factor(speech, noise, snr):
    noise_var = np.var(noise)
    speech_var = np.var(speech)

    factor = np.sqrt(speech_var / (noise_var * 10. ** (snr / 10.)))
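
    # Sanity check: for snr = 0 dB the factor reduces to sqrt(var(speech) / var(noise)),
    # i.e. the scaled noise has exactly the speech power.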

    return factor

def pydub_read(path, sr=16000):
    y = AudioSegment.from_file(path)
    y = y.set_frame_rate(sr)
    channel_sounds = y.split_to_mono()
    samples = [s.get_array_of_samples() for s in channel_sounds]
    fp_arr = np.array(samples).T.astype(np.float32)
    fp_arr /= np.iinfo(samples[0].typecode).max
    return fp_arr

def align(y, ref):
    l = np.argmax(ss.fftconvolve(ref.squeeze(), np.flip(y.squeeze()))) - (ref.shape[0] - 1)
    if l:
        y = torch.from_numpy(np.roll(y, l, axis=-1))
    return y

def wer(r, h):
    '''
    by zszyellow
    https://github.com/zszyellow/WER-in-python/blob/master/wer.py
    This function is to calculate the edit distance of reference sentence and the hypothesis sentence.
    Main algorithm used is dynamic programming.
    Attributes:
        r -> the list of words produced by splitting reference sentence.
        h -> the list of words produced by splitting hypothesis sentence.
    '''
    d = np.zeros((len(r)+1)*(len(h)+1), dtype=np.uint8).reshape((len(r)+1, len(h)+1))
    for i in range(len(r)+1):
        d[i][0] = i
    for j in range(len(h)+1):
        d[0][j] = j
    for i in range(1, len(r)+1):
        for j in range(1, len(h)+1):
            if r[i-1] == h[j-1]:
                d[i][j] = d[i-1][j-1]
            else:
                substitute = d[i-1][j-1] + 1
                insert = d[i][j-1] + 1
                delete = d[i-1][j] + 1
                d[i][j] = min(substitute, insert, delete)
    return float(d[len(r)][len(h)]) / len(r)
--------------------------------------------------------------------------------
/sgmse/backbones/ncsnpp_utils/utils.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""All functions and modules related to model definition.
"""

import torch
#import sde_lib
import numpy as np
from ...odes import OUVESDE, OUVPSDE  # the repo ships odes.py (see tree above), not sdes.py


_MODELS = {}


def register_model(cls=None, *, name=None):
    """A decorator for registering model classes."""

    def _register(cls):
        if name is None:
            local_name = cls.__name__
        else:
            local_name = name
        if local_name in _MODELS:
            raise ValueError(f'Already registered model with name: {local_name}')
        _MODELS[local_name] = cls
        return cls

    if cls is None:
        return _register
    else:
        return _register(cls)


def get_model(name):
    return _MODELS[name]


def get_sigmas(sigma_min, sigma_max, num_scales):
    """Get sigmas --- the set of noise levels for SMLD.
53 |   Args:
54 |     sigma_min, sigma_max, num_scales: smallest and largest noise levels, and the number of geometrically spaced levels between them
55 |   Returns:
56 |     sigmas: a numpy array of noise levels, in decreasing order
57 |   """
58 |   sigmas = np.exp(
59 |       np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales))
60 | 
61 |   return sigmas
62 | 
63 | 
64 | def get_ddpm_params(config):
65 |   """Get betas and alphas --- parameters used in the original DDPM paper."""
66 |   num_diffusion_timesteps = 1000
67 |   # parameters need to be adapted if number of time steps differs from 1000
68 |   beta_start = config.model.beta_min / config.model.num_scales
69 |   beta_end = config.model.beta_max / config.model.num_scales
70 |   betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
71 | 
72 |   alphas = 1. - betas
73 |   alphas_cumprod = np.cumprod(alphas, axis=0)
74 |   sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
75 |   sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
76 | 
77 |   return {
78 |     'betas': betas,
79 |     'alphas': alphas,
80 |     'alphas_cumprod': alphas_cumprod,
81 |     'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
82 |     'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
83 |     'beta_min': beta_start * (num_diffusion_timesteps - 1),
84 |     'beta_max': beta_end * (num_diffusion_timesteps - 1),
85 |     'num_diffusion_timesteps': num_diffusion_timesteps
86 |   }
87 | 
88 | 
89 | def create_model(config):
90 |   """Create the score model."""
91 |   model_name = config.model.name
92 |   score_model = get_model(model_name)(config)
93 |   score_model = score_model.to(config.device)
94 |   score_model = torch.nn.DataParallel(score_model)
95 |   return score_model
96 | 
97 | 
98 | def get_model_fn(model, train=False):
99 |   """Create a function to give the output of the score-based model.
100 | 
101 |   Args:
102 |     model: The score model.
103 |     train: `True` for training and `False` for evaluation.
104 | 
105 |   Returns:
106 |     A model function.
107 |   """
108 | 
109 |   def model_fn(x, labels):
110 |     """Compute the output of the score-based model.
111 | 
112 |     Args:
113 |       x: A mini-batch of input data.
114 |       labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
115 |         for different models.
116 | 
117 |     Returns:
118 |       The model output.
119 |     """
120 |     if not train:
121 |       model.eval()
122 |       return model(x, labels)
123 |     else:
124 |       model.train()
125 |       return model(x, labels)
126 | 
127 |   return model_fn
128 | 
129 | 
130 | def get_score_fn(sde, model, train=False, continuous=False):
131 |   """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
132 | 
133 |   Args:
134 |     sde: An SDE object (OUVESDE or OUVPSDE) that represents the forward SDE.
135 |     model: A score model.
136 |     train: `True` for training and `False` for evaluation.
137 |     continuous: If `True`, the score-based model is expected to directly take continuous time steps.
138 | 
139 |   Returns:
140 |     A score function.
141 |   """
142 |   model_fn = get_model_fn(model, train=train)
143 | 
144 |   #if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
145 |   if isinstance(sde, OUVPSDE):
146 |     def score_fn(x, t):
147 |       # Scale neural network output by standard deviation and flip sign
148 |       if continuous:
149 |         # For VP-trained models, t=0 corresponds to the lowest noise level
150 |         # The maximum value of the time embedding is assumed to be 999 for
151 |         # continuously-trained models.
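        # (a continuous t in [0, 1] is mapped below onto the discrete label
        # range [0, 999] that the time embedding saw during training)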
152 | labels = t * 999 153 | score = model_fn(x, labels) 154 | std = sde.marginal_prob(torch.zeros_like(x), t)[1] 155 | else: 156 | # For VP-trained models, t=0 corresponds to the lowest noise level 157 | labels = t * (sde.N - 1) 158 | score = model_fn(x, labels) 159 | std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()] 160 | 161 | score = -score / std[:, None, None, None] 162 | return score 163 | 164 | #elif isinstance(sde, sde_lib.VESDE): 165 | elif isinstance(sde, OUVESDE): 166 | def score_fn(x, t): 167 | if continuous: 168 | labels = sde.marginal_prob(torch.zeros_like(x), t)[1] 169 | else: 170 | # For VE-trained models, t=0 corresponds to the highest noise level 171 | labels = sde.T - t 172 | labels *= sde.N - 1 173 | labels = torch.round(labels).long() 174 | 175 | score = model_fn(x, labels) 176 | return score 177 | 178 | else: 179 | raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") 180 | 181 | return score_fn 182 | 183 | 184 | def to_flattened_numpy(x): 185 | """Flatten a torch tensor `x` and convert it to numpy.""" 186 | return x.detach().cpu().numpy().reshape((-1,)) 187 | 188 | 189 | def from_flattened_numpy(x, shape): 190 | """Form a torch tensor with the given `shape` from a flattened numpy array `x`.""" 191 | return torch.from_numpy(x.reshape(shape)) -------------------------------------------------------------------------------- /sgmse/backbones/ncsnpp_utils/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | upfirdn2d_op = load( 11 | "upfirdn2d", 12 | sources=[ 13 | os.path.join(module_path, "upfirdn2d.cpp"), 14 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 15 | ], 16 | ) 17 | 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 23 | ): 24 | 25 | up_x, up_y = up 26 | down_x, down_y = down 27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 28 | 29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 30 | 31 | grad_input = upfirdn2d_op.upfirdn2d( 32 | grad_output, 33 | grad_kernel, 34 | down_x, 35 | down_y, 36 | up_x, 37 | up_y, 38 | g_pad_x0, 39 | g_pad_x1, 40 | g_pad_y0, 41 | g_pad_y1, 42 | ) 43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 44 | 45 | ctx.save_for_backward(kernel) 46 | 47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 48 | 49 | ctx.up_x = up_x 50 | ctx.up_y = up_y 51 | ctx.down_x = down_x 52 | ctx.down_y = down_y 53 | ctx.pad_x0 = pad_x0 54 | ctx.pad_x1 = pad_x1 55 | ctx.pad_y0 = pad_y0 56 | ctx.pad_y1 = pad_y1 57 | ctx.in_size = in_size 58 | ctx.out_size = out_size 59 | 60 | return grad_input 61 | 62 | @staticmethod 63 | def backward(ctx, gradgrad_input): 64 | kernel, = ctx.saved_tensors 65 | 66 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 67 | 68 | gradgrad_out = upfirdn2d_op.upfirdn2d( 69 | gradgrad_input, 70 | kernel, 71 | ctx.up_x, 72 | ctx.up_y, 73 | ctx.down_x, 74 | ctx.down_y, 75 | ctx.pad_x0, 76 | ctx.pad_x1, 77 | ctx.pad_y0, 78 | ctx.pad_y1, 79 | ) 80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 81 | gradgrad_out = gradgrad_out.view( 82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], 
ctx.out_size[1] 83 | ) 84 | 85 | return gradgrad_out, None, None, None, None, None, None, None, None 86 | 87 | 88 | class UpFirDn2d(Function): 89 | @staticmethod 90 | def forward(ctx, input, kernel, up, down, pad): 91 | up_x, up_y = up 92 | down_x, down_y = down 93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 94 | 95 | kernel_h, kernel_w = kernel.shape 96 | batch, channel, in_h, in_w = input.shape 97 | ctx.in_size = input.shape 98 | 99 | input = input.reshape(-1, in_h, in_w, 1) 100 | 101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 102 | 103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 105 | ctx.out_size = (out_h, out_w) 106 | 107 | ctx.up = (up_x, up_y) 108 | ctx.down = (down_x, down_y) 109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 110 | 111 | g_pad_x0 = kernel_w - pad_x0 - 1 112 | g_pad_y0 = kernel_h - pad_y0 - 1 113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 115 | 116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 117 | 118 | out = upfirdn2d_op.upfirdn2d( 119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 120 | ) 121 | # out = out.view(major, out_h, out_w, minor) 122 | out = out.view(-1, channel, out_h, out_w) 123 | 124 | return out 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | kernel, grad_kernel = ctx.saved_tensors 129 | 130 | grad_input = UpFirDn2dBackward.apply( 131 | grad_output, 132 | kernel, 133 | grad_kernel, 134 | ctx.up, 135 | ctx.down, 136 | ctx.pad, 137 | ctx.g_pad, 138 | ctx.in_size, 139 | ctx.out_size, 140 | ) 141 | 142 | return grad_input, None, None, None, None 143 | 144 | 145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 146 | if input.device.type == "cpu": 147 | out = upfirdn2d_native( 148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 149 | ) 150 | 151 | else: 152 | out = UpFirDn2d.apply( 153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 154 | ) 155 | 156 | return out 157 | 158 | 159 | def upfirdn2d_native( 160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 161 | ): 162 | _, channel, in_h, in_w = input.shape 163 | input = input.reshape(-1, in_h, in_w, 1) 164 | 165 | _, in_h, in_w, minor = input.shape 166 | kernel_h, kernel_w = kernel.shape 167 | 168 | out = input.view(-1, in_h, 1, in_w, 1, minor) 169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 171 | 172 | out = F.pad( 173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 174 | ) 175 | out = out[ 176 | :, 177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 179 | :, 180 | ] 181 | 182 | out = out.permute(0, 3, 1, 2) 183 | out = out.reshape( 184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 185 | ) 186 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 187 | out = F.conv2d(out, w) 188 | out = out.reshape( 189 | -1, 190 | minor, 191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 193 | ) 194 | out = out.permute(0, 2, 3, 1) 195 | out = out[:, ::down_y, ::down_x, :] 196 | 197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 199 | 200 | return out.view(-1, 
channel, out_h, out_w) 201 | -------------------------------------------------------------------------------- /preprocessing/nonlinear_mixing.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | ffmpeg = "/usr/local/bin/ffmpeg" 4 | 5 | 6 | #!/usr/env/bin/python3 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | import matplotlib.pyplot as plt 11 | import os 12 | from os.path import join 13 | import numpy as np 14 | import soundfile as sf 15 | import glob 16 | import argparse 17 | import time 18 | import json 19 | from tqdm import tqdm 20 | import shutil 21 | import scipy.signal as ss 22 | import io 23 | import scipy.io.wavfile 24 | import pyroomacoustics as pra 25 | 26 | from utils import obtain_noise_file 27 | 28 | SEED = 100 29 | np.random.seed(SEED) 30 | 31 | 32 | def buildFFmpegCommand(params): 33 | 34 | filter_commands = "" 35 | filter_commands += "[1:a]asplit=2[sc][mix];" 36 | filter_commands += "[0:a][sc]sidechaincompress=" + \ 37 | f"threshold={params['threshold']}:" + \ 38 | f"ratio={params['ratio']}:" + \ 39 | f"level_sc={params['sc_gain']}" + \ 40 | f":release={params['release']}" + \ 41 | f":attack={params['attack']}" + \ 42 | "[compr];" 43 | filter_commands += "[compr][mix]amix" 44 | 45 | commands_list = [ 46 | "ffmpeg", 47 | "-y", 48 | "-i", 49 | params["speech_path"], 50 | "-i", 51 | params["noise_path"], 52 | "-filter_complex", 53 | filter_commands, 54 | params["output_path"] 55 | ] 56 | 57 | # return (" ").join(commands_list) 58 | return commands_list 59 | 60 | 61 | 62 | 63 | params = { 64 | "snr_range": [-6, 14], 65 | "threshold_range": [0.1, 0.3], 66 | "ratio_range": [1, 20], 67 | "attack_range": [5, 100], 68 | "release_range": [5, 500], 69 | "sc_gain_range": [0.8, 1.2], 70 | "clipping_threshold_range": [0.85, 1.], 71 | "clipping_chance": .75, 72 | } 73 | 74 | # ROOT = "" ## put your root directory here 75 | ROOT = "/data/lemercier/databases" 76 | assert ROOT != "", "You need to have a root databases directory" 77 | 78 | parser = argparse.ArgumentParser() 79 | 80 | parser.add_argument('--speech_dir', type=str, help='Clean speech', default="/data/lemercier/databases/wsj0+chime3/audio/{}/clean") #Put the correct regexp for your paths here 81 | parser.add_argument('--noise_dir', type=str, help='Noise', default="/data/lemercier/databases/wind_noise_16k/{}/**/") #Put the correct regexp for your paths here 82 | parser.add_argument('--sr', type=int, default=16000) 83 | parser.add_argument('--dummy', action="store_true", help='Number of samples') 84 | 85 | args = parser.parse_args() 86 | 87 | output_dir = join(ROOT, "speech_in_noise_nonlinear") 88 | if os.path.exists(output_dir): 89 | shutil.rmtree(output_dir) 90 | os.makedirs(output_dir, exist_ok=True) 91 | log = open(join(output_dir, "log_stats.txt"), "w") 92 | log.write("Parameters \n ========== \n") 93 | for key, param in params.items(): 94 | log.write(key + " : " + str(param) + "\n") 95 | 96 | for i_split, split in enumerate(["cv", "tr", "tt"]): 97 | 98 | print("Processing split {}...".format(split)) 99 | 100 | speech_split = sorted(glob.glob(join(args.speech_dir.format(split), "*.wav"))) 101 | noise_split = sorted(glob.glob(join(args.noise_dir.format(split), "*.wav"))) 102 | 103 | clean_output_dir = join(output_dir, split, "clean") 104 | noisy_output_dir = join(output_dir, split, "noisy") 105 | os.makedirs(clean_output_dir, exist_ok=True) 106 | os.makedirs(noisy_output_dir, exist_ok=True) 107 | 108 | real_nb_samples = 5 if args.dummy else 
len(speech_split) 109 | 110 | os.makedirs(join(output_dir, ".cache_noise"), exist_ok=True) 111 | os.makedirs(join(output_dir, ".cache_output"), exist_ok=True) 112 | 113 | for i in tqdm(range(real_nb_samples)): 114 | 115 | speech, sr = sf.read(speech_split[i]) 116 | assert sr == args.sr, "Obtained an unexpected Sampling rate" 117 | i_noise = np.random.randint(len(noise_split)) 118 | noise, sr = sf.read(noise_split[i_noise]) 119 | assert sr == args.sr, "Obtained an unexpected Sampling rate" 120 | 121 | speech_scale = np.max(np.abs(speech)) 122 | noise_scale = np.max(np.abs(noise)) 123 | 124 | if noise.shape[0] < speech.shape[0]: 125 | noise = np.pad(noise, ((0, speech.shape[0] - noise.shape[0]))) 126 | else: 127 | offset = np.random.randint(noise.shape[0] - speech.shape[0]) 128 | noise = noise[offset: offset + speech.shape[0]] 129 | 130 | snr = np.random.uniform(params["snr_range"][0], params["snr_range"][1]) 131 | noise_power = 1/noise.shape[0]*np.sum(noise**2) 132 | speech_power = 1/speech.shape[0]*np.sum(speech**2) 133 | noise_power_target = speech_power*np.power(10, -snr/10) 134 | noise_scaling = np.sqrt(noise_power_target / noise_power) 135 | 136 | noise_tmp = join(output_dir, ".cache_noise", "noise_tmp.wav") 137 | sf.write(noise_tmp, noise*noise_scaling, sr) 138 | 139 | # Compressor 140 | threshold = np.random.uniform(params["threshold_range"][0], params["threshold_range"][1]) 141 | ratio = np.random.uniform(params["ratio_range"][0], params["ratio_range"][1]) 142 | attack = np.random.uniform(params["attack_range"][0], params["attack_range"][1]) 143 | release = np.random.uniform(params["release_range"][0], params["release_range"][1]) 144 | sc_gain = np.random.uniform(params["sc_gain_range"][0], params["sc_gain_range"][1]) 145 | 146 | output_tmp = join(output_dir, ".cache_output", "output_tmp.wav") 147 | 148 | commands = buildFFmpegCommand({ 149 | "speech_path": speech_split[i], 150 | "noise_path": noise_tmp, 151 | "output_path": output_tmp, 152 | "threshold": threshold, 153 | "ratio": ratio, 154 | "attack": attack, 155 | "release": release, 156 | "sc_gain": sc_gain 157 | }) 158 | 159 | print(commands) 160 | if subprocess.run(commands).returncode != 0: 161 | print ("There was an error running your FFmpeg script") 162 | 163 | # Clipper 164 | mix, sr = sf.read(output_tmp) 165 | if np.random.random() < params["clipping_chance"]: 166 | clipping_threshold = np.random.uniform(params["clipping_threshold_range"][0], params["clipping_threshold_range"][1]) 167 | mix = np.maximum(clipping_threshold * np.min(mix)*np.ones_like(mix), mix) 168 | mix = np.minimum(clipping_threshold * np.max(mix)*np.ones_like(mix), mix) 169 | 170 | # Export 171 | output = os.path.basename(speech_split[i])[: -4] + f"_{i}_snr={snr:.1f}.wav" 172 | sf.write(join(noisy_output_dir, output), mix, sr) 173 | sf.write(join(clean_output_dir, os.path.basename(speech_split[i])), speech, sr) 174 | 175 | shutil.rmtree(join(output_dir, ".cache_noise")) 176 | shutil.rmtree(join(output_dir, ".cache_output")) 177 | 178 | -------------------------------------------------------------------------------- /sgmse/odes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Abstract SDE classes, Reverse SDE, and VE/VP SDEs. 
3 | 
4 | Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py
5 | """
6 | import abc
7 | 
8 | import warnings
9 | 
10 | import numpy as np
11 | from sgmse.util.tensors import batch_broadcast
12 | import torch
13 | 
14 | from sgmse.util.registry import Registry
15 | import os
16 | 
17 | ODERegistry = Registry("ODE")
18 | 
19 | 
20 | class ODE(abc.ABC):
21 |     """ODE abstract class. Functions are designed for a mini-batch of inputs."""
22 | 
23 |     def __init__(self):
24 |         super().__init__()
25 | 
26 | 
27 |     @property
28 |     @abc.abstractmethod
29 |     def T(self):
30 |         pass
31 |     @abc.abstractmethod
32 |     def ode(self, x, t, *args):
33 |         pass
34 | 
35 |     @abc.abstractmethod
36 |     def marginal_prob(self, x, t, *args):
37 |         """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
38 |         pass
39 | 
40 |     @abc.abstractmethod
41 |     def prior_sampling(self, shape, *args):
42 |         """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""
43 |         pass
44 | 
45 | 
46 |     @staticmethod
47 |     @abc.abstractmethod
48 |     def add_argparse_args(parent_parser):
49 |         """
50 |         Add the necessary arguments for instantiation of this ODE class to an argparse ArgumentParser.
51 |         """
52 |         pass
53 | 
54 | 
55 |     @abc.abstractmethod
56 |     def copy(self):
57 |         pass
58 | 
59 | 
60 | 
61 | ###################### The classes below are the training targets ##############
62 | 
63 | 
64 | @ODERegistry.register("flowmatching")
65 | class FLOWMATCHING(ODE):
66 |     # original flow matching
67 |     # Yaron Lipman, Ricky T. Q. Chen, Heli Ben-Hamu, Maximilian Nickel, and Matt Le. Flow matching for generative modeling. International Conference on Learning Representations (ICLR), 2023.
68 |     # mu_t = (1-t) x + t y, sigma_t = (1-t) sigma_min + t
69 |     # t range: 0 < t < 1
-------------------------------------------------------------------------------- /sgmse/backbones/ncsnpp_utils/up_or_down_sampling.py: --------------------------------------------------------------------------------
32 |         assert kernel >= 1 and kernel % 2 == 1
33 |         self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
34 |         if kernel_init is not None:
35 |             self.weight.data = kernel_init(self.weight.data.shape)
36 |         if use_bias:
37 |             self.bias = nn.Parameter(torch.zeros(out_ch))
38 | 
39 |         self.up = up
40 |         self.down = down
41 |         self.resample_kernel = resample_kernel
42 |         self.kernel = kernel
43 |         self.use_bias = use_bias
44 | 
45 |     def forward(self, x):
46 |         if self.up:
47 |             x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
48 |         elif self.down:
49 |             x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
50 |         else:
51 |             x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)
52 | 
53 |         if self.use_bias:
54 |             x = x + self.bias.reshape(1, -1, 1, 1)
55 | 
56 |         return x
57 | 
58 | 
59 | def naive_upsample_2d(x, factor=2):
60 |     _N, C, H, W = x.shape
61 |     x = torch.reshape(x, (-1, C, H, 1, W, 1))
62 |     x = x.repeat(1, 1, 1, factor, 1, factor)
63 |     return torch.reshape(x, (-1, C, H * factor, W * factor))
64 | 
65 | 
66 | def naive_downsample_2d(x, factor=2):
67 |     _N, C, H, W = x.shape
68 |     x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
69 |     return torch.mean(x, dim=(3, 5))
70 | 
71 | 
72 | def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
73 |     """Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
74 | 
75 |     Padding is performed only once at the beginning, not between the
76 |     operations.
77 |     The fused op is considerably more efficient than performing the same
78 |     calculation
79 |     using standard TensorFlow ops. It supports gradients of arbitrary order.
80 |     Args:
81 |       x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
82 |         C]`.
83 | w: Weight tensor of the shape `[filterH, filterW, inChannels, 84 | outChannels]`. Grouped convolution can be performed by `inChannels = 85 | x.shape[0] // numGroups`. 86 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 87 | (separable). The default is `[1] * factor`, which corresponds to 88 | nearest-neighbor upsampling. 89 | factor: Integer upsampling factor (default: 2). 90 | gain: Scaling factor for signal magnitude (default: 1.0). 91 | 92 | Returns: 93 | Tensor of the shape `[N, C, H * factor, W * factor]` or 94 | `[N, H * factor, W * factor, C]`, and same datatype as `x`. 95 | """ 96 | 97 | assert isinstance(factor, int) and factor >= 1 98 | 99 | # Check weight shape. 100 | assert len(w.shape) == 4 101 | convH = w.shape[2] 102 | convW = w.shape[3] 103 | inC = w.shape[1] 104 | outC = w.shape[0] 105 | 106 | assert convW == convH 107 | 108 | # Setup filter kernel. 109 | if k is None: 110 | k = [1] * factor 111 | k = _setup_kernel(k) * (gain * (factor ** 2)) 112 | p = (k.shape[0] - factor) - (convW - 1) 113 | 114 | stride = (factor, factor) 115 | 116 | # Determine data dimensions. 117 | stride = [1, 1, factor, factor] 118 | output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW) 119 | output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH, 120 | output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW) 121 | assert output_padding[0] >= 0 and output_padding[1] >= 0 122 | num_groups = _shape(x, 1) // inC 123 | 124 | # Transpose weights. 125 | w = torch.reshape(w, (num_groups, -1, inC, convH, convW)) 126 | w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4) 127 | w = torch.reshape(w, (num_groups * inC, -1, convH, convW)) 128 | 129 | x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0) 130 | ## Original TF code. 131 | # x = tf.nn.conv2d_transpose( 132 | # x, 133 | # w, 134 | # output_shape=output_shape, 135 | # strides=stride, 136 | # padding='VALID', 137 | # data_format=data_format) 138 | ## JAX equivalent 139 | 140 | return upfirdn2d(x, torch.tensor(k, device=x.device), 141 | pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) 142 | 143 | 144 | def conv_downsample_2d(x, w, k=None, factor=2, gain=1): 145 | """Fused `tf.nn.conv2d()` followed by `downsample_2d()`. 146 | 147 | Padding is performed only once at the beginning, not between the operations. 148 | The fused op is considerably more efficient than performing the same 149 | calculation 150 | using standard TensorFlow ops. It supports gradients of arbitrary order. 151 | Args: 152 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 153 | C]`. 154 | w: Weight tensor of the shape `[filterH, filterW, inChannels, 155 | outChannels]`. Grouped convolution can be performed by `inChannels = 156 | x.shape[0] // numGroups`. 157 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 158 | (separable). The default is `[1] * factor`, which corresponds to 159 | average pooling. 160 | factor: Integer downsampling factor (default: 2). 161 | gain: Scaling factor for signal magnitude (default: 1.0). 162 | 163 | Returns: 164 | Tensor of the shape `[N, C, H // factor, W // factor]` or 165 | `[N, H // factor, W // factor, C]`, and same datatype as `x`. 
166 | """ 167 | 168 | assert isinstance(factor, int) and factor >= 1 169 | _outC, _inC, convH, convW = w.shape 170 | assert convW == convH 171 | if k is None: 172 | k = [1] * factor 173 | k = _setup_kernel(k) * gain 174 | p = (k.shape[0] - factor) + (convW - 1) 175 | s = [factor, factor] 176 | x = upfirdn2d(x, torch.tensor(k, device=x.device), 177 | pad=((p + 1) // 2, p // 2)) 178 | return F.conv2d(x, w, stride=s, padding=0) 179 | 180 | 181 | def _setup_kernel(k): 182 | k = np.asarray(k, dtype=np.float32) 183 | if k.ndim == 1: 184 | k = np.outer(k, k) 185 | k /= np.sum(k) 186 | assert k.ndim == 2 187 | assert k.shape[0] == k.shape[1] 188 | return k 189 | 190 | 191 | def _shape(x, dim): 192 | return x.shape[dim] 193 | 194 | 195 | def upsample_2d(x, k=None, factor=2, gain=1): 196 | r"""Upsample a batch of 2D images with the given filter. 197 | 198 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 199 | and upsamples each image with the given filter. The filter is normalized so 200 | that 201 | if the input pixels are constant, they will be scaled by the specified 202 | `gain`. 203 | Pixels outside the image are assumed to be zero, and the filter is padded 204 | with 205 | zeros so that its shape is a multiple of the upsampling factor. 206 | Args: 207 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 208 | C]`. 209 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 210 | (separable). The default is `[1] * factor`, which corresponds to 211 | nearest-neighbor upsampling. 212 | factor: Integer upsampling factor (default: 2). 213 | gain: Scaling factor for signal magnitude (default: 1.0). 214 | 215 | Returns: 216 | Tensor of the shape `[N, C, H * factor, W * factor]` 217 | """ 218 | assert isinstance(factor, int) and factor >= 1 219 | if k is None: 220 | k = [1] * factor 221 | k = _setup_kernel(k) * (gain * (factor ** 2)) 222 | p = k.shape[0] - factor 223 | return upfirdn2d(x, torch.tensor(k, device=x.device), 224 | up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) 225 | 226 | 227 | def downsample_2d(x, k=None, factor=2, gain=1): 228 | r"""Downsample a batch of 2D images with the given filter. 229 | 230 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 231 | and downsamples each image with the given filter. The filter is normalized 232 | so that 233 | if the input pixels are constant, they will be scaled by the specified 234 | `gain`. 235 | Pixels outside the image are assumed to be zero, and the filter is padded 236 | with 237 | zeros so that its shape is a multiple of the downsampling factor. 238 | Args: 239 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 240 | C]`. 241 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 242 | (separable). The default is `[1] * factor`, which corresponds to 243 | average pooling. 244 | factor: Integer downsampling factor (default: 2). 245 | gain: Scaling factor for signal magnitude (default: 1.0). 
246 | 247 | Returns: 248 | Tensor of the shape `[N, C, H // factor, W // factor]` 249 | """ 250 | 251 | assert isinstance(factor, int) and factor >= 1 252 | if k is None: 253 | k = [1] * factor 254 | k = _setup_kernel(k) * gain 255 | p = k.shape[0] - factor 256 | return upfirdn2d(x, torch.tensor(k, device=x.device), 257 | down=factor, pad=((p + 1) // 2, p // 2)) 258 | -------------------------------------------------------------------------------- /sgmse/backbones/ncsnpp_utils/layerspp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """Layers for defining NCSN++. 18 | """ 19 | from . import layers 20 | from . import up_or_down_sampling 21 | import torch.nn as nn 22 | import torch 23 | import torch.nn.functional as F 24 | import numpy as np 25 | 26 | conv1x1 = layers.ddpm_conv1x1 27 | conv3x3 = layers.ddpm_conv3x3 28 | NIN = layers.NIN 29 | default_init = layers.default_init 30 | 31 | 32 | class GaussianFourierProjection(nn.Module): 33 | """Gaussian Fourier embeddings for noise levels.""" 34 | 35 | def __init__(self, embedding_size=256, scale=1.0): 36 | super().__init__() 37 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 38 | 39 | def forward(self, x): 40 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 41 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 42 | 43 | 44 | class Combine(nn.Module): 45 | """Combine information from skip connections.""" 46 | 47 | def __init__(self, dim1, dim2, method='cat'): 48 | super().__init__() 49 | self.Conv_0 = conv1x1(dim1, dim2) 50 | self.method = method 51 | 52 | def forward(self, x, y): 53 | h = self.Conv_0(x) 54 | if self.method == 'cat': 55 | return torch.cat([h, y], dim=1) 56 | elif self.method == 'sum': 57 | return h + y 58 | else: 59 | raise ValueError(f'Method {self.method} not recognized.') 60 | 61 | 62 | class AttnBlockpp(nn.Module): 63 | """Channel-wise self-attention block. 
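    Computes dense softmax attention over all H*W spatial positions, with
    queries, keys and values produced by 1x1 NIN projections.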
    Modified from DDPM."""
64 | 
65 |     def __init__(self, channels, skip_rescale=False, init_scale=0.):
66 |         super().__init__()
67 |         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels,
68 |                                         eps=1e-6)
69 |         self.NIN_0 = NIN(channels, channels)
70 |         self.NIN_1 = NIN(channels, channels)
71 |         self.NIN_2 = NIN(channels, channels)
72 |         self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
73 |         self.skip_rescale = skip_rescale
74 | 
75 |     def forward(self, x):
76 |         B, C, H, W = x.shape
77 |         h = self.GroupNorm_0(x)
78 |         q = self.NIN_0(h)
79 |         k = self.NIN_1(h)
80 |         v = self.NIN_2(h)
81 | 
82 |         w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
83 |         w = torch.reshape(w, (B, H, W, H * W))
84 |         w = F.softmax(w, dim=-1)
85 |         w = torch.reshape(w, (B, H, W, H, W))
86 |         h = torch.einsum('bhwij,bcij->bchw', w, v)
87 |         h = self.NIN_3(h)
88 |         if not self.skip_rescale:
89 |             return x + h
90 |         else:
91 |             return (x + h) / np.sqrt(2.)
92 | 
93 | 
94 | class Upsample(nn.Module):
95 |     def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
96 |                  fir_kernel=(1, 3, 3, 1)):
97 |         super().__init__()
98 |         out_ch = out_ch if out_ch else in_ch
99 |         if not fir:
100 |             if with_conv:
101 |                 self.Conv_0 = conv3x3(in_ch, out_ch)
102 |         else:
103 |             if with_conv:
104 |                 self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
105 |                                                            kernel=3, up=True,
106 |                                                            resample_kernel=fir_kernel,
107 |                                                            use_bias=True,
108 |                                                            kernel_init=default_init())
109 |         self.fir = fir
110 |         self.with_conv = with_conv
111 |         self.fir_kernel = fir_kernel
112 |         self.out_ch = out_ch
113 | 
114 |     def forward(self, x):
115 |         B, C, H, W = x.shape
116 |         if not self.fir:
117 |             h = F.interpolate(x, (H * 2, W * 2), mode='nearest')
118 |             if self.with_conv:
119 |                 h = self.Conv_0(h)
120 |         else:
121 |             if not self.with_conv:
122 |                 h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
123 |             else:
124 |                 h = self.Conv2d_0(x)
125 | 
126 |         return h
127 | 
128 | 
129 | class Downsample(nn.Module):
130 |     def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
131 |                  fir_kernel=(1, 3, 3, 1)):
132 |         super().__init__()
133 |         out_ch = out_ch if out_ch else in_ch
134 |         if not fir:
135 |             if with_conv:
136 |                 self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
137 |         else:
138 |             if with_conv:
139 |                 self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
140 |                                                            kernel=3, down=True,
141 |                                                            resample_kernel=fir_kernel,
142 |                                                            use_bias=True,
143 |                                                            kernel_init=default_init())
144 |         self.fir = fir
145 |         self.fir_kernel = fir_kernel
146 |         self.with_conv = with_conv
147 |         self.out_ch = out_ch
148 | 
149 |     def forward(self, x):
150 |         B, C, H, W = x.shape
151 |         if not self.fir:
152 |             if self.with_conv:
153 |                 x = F.pad(x, (0, 1, 0, 1))
154 |                 x = self.Conv_0(x)
155 |             else:
156 |                 x = F.avg_pool2d(x, 2, stride=2)
157 |         else:
158 |             if not self.with_conv:
159 |                 x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
160 |             else:
161 |                 x = self.Conv2d_0(x)
162 | 
163 |         return x
164 | 
165 | 
166 | class ResnetBlockDDPMpp(nn.Module):
167 |     """ResBlock adapted from DDPM."""
168 | 
169 |     def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False,
170 |                  dropout=0.1, skip_rescale=False, init_scale=0.):
171 |         super().__init__()
172 |         out_ch = out_ch if out_ch else in_ch
173 |         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
174 |         self.Conv_0 = conv3x3(in_ch, out_ch)
175 |         if temb_dim is not None:
176 |             self.Dense_0 = nn.Linear(temb_dim, out_ch)
177 | 
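            # The conditioning projection below uses the default
            # variance-scaling initializer for the weight and a zero bias.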
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) 178 | nn.init.zeros_(self.Dense_0.bias) 179 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 180 | self.Dropout_0 = nn.Dropout(dropout) 181 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 182 | if in_ch != out_ch: 183 | if conv_shortcut: 184 | self.Conv_2 = conv3x3(in_ch, out_ch) 185 | else: 186 | self.NIN_0 = NIN(in_ch, out_ch) 187 | 188 | self.skip_rescale = skip_rescale 189 | self.act = act 190 | self.out_ch = out_ch 191 | self.conv_shortcut = conv_shortcut 192 | 193 | def forward(self, x, temb=None): 194 | h = self.act(self.GroupNorm_0(x)) 195 | h = self.Conv_0(h) 196 | if temb is not None: 197 | h += self.Dense_0(self.act(temb))[:, :, None, None] 198 | h = self.act(self.GroupNorm_1(h)) 199 | h = self.Dropout_0(h) 200 | h = self.Conv_1(h) 201 | if x.shape[1] != self.out_ch: 202 | if self.conv_shortcut: 203 | x = self.Conv_2(x) 204 | else: 205 | x = self.NIN_0(x) 206 | if not self.skip_rescale: 207 | return x + h 208 | else: 209 | return (x + h) / np.sqrt(2.) 210 | 211 | 212 | class ResnetBlockBigGANpp(nn.Module): 213 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False, 214 | dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1), 215 | skip_rescale=True, init_scale=0.): 216 | super().__init__() 217 | 218 | out_ch = out_ch if out_ch else in_ch 219 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) 220 | self.up = up 221 | self.down = down 222 | self.fir = fir 223 | self.fir_kernel = fir_kernel 224 | 225 | self.Conv_0 = conv3x3(in_ch, out_ch) 226 | if temb_dim is not None: 227 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 228 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape) 229 | nn.init.zeros_(self.Dense_0.bias) 230 | 231 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 232 | self.Dropout_0 = nn.Dropout(dropout) 233 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 234 | if in_ch != out_ch or up or down: 235 | self.Conv_2 = conv1x1(in_ch, out_ch) 236 | 237 | self.skip_rescale = skip_rescale 238 | self.act = act 239 | self.in_ch = in_ch 240 | self.out_ch = out_ch 241 | 242 | def forward(self, x, temb=None): 243 | h = self.act(self.GroupNorm_0(x)) 244 | 245 | if self.up: 246 | if self.fir: 247 | h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2) 248 | x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) 249 | else: 250 | h = up_or_down_sampling.naive_upsample_2d(h, factor=2) 251 | x = up_or_down_sampling.naive_upsample_2d(x, factor=2) 252 | elif self.down: 253 | if self.fir: 254 | h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2) 255 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) 256 | else: 257 | h = up_or_down_sampling.naive_downsample_2d(h, factor=2) 258 | x = up_or_down_sampling.naive_downsample_2d(x, factor=2) 259 | 260 | h = self.Conv_0(h) 261 | # Add bias to each feature map conditioned on the time embedding 262 | if temb is not None: 263 | h += self.Dense_0(self.act(temb))[:, :, None, None] 264 | h = self.act(self.GroupNorm_1(h)) 265 | h = self.Dropout_0(h) 266 | h = self.Conv_1(h) 267 | 268 | if self.in_ch != self.out_ch or self.up or self.down: 269 | x = self.Conv_2(x) 270 | 271 | if not self.skip_rescale: 272 | return x + h 273 | else: 274 | return (x + h) / np.sqrt(2.) 
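
# --- Usage sketch (an editor's addition, not part of the original module) ---
# A minimal shape smoke test for ResnetBlockBigGANpp above; run it via
# `python -m sgmse.backbones.ncsnpp_utils.layerspp` so the relative imports
# resolve. All shapes and hyperparameters here are illustrative assumptions.
if __name__ == "__main__":
    act = nn.SiLU()
    temb = torch.randn(2, 128)      # hypothetical time embedding, dim 128
    x = torch.randn(2, 32, 16, 16)  # (B, C, H, W)

    # Downsampling variant: spatial dims halve, channels remap 32 -> 64
    down_block = ResnetBlockBigGANpp(act, in_ch=32, out_ch=64, temb_dim=128, down=True)
    assert down_block(x, temb).shape == (2, 64, 8, 8)

    # Upsampling variant: spatial dims double, channels stay at 32
    up_block = ResnetBlockBigGANpp(act, in_ch=32, out_ch=32, temb_dim=128, up=True)
    assert up_block(x, temb).shape == (2, 32, 32, 32)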
275 | -------------------------------------------------------------------------------- /preprocessing/create_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/bin/python3 2 | 3 | import matplotlib 4 | matplotlib.use('Agg') 5 | import matplotlib.pyplot as plt 6 | import os 7 | from os.path import join 8 | import numpy as np 9 | import soundfile as sf 10 | import glob 11 | import argparse 12 | import time 13 | import json 14 | from tqdm import tqdm 15 | import shutil 16 | import scipy.signal as ss 17 | import io 18 | import scipy.io.wavfile 19 | import pyroomacoustics as pra 20 | 21 | from utils import obtain_noise_file 22 | 23 | SEED = 100 24 | np.random.seed(SEED) 25 | 26 | bwe_params = { 27 | "scale_factors": [2, 4, 8], 28 | "scale_probas": [.33, .33, .34], 29 | "lp_types": ["bessel", "butter", "cheby2"], 30 | "lp_orders": [2, 4, 8] 31 | } 32 | 33 | # enh_params = { 34 | # "snr_range": [0, 20] #weak setting -> sgmse+, icassp2023 35 | # } 36 | enh_params = { 37 | "snr_range": [-6, 14] #hard setting -> julian, tasl2023 38 | } 39 | 40 | derev_params = { 41 | "t60_range": [0.4, 1.0], 42 | "dim_range": [5, 15, 5, 15, 2, 6], 43 | "min_distance_to_wall": 1. 44 | } 45 | 46 | # ROOT = "" ## put your root directory here 47 | ROOT = "/data/lemercier/databases" 48 | assert ROOT != "", "You need to have a root databases directory" 49 | 50 | parser = argparse.ArgumentParser() 51 | 52 | parser.add_argument('--task', type=str, choices=["enh", "derev", "derev+enh", "bwe"]) 53 | parser.add_argument('--speech', type=str, choices=["vctk", "wsj0", "dns", "timit"], default="wsj0", help='Clean speech') 54 | parser.add_argument('--noise', type=str, choices=["none", "chime", "qut", "wham"], default="chime", help='Noise') 55 | parser.add_argument('--sr', type=int, default=16000) 56 | parser.add_argument('--splits', type=str, default="cv,tr,tt", help='Split folders of the dataset') 57 | parser.add_argument('--corruption-per-sample', type=int, default=1) 58 | parser.add_argument('--dummy', action="store_true", help='Number of samples') 59 | parser.add_argument('--bwe-method', type=str, default="polyphase", choices=["decimate", "polyphase"]) 60 | 61 | args = parser.parse_args() 62 | splits = args.splits.strip().split(",") 63 | dic_splits = {"cv": "valid", "tt": "test", "tr": "train"} 64 | 65 | params = vars(args) 66 | if "enh" in args.task: 67 | params["noise_dirs"] = { 68 | "wham": {split:f"/data/lemercier/databases/whamr/wham_noise/{split}" for split in splits}, 69 | "chime": {split:f"/data/lemercier/databases/CHiME3/data/audio/16kHz/backgrounds" for split in splits}, 70 | "qut": {split:f"/data/lemercier/databases/dns_chime3/{dic_splits[split]}" for split in splits} 71 | } 72 | params = {**enh_params, **params} 73 | if "derev" in args.task: 74 | params = {**derev_params, **params} 75 | if "bwe" in args.task: 76 | params = {**bwe_params, **params} 77 | 78 | output_dir = join(ROOT, args.speech + "_" + args.task) 79 | if args.task == "enh": 80 | output_dir += "_" + args.noise 81 | 82 | t0 = time.time() 83 | 84 | if args.speech == "wsj0": 85 | dic_split = {"cv": "si_dt_05", "tr": "si_tr_s", "tt": "si_et_05"} 86 | speech_lists = {split:glob.glob(f"{ROOT}/WSJ0/wsj0/{dic_split[split]}/**/*.wav") for split in splits} 87 | elif args.speech == "vctk": 88 | speakers = sorted(os.listdir(f"{ROOT}/VCTK-Corpus/wav48/")) 89 | speakers.remove("p280") 90 | speakers.remove("p315") 91 | ranges = {"tr": [0, 99], "cv": [97, 99], "tt": [99, 107]} 92 | speech_lists = {split:[] for 
split in splits} 93 | for split in splits: 94 | for spk_idx in range(*ranges[split]): 95 | speech_lists[split] += glob.glob(f"{ROOT}/VCTK-Corpus/wav48/{speakers[spk_idx]}/*.wav") 96 | elif args.speech == "timit": 97 | ranges = {"tr": [1, 7], "cv": [7, 8], "tt": [1, 8]} 98 | speech_lists = {split:[] for split in splits} 99 | transcription_lists = {split:[] for split in splits} 100 | for split in splits: 101 | splt_dr = "train" if split in ["cv", "tr"] else "test" 102 | for dr_idx in range(*ranges[split]): 103 | speech_lists[split] += sorted(glob.glob(f"{ROOT}/TIMIT/timit/{splt_dr}/dr{dr_idx}/**/*.wav")) 104 | transcription_lists[split] += sorted(glob.glob(f"{ROOT}/TIMIT/timit/{splt_dr}/dr{dr_idx}/**/*.txt")) 105 | 106 | if os.path.exists(output_dir): 107 | shutil.rmtree(output_dir) 108 | os.makedirs(output_dir, exist_ok=True) 109 | log = open(join(output_dir, "log_stats.txt"), "w") 110 | log.write("Parameters \n ========== \n") 111 | for key, param in params.items(): 112 | log.write(key + " : " + str(param) + "\n") 113 | 114 | for i_split, split in enumerate(splits): 115 | 116 | print("Processing split n° {}: {}...".format(i_split+1, split)) 117 | 118 | clean_output_dir = join(output_dir, "audio", split, "clean") 119 | noisy_output_dir = join(output_dir, "audio", split, "noisy") 120 | os.makedirs(clean_output_dir, exist_ok=True) 121 | os.makedirs(noisy_output_dir, exist_ok=True) 122 | if args.speech == "timit": 123 | transcription_output_dir = join(output_dir, "transcriptions", split) 124 | os.makedirs(transcription_output_dir, exist_ok=True) 125 | 126 | speech_list = speech_lists[split] 127 | speech_dir = None 128 | real_nb_samples = 5 if args.dummy else len(speech_list) 129 | nb_corruptions_per_sample = 1 if split == "tt" else args.corruption_per_sample 130 | 131 | for i_sample in tqdm(range(real_nb_samples)): 132 | 133 | speech_basename = os.path.basename(speech_list[i_sample]) 134 | speech, sr = sf.read(speech_list[i_sample]) 135 | assert sr == args.sr, "Obtained an unexpected Sampling rate" 136 | original_scale = np.max(np.abs(speech)) 137 | for ic in range(nb_corruptions_per_sample): 138 | 139 | lossy_speech = speech.copy() 140 | 141 | 142 | ### Dereverberation 143 | if "derev" in args.task: 144 | 145 | t60 = np.random.uniform(params["t60_range"][0], params["t60_range"][1]) #sample t60 146 | room_dim = np.array([ np.random.uniform(params["dim_range"][2*n], params["dim_range"][2*n+1]) for n in range(3) ]) #sample Dimensions 147 | center_mic_position = np.array([ np.random.uniform(params["min_distance_to_wall"], room_dim[n] - params["min_distance_to_wall"]) for n in range(3) ]) #sample microphone position 148 | source_position = np.array([ np.random.uniform(params["min_distance_to_wall"], room_dim[n] - params["min_distance_to_wall"]) for n in range(3) ]) #sample source position 149 | distance_source = 1/np.sqrt(center_mic_position.ndim)*np.linalg.norm(center_mic_position - source_position) 150 | mic_array_2d = pra.beamforming.circular_2D_array(center_mic_position[: -1], 1, phi0=0, radius=1.) 
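# (a "circular array" of a single microphone in the x-y plane; its z-coordinate is appended via np.pad below)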
# Compute microphone array 151 | mic_array = np.pad(mic_array_2d, ((0, 1), (0, 0)), mode="constant", constant_values=center_mic_position[-1]) 152 | 153 | ### Reverberant Room 154 | e_absorption, max_order = pra.inverse_sabine(t60, room_dim) #Compute absorption coeff 155 | reverberant_room = pra.ShoeBox( 156 | room_dim, fs=16000, materials=pra.Material(e_absorption), max_order=min(3, max_order), ray_tracing=True 157 | ) # Create room 158 | reverberant_room.set_ray_tracing() 159 | 160 | reverberant_room.add_microphone_array(mic_array) # Add microphone array 161 | 162 | # Generate reverberant room 163 | reverberant_room.add_source(source_position, signal=lossy_speech) 164 | reverberant_room.compute_rir() 165 | reverberant_room.simulate() 166 | t60_real = np.mean(reverberant_room.measure_rt60()).squeeze() 167 | lossy_speech = np.squeeze(np.array(reverberant_room.mic_array.signals)) 168 | 169 | #compute target 170 | e_absorption_dry = 0.99 171 | dry_room = pra.ShoeBox( 172 | room_dim, fs=16000, materials=pra.Material(e_absorption_dry), max_order=0 173 | ) # Create room 174 | dry_room.add_microphone_array(mic_array) # Add microphone array 175 | 176 | # Generate dry room 177 | dry_room.add_source(source_position, signal=speech) 178 | dry_room.compute_rir() 179 | dry_room.simulate() 180 | t60_real_dry = np.mean(dry_room.measure_rt60()).squeeze() 181 | speech = np.squeeze(np.array(dry_room.mic_array.signals)) 182 | noise_floor_snr = 50 183 | noise_floor_power = 1/speech.shape[0]*np.sum(speech**2)*np.power(10,-noise_floor_snr/10) 184 | noise_floor_signal = np.random.rand(int(.5*args.sr)) * np.sqrt(noise_floor_power) 185 | speech = np.concatenate([ speech, noise_floor_signal ]) 186 | 187 | min_length = min(lossy_speech.shape[0], speech.shape[0]) 188 | lossy_speech, speech = lossy_speech[: min_length], speech[: min_length] 189 | 190 | 191 | 192 | 193 | 194 | ### Enhancement 195 | 196 | if "enh" in args.task: 197 | 198 | noise, noise_sr = obtain_noise_file(params["noise_dirs"][args.noise][split], i_sample, 1, dataset=args.noise, sample_rate=args.sr, len_speech=speech.shape[0]) 199 | noise = np.squeeze(noise) 200 | if noise.shape[0] < speech.shape[0]: 201 | noise = np.pad(noise, ((0, speech.shape[0] - noise.shape[0]))) 202 | else: 203 | noise = noise[: speech.shape[0]] 204 | 205 | snr = np.random.uniform(params["snr_range"][0], params["snr_range"][1]) 206 | noise_power = 1/noise.shape[0]*np.sum(noise**2) 207 | speech_power = 1/speech.shape[0]*np.sum(speech**2) 208 | noise_power_target = speech_power*np.power(10,-snr/10) 209 | noise_scaling = np.sqrt(noise_power_target / noise_power) 210 | if "derev" in args.task: 211 | lossy_speech = lossy_speech + noise_scaling * noise 212 | else: 213 | lossy_speech = speech + noise_scaling * noise 214 | 215 | 216 | 217 | 218 | 219 | ### Bandwidth Reduction 220 | 221 | if "bwe" in args.task: 222 | scale_factor = np.random.choice(params["scale_factors"], p=params["scale_probas"]) 223 | lp_type = np.random.choice(params["lp_types"]) 224 | lp_order = np.random.choice(params["lp_orders"]) 225 | Wn = 1./(2*scale_factor) 226 | if lp_type == "bessel": 227 | kwargs = {} 228 | elif lp_type == "butter": 229 | kwargs = {} 230 | elif lp_type == "cheby2": 231 | kwargs = {"rs": 10. 
+ 20*np.random.random()} 232 | 233 | if lp_order > 2: 234 | kwargs["output"] = "sos" 235 | lp_filter_coefs = getattr(scipy.signal, lp_type)(N=lp_order, Wn=Wn, fs=1, **kwargs) 236 | 237 | if args.bwe_method == "decimate": #method used by HiFI++ and VoiceFixer 238 | z, p, k = ss.sos2zpk(lp_filter_coefs) if lp_order > 2 else ss.tf2zpk(*lp_filter_coefs) 239 | filter_instance = ss.dlti(z,p,k) 240 | lossy_speech_subsampled = ss.decimate(lossy_speech, q=scale_factor, ftype=filter_instance) 241 | lossy_speech = ss.resample_poly(lossy_speech_subsampled, up=scale_factor, down=1) 242 | elif args.bwe_method == "polyphase": #method used by NVSR 243 | sos = lp_filter_coefs if lp_order > 2 else ss.tf2sos(*lp_filter_coefs) 244 | lossy_speech_filtered = ss.sosfilt(sos, lossy_speech) 245 | lossy_speech_subsampled = ss.resample_poly(lossy_speech_filtered, down=scale_factor, up=1) 246 | lossy_speech = ss.resample_poly(lossy_speech_subsampled, up=args.sr, down=sr/scale_factor) 247 | 248 | 249 | 250 | filename = speech_basename[: -4] + f"_{i_sample*args.corruption_per_sample + ic}" 251 | if "derev" in args.task: 252 | filename += f"_t60={t60_real:.2f}" 253 | if "enh" in args.task: 254 | filename += f"_snr={snr:.1f}" 255 | if "bwe" in args.task: 256 | filename += f"_down={scale_factor}" 257 | filename += ".wav" 258 | 259 | ### Export 260 | sf.write(join(clean_output_dir, filename), speech, args.sr) 261 | sf.write(join(noisy_output_dir, filename), lossy_speech, args.sr) 262 | if args.speech == "timit": 263 | shutil.copy(transcription_lists[split][i_sample], join(transcription_output_dir, filename[: -4] + ".txt")) -------------------------------------------------------------------------------- /sgmse/backbones/convtasnet_utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | class cLN(nn.Module): 9 | def __init__(self, dimension, eps = 1e-8, trainable=True): 10 | super(cLN, self).__init__() 11 | 12 | self.eps = eps 13 | if trainable: 14 | self.gain = nn.Parameter(torch.ones(1, dimension, 1)) 15 | self.bias = nn.Parameter(torch.zeros(1, dimension, 1)) 16 | else: 17 | self.gain = Variable(torch.ones(1, dimension, 1), requires_grad=False) 18 | self.bias = Variable(torch.zeros(1, dimension, 1), requires_grad=False) 19 | 20 | def forward(self, input): 21 | # input size: (Batch, Freq, Time) 22 | # cumulative mean for each time step 23 | 24 | batch_size = input.size(0) 25 | channel = input.size(1) 26 | time_step = input.size(2) 27 | 28 | step_sum = input.sum(1) # B, T 29 | step_pow_sum = input.pow(2).sum(1) # B, T 30 | cum_sum = torch.cumsum(step_sum, dim=1) # B, T 31 | cum_pow_sum = torch.cumsum(step_pow_sum, dim=1) # B, T 32 | 33 | entry_cnt = np.arange(channel, channel*(time_step+1), channel) 34 | entry_cnt = torch.from_numpy(entry_cnt).type(input.type()) 35 | entry_cnt = entry_cnt.view(1, -1).expand_as(cum_sum) 36 | 37 | cum_mean = cum_sum / entry_cnt # B, T 38 | cum_var = (cum_pow_sum - 2*cum_mean*cum_sum) / entry_cnt + cum_mean.pow(2) # B, T 39 | cum_std = (cum_var + self.eps).sqrt() # B, T 40 | 41 | cum_mean = cum_mean.unsqueeze(1) 42 | cum_std = cum_std.unsqueeze(1) 43 | 44 | x = (input - cum_mean.expand_as(input)) / cum_std.expand_as(input) 45 | return x * self.gain.expand_as(x).type(x.type()) + self.bias.expand_as(x).type(x.type()) 46 | 47 | def repackage_hidden(h): 48 | """ 49 | Wraps hidden states in new 
Variables, to detach them from their history. 50 | """ 51 | 52 | if type(h) == Variable: 53 | return Variable(h.data) 54 | else: 55 | return tuple(repackage_hidden(v) for v in h) 56 | 57 | class MultiRNN(nn.Module): 58 | """ 59 | Container module for multiple stacked RNN layers. 60 | 61 | args: 62 | rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'. 63 | input_size: int, dimension of the input feature. The input should have shape 64 | (batch, seq_len, input_size). 65 | hidden_size: int, dimension of the hidden state. The corresponding output should 66 | have shape (batch, seq_len, hidden_size). 67 | num_layers: int, number of stacked RNN layers. Default is 1. 68 | bidirectional: bool, whether the RNN layers are bidirectional. Default is False. 69 | """ 70 | 71 | def __init__(self, rnn_type, input_size, hidden_size, dropout=0, num_layers=1, bidirectional=False): 72 | super(MultiRNN, self).__init__() 73 | 74 | self.rnn = getattr(nn, rnn_type)(input_size, hidden_size, num_layers, dropout=dropout, 75 | batch_first=True, bidirectional=bidirectional) 76 | 77 | 78 | 79 | self.rnn_type = rnn_type 80 | self.hidden_size = hidden_size 81 | self.num_layers = num_layers 82 | self.num_direction = int(bidirectional) + 1 83 | 84 | def forward(self, input): 85 | hidden = self.init_hidden(input.size(0)) 86 | self.rnn.flatten_parameters() 87 | return self.rnn(input, hidden) 88 | 89 | def init_hidden(self, batch_size): 90 | weight = next(self.parameters()).data 91 | if self.rnn_type == 'LSTM': 92 | return (Variable(weight.new(self.num_layers*self.num_direction, batch_size, self.hidden_size).zero_()), 93 | Variable(weight.new(self.num_layers*self.num_direction, batch_size, self.hidden_size).zero_())) 94 | else: 95 | return Variable(weight.new(self.num_layers*self.num_direction, batch_size, self.hidden_size).zero_()) 96 | 97 | 98 | class FCLayer(nn.Module): 99 | """ 100 | Container module for a fully-connected layer. 101 | 102 | args: 103 | input_size: int, dimension of the input feature. The input should have shape 104 | (batch, input_size). 105 | hidden_size: int, dimension of the output. The corresponding output should 106 | have shape (batch, hidden_size). 107 | nonlinearity: string, the nonlinearity applied to the transformation. Default is None. 108 | """ 109 | 110 | def __init__(self, input_size, hidden_size, bias=True, nonlinearity=None): 111 | super(FCLayer, self).__init__() 112 | 113 | self.input_size = input_size 114 | self.hidden_size = hidden_size 115 | self.bias = bias 116 | self.FC = nn.Linear(self.input_size, self.hidden_size, bias=bias) 117 | if nonlinearity: 118 | self.nonlinearity = getattr(F, nonlinearity) 119 | else: 120 | self.nonlinearity = None 121 | 122 | self.init_hidden() 123 | 124 | def forward(self, input): 125 | if self.nonlinearity is not None: 126 | return self.nonlinearity(self.FC(input)) 127 | else: 128 | return self.FC(input) 129 | 130 | def init_hidden(self): 131 | initrange = 1. 
/ np.sqrt(self.input_size * self.hidden_size) 132 | self.FC.weight.data.uniform_(-initrange, initrange) 133 | if self.bias: 134 | self.FC.bias.data.fill_(0) 135 | 136 | 137 | class DepthConv1d(nn.Module): 138 | 139 | def __init__(self, input_channel, hidden_channel, kernel, padding, dilation=1, skip=True, causal=False): 140 | super(DepthConv1d, self).__init__() 141 | 142 | self.causal = causal 143 | self.skip = skip 144 | 145 | self.conv1d = nn.Conv1d(input_channel, hidden_channel, 1) 146 | if self.causal: 147 | self.padding = (kernel - 1) * dilation 148 | else: 149 | self.padding = padding 150 | self.dconv1d = nn.Conv1d(hidden_channel, hidden_channel, kernel, dilation=dilation, 151 | groups=hidden_channel, 152 | padding=self.padding) 153 | self.res_out = nn.Conv1d(hidden_channel, input_channel, 1) 154 | self.nonlinearity1 = nn.PReLU() 155 | self.nonlinearity2 = nn.PReLU() 156 | if self.causal: 157 | self.reg1 = cLN(hidden_channel, eps=1e-08) 158 | self.reg2 = cLN(hidden_channel, eps=1e-08) 159 | else: 160 | self.reg1 = nn.GroupNorm(1, hidden_channel, eps=1e-08) 161 | self.reg2 = nn.GroupNorm(1, hidden_channel, eps=1e-08) 162 | 163 | if self.skip: 164 | self.skip_out = nn.Conv1d(hidden_channel, input_channel, 1) 165 | 166 | def forward(self, input): 167 | output = self.reg1(self.nonlinearity1(self.conv1d(input))) 168 | if self.causal: 169 | output = self.reg2(self.nonlinearity2(self.dconv1d(output)[:,:,:-self.padding])) 170 | else: 171 | output = self.reg2(self.nonlinearity2(self.dconv1d(output))) 172 | residual = self.res_out(output) 173 | if self.skip: 174 | skip = self.skip_out(output) 175 | return residual, skip 176 | else: 177 | return residual 178 | 179 | class TCN(nn.Module): 180 | def __init__(self, input_dim, output_dim, BN_dim, hidden_dim, 181 | layer, stack, kernel=3, skip=True, 182 | causal=False, dilated=True): 183 | super(TCN, self).__init__() 184 | 185 | # input is a sequence of features of shape (B, N, L) 186 | 187 | # normalization 188 | if not causal: 189 | self.LN = nn.GroupNorm(1, input_dim, eps=1e-8) 190 | else: 191 | self.LN = cLN(input_dim, eps=1e-8) 192 | 193 | self.BN = nn.Conv1d(input_dim, BN_dim, 1) 194 | 195 | # TCN for feature extraction 196 | self.receptive_field = 0 197 | self.dilated = dilated 198 | 199 | self.TCN = nn.ModuleList([]) 200 | for s in range(stack): 201 | for i in range(layer): 202 | if self.dilated: 203 | self.TCN.append(DepthConv1d(BN_dim, hidden_dim, kernel, dilation=2**i, padding=2**i, skip=skip, causal=causal)) 204 | else: 205 | self.TCN.append(DepthConv1d(BN_dim, hidden_dim, kernel, dilation=1, padding=1, skip=skip, causal=causal)) 206 | if i == 0 and s == 0: 207 | self.receptive_field += kernel 208 | else: 209 | if self.dilated: 210 | self.receptive_field += (kernel - 1) * 2**i 211 | else: 212 | self.receptive_field += (kernel - 1) 213 | 214 | #print("Receptive field: {:3d} frames.".format(self.receptive_field)) 215 | 216 | # output layer 217 | 218 | self.output = nn.Sequential(nn.PReLU(), 219 | nn.Conv1d(BN_dim, output_dim, 1) 220 | ) 221 | 222 | self.skip = skip 223 | 224 | def forward(self, input): 225 | 226 | # input shape: (B, N, L) 227 | 228 | # normalization 229 | output = self.BN(self.LN(input)) 230 | 231 | # pass to TCN 232 | if self.skip: 233 | skip_connection = 0. 
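            # The skip outputs of all blocks are summed; this accumulated skip
            # path, not the residual stream, is what feeds the output layer.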
234 | for i in range(len(self.TCN)): 235 | residual, skip = self.TCN[i](output) 236 | output = output + residual 237 | skip_connection = skip_connection + skip 238 | else: 239 | for i in range(len(self.TCN)): 240 | residual = self.TCN[i](output) 241 | output = output + residual 242 | 243 | # output layer 244 | if self.skip: 245 | output = self.output(skip_connection) 246 | else: 247 | output = self.output(output) 248 | 249 | return output 250 | 251 | 252 | 253 | 254 | class TCNKoyama33(nn.Module): 255 | def __init__(self, input_dim, output_dim, BN_dim, hidden_dim, 256 | layer, stack, kernel=3, skip=True, **ignored_kwargs): 257 | 258 | super(TCNKoyama33, self).__init__() 259 | # input is a sequence of features of shape (B, N, L) 260 | 261 | # normalization 262 | self.LN = cLN(input_dim, eps=1e-8) 263 | self.BN = nn.Conv1d(input_dim, BN_dim, 1) 264 | 265 | # TCN for feature extraction 266 | self.receptive_field = 0 267 | self.dilated = True 268 | 269 | self.TCN = nn.ModuleList([]) 270 | for s in range(stack): 271 | for i in range(layer): 272 | self.TCN.append(DepthConv1d(BN_dim, hidden_dim, kernel, dilation=2**i, padding=2**i, skip=skip, causal=(i>4) )) 273 | if i == 0 and s == 0: 274 | self.receptive_field += kernel 275 | else: 276 | self.receptive_field += (kernel - 1) * 2**i 277 | 278 | # print("Receptive field: {:3d} frames.".format(self.receptive_field)) 279 | # output layer 280 | self.output = nn.Sequential(nn.PReLU(), 281 | nn.Conv1d(BN_dim, output_dim, 1) 282 | ) 283 | 284 | self.skip = skip 285 | 286 | 287 | def forward(self, input): 288 | 289 | # input shape: (B, N, L) 290 | 291 | # normalization 292 | output = self.BN(self.LN(input)) 293 | 294 | # pass to TCN 295 | if self.skip: 296 | skip_connection = 0. 297 | for i in range(len(self.TCN)): 298 | residual, skip = self.TCN[i](output) 299 | output = output + residual 300 | skip_connection = skip_connection + skip 301 | else: 302 | for i in range(len(self.TCN)): 303 | residual = self.TCN[i](output) 304 | output = output + residual 305 | 306 | # output layer 307 | if self.skip: 308 | output = self.output(skip_connection) 309 | else: 310 | output = self.output(output) 311 | 312 | return output 313 | -------------------------------------------------------------------------------- /sgmse/data_module.py: -------------------------------------------------------------------------------- 1 | 2 | from os.path import join 3 | import os 4 | import torch 5 | import pytorch_lightning as pl 6 | from torch.utils.data import Dataset 7 | from torch.utils.data import DataLoader 8 | from glob import glob 9 | from torchaudio import load 10 | import numpy as np 11 | import torch.nn.functional as F 12 | import h5py 13 | import json 14 | from sgmse.util.other import snr_scale_factor, pydub_read, align 15 | 16 | SEED = 10 17 | np.random.seed(SEED) 18 | 19 | def get_window(window_type, window_length): 20 | if window_type == 'sqrthann': 21 | return torch.sqrt(torch.hann_window(window_length, periodic=True)) 22 | elif window_type == 'hann': 23 | return torch.hann_window(window_length, periodic=True) 24 | else: 25 | raise NotImplementedError(f"Window type {window_type} not implemented!") 26 | 27 | class Specs(Dataset): 28 | def __init__( 29 | self, data_dir, subset, dummy, shuffle_spec, num_frames, format, 30 | normalize_audio=True, spec_transform=None, stft_kwargs=None, spatial_channels=1, 31 | return_time=False, 32 | **ignored_kwargs 33 | ): 34 | self.data_dir = data_dir 35 | self.subset = subset 36 | self.format = format 37 | self.spatial_channels = 
spatial_channels
38 |         self.return_time = return_time
39 | 
40 |         dic_correspondence_subsets = {"train": "train", "valid": "valid", "test": "test"}
41 |         self.clean_files = sorted(glob(join(data_dir, dic_correspondence_subsets[subset]) + '/clean/*.wav'))
42 |         self.noisy_files = sorted(glob(join(data_dir, dic_correspondence_subsets[subset]) + '/noisy/*.wav'))
43 | 
44 | 
45 |         self.dummy = dummy
46 |         self.num_frames = num_frames
47 |         self.shuffle_spec = shuffle_spec
48 |         self.normalize_audio = normalize_audio
49 |         self.spec_transform = spec_transform
50 | 
51 |         assert all(k in stft_kwargs.keys() for k in ["n_fft", "hop_length", "center", "window"]), "misconfigured STFT kwargs"
52 |         self.stft_kwargs = stft_kwargs
53 |         self.hop_length = self.stft_kwargs["hop_length"]
54 |         assert self.stft_kwargs.get("center", None) is True, "'center' must be True for current implementation"
55 | 
56 |     def _open_hdf5(self):
57 |         self.meta_data = json.load(open(sorted(glob(join(self.data_dir, "*.json")))[-1], "r"))
58 |         self.prep_file = h5py.File(sorted(glob(join(self.data_dir, "*.hdf5")))[-1], 'r')
59 | 
60 |     def __getitem__(self, i, raw=False):
61 |         x, sr = load(self.clean_files[i])
62 |         y, sr = load(self.noisy_files[i])
63 | 
64 |         min_len = min(x.size(-1), y.size(-1))
65 |         x, y = x[..., : min_len], y[..., : min_len]
66 | 
67 |         if x.ndimension() == 2 and self.spatial_channels == 1:
68 |             x, y = x[0].unsqueeze(0), y[0].unsqueeze(0)  # select first channel
69 |         # Select channels
70 |         assert self.spatial_channels <= x.size(0), f"Requested {self.spatial_channels} spatial channels, but the file only provides {x.size(0)}"
71 |         x, y = x[: self.spatial_channels], y[: self.spatial_channels]
72 | 
73 |         if raw:
74 |             return x, y
75 | 
76 |         normfac = y.abs().max()
77 | 
78 |         # formula applies for center=True
79 |         target_len = (self.num_frames - 1) * self.hop_length
80 |         current_len = x.size(-1)
81 |         pad = max(target_len - current_len, 0)
82 |         if pad == 0:
83 |             # extract random part of the audio file
84 |             if self.shuffle_spec:
85 |                 start = int(np.random.uniform(0, current_len-target_len))
86 |             else:
87 |                 start = int((current_len-target_len)/2)
88 |             x = x[..., start:start+target_len]
89 |             y = y[..., start:start+target_len]
90 |         else:
91 |             # pad audio if the length T is smaller than num_frames
92 |             x = F.pad(x, (pad//2, pad//2+(pad%2)), mode='constant')
93 |             y = F.pad(y, (pad//2, pad//2+(pad%2)), mode='constant')
94 | 
95 |         if self.normalize_audio:
96 |             # normalize both based on the noisy speech, to ensure the same clean signal power in x and y
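            # (normfac = y.abs().max() from above: the noisy signal is scaled to unit
            # peak amplitude, and the clean signal by the same factor)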
97 |             x = x / normfac
98 |             y = y / normfac
99 | 
100 |         if self.return_time:
101 |             return x, y
102 | 
103 |         X = torch.stft(x, **self.stft_kwargs)
104 |         Y = torch.stft(y, **self.stft_kwargs)
105 | 
106 |         X, Y = self.spec_transform(X), self.spec_transform(Y)
107 | 
108 |         return X, Y
109 | 
110 |     def __len__(self):
111 |         if self.dummy:
112 |             # for debugging, shrink the dataset size
113 |             return int(len(self.clean_files)/10)
114 |         else:
115 |             if self.format == "vctk":
116 |                 return len(self.clean_files)//2
117 |             else:
118 |                 return len(self.clean_files)
119 | 
120 | 
121 | 
122 | 
123 | 
124 | class SpecsDataModule(pl.LightningDataModule):
125 |     def __init__(
126 |         self, base_dir="", format="wsj0", spatial_channels=1, batch_size=8,
127 |         n_fft=510, hop_length=128, num_frames=256, window="hann",
128 |         num_workers=8, dummy=False, spec_factor=0.15, spec_abs_exponent=0.5,
129 |         gpu=True, return_time=False, **kwargs
130 |     ):
131 |         super().__init__()
132 |         self.base_dir = base_dir
133 |         self.format = format
134 |         self.spatial_channels = spatial_channels
135 |         self.batch_size = batch_size
136 |         self.n_fft = n_fft
137 |         self.hop_length = hop_length
138 |         self.num_frames = num_frames
139 |         self.window = get_window(window, self.n_fft)
140 |         self.windows = {}
141 |         self.num_workers = num_workers
142 |         self.dummy = dummy
143 |         self.spec_factor = spec_factor
144 |         self.spec_abs_exponent = spec_abs_exponent
145 |         self.gpu = gpu
146 |         self.return_time = return_time
147 |         self.kwargs = kwargs
148 | 
149 |     def setup(self, stage=None):
150 |         specs_kwargs = dict(
151 |             stft_kwargs=self.stft_kwargs, num_frames=self.num_frames, spec_transform=self.spec_fwd,
152 |             **self.stft_kwargs, **self.kwargs
153 |         )
154 |         if stage == 'fit' or stage is None:
155 |             self.train_set = Specs(self.base_dir, 'train', self.dummy, True,
156 |                 format=self.format, spatial_channels=self.spatial_channels,
157 |                 return_time=self.return_time, **specs_kwargs)
158 |             self.valid_set = Specs(self.base_dir, 'valid', self.dummy, False,
159 |                 format=self.format, spatial_channels=self.spatial_channels,
160 |                 return_time=self.return_time, **specs_kwargs)
161 |         if stage == 'test' or stage is None:
162 |             self.test_set = Specs(self.base_dir, 'test', self.dummy, False,
163 |                 format=self.format, spatial_channels=self.spatial_channels,
164 |                 return_time=self.return_time, **specs_kwargs)
165 | 
166 |     def spec_fwd(self, spec):
167 |         if self.spec_abs_exponent != 1:
168 |             e = self.spec_abs_exponent
169 |             spec = spec.abs()**e * torch.exp(1j * spec.angle())
170 |         return spec * self.spec_factor
171 | 
172 |     def spec_back(self, spec):
173 |         spec = spec / self.spec_factor
174 |         if self.spec_abs_exponent != 1:
175 |             e = self.spec_abs_exponent
176 |             spec = spec.abs()**(1/e) * torch.exp(1j * spec.angle())
177 |         return spec
178 | 
179 |     @property
180 |     def stft_kwargs(self):
181 |         return {**self.istft_kwargs, "return_complex": True}
182 | 
183 |     @property
184 |     def istft_kwargs(self):
185 |         return dict(
186 |             n_fft=self.n_fft, hop_length=self.hop_length,
187 |             window=self.window, center=True
188 |         )
189 | 
190 |     def _get_window(self, x):
191 |         """
192 |         Retrieve an appropriate window for the given tensor x, matching the device.
193 |         Caches the retrieved windows so that only one window tensor will be allocated per device.
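        Windows are created lazily: the first stft/istft call on a new device
        moves self.window there and caches the copy in self.windows.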
194 |         """
195 |         window = self.windows.get(x.device, None)
196 |         if window is None:
197 |             window = self.window.to(x.device)
198 |             self.windows[x.device] = window
199 |         return window
200 | 
201 |     def stft(self, sig):
202 |         window = self._get_window(sig)
203 |         return torch.stft(sig, **{**self.stft_kwargs, "window": window})
204 | 
205 |     def istft(self, spec, length=None):
206 |         window = self._get_window(spec)
207 |         return torch.istft(spec, **{**self.istft_kwargs, "window": window, "length": length})
208 | 
209 |     @staticmethod
210 |     def add_argparse_args(parser):
211 |         parser.add_argument("--format", required=True, choices=["wsj0_low", "wsj0_high", "wsj0_reverb", 'vb_dmd'], help="Dataset variant; determines the expected file layout.")
212 |         parser.add_argument("--base_dir", type=str, default="/data/lemercier/databases/wsj0+chime_julian/audio",
213 |             help="The base directory of the dataset. Should contain `train`, `valid` and `test` subdirectories, "
214 |                 "each of which contains `clean` and `noisy` subdirectories.")
215 |         parser.add_argument("--batch_size", type=int, default=8, help="The batch size. 8 by default.")
216 |         parser.add_argument("--n_fft", type=int, default=510, help="Number of FFT bins. 510 by default.")  # to ensure 256 frequency bins
217 |         parser.add_argument("--hop_length", type=int, default=128, help="Window hop length. 128 by default.")
218 |         parser.add_argument("--num_frames", type=int, default=256, help="Number of frames for the dataset. 256 by default.")
219 |         parser.add_argument("--window", type=str, choices=("sqrthann", "hann"), default="hann", help="The window function to use for the STFT. 'hann' by default.")
220 |         parser.add_argument("--num_workers", type=int, default=8, help="Number of workers to use for DataLoaders. 8 by default.")
221 |         parser.add_argument("--dummy", action="store_true", help="Use reduced dummy dataset for prototyping.")
222 |         parser.add_argument("--spec_factor", type=float, default=0.33, help="Factor to multiply complex STFT coefficients by.")  # NOTE: the SpecsDataModule constructor default is 0.15, while the CLI default here is 0.33
223 |         parser.add_argument("--spec_abs_exponent", type=float, default=0.5,
224 |             help="Exponent e for the transformation abs(z)**e * exp(1j*angle(z)). "
225 |                 "0.5 by default; set to values < 1 to bring out quieter features.")
226 |         parser.add_argument("--return_time", action="store_true", help="Return the waveform instead of the STFT.")
227 | 
228 |         return parser
229 | 
230 |     def train_dataloader(self):
231 |         # return DataLoader(
232 |         return DataLoader(
233 |             self.train_set, batch_size=self.batch_size,
234 |             num_workers=self.num_workers, pin_memory=self.gpu, shuffle=True
235 |         )
236 | 
237 |     def val_dataloader(self):
238 |         # return DataLoader(
239 |         return DataLoader(
240 |             self.valid_set, batch_size=self.batch_size,
241 |             num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
242 |         )
243 | 
244 |     def test_dataloader(self):
245 |         # return DataLoader(
246 |         return DataLoader(
247 |             self.test_set, batch_size=self.batch_size,
248 |             num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
249 |         )
250 | 
251 | 
252 | 
253 | 
254 | 
255 | 
256 | 
257 | 
258 | 
259 | 
260 | 
261 | 
262 | 
263 | 
264 | class SpecsAndTranscriptions(Specs):
265 | 
266 |     def __init__(
267 |             self, data_dir, subset, dummy, shuffle_spec, num_frames, format,
268 |             **kwargs
269 |     ):
270 |         super().__init__(data_dir, subset, dummy, shuffle_spec, num_frames, format, **kwargs)
271 |         if format == "timit":
272 |             dic_correspondence_subsets = {"train": "tr", "valid": "cv", "test": "tt"}
273 |             self.clean_files = sorted(glob(join(data_dir, "audio", dic_correspondence_subsets[subset]) + '/clean/*.wav'))
274 |             self.noisy_files = sorted(glob(join(data_dir, "audio", dic_correspondence_subsets[subset]) + '/noisy/*.wav'))
275 |             self.transcriptions = sorted(glob(join(data_dir, "transcriptions", dic_correspondence_subsets[subset]) + '/*.txt'))
276 |         else:
277 |             raise NotImplementedError
278 | 
279 |     def __getitem__(self, i, raw=False):
280 |         X, Y = super().__getitem__(i, raw=raw)
281 |         transcription = open(self.transcriptions[i], "r").read()
282 |         if self.format == "timit":  # remove the number at the beginning
283 |             transcription = " ".join(transcription.split(" ")[2: ])
284 | 
285 |         return X, Y, transcription
286 | 
287 |     def __len__(self):
288 |         if self.dummy:
289 |             return int(len(self.clean_files)/10)
290 |         else:
291 |             return len(self.clean_files)
292 | 
293 | class SpecsAndTranscriptionsDataModule(SpecsDataModule):
294 | 
295 |     def setup(self, stage=None):
296 |         specs_kwargs = dict(
297 |             stft_kwargs=self.stft_kwargs, num_frames=self.num_frames, spec_transform=self.spec_fwd,
298 |             **self.stft_kwargs, **self.kwargs
299 |         )
300 |         if stage == 'fit' or stage is None:
301 |             raise NotImplementedError
302 |         if stage == 'test' or stage is None:
303 |             self.test_set = SpecsAndTranscriptions(self.base_dir, 'test', self.dummy, False,
304 |                 format=self.format, **specs_kwargs)
305 | 
306 | 
307 |     @staticmethod
308 |     def add_argparse_args(parser):
309 |         parser.add_argument("--format", type=str, default="reverb_wsj0", choices=["wsj0", "vctk", "dns", "reverb_wsj0"], help="Dataset variant; determines the expected file layout.")
310 |         parser.add_argument("--base-dir", type=str, default="/data/lemercier/databases/reverb_wsj0+chime/audio")
311 |         parser.add_argument("--batch-size", type=int, default=8, help="The batch size.")
312 |         parser.add_argument("--num-workers", type=int, default=8, help="Number of workers to use for DataLoaders.")
313 |         parser.add_argument("--dummy", action="store_true", help="Use reduced dummy dataset for prototyping.")
314 |         return parser
--------------------------------------------------------------------------------
/sgmse/backbones/convtasnet.py:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import torch 4 | from torch.autograd import Variable 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import functools 8 | import numpy as np 9 | 10 | from .shared import BackboneRegistry 11 | 12 | @BackboneRegistry.register("convtasnet") 13 | class ConvTasNet(torch.nn.Module): 14 | def __init__(self, 15 | fs=16000, 16 | win=2, 17 | enc_dim=256, 18 | feature_dim=128, 19 | layer=8, 20 | stack=3, 21 | kernel=3, 22 | causal=False, 23 | **kwargs): 24 | super(ConvTasNet, self).__init__() 25 | 26 | self.num_spk = 1 27 | 28 | self.FORCE_STFT_OUT = True 29 | 30 | #Encoder 31 | self.enc_dim = enc_dim 32 | self.win = int(fs*win/1000) 33 | self.stride = self.win // 2 34 | 35 | self.encoder = torch.nn.Conv1d(1, self.enc_dim, self.win, bias=False, stride=self.stride) 36 | 37 | #TCN 38 | self.feature_dim = feature_dim 39 | self.layer = layer 40 | self.stack = stack 41 | self.kernel = kernel 42 | self.causal = causal 43 | self.TCN = TCN(self.enc_dim, self.num_spk*self.enc_dim, self.feature_dim, self.feature_dim*4, 44 | self.layer, self.stack, self.kernel, causal=self.causal) 45 | self.total_receptive_field = self.stride * self.TCN.receptive_field #Take encoder into account 46 | 47 | #Learnt Decoder 48 | self.decoder = torch.nn.ConvTranspose1d(self.enc_dim, 1, self.win, bias=False, stride=self.stride) 49 | 50 | @staticmethod 51 | def add_argparse_args(parser): 52 | parser.add_argument("--causal", action="store_true", default=False) 53 | return parser 54 | 55 | def forward(self, input, *args, **ignored_kwargs): 56 | 57 | # padding 58 | output, rest = self.pad_signal(input) 59 | batch_size = output.size(0) 60 | 61 | # Encoder 62 | enc_output = self.encoder(output) # B, N, L 63 | 64 | # Separator 65 | masks = torch.sigmoid(self.TCN(enc_output)).view(batch_size, self.num_spk, self.enc_dim, -1) # B, C, N, L 66 | masked_output = enc_output.unsqueeze(1) * masks # B, C, N, L 67 | 68 | # Decoder 69 | output = self.decoder(masked_output.view(batch_size*self.num_spk, self.enc_dim, -1)) # B*C, 1, L 70 | output = output.squeeze(1) # B, T 71 | 72 | return output 73 | 74 | 75 | def pad_signal(self, input): 76 | 77 | # input is the waveforms: (B, T) or (B, 1, T) 78 | # reshape and padding 79 | if input.dim() not in [2, 3]: 80 | raise RuntimeError("Input can only be 2 or 3 dimensional.") 81 | 82 | if input.dim() == 2: 83 | input = input.unsqueeze(1) 84 | batch_size = input.size(0) 85 | nsample = input.size(2) 86 | rest = self.win - (self.stride + nsample % self.win) % self.win 87 | if rest > 0: 88 | pad = Variable(torch.zeros(batch_size, 1, rest)).type(input.type()) 89 | input = torch.cat([input, pad], 2) 90 | 91 | pad_aux = Variable(torch.zeros(batch_size, 1, self.stride)).type(input.type()) 92 | input = torch.cat([pad_aux, input, pad_aux], 2) 93 | 94 | return input, rest 95 | 96 | 97 | 98 | class cLN(nn.Module): 99 | def __init__(self, dimension, eps = 1e-8, trainable=True): 100 | super(cLN, self).__init__() 101 | 102 | self.eps = eps 103 | if trainable: 104 | self.gain = nn.Parameter(torch.ones(1, dimension, 1)) 105 | self.bias = nn.Parameter(torch.zeros(1, dimension, 1)) 106 | else: 107 | self.gain = Variable(torch.ones(1, dimension, 1), requires_grad=False) 108 | self.bias = Variable(torch.zeros(1, dimension, 1), requires_grad=False) 109 | 110 | def forward(self, input): 111 | # input size: (Batch, Freq, Time) 112 | # cumulative mean for each time step 113 | 114 | batch_size = input.size(0) 
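        # note: the cumulative sums below yield, for each time step t, the mean and
        # variance over all channels and all frames 0..t, so the normalization is
        # causal: no future frames are used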
115 | channel = input.size(1) 116 | time_step = input.size(2) 117 | 118 | step_sum = input.sum(1) # B, T 119 | step_pow_sum = input.pow(2).sum(1) # B, T 120 | cum_sum = torch.cumsum(step_sum, dim=1) # B, T 121 | cum_pow_sum = torch.cumsum(step_pow_sum, dim=1) # B, T 122 | 123 | entry_cnt = np.arange(channel, channel*(time_step+1), channel) 124 | entry_cnt = torch.from_numpy(entry_cnt).type(input.type()) 125 | entry_cnt = entry_cnt.view(1, -1).expand_as(cum_sum) 126 | 127 | cum_mean = cum_sum / entry_cnt # B, T 128 | cum_var = (cum_pow_sum - 2*cum_mean*cum_sum) / entry_cnt + cum_mean.pow(2) # B, T 129 | cum_std = (cum_var + self.eps).sqrt() # B, T 130 | 131 | cum_mean = cum_mean.unsqueeze(1) 132 | cum_std = cum_std.unsqueeze(1) 133 | 134 | x = (input - cum_mean.expand_as(input)) / cum_std.expand_as(input) 135 | return x * self.gain.expand_as(x).type(x.type()) + self.bias.expand_as(x).type(x.type()) 136 | 137 | def repackage_hidden(h): 138 | """ 139 | Wraps hidden states in new Variables, to detach them from their history. 140 | """ 141 | 142 | if type(h) == Variable: 143 | return Variable(h.data) 144 | else: 145 | return tuple(repackage_hidden(v) for v in h) 146 | 147 | class MultiRNN(nn.Module): 148 | """ 149 | Container module for multiple stacked RNN layers. 150 | 151 | args: 152 | rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'. 153 | input_size: int, dimension of the input feature. The input should have shape 154 | (batch, seq_len, input_size). 155 | hidden_size: int, dimension of the hidden state. The corresponding output should 156 | have shape (batch, seq_len, hidden_size). 157 | num_layers: int, number of stacked RNN layers. Default is 1. 158 | bidirectional: bool, whether the RNN layers are bidirectional. Default is False. 159 | """ 160 | 161 | def __init__(self, rnn_type, input_size, hidden_size, dropout=0, num_layers=1, bidirectional=False): 162 | super(MultiRNN, self).__init__() 163 | 164 | self.rnn = getattr(nn, rnn_type)(input_size, hidden_size, num_layers, dropout=dropout, 165 | batch_first=True, bidirectional=bidirectional) 166 | 167 | 168 | 169 | self.rnn_type = rnn_type 170 | self.hidden_size = hidden_size 171 | self.num_layers = num_layers 172 | self.num_direction = int(bidirectional) + 1 173 | 174 | def forward(self, input): 175 | hidden = self.init_hidden(input.size(0)) 176 | self.rnn.flatten_parameters() 177 | return self.rnn(input, hidden) 178 | 179 | def init_hidden(self, batch_size): 180 | weight = next(self.parameters()).data 181 | if self.rnn_type == 'LSTM': 182 | return (Variable(weight.new(self.num_layers*self.num_direction, batch_size, self.hidden_size).zero_()), 183 | Variable(weight.new(self.num_layers*self.num_direction, batch_size, self.hidden_size).zero_())) 184 | else: 185 | return Variable(weight.new(self.num_layers*self.num_direction, batch_size, self.hidden_size).zero_()) 186 | 187 | 188 | class FCLayer(nn.Module): 189 | """ 190 | Container module for a fully-connected layer. 191 | 192 | args: 193 | input_size: int, dimension of the input feature. The input should have shape 194 | (batch, input_size). 195 | hidden_size: int, dimension of the output. The corresponding output should 196 | have shape (batch, hidden_size). 197 | nonlinearity: string, the nonlinearity applied to the transformation. Default is None. 
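        Example: FCLayer(256, 128, nonlinearity='relu') maps an input of shape
            (batch, 256) to (batch, 128) and applies F.relu to the result.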
198 | """ 199 | 200 | def __init__(self, input_size, hidden_size, bias=True, nonlinearity=None): 201 | super(FCLayer, self).__init__() 202 | 203 | self.input_size = input_size 204 | self.hidden_size = hidden_size 205 | self.bias = bias 206 | self.FC = nn.Linear(self.input_size, self.hidden_size, bias=bias) 207 | if nonlinearity: 208 | self.nonlinearity = getattr(F, nonlinearity) 209 | else: 210 | self.nonlinearity = None 211 | 212 | self.init_hidden() 213 | 214 | def forward(self, input): 215 | if self.nonlinearity is not None: 216 | return self.nonlinearity(self.FC(input)) 217 | else: 218 | return self.FC(input) 219 | 220 | def init_hidden(self): 221 | initrange = 1. / np.sqrt(self.input_size * self.hidden_size) 222 | self.FC.weight.data.uniform_(-initrange, initrange) 223 | if self.bias: 224 | self.FC.bias.data.fill_(0) 225 | 226 | 227 | class DepthConv1d(nn.Module): 228 | 229 | def __init__(self, input_channel, hidden_channel, kernel, padding, dilation=1, skip=True, causal=False): 230 | super(DepthConv1d, self).__init__() 231 | 232 | self.causal = causal 233 | self.skip = skip 234 | 235 | self.conv1d = nn.Conv1d(input_channel, hidden_channel, 1) 236 | if self.causal: 237 | self.padding = (kernel - 1) * dilation 238 | else: 239 | self.padding = padding 240 | self.dconv1d = nn.Conv1d(hidden_channel, hidden_channel, kernel, dilation=dilation, 241 | groups=hidden_channel, 242 | padding=self.padding) 243 | self.res_out = nn.Conv1d(hidden_channel, input_channel, 1) 244 | self.nonlinearity1 = nn.PReLU() 245 | self.nonlinearity2 = nn.PReLU() 246 | if self.causal: 247 | self.reg1 = cLN(hidden_channel, eps=1e-08) 248 | self.reg2 = cLN(hidden_channel, eps=1e-08) 249 | else: 250 | self.reg1 = nn.GroupNorm(1, hidden_channel, eps=1e-08) 251 | self.reg2 = nn.GroupNorm(1, hidden_channel, eps=1e-08) 252 | 253 | if self.skip: 254 | self.skip_out = nn.Conv1d(hidden_channel, input_channel, 1) 255 | 256 | def forward(self, input): 257 | output = self.reg1(self.nonlinearity1(self.conv1d(input))) 258 | if self.causal: 259 | output = self.reg2(self.nonlinearity2(self.dconv1d(output)[:,:,:-self.padding])) 260 | else: 261 | output = self.reg2(self.nonlinearity2(self.dconv1d(output))) 262 | residual = self.res_out(output) 263 | if self.skip: 264 | skip = self.skip_out(output) 265 | return residual, skip 266 | else: 267 | return residual 268 | 269 | class TCN(nn.Module): 270 | def __init__(self, input_dim, output_dim, BN_dim, hidden_dim, 271 | layer, stack, kernel=3, skip=True, 272 | causal=False, dilated=True): 273 | super(TCN, self).__init__() 274 | 275 | # input is a sequence of features of shape (B, N, L) 276 | 277 | # normalization 278 | if not causal: 279 | self.LN = nn.GroupNorm(1, input_dim, eps=1e-8) 280 | else: 281 | self.LN = cLN(input_dim, eps=1e-8) 282 | 283 | self.BN = nn.Conv1d(input_dim, BN_dim, 1) 284 | 285 | # TCN for feature extraction 286 | self.receptive_field = 0 287 | self.dilated = dilated 288 | 289 | self.TCN = nn.ModuleList([]) 290 | for s in range(stack): 291 | for i in range(layer): 292 | if self.dilated: 293 | self.TCN.append(DepthConv1d(BN_dim, hidden_dim, kernel, dilation=2**i, padding=2**i, skip=skip, causal=causal)) 294 | else: 295 | self.TCN.append(DepthConv1d(BN_dim, hidden_dim, kernel, dilation=1, padding=1, skip=skip, causal=causal)) 296 | if i == 0 and s == 0: 297 | self.receptive_field += kernel 298 | else: 299 | if self.dilated: 300 | self.receptive_field += (kernel - 1) * 2**i 301 | else: 302 | self.receptive_field += (kernel - 1) 303 | 304 | #print("Receptive field: {:3d} 
frames.".format(self.receptive_field)) 305 | 306 | # output layer 307 | 308 | self.output = nn.Sequential(nn.PReLU(), 309 | nn.Conv1d(BN_dim, output_dim, 1) 310 | ) 311 | 312 | self.skip = skip 313 | 314 | def forward(self, input): 315 | 316 | # input shape: (B, N, L) 317 | 318 | # normalization 319 | output = self.BN(self.LN(input)) 320 | 321 | # pass to TCN 322 | if self.skip: 323 | skip_connection = 0. 324 | for i in range(len(self.TCN)): 325 | residual, skip = self.TCN[i](output) 326 | output = output + residual 327 | skip_connection = skip_connection + skip 328 | else: 329 | for i in range(len(self.TCN)): 330 | residual = self.TCN[i](output) 331 | output = output + residual 332 | 333 | # output layer 334 | if self.skip: 335 | output = self.output(skip_connection) 336 | else: 337 | output = self.output(output) 338 | 339 | return output 340 | 341 | 342 | 343 | 344 | class TCNKoyama33(nn.Module): 345 | def __init__(self, input_dim, output_dim, BN_dim, hidden_dim, 346 | layer, stack, kernel=3, skip=True, **ignored_kwargs): 347 | 348 | super(TCNKoyama33, self).__init__() 349 | # input is a sequence of features of shape (B, N, L) 350 | 351 | # normalization 352 | self.LN = cLN(input_dim, eps=1e-8) 353 | self.BN = nn.Conv1d(input_dim, BN_dim, 1) 354 | 355 | # TCN for feature extraction 356 | self.receptive_field = 0 357 | self.dilated = True 358 | 359 | self.TCN = nn.ModuleList([]) 360 | for s in range(stack): 361 | for i in range(layer): 362 | self.TCN.append(DepthConv1d(BN_dim, hidden_dim, kernel, dilation=2**i, padding=2**i, skip=skip, causal=(i>4) )) 363 | if i == 0 and s == 0: 364 | self.receptive_field += kernel 365 | else: 366 | self.receptive_field += (kernel - 1) * 2**i 367 | 368 | # print("Receptive field: {:3d} frames.".format(self.receptive_field)) 369 | # output layer 370 | self.output = nn.Sequential(nn.PReLU(), 371 | nn.Conv1d(BN_dim, output_dim, 1) 372 | ) 373 | 374 | self.skip = skip 375 | 376 | 377 | def forward(self, input): 378 | 379 | # input shape: (B, N, L) 380 | 381 | # normalization 382 | output = self.BN(self.LN(input)) 383 | 384 | # pass to TCN 385 | if self.skip: 386 | skip_connection = 0. 387 | for i in range(len(self.TCN)): 388 | residual, skip = self.TCN[i](output) 389 | output = output + residual 390 | skip_connection = skip_connection + skip 391 | else: 392 | for i in range(len(self.TCN)): 393 | residual = self.TCN[i](output) 394 | output = output + residual 395 | 396 | # output layer 397 | if self.skip: 398 | output = self.output(skip_connection) 399 | else: 400 | output = self.output(output) 401 | 402 | return output 403 | -------------------------------------------------------------------------------- /sgmse/backbones/ncsnpp_utils/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 | 
7 | #include <torch/types.h>
8 | 
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 | 
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 | 
17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18 |   int c = a / b;
19 | 
20 |   if (c * b > a) {
21 |     c--;
22 |   }
23 | 
24 |   return c;
25 | }
26 | 
27 | struct UpFirDn2DKernelParams {
28 |   int up_x;
29 |   int up_y;
30 |   int down_x;
31 |   int down_y;
32 |   int pad_x0;
33 |   int pad_x1;
34 |   int pad_y0;
35 |   int pad_y1;
36 | 
37 |   int major_dim;
38 |   int in_h;
39 |   int in_w;
40 |   int minor_dim;
41 |   int kernel_h;
42 |   int kernel_w;
43 |   int out_h;
44 |   int out_w;
45 |   int loop_major;
46 |   int loop_x;
47 | };
48 | 
49 | template <typename scalar_t>
50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51 |                                        const scalar_t *kernel,
52 |                                        const UpFirDn2DKernelParams p) {
53 |   int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54 |   int out_y = minor_idx / p.minor_dim;
55 |   minor_idx -= out_y * p.minor_dim;
56 |   int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57 |   int major_idx_base = blockIdx.z * p.loop_major;
58 | 
59 |   if (out_x_base >= p.out_w || out_y >= p.out_h ||
60 |       major_idx_base >= p.major_dim) {
61 |     return;
62 |   }
63 | 
64 |   int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65 |   int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66 |   int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67 |   int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68 | 
69 |   for (int loop_major = 0, major_idx = major_idx_base;
70 |        loop_major < p.loop_major && major_idx < p.major_dim;
71 |        loop_major++, major_idx++) {
72 |     for (int loop_x = 0, out_x = out_x_base;
73 |          loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74 |       int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75 |       int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76 |       int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77 |       int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78 | 
79 |       const scalar_t *x_p =
80 |           &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81 |                  minor_idx];
82 |       const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83 |       int x_px = p.minor_dim;
84 |       int k_px = -p.up_x;
85 |       int x_py = p.in_w * p.minor_dim;
86 |       int k_py = -p.up_y * p.kernel_w;
87 | 
88 |       scalar_t v = 0.0f;
89 | 
90 |       for (int y = 0; y < h; y++) {
91 |         for (int x = 0; x < w; x++) {
92 |           v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93 |           x_p += x_px;
94 |           k_p += k_px;
95 |         }
96 | 
97 |         x_p += x_py - w * x_px;
98 |         k_p += k_py - w * k_px;
99 |       }
100 | 
101 |       out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102 |           minor_idx] = v;
103 |     }
104 |   }
105 | }
106 | 
107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108 |           int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110 |                                  const scalar_t *kernel,
111 |                                  const UpFirDn2DKernelParams p) {
112 |   const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113 |   const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114 | 
115 |   __shared__ volatile float sk[kernel_h][kernel_w];
116 |   __shared__ volatile float sx[tile_in_h][tile_in_w];
117 | 
118 |   int minor_idx = blockIdx.x;
119 |   int tile_out_y = minor_idx / p.minor_dim;
120 |   minor_idx -= tile_out_y * p.minor_dim;
121 |   tile_out_y *= tile_out_h;
122 |   int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123 |   int major_idx_base = blockIdx.z *
p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = 
pad_x1;
234 |   p.pad_y0 = pad_y0;
235 |   p.pad_y1 = pad_y1;
236 | 
237 |   p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238 |             p.down_y;
239 |   p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240 |             p.down_x;
241 | 
242 |   auto out =
243 |       at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244 | 
245 |   int mode = -1;
246 | 
247 |   int tile_out_h = -1;
248 |   int tile_out_w = -1;
249 | 
250 |   if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251 |       p.kernel_h <= 4 && p.kernel_w <= 4) {
252 |     mode = 1;
253 |     tile_out_h = 16;
254 |     tile_out_w = 64;
255 |   }
256 | 
257 |   if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258 |       p.kernel_h <= 3 && p.kernel_w <= 3) {
259 |     mode = 2;
260 |     tile_out_h = 16;
261 |     tile_out_w = 64;
262 |   }
263 | 
264 |   if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265 |       p.kernel_h <= 4 && p.kernel_w <= 4) {
266 |     mode = 3;
267 |     tile_out_h = 16;
268 |     tile_out_w = 64;
269 |   }
270 | 
271 |   if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272 |       p.kernel_h <= 2 && p.kernel_w <= 2) {
273 |     mode = 4;
274 |     tile_out_h = 16;
275 |     tile_out_w = 64;
276 |   }
277 | 
278 |   if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279 |       p.kernel_h <= 4 && p.kernel_w <= 4) {
280 |     mode = 5;
281 |     tile_out_h = 8;
282 |     tile_out_w = 32;
283 |   }
284 | 
285 |   if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286 |       p.kernel_h <= 2 && p.kernel_w <= 2) {
287 |     mode = 6;
288 |     tile_out_h = 8;
289 |     tile_out_w = 32;
290 |   }
291 | 
292 |   dim3 block_size;
293 |   dim3 grid_size;
294 | 
295 |   if (tile_out_h > 0 && tile_out_w > 0) {
296 |     p.loop_major = (p.major_dim - 1) / 16384 + 1;
297 |     p.loop_x = 1;
298 |     block_size = dim3(32 * 8, 1, 1);
299 |     grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300 |                      (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301 |                      (p.major_dim - 1) / p.loop_major + 1);
302 |   } else {
303 |     p.loop_major = (p.major_dim - 1) / 16384 + 1;
304 |     p.loop_x = 4;
305 |     block_size = dim3(4, 32, 1);
306 |     grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307 |                      (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308 |                      (p.major_dim - 1) / p.loop_major + 1);
309 |   }
310 | 
311 |   AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312 |     switch (mode) {
313 |     case 1:
314 |       upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316 |                                                  x.data_ptr<scalar_t>(),
317 |                                                  k.data_ptr<scalar_t>(), p);
318 | 
319 |       break;
320 | 
321 |     case 2:
322 |       upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324 |                                                  x.data_ptr<scalar_t>(),
325 |                                                  k.data_ptr<scalar_t>(), p);
326 | 
327 |       break;
328 | 
329 |     case 3:
330 |       upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332 |                                                  x.data_ptr<scalar_t>(),
333 |                                                  k.data_ptr<scalar_t>(), p);
334 | 
335 |       break;
336 | 
337 |     case 4:
338 |       upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340 |                                                  x.data_ptr<scalar_t>(),
341 |                                                  k.data_ptr<scalar_t>(), p);
342 | 
343 |       break;
344 | 
345 |     case 5:
346 |       upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348 |                                                  x.data_ptr<scalar_t>(),
349 |                                                  k.data_ptr<scalar_t>(), p);
350 | 
351 |       break;
352 | 
353 |     case 6:
354 |       upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32>
355 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356 |                                                  x.data_ptr<scalar_t>(),
357 |                                                  k.data_ptr<scalar_t>(), p);
358 | 
359 |       break;
360 | 
361 |     default:
362 |       upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363 |           out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364 |           k.data_ptr<scalar_t>(), p);
365 |     }
366 |   });
367 | 
368 |   return out;
369 | }
--------------------------------------------------------------------------------
/sgmse/backbones/ncsnpp_utils/layers.py:
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """Common layers for defining score networks. 18 | """ 19 | import math 20 | import string 21 | from functools import partial 22 | import torch.nn as nn 23 | import torch 24 | import torch.nn.functional as F 25 | import numpy as np 26 | from .normalization import ConditionalInstanceNorm2dPlus 27 | 28 | 29 | def get_act(config): 30 | """Get activation functions from the config file.""" 31 | 32 | if config == 'elu': 33 | return nn.ELU() 34 | elif config == 'relu': 35 | return nn.ReLU() 36 | elif config == 'lrelu': 37 | return nn.LeakyReLU(negative_slope=0.2) 38 | elif config == 'swish': 39 | return nn.SiLU() 40 | else: 41 | raise NotImplementedError('activation function does not exist!') 42 | 43 | 44 | def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0): 45 | """1x1 convolution. Same as NCSNv1/v2.""" 46 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation, 47 | padding=padding) 48 | init_scale = 1e-10 if init_scale == 0 else init_scale 49 | conv.weight.data *= init_scale 50 | conv.bias.data *= init_scale 51 | return conv 52 | 53 | 54 | def variance_scaling(scale, mode, distribution, 55 | in_axis=1, out_axis=0, 56 | dtype=torch.float32, 57 | device='cpu'): 58 | """Ported from JAX. """ 59 | 60 | def _compute_fans(shape, in_axis=1, out_axis=0): 61 | receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis] 62 | fan_in = shape[in_axis] * receptive_field_size 63 | fan_out = shape[out_axis] * receptive_field_size 64 | return fan_in, fan_out 65 | 66 | def init(shape, dtype=dtype, device=device): 67 | fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) 68 | if mode == "fan_in": 69 | denominator = fan_in 70 | elif mode == "fan_out": 71 | denominator = fan_out 72 | elif mode == "fan_avg": 73 | denominator = (fan_in + fan_out) / 2 74 | else: 75 | raise ValueError( 76 | "invalid mode for variance scaling initializer: {}".format(mode)) 77 | variance = scale / denominator 78 | if distribution == "normal": 79 | return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance) 80 | elif distribution == "uniform": 81 | return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) 
* np.sqrt(3 * variance) 82 | else: 83 | raise ValueError("invalid distribution for variance scaling initializer") 84 | 85 | return init 86 | 87 | 88 | def default_init(scale=1.): 89 | """The same initialization used in DDPM.""" 90 | scale = 1e-10 if scale == 0 else scale 91 | return variance_scaling(scale, 'fan_avg', 'uniform') 92 | 93 | 94 | class Dense(nn.Module): 95 | """Linear layer with `default_init`.""" 96 | def __init__(self): 97 | super().__init__() 98 | 99 | 100 | def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0): 101 | """1x1 convolution with DDPM initialization.""" 102 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias) 103 | conv.weight.data = default_init(init_scale)(conv.weight.data.shape) 104 | if bias: 105 | nn.init.zeros_(conv.bias) 106 | return conv 107 | 108 | 109 | def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1): 110 | """3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2.""" 111 | init_scale = 1e-10 if init_scale == 0 else init_scale 112 | conv = nn.Conv2d(in_planes, out_planes, stride=stride, bias=bias, 113 | dilation=dilation, padding=padding, kernel_size=3) 114 | conv.weight.data *= init_scale 115 | conv.bias.data *= init_scale 116 | return conv 117 | 118 | 119 | def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1): 120 | """3x3 convolution with DDPM initialization.""" 121 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, 122 | dilation=dilation, bias=bias) 123 | conv.weight.data = default_init(init_scale)(conv.weight.data.shape) 124 | if bias: 125 | nn.init.zeros_(conv.bias) 126 | return conv 127 | 128 | ########################################################################### 129 | # Functions below are ported over from the NCSNv1/NCSNv2 codebase: 130 | # https://github.com/ermongroup/ncsn 131 | # https://github.com/ermongroup/ncsnv2 132 | ########################################################################### 133 | 134 | 135 | class CRPBlock(nn.Module): 136 | def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True): 137 | super().__init__() 138 | self.convs = nn.ModuleList() 139 | for i in range(n_stages): 140 | self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False)) 141 | self.n_stages = n_stages 142 | if maxpool: 143 | self.pool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2) 144 | else: 145 | self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) 146 | 147 | self.act = act 148 | 149 | def forward(self, x): 150 | x = self.act(x) 151 | path = x 152 | for i in range(self.n_stages): 153 | path = self.pool(path) 154 | path = self.convs[i](path) 155 | x = path + x 156 | return x 157 | 158 | 159 | class CondCRPBlock(nn.Module): 160 | def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()): 161 | super().__init__() 162 | self.convs = nn.ModuleList() 163 | self.norms = nn.ModuleList() 164 | self.normalizer = normalizer 165 | for i in range(n_stages): 166 | self.norms.append(normalizer(features, num_classes, bias=True)) 167 | self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False)) 168 | 169 | self.n_stages = n_stages 170 | self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) 171 | self.act = act 172 | 173 | def forward(self, x, y): 174 | x = self.act(x) 175 | path = x 176 | for i in range(self.n_stages): 177 | path = self.norms[i](path, y) 178 | 
path = self.pool(path) 179 | path = self.convs[i](path) 180 | 181 | x = path + x 182 | return x 183 | 184 | 185 | class RCUBlock(nn.Module): 186 | def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()): 187 | super().__init__() 188 | 189 | for i in range(n_blocks): 190 | for j in range(n_stages): 191 | setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False)) 192 | 193 | self.stride = 1 194 | self.n_blocks = n_blocks 195 | self.n_stages = n_stages 196 | self.act = act 197 | 198 | def forward(self, x): 199 | for i in range(self.n_blocks): 200 | residual = x 201 | for j in range(self.n_stages): 202 | x = self.act(x) 203 | x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x) 204 | 205 | x += residual 206 | return x 207 | 208 | 209 | class CondRCUBlock(nn.Module): 210 | def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()): 211 | super().__init__() 212 | 213 | for i in range(n_blocks): 214 | for j in range(n_stages): 215 | setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True)) 216 | setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False)) 217 | 218 | self.stride = 1 219 | self.n_blocks = n_blocks 220 | self.n_stages = n_stages 221 | self.act = act 222 | self.normalizer = normalizer 223 | 224 | def forward(self, x, y): 225 | for i in range(self.n_blocks): 226 | residual = x 227 | for j in range(self.n_stages): 228 | x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y) 229 | x = self.act(x) 230 | x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x) 231 | 232 | x += residual 233 | return x 234 | 235 | 236 | class MSFBlock(nn.Module): 237 | def __init__(self, in_planes, features): 238 | super().__init__() 239 | assert isinstance(in_planes, list) or isinstance(in_planes, tuple) 240 | self.convs = nn.ModuleList() 241 | self.features = features 242 | 243 | for i in range(len(in_planes)): 244 | self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True)) 245 | 246 | def forward(self, xs, shape): 247 | sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device) 248 | for i in range(len(self.convs)): 249 | h = self.convs[i](xs[i]) 250 | h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True) 251 | sums += h 252 | return sums 253 | 254 | 255 | class CondMSFBlock(nn.Module): 256 | def __init__(self, in_planes, features, num_classes, normalizer): 257 | super().__init__() 258 | assert isinstance(in_planes, list) or isinstance(in_planes, tuple) 259 | 260 | self.convs = nn.ModuleList() 261 | self.norms = nn.ModuleList() 262 | self.features = features 263 | self.normalizer = normalizer 264 | 265 | for i in range(len(in_planes)): 266 | self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True)) 267 | self.norms.append(normalizer(in_planes[i], num_classes, bias=True)) 268 | 269 | def forward(self, xs, y, shape): 270 | sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device) 271 | for i in range(len(self.convs)): 272 | h = self.norms[i](xs[i], y) 273 | h = self.convs[i](h) 274 | h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True) 275 | sums += h 276 | return sums 277 | 278 | 279 | class RefineBlock(nn.Module): 280 | def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True): 281 | super().__init__() 282 | 283 | assert isinstance(in_planes, tuple) or isinstance(in_planes, list) 284 | 
self.n_blocks = n_blocks = len(in_planes) 285 | 286 | self.adapt_convs = nn.ModuleList() 287 | for i in range(n_blocks): 288 | self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act)) 289 | 290 | self.output_convs = RCUBlock(features, 3 if end else 1, 2, act) 291 | 292 | if not start: 293 | self.msf = MSFBlock(in_planes, features) 294 | 295 | self.crp = CRPBlock(features, 2, act, maxpool=maxpool) 296 | 297 | def forward(self, xs, output_shape): 298 | assert isinstance(xs, tuple) or isinstance(xs, list) 299 | hs = [] 300 | for i in range(len(xs)): 301 | h = self.adapt_convs[i](xs[i]) 302 | hs.append(h) 303 | 304 | if self.n_blocks > 1: 305 | h = self.msf(hs, output_shape) 306 | else: 307 | h = hs[0] 308 | 309 | h = self.crp(h) 310 | h = self.output_convs(h) 311 | 312 | return h 313 | 314 | 315 | class CondRefineBlock(nn.Module): 316 | def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False): 317 | super().__init__() 318 | 319 | assert isinstance(in_planes, tuple) or isinstance(in_planes, list) 320 | self.n_blocks = n_blocks = len(in_planes) 321 | 322 | self.adapt_convs = nn.ModuleList() 323 | for i in range(n_blocks): 324 | self.adapt_convs.append( 325 | CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act) 326 | ) 327 | 328 | self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act) 329 | 330 | if not start: 331 | self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer) 332 | 333 | self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act) 334 | 335 | def forward(self, xs, y, output_shape): 336 | assert isinstance(xs, tuple) or isinstance(xs, list) 337 | hs = [] 338 | for i in range(len(xs)): 339 | h = self.adapt_convs[i](xs[i], y) 340 | hs.append(h) 341 | 342 | if self.n_blocks > 1: 343 | h = self.msf(hs, y, output_shape) 344 | else: 345 | h = hs[0] 346 | 347 | h = self.crp(h, y) 348 | h = self.output_convs(h, y) 349 | 350 | return h 351 | 352 | 353 | class ConvMeanPool(nn.Module): 354 | def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False): 355 | super().__init__() 356 | if not adjust_padding: 357 | conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) 358 | self.conv = conv 359 | else: 360 | conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) 361 | 362 | self.conv = nn.Sequential( 363 | nn.ZeroPad2d((1, 0, 1, 0)), 364 | conv 365 | ) 366 | 367 | def forward(self, inputs): 368 | output = self.conv(inputs) 369 | output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2], 370 | output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4. 371 | return output 372 | 373 | 374 | class MeanPoolConv(nn.Module): 375 | def __init__(self, input_dim, output_dim, kernel_size=3, biases=True): 376 | super().__init__() 377 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) 378 | 379 | def forward(self, inputs): 380 | output = inputs 381 | output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2], 382 | output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4. 
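        # note: averaging the four polyphase components above is equivalent to 2x2
        # mean pooling with stride 2; the convolution is then applied to the pooled map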
383 | return self.conv(output) 384 | 385 | 386 | class UpsampleConv(nn.Module): 387 | def __init__(self, input_dim, output_dim, kernel_size=3, biases=True): 388 | super().__init__() 389 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) 390 | self.pixelshuffle = nn.PixelShuffle(upscale_factor=2) 391 | 392 | def forward(self, inputs): 393 | output = inputs 394 | output = torch.cat([output, output, output, output], dim=1) 395 | output = self.pixelshuffle(output) 396 | return self.conv(output) 397 | 398 | 399 | class ConditionalResidualBlock(nn.Module): 400 | def __init__(self, input_dim, output_dim, num_classes, resample=1, act=nn.ELU(), 401 | normalization=ConditionalInstanceNorm2dPlus, adjust_padding=False, dilation=None): 402 | super().__init__() 403 | self.non_linearity = act 404 | self.input_dim = input_dim 405 | self.output_dim = output_dim 406 | self.resample = resample 407 | self.normalization = normalization 408 | if resample == 'down': 409 | if dilation > 1: 410 | self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation) 411 | self.normalize2 = normalization(input_dim, num_classes) 412 | self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation) 413 | conv_shortcut = partial(ncsn_conv3x3, dilation=dilation) 414 | else: 415 | self.conv1 = ncsn_conv3x3(input_dim, input_dim) 416 | self.normalize2 = normalization(input_dim, num_classes) 417 | self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding) 418 | conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding) 419 | 420 | elif resample is None: 421 | if dilation > 1: 422 | conv_shortcut = partial(ncsn_conv3x3, dilation=dilation) 423 | self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation) 424 | self.normalize2 = normalization(output_dim, num_classes) 425 | self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation) 426 | else: 427 | conv_shortcut = nn.Conv2d 428 | self.conv1 = ncsn_conv3x3(input_dim, output_dim) 429 | self.normalize2 = normalization(output_dim, num_classes) 430 | self.conv2 = ncsn_conv3x3(output_dim, output_dim) 431 | else: 432 | raise Exception('invalid resample value') 433 | 434 | if output_dim != input_dim or resample is not None: 435 | self.shortcut = conv_shortcut(input_dim, output_dim) 436 | 437 | self.normalize1 = normalization(input_dim, num_classes) 438 | 439 | def forward(self, x, y): 440 | output = self.normalize1(x, y) 441 | output = self.non_linearity(output) 442 | output = self.conv1(output) 443 | output = self.normalize2(output, y) 444 | output = self.non_linearity(output) 445 | output = self.conv2(output) 446 | 447 | if self.output_dim == self.input_dim and self.resample is None: 448 | shortcut = x 449 | else: 450 | shortcut = self.shortcut(x) 451 | 452 | return shortcut + output 453 | 454 | 455 | class ResidualBlock(nn.Module): 456 | def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(), 457 | normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1): 458 | super().__init__() 459 | self.non_linearity = act 460 | self.input_dim = input_dim 461 | self.output_dim = output_dim 462 | self.resample = resample 463 | self.normalization = normalization 464 | if resample == 'down': 465 | if dilation > 1: 466 | self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation) 467 | self.normalize2 = normalization(input_dim) 468 | self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation) 469 | conv_shortcut = partial(ncsn_conv3x3, 
dilation=dilation)
470 |             else:
471 |                 self.conv1 = ncsn_conv3x3(input_dim, input_dim)
472 |                 self.normalize2 = normalization(input_dim)
473 |                 self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
474 |                 conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
475 | 
476 |         elif resample is None:
477 |             if dilation > 1:
478 |                 conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
479 |                 self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
480 |                 self.normalize2 = normalization(output_dim)
481 |                 self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
482 |             else:
483 |                 # conv_shortcut = nn.Conv2d  ### Something weird here.
484 |                 conv_shortcut = partial(ncsn_conv1x1)
485 |                 self.conv1 = ncsn_conv3x3(input_dim, output_dim)
486 |                 self.normalize2 = normalization(output_dim)
487 |                 self.conv2 = ncsn_conv3x3(output_dim, output_dim)
488 |         else:
489 |             raise Exception('invalid resample value')
490 | 
491 |         if output_dim != input_dim or resample is not None:
492 |             self.shortcut = conv_shortcut(input_dim, output_dim)
493 | 
494 |         self.normalize1 = normalization(input_dim)
495 | 
496 |     def forward(self, x):
497 |         output = self.normalize1(x)
498 |         output = self.non_linearity(output)
499 |         output = self.conv1(output)
500 |         output = self.normalize2(output)
501 |         output = self.non_linearity(output)
502 |         output = self.conv2(output)
503 | 
504 |         if self.output_dim == self.input_dim and self.resample is None:
505 |             shortcut = x
506 |         else:
507 |             shortcut = self.shortcut(x)
508 | 
509 |         return shortcut + output
510 | 
511 | 
512 | ###########################################################################
513 | # Functions below are ported over from the DDPM codebase:
514 | # https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
515 | ###########################################################################
516 | 
517 | def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
518 |     assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
519 |     half_dim = embedding_dim // 2
520 |     # magic number 10000 is from transformers
521 |     emb = math.log(max_positions) / (half_dim - 1)
522 |     # emb = math.log(2.)
/ (half_dim - 1) 523 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) 524 | # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] 525 | # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] 526 | emb = timesteps.float()[:, None] * emb[None, :] 527 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 528 | if embedding_dim % 2 == 1: # zero pad 529 | emb = F.pad(emb, (0, 1), mode='constant') 530 | assert emb.shape == (timesteps.shape[0], embedding_dim) 531 | return emb 532 | 533 | 534 | def _einsum(a, b, c, x, y): 535 | einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c)) 536 | return torch.einsum(einsum_str, x, y) 537 | 538 | 539 | def contract_inner(x, y): 540 | """tensordot(x, y, 1).""" 541 | x_chars = list(string.ascii_lowercase[:len(x.shape)]) 542 | y_chars = list(string.ascii_lowercase[len(x.shape):len(y.shape) + len(x.shape)]) 543 | y_chars[0] = x_chars[-1] # first axis of y and last of x get summed 544 | out_chars = x_chars[:-1] + y_chars[1:] 545 | return _einsum(x_chars, y_chars, out_chars, x, y) 546 | 547 | 548 | class NIN(nn.Module): 549 | def __init__(self, in_dim, num_units, init_scale=0.1): 550 | super().__init__() 551 | self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True) 552 | self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True) 553 | 554 | def forward(self, x): 555 | x = x.permute(0, 2, 3, 1) 556 | y = contract_inner(x, self.W) + self.b 557 | return y.permute(0, 3, 1, 2) 558 | 559 | 560 | 561 | class AttnBlock(nn.Module): 562 | """Channel-wise self-attention block.""" 563 | def __init__(self, channels): 564 | super().__init__() 565 | self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6) 566 | self.NIN_0 = NIN(channels, channels) 567 | self.NIN_1 = NIN(channels, channels) 568 | self.NIN_2 = NIN(channels, channels) 569 | self.NIN_3 = NIN(channels, channels, init_scale=0.) 
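        # note: init_scale=0. is clamped to 1e-10 in default_init, so the output
        # projection starts near zero and the residual `x + h` in forward() is
        # approximately the identity at initialization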
570 | 571 | def forward(self, x): 572 | B, C, H, W = x.shape 573 | h = self.GroupNorm_0(x) 574 | q = self.NIN_0(h) 575 | k = self.NIN_1(h) 576 | v = self.NIN_2(h) 577 | 578 | w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5)) 579 | w = torch.reshape(w, (B, H, W, H * W)) 580 | w = F.softmax(w, dim=-1) 581 | w = torch.reshape(w, (B, H, W, H, W)) 582 | h = torch.einsum('bhwij,bcij->bchw', w, v) 583 | h = self.NIN_3(h) 584 | return x + h 585 | 586 | 587 | class Upsample(nn.Module): 588 | def __init__(self, channels, with_conv=False): 589 | super().__init__() 590 | if with_conv: 591 | self.Conv_0 = ddpm_conv3x3(channels, channels) 592 | self.with_conv = with_conv 593 | 594 | def forward(self, x): 595 | B, C, H, W = x.shape 596 | h = F.interpolate(x, (H * 2, W * 2), mode='nearest') 597 | if self.with_conv: 598 | h = self.Conv_0(h) 599 | return h 600 | 601 | 602 | class Downsample(nn.Module): 603 | def __init__(self, channels, with_conv=False): 604 | super().__init__() 605 | if with_conv: 606 | self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0) 607 | self.with_conv = with_conv 608 | 609 | def forward(self, x): 610 | B, C, H, W = x.shape 611 | # Emulate 'SAME' padding 612 | if self.with_conv: 613 | x = F.pad(x, (0, 1, 0, 1)) 614 | x = self.Conv_0(x) 615 | else: 616 | x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0) 617 | 618 | assert x.shape == (B, C, H // 2, W // 2) 619 | return x 620 | 621 | 622 | class ResnetBlockDDPM(nn.Module): 623 | """The ResNet Blocks used in DDPM.""" 624 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1): 625 | super().__init__() 626 | if out_ch is None: 627 | out_ch = in_ch 628 | self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6) 629 | self.act = act 630 | self.Conv_0 = ddpm_conv3x3(in_ch, out_ch) 631 | if temb_dim is not None: 632 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 633 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) 634 | nn.init.zeros_(self.Dense_0.bias) 635 | 636 | self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6) 637 | self.Dropout_0 = nn.Dropout(dropout) 638 | self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.) 639 | if in_ch != out_ch: 640 | if conv_shortcut: 641 | self.Conv_2 = ddpm_conv3x3(in_ch, out_ch) 642 | else: 643 | self.NIN_0 = NIN(in_ch, out_ch) 644 | self.out_ch = out_ch 645 | self.in_ch = in_ch 646 | self.conv_shortcut = conv_shortcut 647 | 648 | def forward(self, x, temb=None): 649 | B, C, H, W = x.shape 650 | assert C == self.in_ch 651 | out_ch = self.out_ch if self.out_ch else self.in_ch 652 | h = self.act(self.GroupNorm_0(x)) 653 | h = self.Conv_0(h) 654 | # Add bias to each feature map conditioned on the time embedding 655 | if temb is not None: 656 | h += self.Dense_0(self.act(temb))[:, :, None, None] 657 | h = self.act(self.GroupNorm_1(h)) 658 | h = self.Dropout_0(h) 659 | h = self.Conv_1(h) 660 | if C != out_ch: 661 | if self.conv_shortcut: 662 | x = self.Conv_2(x) 663 | else: 664 | x = self.NIN_0(x) 665 | return x + h --------------------------------------------------------------------------------
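A minimal, self-contained round-trip sketch of the spec_fwd/spec_back pair from /sgmse/data_module.py above, assuming the constructor defaults spec_factor=0.15 and spec_abs_exponent=0.5 (illustration only, not a repository file):

import torch

spec_factor, e = 0.15, 0.5

def spec_fwd(spec):
    # compress magnitudes (|z|**e), keep the phase, then apply the global scale
    return spec.abs()**e * torch.exp(1j * spec.angle()) * spec_factor

def spec_back(spec):
    # exact inverse: undo the scale, then expand the magnitudes
    spec = spec / spec_factor
    return spec.abs()**(1 / e) * torch.exp(1j * spec.angle())

z = torch.randn(2, 3, dtype=torch.cfloat)
assert torch.allclose(spec_back(spec_fwd(z)), z, atol=1e-5)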