├── dataset ├── __init__.py ├── config.py ├── prepare_saraga.py └── saraga.py ├── requirements.txt ├── ckpt └── README.md ├── utils ├── separation_eval.py ├── noam_schedule.py ├── phase_vocoder.py └── signal_processing.py ├── .gitignore ├── model ├── clustering.py ├── unet_utils.py ├── config.py ├── vad.py ├── __init__.py ├── unet.py └── estnoise_ms.py ├── LICENSE ├── config.py ├── README.md ├── separate.py └── train.py /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .saraga import SARAGA -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa==0.7.0 2 | matplotlib==3.3.1 3 | numba==0.48 # for preventing librosa conflict 4 | numpy==1.19.1 5 | pydub==0.23.1 6 | tensorflow>=2.1.0 7 | tqdm==4.48.2 -------------------------------------------------------------------------------- /ckpt/README.md: -------------------------------------------------------------------------------- 1 | ## Model weights 2 | Download and store the model weights here. The folder structure, for a model named `saraga-8`, should look like this: 3 | * `./ckpt/saraga-8/saraga-8.json`: config file for the `saraga-8` model. 4 | * `./ckpt/saraga-8/saraga-8/`: folder containing the checkpoint for the `saraga-8` model. -------------------------------------------------------------------------------- /utils/separation_eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def GlobalSDR(references, separations): 4 | """ Global SDR """ 5 | delta = 1e-7 # avoid numerical errors 6 | num = np.sum(np.square(references), axis=(1, 2)) 7 | den = np.sum(np.square(references - separations), axis=(1, 2)) 8 | num += delta 9 | den += delta 10 | return 10 * np.log10(num / den) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # python 2 | __pycache__ 3 | ckpt/__pycache__ 4 | dataset/__pycache__ 5 | model/__pycache__ 6 | utils/__pycache__ 7 | 8 | # mac osX 9 | .DS_Store 10 | ckpt/.DS_Store 11 | dataset/.DS_Store 12 | model/.DS_Store 13 | utils/.DS_Store 14 | 15 | # train 16 | ckpt/saraga-8/ 17 | ckpt/saraga-8.json 18 | log 19 | 20 | # test sample 21 | sample 22 | output 23 | 24 | # testing stuff for paper 25 | testing_files/ 26 | 27 | # wip for faster, cleaner, and better execution 28 | evaluate.py -------------------------------------------------------------------------------- /model/clustering.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from sklearn.cluster import KMeans 4 | 5 | def get_mask(normalized_feat, clusters, scheduler): 6 | kmeans = KMeans(n_clusters=clusters, random_state=0).fit(normalized_feat) 7 | centers = kmeans.cluster_centers_ 8 | original_means = np.mean(centers, axis=1) 9 | ordered_means = np.sort(np.mean(centers, axis=1)) 10 | means_and_pos = {} 11 | manual_weights = np.linspace(0, 1, clusters)**scheduler 12 | for idx, j in zip(manual_weights, ordered_means): 13 | means_and_pos[j] = idx 14 | label_and_dist = [] 15 | for j in original_means: 16 | label_and_dist.append(means_and_pos[j]) 17 | weights = [] 18 | for j in kmeans.labels_: 19 | weights.append(label_and_dist[j]) 20 | return np.array(weights, dtype=np.float32) / float(clusters-1) 
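For reference, here is a minimal sketch of how `get_mask` can be used to turn a magnitude spectrogram into a soft vocal mask. The feature preparation below is an assumption for illustration only (`separate.py` may prepare its input differently); the values 5 and 4 simply mirror the `--clusters 5 --scheduler 4` defaults suggested in the main README.

```python
import numpy as np
from model.clustering import get_mask

# Toy magnitude spectrogram, shape [frames, freq_bins].
mag = np.abs(np.random.randn(128, 513)).astype(np.float32)

# One sample per time-frequency bin, normalized to [0, 1]
# (illustrative; the actual normalization used at inference may differ).
feat = (mag / (mag.max() + 1e-8)).reshape(-1, 1)

# Cluster the bins into 5 groups by energy and map each group to a soft
# weight shaped by the scheduler exponent.
weights = get_mask(feat, clusters=5, scheduler=4)

# Fold the per-bin weights back into a time-frequency mask and apply it.
soft_mask = weights.reshape(mag.shape)
masked_mag = mag * soft_mask
```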
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 YoungJoong Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/noam_schedule.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class NoamScheduler(tf.keras.optimizers.schedules.LearningRateSchedule): 5 | """Noam learning rate scheduler from Vaswani et al., 2017. 6 | """ 7 | def __init__(self, learning_rate, warmup_steps, channels): 8 | """Initializer. 9 | Args: 10 | learning_rate: float, initial learning rate. 11 | warmup_steps: int, warmup steps. 12 | channels: int, base hidden size of the model. 13 | """ 14 | super(NoamScheduler, self).__init__() 15 | self.learning_rate = learning_rate 16 | self.warmup_steps = warmup_steps 17 | self.channels = channels 18 | 19 | def __call__(self, step): 20 | """Compute learning rate. 21 | """ 22 | return self.learning_rate * self.channels ** -0.5 * \ 23 | tf.minimum(step ** -0.5, step * self.warmup_steps ** -1.5) 24 | 25 | def get_config(self): 26 | """Serialize configurations. 27 | """ 28 | return { 29 | 'learning_rate': self.learning_rate, 30 | 'warmup_steps': self.warmup_steps, 31 | 'channels': self.channels, 32 | } 33 | -------------------------------------------------------------------------------- /dataset/config.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | class Config: 4 | """Configuration for dataset construction. 5 | """ 6 | def __init__(self): 7 | # audio config 8 | self.sr = 22050 9 | 10 | # stft 11 | self.hop = 256 12 | self.win = 1024 13 | self.fft = self.win 14 | self.win_fn = 'hann' 15 | 16 | # mel-scale filter bank 17 | self.mel = 80 18 | self.fmin = 0 19 | self.fmax = 8000 20 | 21 | self.eps = 1e-5 22 | 23 | # sample size 24 | self.frames = (self.hop + 6) * 128 # 16384 25 | self.batch = 8 26 | 27 | self.eval_tracks = [ 28 | 'kailasapathe', 'ragam_tanam_pallavi', 'gopi_gopala_bala', 'ananda_sagara' 29 | ] # four manually selected tracks for validation 30 | 31 | def window_fn(self): 32 | """Return window generator. 33 | Returns: 34 | Callable, window function of tf.signal 35 | , which corresponds to self.win_fn. 
36 | """ 37 | mapper = { 38 | 'hann': tf.signal.hann_window, 39 | 'hamming': tf.signal.hamming_window 40 | } 41 | if self.win_fn in mapper: 42 | return mapper[self.win_fn] 43 | 44 | raise ValueError('invalid window function: ' + self.win_fn) 45 | -------------------------------------------------------------------------------- /model/unet_utils.py: -------------------------------------------------------------------------------- 1 | # nn.py 2 | # Source: https://github.com/hojonathanho/diffusion/blob/master/ 3 | # diffusion_tf/nn.py 4 | # Tensorflow 2.4.0 5 | # Windows/MacOS/Linux 6 | # Python 3.7 7 | 8 | 9 | import math 10 | import tensorflow as tf 11 | 12 | 13 | def default_init(scale): 14 | return tf.initializers.variance_scaling( 15 | scale=1e-10 if scale == 0 else scale, 16 | mode="fan_avg", 17 | distribution="uniform") 18 | 19 | 20 | def meanflat(x): 21 | return tf.math.reduce_mean(x, axis=list(range(1, len(x.shape)))) 22 | 23 | 24 | def get_timestep_embedding(timesteps, embedding_dim): 25 | # From fairseq. Build sinusoidal embeddings. This matches the 26 | # implementation in tensor2tensor, but differs slightly from the 27 | # description in Section 3.5 of "Attention Is All You Need". 28 | assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 29 | 30 | half_dim = embedding_dim // 2 31 | emb = math.log(10000) / (half_dim - 1) 32 | emb = tf.math.exp(tf.range(half_dim, dtype=tf.float32) * -emb) 33 | # emb = tf.range(num_embeddings, dtype=tf.float32)[:, None] * emb[None, :] 34 | emb = tf.cast(timesteps, dtype=tf.float32)[:, None] * emb[None, :] 35 | emb = tf.concat([tf.math.sin(emb), tf.math.cos(emb)], axis=1) 36 | if embedding_dim % 2 == 1: # zero pad. 37 | # emb = tf.concat([emb, tf.zeros([num_embeddings, 1])], axis=1) 38 | emb = tf.pad(emb, [[0, 0], [0, 1]]) 39 | assert emb.shape == [timesteps.shape[0], embedding_dim] 40 | return emb -------------------------------------------------------------------------------- /dataset/prepare_saraga.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import glob 5 | import tqdm 6 | 7 | import numpy as np 8 | import torch.nn.functional as F 9 | import torchaudio as T 10 | 11 | SR = 22050 12 | 13 | 14 | def main(args): 15 | concert = glob.glob(os.path.join(args.saraga_dir, '*/')) 16 | 17 | for i in tqdm(concert): 18 | songs = glob.glob(os.path.join(args.saraga_dir, i, '*/')) 19 | for j in tqdm.tqdm(songs): 20 | song_name = j.split("/")[-2] 21 | mixture = os.path.join(j, song_name + ".mp3.mp3") 22 | vocals = os.path.join(j, song_name + ".multitrack-vocal.mp3") 23 | 24 | if os.path.exists(mixture): 25 | audio_mix, sr = T.load(mixture) 26 | audio_voc, _ = T.load(vocals) 27 | resampling = T.transforms.Resample(sr, SR) 28 | audio_mix = resampling(audio_mix) 29 | audio_voc = resampling(audio_voc) 30 | audio_mix = torch.mean(audio_mix, dim=0).unsqueeze(0) 31 | audio_mix = torch.clamp(audio_mix, -1.0, 1.0) 32 | audio_voc = torch.mean(audio_voc, dim=0).unsqueeze(0) 33 | audio_voc = torch.clamp(audio_voc, -1.0, 1.0) 34 | 35 | actual_len = audio_voc.shape 36 | for trim in np.arange(actual_len[1] // (args.sample_len*SR)): 37 | T.save( 38 | os.path.join( 39 | args.output_dir, song_name.lower().replace(" ", "_") + '_' + str(trim) + '_mixture.wav'), 40 | audio_mix[:, trim*args.sample_len*SR:(trim+1)*args.sample_len*SR].cpu(), 41 | sample_rate=sr, 42 | bits_per_sample=16) 43 | T.save( 44 | os.path.join( 45 | args.output_dir, song_name.lower().replace(" ", "_") + '_' 
+ str(trim) + '_vocals.wav'), 46 | audio_voc[:, trim*args.sample_len*SR:(trim+1)*args.sample_len*SR].cpu(), 47 | sample_rate=sr, 48 | bits_per_sample=16) 49 | else: 50 | print("no file...") 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument('--saraga-dir', default=None, type=str) 55 | parser.add_argument('--output-dir', default=None, type=str) 56 | parser.add_argument('--sample-len', default=6) 57 | parser.add_argument('--gpu', default=None) 58 | args = parser.parse_args() 59 | main(args) 60 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from dataset.config import Config as UnetDataConfig 2 | from model.config import Config as UnetModelConfig 3 | from utils.noam_schedule import NoamScheduler 4 | 5 | 6 | class UnetTrainConfig: 7 | """Configuration for training loop. 8 | """ 9 | def __init__(self): 10 | # optimizer 11 | self.lr_policy = 'fixed' 12 | self.learning_rate = 0.000002 13 | # self.lr_policy = 'noam' 14 | # self.learning_rate = 1 15 | # self.lr_params = { 16 | # 'warmup_steps': 4000, 17 | # 'channels': 64 18 | # } 19 | 20 | self.beta1 = 0.9 21 | self.beta2 = 0.98 22 | self.eps = 1e-9 23 | 24 | # 13000:100 25 | self.split = 3001*8 26 | # self.split = 50 27 | self.bufsiz = 50 28 | 29 | self.epoch = 10000 30 | 31 | # path config 32 | self.log = './log' 33 | self.ckpt = './ckpt' 34 | self.sounds = './sounds' 35 | 36 | # model name 37 | self.model_type = None 38 | self.name = 'saraga-8' 39 | 40 | # interval configuration 41 | self.eval_intval = 5000 42 | self.ckpt_intval = 10000 43 | def lr(self): 44 | """Generate proper learning rate scheduler. 45 | """ 46 | mapper = { 47 | 'noam': NoamScheduler 48 | } 49 | if self.lr_policy == 'fixed': 50 | return self.learning_rate 51 | if self.lr_policy in mapper: 52 | return mapper[self.lr_policy](self.learning_rate, **self.lr_params) 53 | raise ValueError('invalid lr_policy') 54 | 55 | class Config(): 56 | """Integrated configuration. 57 | """ 58 | def __init__(self): 59 | self.data = UnetDataConfig() 60 | self.model = UnetModelConfig() 61 | self.train = UnetTrainConfig() 62 | 63 | def dump(self): 64 | """Dump configurations into serializable dictionary. 65 | """ 66 | return {k: vars(v) for k, v in vars(self).items()} 67 | 68 | @staticmethod 69 | def load(dump_): 70 | """Load dumped configurations into new configuration. 71 | """ 72 | conf = Config() 73 | for k, v in dump_.items(): 74 | if hasattr(conf, k): 75 | obj = getattr(conf, k) 76 | load_state(obj, v) 77 | return conf 78 | 79 | 80 | def load_state(obj, dump_): 81 | """Load dictionary items to attributes. 82 | """ 83 | for k, v in dump_.items(): 84 | if hasattr(obj, k): 85 | setattr(obj, k, v) 86 | return obj 87 | -------------------------------------------------------------------------------- /model/config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | class Config: 5 | """Configuration for DiffWave implementation. 
6 | """ 7 | def __init__(self): 8 | self.model_type = None 9 | 10 | self.sr = 22050 11 | 12 | self.hop = 256 13 | self.win = 1024 14 | 15 | # mel-scale filter bank 16 | self.mel = 80 17 | self.fmin = 0 18 | self.fmax = 8000 19 | 20 | self.eps = 1e-5 21 | 22 | # sample size 23 | self.frames = (self.hop + 6) * 128 # 16384 24 | self.batch = 8 25 | 26 | # leaky relu coefficient 27 | self.leak = 0.4 28 | 29 | # embdding config 30 | self.embedding_size = 128 31 | self.embedding_proj = 512 32 | self.embedding_layers = 2 33 | self.embedding_factor = 4 34 | 35 | # upsampler config 36 | self.upsample_stride = [4, 1] 37 | self.upsample_kernel = [32, 3] 38 | self.upsample_layers = 4 39 | # computed hop size 40 | # block config 41 | self.channels = 64 42 | self.kernel_size = 3 43 | self.dilation_rate = 2 44 | self.num_layers = 30 45 | self.num_cycles = 3 46 | 47 | # noise schedule 48 | self.iter = 8 # 20, 40, 50 49 | self.noise_policy = 'linear' 50 | self.noise_start = 1e-4 51 | self.noise_end = 0.5 # 0.02 for 200 52 | 53 | def beta(self): 54 | """Generate beta-sequence. 55 | Returns: 56 | List[float], [iter], beta values. 57 | """ 58 | mapper = { 59 | 'linear': self._linear_sched, 60 | } 61 | if self.noise_policy not in mapper: 62 | raise ValueError('invalid beta policy') 63 | return mapper[self.noise_policy]() 64 | 65 | def _linear_sched(self): 66 | """Linearly generated noise. 67 | Returns: 68 | List[float], [iter], beta values. 69 | """ 70 | return np.linspace( 71 | self.noise_start, self.noise_end, self.iter, dtype=np.float32) 72 | 73 | def window_fn(self): 74 | """Return window generator. 75 | Returns: 76 | Callable, window function of tf.signal 77 | , which corresponds to self.win_fn. 78 | """ 79 | mapper = { 80 | 'hann': tf.signal.hann_window, 81 | 'hamming': tf.signal.hamming_window 82 | } 83 | if self.win_fn in mapper: 84 | return mapper[self.win_fn] 85 | 86 | raise ValueError('invalid window function: ' + self.win_fn) 87 | -------------------------------------------------------------------------------- /utils/phase_vocoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | def phase_vocoder(D, hop_len=256, rate=0.8): 6 | """Phase vocoder. Given an STFT matrix D, speed up by a factor of `rate`. 
7 | Based on implementation provided by: 8 | https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html#phase_vocoder 9 | :param D: tf.complex64([num_frames, num_bins]): the STFT tensor 10 | :param hop_len: float: the hop length param of the STFT 11 | :param rate: float > 0: the speed-up factor 12 | :return: D_stretched: tf.complex64([num_frames, num_bins]): the stretched STFT tensor 13 | """ 14 | # get shape 15 | sh = tf.shape(D, name="STFT_shape") 16 | frames = sh[0] 17 | fbins = sh[1] 18 | 19 | # time steps range 20 | t = tf.range(0.0, tf.cast(frames, tf.float32), rate, dtype=tf.float32, name="time_steps") 21 | 22 | # Expected phase advance in each bin 23 | dphi = tf.linspace(0.0, np.pi * hop_len, fbins, name="dphi_expected_phase_advance") 24 | phase_acc = tf.math.angle(D[0, :], name="phase_acc_init") 25 | 26 | # Pad 0 columns to simplify boundary logic 27 | D = tf.pad(D, [(0, 2), (0, 0)], mode='CONSTANT', name="padded_STFT") 28 | 29 | # def fn(previous_output, current_input): 30 | def _pvoc_mag_and_cum_phase(previous_output, current_input): 31 | # unpack prev phase 32 | _, prev = previous_output 33 | 34 | # grab the two current columns of the STFT 35 | i = tf.cast((tf.floor(current_input) + [0, 1]), tf.int32) 36 | bcols = tf.gather_nd(D, [[i[0]], [i[1]]]) 37 | 38 | # Weighting for linear magnitude interpolation 39 | t_dif = current_input - tf.floor(current_input) 40 | bmag = (1 - t_dif) * tf.abs(bcols[0, :]) + t_dif * (tf.abs(bcols[1, :])) 41 | 42 | # Compute phase advance 43 | dp = tf.math.angle(bcols[1, :]) - tf.math.angle(bcols[0, :]) - dphi 44 | dp = dp - 2 * np.pi * tf.round(dp / (2.0 * np.pi)) 45 | 46 | # return linear mag, accumulated phase 47 | return bmag, tf.squeeze(prev + dp + dphi) 48 | 49 | # initializer of zeros of correct shape for mag, and phase_acc for phase 50 | initializer = (tf.zeros(fbins, tf.float32), phase_acc) 51 | mag, phase = tf.scan(_pvoc_mag_and_cum_phase, t, initializer=initializer, 52 | parallel_iterations=10, back_prop=False, 53 | name="pvoc_cum_phase") 54 | 55 | # add the original phase_acc in 56 | phase2 = tf.concat([tf.expand_dims(phase_acc, 0), phase], 0)[:-1, :] 57 | D_stretched = tf.cast(mag, tf.complex64) * tf.exp(1.j * tf.cast(phase2, tf.complex64), name="stretched_STFT") 58 | 59 | return D_stretched -------------------------------------------------------------------------------- /utils/signal_processing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | def get_overlap_window(signal, boundary=None): 6 | window_out = np.ones(signal.shape) 7 | midpoint = window_out.shape[0] // 2 8 | if boundary == "start": 9 | window_out[midpoint:] = np.linspace(1, 0, window_out.shape[0]-midpoint) 10 | elif boundary == "end": 11 | window_out[:midpoint] = np.linspace(0, 1, window_out.shape[0]-midpoint) 12 | else: 13 | window_out[:midpoint] = np.linspace(0, 1, window_out.shape[0]-midpoint) 14 | window_out[midpoint:] = np.linspace(1, 0, window_out.shape[0]-midpoint) 15 | return window_out 16 | 17 | 18 | def compute_stft(signal, unet_config): 19 | signal_stft = check_shape_3d( 20 | check_shape_3d( 21 | tf.signal.stft( 22 | signal, 23 | frame_length=unet_config.model.win, 24 | frame_step=unet_config.model.hop, 25 | fft_length=unet_config.model.win, 26 | window_fn=tf.signal.hann_window), 1), 2) 27 | mag = tf.abs(signal_stft) 28 | phase = tf.math.angle(signal_stft) 29 | return mag, phase 30 | 31 | 32 | def compute_signal_from_stft(spec, phase, config): 33 | 
polar_spec = tf.complex(tf.multiply(spec, tf.math.cos(phase)), tf.zeros(spec.shape)) + \ 34 | tf.multiply(tf.complex(spec, tf.zeros(spec.shape)), tf.complex(tf.zeros(phase.shape), tf.math.sin(phase))) 35 | return tf.signal.inverse_stft( 36 | polar_spec, 37 | frame_length=config.model.win, 38 | frame_step=config.model.hop, 39 | window_fn=tf.signal.inverse_stft_window_fn( 40 | config.model.hop, 41 | forward_window_fn=tf.signal.hann_window)) 42 | 43 | 44 | def log2(x, base): 45 | return int(np.log(x) / np.log(base)) 46 | 47 | 48 | def next_power_of_2(n): 49 | # decrement `n` (to handle the case when `n` itself is a power of 2) 50 | n = n - 1 51 | # calculate the position of the last set bit of `n` 52 | lg = log2(n, 2) 53 | # next power of two will have a bit set at position `lg+1`. 54 | return 1 << lg #+ 1 55 | 56 | 57 | def check_shape_3d(data, dim): 58 | n = data.shape[dim] 59 | if n % 2 != 0: 60 | n = data.shape[dim] - 1 61 | if dim==0: 62 | return data[:n, :, :] 63 | if dim==1: 64 | return data[:, :n, :] 65 | if dim==2: 66 | return data[:, :, :n] 67 | 68 | 69 | def load_audio(paths): 70 | mixture = tf.io.read_file(paths[0]) 71 | vocals = tf.io.read_file(paths[1]) 72 | mixture_audio, _ = tf.audio.decode_wav(mixture, desired_channels=1) 73 | vocal_audio, _ = tf.audio.decode_wav(vocals, desired_channels=1) 74 | return tf.squeeze(mixture_audio, axis=-1), tf.squeeze(vocal_audio, axis=-1) 75 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Carnatic singing voice separation trained with in-domain data with leakage 2 | This is the official repository for: 3 | 4 | - Carnatic Singing Voice Separation Using Cold Diffusion on Training Data with Bleeding, G. Plaja-Roglans, M. Miron, A. Shankar and X. Serra, 2023 (accepted for presentation at ISMIR 2023, Milan, Italy). 5 | 6 | **IMPORTANT NOTE:** The code structure and an important part of the data loader and training code are an adaptation of the unofficial TensorFlow implementation of DiffWave (Zhifeng Kong et al., 2020). [Link to original repo](https://github.com/revsic/tf-diffwave). 7 | 8 | **ANOTHER IMPORTANT NOTE:** The model in this repo can also be used for hassle-free inference through the Python library [compIAM](https://github.com/MTG/compIAM), a centralized repository of tools, models, and datasets for the computational analysis of Carnatic and Hindustani Music. With a few commands, you can easily download and run the separation model. Refer to `compIAM` to use this model (and many others!) out-of-the-box. `compIAM` is installed with `pip install compiam`; make sure to install `v0.3.0` to use the separation model in this repo. 9 | 10 | ## Requirements 11 | 12 | The repository is based on TensorFlow 2. [See a complete list of requirements here](./requirements.txt). 13 | 14 | ## Run separation inference 15 | 16 | To run separation inference, you can use the `separate.py` file. 17 | 18 | ```bash 19 | python3 separate.py --input-signal /path/to/file.wav --clusters 5 --scheduler 4 20 | ``` 21 | 22 | Additional arguments can be passed to use a different model (`--model-name`), modify the batch size (i.e., the chunk size processed by the model for optimized inference; `--batch-size`), and specify which GPU the process should be routed to (`--gpu`). 23 | 24 | ## Train the model 25 | 26 | To train your own model, you should first prepare the data. See [how we process Saraga](./dataset/prepare_saraga.py) before the training process detailed in the paper. The key idea is to have the chunked and aligned audio samples of the dataset with a naming like `<chunk-id>_<source>.wav`, where `<source>` corresponds to `mixture` and `vocals`. 
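As a reference, a possible invocation of the preparation script could look like the following (the paths are placeholders; adapt them to your local copy of the Saraga Carnatic dataset):

```bash
python3 dataset/prepare_saraga.py \
    --saraga-dir /path/to/saraga_carnatic/ \
    --output-dir /path/to/prepared_chunks/
```

The output folder should exist beforehand; the chunk length defaults to 6 seconds and can be changed with `--sample-len`.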
27 | 28 | Then, run model training in [train.py](./train.py). Checkpoints will be stored every X training steps, where X is defined by the user in the [config.py](./config.py) file. 29 | 30 | To resume training from a previous checkpoint, the `--load-step` argument is available. 31 | 32 | ```bash 33 | python3 train.py --load-step 416 --config ./ckpt/<model-name>.json 34 | ``` 35 | 36 | Download the pre-trained weights for the feature extraction U-Net [here](https://drive.google.com/uc?export=download&id=1yj9iHTY7nCh2qrIM2RIUOXhLXt1K8WcE). 37 | 38 | Unzip and store the weights into the [ckpt folder](./ckpt/). There should be a .json file with the configuration and a folder with the model weight checkpoint inside. Here's an example: 39 | 40 | ```py 41 | with open('./ckpt/saraga-8.json') as f: 42 | config = Config.load(json.load(f)) 43 | 44 | diffwave = DiffWave(config.model) 45 | diffwave.restore('./ckpt/saraga-8/saraga-8.ckpt-1').expect_partial() 46 | ``` 47 | 48 | [Write us](mailto:genis.plaja@upf.edu) or open an issue if you have any issues or questions! 49 | -------------------------------------------------------------------------------- /model/vad.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue May 1 20:43:28 2018 4 | @author: eesungkim 5 | """ 6 | 7 | import math 8 | import numpy as np 9 | from model.estnoise_ms import * 10 | 11 | 12 | def VAD(signal, sr, nFFT=512, win_length=0.025, hop_length=0.01, theshold=0.7): 13 | """Voice Activity Detector 14 | Parameters 15 | ---------- 16 | signal : audio time series 17 | sr : sampling rate of `signal` 18 | nFFT : length of the FFT window 19 | win_length : window size in sec 20 | hop_length : hop size in sec 21 | 22 | Returns 23 | ------- 24 | probRatio : frame-based voice activity probability sequence 25 | """ 26 | signal=signal.astype('float') 27 | 28 | maxPosteriorSNR= 1000 29 | minPosteriorSNR= 0.0001 30 | 31 | win_length_sample = round(win_length*sr) 32 | hop_length_sample = round(hop_length*sr) 33 | 34 | # the variance of the speech; lambda_x(k) 35 | _stft = stft(signal, n_fft=nFFT, win_length=win_length_sample, hop_length=hop_length_sample) 36 | pSpectrum = np.abs(_stft) ** 2 37 | 38 | # estimate the variance of the noise using minimum statistics noise PSD estimation ; lambda_d(k). 
39 | estNoise = estnoisem(pSpectrum,hop_length) 40 | estNoise = estNoise 41 | 42 | aPosterioriSNR=pSpectrum/estNoise 43 | aPosterioriSNR=aPosterioriSNR 44 | aPosterioriSNR[aPosterioriSNR > maxPosteriorSNR] = maxPosteriorSNR 45 | aPosterioriSNR[aPosterioriSNR < minPosteriorSNR] = minPosteriorSNR 46 | 47 | a01=hop_length/0.05 # a01=P(signallence->speech) hop_length/mean signallence length (50 ms) 48 | a00=1-a01 # a00=P(signallence->signallence) 49 | a10=hop_length/0.1 # a10=P(speech->signallence) hop/mean talkspurt length (100 ms) 50 | a11=1-a10 # a11=P(speech->speech) 51 | 52 | b01=a01/a00 53 | b10=a11-a10*a01/a00 54 | 55 | smoothFactorDD=0.99 56 | previousGainedaPosSNR=1 57 | (nFrames,nFFT2) = pSpectrum.shape 58 | probRatio=np.zeros((nFrames,1)) 59 | logGamma_frame=0 60 | for i in range(nFrames): 61 | aPosterioriSNR_frame = aPosterioriSNR[i,:] 62 | 63 | #operator [2](52) 64 | oper=aPosterioriSNR_frame-1 65 | oper[oper < 0] = 0 66 | smoothed_a_priori_SNR = smoothFactorDD * previousGainedaPosSNR + (1-smoothFactorDD) * oper 67 | 68 | #V for MMSE estimate ([2](8)) 69 | V=0.1*smoothed_a_priori_SNR*aPosterioriSNR_frame/(1+smoothed_a_priori_SNR) 70 | 71 | #geometric mean of log likelihood ratios for individual frequency band [1](4) 72 | logLRforFreqBins=2*V-np.log(smoothed_a_priori_SNR+1) 73 | # logLRforFreqBins=np.exp(smoothed_a_priori_SNR*aPosterioriSNR_frame/(1+smoothed_a_priori_SNR))/(1+smoothed_a_priori_SNR) 74 | gMeanLogLRT=np.mean(logLRforFreqBins) 75 | logGamma_frame=np.log(a10/a01) + gMeanLogLRT + np.log(b01+b10/( a10+a00*np.exp(-logGamma_frame) ) ) 76 | probRatio[i]=1/(1+np.exp(-logGamma_frame)) 77 | 78 | #Calculate Gain function which results from the MMSE [2](7). 79 | gain = (math.gamma(1.5) * np.sqrt(V)) / aPosterioriSNR_frame * np.exp(-1 * V / 2) * ((1 + V) * bessel(0, V / 2) + V * bessel(1, V / 2)) 80 | 81 | previousGainedaPosSNR = (gain**2) * aPosterioriSNR_frame 82 | probRatio[probRatio>theshold]=1 83 | probRatio[probRatio0: 171 | if np.max(tf.squeeze(vocals, axis=0).numpy())>0: 172 | mix_mag, _ = self.compute_stft(mixture) 173 | _, voc_phase = self.compute_stft(vocals) 174 | 175 | pred = self.model(mix_mag) 176 | pred = self.compute_signal_from_stft(pred, voc_phase) 177 | mixture = mixture[:, :pred.shape[1]] 178 | vocals = vocals[:, :pred.shape[1]] 179 | pred = tf.transpose(pred, [1, 0]).numpy() 180 | vocals = tf.transpose(vocals, [1, 0]).numpy() 181 | 182 | ref = np.array([vocals]) 183 | est = np.array([pred]) 184 | 185 | scores = GlobalSDR(ref, est) 186 | voc_sdr.append(scores[0]) 187 | 188 | print('Median SDR:', np.median(voc_sdr)) 189 | print('Best model:', best_SDR) 190 | if np.median(voc_sdr) > best_SDR: 191 | print('Saving best new model with SDR:', np.median(voc_sdr)) 192 | self.model.write('{}_BEST-SDR.ckpt'.format(self.ckpt_path),self.optim) 193 | best_SDR = np.median(voc_sdr) 194 | return best_SDR, best_step 195 | 196 | def compute_stft(self, signal): 197 | signal_stft = check_shape_3d( 198 | check_shape_3d( 199 | tf.signal.stft( 200 | signal, 201 | frame_length=self.config.model.win, 202 | frame_step=self.config.model.hop, 203 | fft_length=self.config.model.win, 204 | window_fn=tf.signal.hann_window), 1), 2) 205 | mag = tf.abs(signal_stft) 206 | phase = tf.math.angle(signal_stft) 207 | return mag, phase 208 | 209 | def compute_signal_from_stft(self, spec, phase): 210 | polar_spec = tf.complex(tf.multiply(spec, tf.math.cos(phase)), tf.zeros(spec.shape)) + \ 211 | tf.multiply(tf.complex(spec, tf.zeros(spec.shape)), tf.complex(tf.zeros(phase.shape), tf.math.sin(phase))) 
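# polar_spec rebuilds the complex STFT as spec * exp(i * phase): the first term is the real part (spec * cos(phase)), the second the imaginary part (spec * sin(phase)); the inverse STFT below turns it back into a time-domain waveform.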
212 | return tf.signal.inverse_stft( 213 | polar_spec, 214 | frame_length=self.config.model.win, 215 | frame_step=self.config.model.hop, 216 | window_fn=tf.signal.inverse_stft_window_fn( 217 | self.config.model.hop, 218 | forward_window_fn=tf.signal.hann_window)) 219 | 220 | @staticmethod 221 | def load_audio(paths): 222 | mixture = tf.io.read_file(paths[0]) 223 | vocals = tf.io.read_file(paths[1]) 224 | mixture_audio, _ = tf.audio.decode_wav(mixture, desired_channels=1) 225 | vocal_audio, _ = tf.audio.decode_wav(vocals, desired_channels=1) 226 | return tf.squeeze(mixture_audio, axis=-1), tf.squeeze(vocal_audio, axis=-1) 227 | 228 | 229 | if __name__ == '__main__': 230 | parser = argparse.ArgumentParser() 231 | parser.add_argument('--config', default=None) 232 | parser.add_argument('--load-step', default=0, type=int) 233 | parser.add_argument('--data-dir', default=None) 234 | parser.add_argument('--gpu', default=None) 235 | args = parser.parse_args() 236 | 237 | # Activate CUDA if GPU id is given 238 | if args.gpu is not None: 239 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 240 | else: 241 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 242 | 243 | config = UnetConfig() 244 | 245 | if args.config is not None: 246 | print('[*] load config: ' + args.config) 247 | with open(args.config) as f: 248 | config = UnetConfig.load(json.load(f)) 249 | 250 | log_path = os.path.join(config.train.log, config.train.name) 251 | if not os.path.exists(log_path): 252 | os.makedirs(log_path) 253 | 254 | ckpt_path = os.path.join(config.train.ckpt, config.train.name) 255 | if not os.path.exists(ckpt_path): 256 | os.makedirs(ckpt_path) 257 | 258 | sounds_path = os.path.join(config.train.sounds, config.train.name) 259 | if not os.path.exists(sounds_path): 260 | os.makedirs(sounds_path) 261 | 262 | saraga = SARAGA(config.data, data_dir=args.data_dir) 263 | diffwave = DiffWave(config.model) 264 | trainer = Trainer(diffwave, saraga, config, data_dir=args.data_dir) 265 | 266 | if args.load_step > 0: 267 | super_path = os.path.join(config.train.ckpt, config.train.name) 268 | ckpt_path = os.path.join(super_path, '{}.ckpt-1'.format(config.train.name)) 269 | print('[*] load checkpoint: ' + ckpt_path) 270 | trainer.model.restore(ckpt_path, trainer.optim) 271 | print("Loaded!") 272 | 273 | with open(os.path.join(config.train.ckpt, config.train.name + '.json'), 'w') as f: 274 | json.dump(config.dump(), f) 275 | 276 | trainer.train(args.load_step) 277 | -------------------------------------------------------------------------------- /model/unet.py: -------------------------------------------------------------------------------- 1 | # unet.py 2 | # Source: https://github.com/hojonathanho/diffusion/blob/master/ 3 | # diffusion_tf/models/unet.py 4 | # Tensorflow 2.4.0 5 | # Windows/MacOS/Linux 6 | # Python 3.7 7 | 8 | 9 | from .unet_utils import get_timestep_embedding 10 | 11 | import tensorflow as tf 12 | import keras.backend as K 13 | import tensorflow_addons as tfa 14 | from tensorflow.keras import layers, models 15 | 16 | 17 | class TimestepEmbedding(layers.Layer): 18 | def __init__(self, dim): 19 | super(TimestepEmbedding, self).__init__() 20 | self.dim = dim 21 | 22 | def call(self, t): 23 | return get_timestep_embedding(t, self.dim) 24 | 25 | 26 | class Upsample(layers.Layer): 27 | def __init__(self, channels, with_conv=True): 28 | super(Upsample, self).__init__() 29 | self.channels = channels 30 | self.with_conv = with_conv 31 | self.conv = layers.Conv2DTranspose(self.channels, (3, 3), padding="same", strides=2) 32 
| 33 | def call(self, inputs): 34 | batch_size, height, width, _ = inputs.shape 35 | x = self.conv(inputs) 36 | assert x.shape == [batch_size, height * 2, width * 2, self.channels] 37 | return x 38 | 39 | 40 | class Downsample(layers.Layer): 41 | def __init__(self, channels, with_conv=True): 42 | super(Downsample, self).__init__() 43 | self.with_conv = with_conv 44 | self.channels = channels 45 | self.conv = layers.Conv2D(self.channels, (3, 3), padding="same", strides=2) 46 | self.avg_pool = layers.AveragePooling2D(strides=2, padding="same") 47 | 48 | def call(self, inputs): 49 | batch_size, height, width, _ = inputs.shape 50 | if self.with_conv: 51 | x = self.conv(inputs) 52 | else: 53 | x = self.avg_pool(inputs) 54 | assert x.shape == [batch_size, height // 2, width // 2, self.channels] 55 | return x 56 | 57 | 58 | class ResNetBlock(layers.Layer): 59 | def __init__(self, in_ch, cond_track=None, out_ch=None, conv_shortcut=False, dropout=0.): 60 | super(ResNetBlock, self).__init__() 61 | self.in_ch = in_ch 62 | self.cond_track = cond_track 63 | self.out_ch = out_ch 64 | self.conv_shortcut = conv_shortcut 65 | self.dropout = dropout 66 | 67 | if self.out_ch is None: 68 | self.out_ch = self.in_ch 69 | self.c_not_out_ch = self.in_ch != self.out_ch 70 | 71 | # Layers. 72 | self.group_norm1 = tf.keras.layers.BatchNormalization() 73 | self.non_linear1 = layers.Activation("swish") 74 | self.conv1 = layers.Conv2D(self.out_ch, (3,3), padding="same") 75 | 76 | self.non_linear2 = layers.Activation("swish") 77 | self.dense2 = layers.Dense(self.out_ch) 78 | 79 | self.group_norm3 = tf.keras.layers.BatchNormalization() 80 | self.non_linear3 = layers.Activation("swish") 81 | self.dropout3 = layers.Dropout(self.dropout) 82 | self.conv3 = layers.Conv2D(self.out_ch, (3, 3), padding="same") 83 | if self.cond_track is not None: 84 | self.downsample_cond = layers.Conv2D(self.out_ch, (3, 3), padding="same", strides=2*self.cond_track) 85 | self.proj_cond = tf.keras.layers.Conv1D(self.out_ch, 1, padding="same") 86 | 87 | self.conv4 = layers.Conv2D(self.out_ch, (3, 3), padding="same") 88 | self.dense4 = layers.Dense(self.out_ch) 89 | 90 | 91 | def call(self, inputs, temb, cond=None): 92 | x = inputs 93 | 94 | x = self.group_norm1(x) 95 | x = self.non_linear1(x) 96 | x = self.conv1(x) 97 | 98 | # Add in timestep embedding. 
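# The timestep embedding is passed through a swish activation and a Dense projection to out_ch channels, then broadcast over the spatial dimensions of the feature map via [:, None, None, :].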
99 | x += self.dense2(self.non_linear2(temb))[:, None, None, :] 100 | 101 | x = self.group_norm3(x) 102 | x = self.non_linear3(x) 103 | x = self.dropout3(x) 104 | 105 | if cond is not None: 106 | if self.cond_track is not None: 107 | cond = self.downsample_cond(cond) 108 | x = self.conv3(x) + self.proj_cond(cond) 109 | else: 110 | x = self.conv3(x) 111 | 112 | if self.c_not_out_ch: 113 | if self.conv_shortcut: 114 | inputs = self.conv4(inputs) 115 | else: 116 | inputs = self.dense4(inputs) 117 | assert x.shape == inputs.shape 118 | return inputs + x 119 | 120 | 121 | class AttentionBlock(layers.Layer): 122 | def __init__(self, channels): 123 | super(AttentionBlock, self).__init__() 124 | self.channels = channels 125 | 126 | self.avg_pool = tf.keras.layers.Lambda(lambda x: K.mean(x, axis=3, keepdims=True)) 127 | self.max_pool = tf.keras.layers.Lambda(lambda x: K.max(x, axis=3, keepdims=True)) 128 | self.cbam_feature = tf.keras.layers.Conv2D( 129 | filters=1, 130 | kernel_size=(3, 3), 131 | strides=1, 132 | padding='same', 133 | activation='sigmoid', 134 | kernel_initializer='he_normal', 135 | use_bias=False) 136 | 137 | def call(self, inputs): 138 | x = inputs 139 | avg_pool = self.avg_pool(x) 140 | max_pool = self.max_pool(x) 141 | cbam = tf.keras.layers.Concatenate(axis=3)([avg_pool, max_pool]) 142 | cbam = self.cbam_feature(cbam) 143 | return tf.keras.layers.multiply([inputs, cbam]) 144 | 145 | 146 | class UNet(models.Model): 147 | def __init__(self, config, num_res_blocks=2, attn_resolutions=(8, 16, 32), channels=16, 148 | ch_mult=(1, 2, 4, 8, 16, 32, 64), dropout=0.2, resample_with_conv=False): 149 | super(UNet, self).__init__() 150 | self.config = config 151 | self.num_res_blocks = num_res_blocks 152 | self.attn_resolutions = attn_resolutions 153 | self.channels = channels 154 | self.ch_mult = ch_mult 155 | self.dropout = dropout 156 | self.resample_with_conv = resample_with_conv 157 | self.num_resolutions = len(self.ch_mult) 158 | 159 | self.in_embed = [ 160 | TimestepEmbedding(self.channels), 161 | layers.Dense(self.channels*4), 162 | layers.Activation("swish"), 163 | layers.Dense(self.channels*4)] 164 | 165 | self.upsample_cond = [ 166 | tf.keras.layers.Conv2DTranspose(self.channels*2, [32, 3], [1, 1], padding='same') for _ in range(1)] 167 | # mel-downsampler 168 | self.downsample_cond = [ 169 | tf.keras.layers.Conv2D(1, [16, 3], [1, 8], padding='same') for _ in range(2)] 170 | 171 | # Downsampling. 172 | self.pre_process = layers.Conv2D(self.channels, (3, 3), padding="same") 173 | self.downsampling = [] 174 | cond_track = 1 175 | input_track = self.channels 176 | channel_track = self.channels 177 | for i_level in range(len(ch_mult)): 178 | downsampling_block = [] 179 | # Residual blocks for this resolution. 
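# Each resolution level stacks num_res_blocks ResNetBlocks with out_ch = channels * ch_mult[i_level]; every level except the last is followed by a Downsample layer, and cond_track / input_track / channel_track keep the conditioning stride and channel counts in sync for the next level.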
180 | for _ in range(self.num_res_blocks): 181 | if input_track in self.attn_resolutions: 182 | downsampling_block.append( 183 | ResNetBlock( 184 | in_ch=channel_track, 185 | cond_track=cond_track, 186 | out_ch=self.channels*self.ch_mult[i_level], 187 | dropout=self.dropout) 188 | ) 189 | else: 190 | downsampling_block.append( 191 | ResNetBlock( 192 | in_ch=channel_track, 193 | cond_track=cond_track, 194 | out_ch=self.channels*self.ch_mult[i_level], 195 | dropout=self.dropout)) 196 | if i_level != self.num_resolutions-1: 197 | downsampling_block.append( 198 | Downsample( 199 | channels=self.channels*self.ch_mult[i_level], 200 | with_conv=self.resample_with_conv)) 201 | cond_track *= 2 202 | input_track //= 2 203 | channel_track = self.channels*self.ch_mult[i_level] 204 | self.downsampling.append(downsampling_block) 205 | 206 | # Middle. 207 | self.middle = [ 208 | ResNetBlock(in_ch=channel_track, dropout=self.dropout), 209 | ResNetBlock(in_ch=channel_track, dropout=self.dropout) 210 | ] 211 | 212 | # Upsampling. 213 | self.upsampling = [] 214 | channel_track = self.channels*self.ch_mult[-1]*2 215 | for i_level in reversed(range(self.num_resolutions)): 216 | upsampling_block = [] 217 | # Residual blocks for this resolution. 218 | for _ in range(self.num_res_blocks + 1): 219 | if input_track in self.attn_resolutions: 220 | upsampling_block.append( 221 | ResNetBlock( 222 | in_ch=channel_track, 223 | cond_track=cond_track, 224 | out_ch=self.channels*self.ch_mult[i_level], 225 | dropout=0.2) 226 | ) 227 | else: 228 | upsampling_block.append( 229 | ResNetBlock( 230 | in_ch=channel_track, 231 | cond_track=cond_track, 232 | out_ch=self.channels*self.ch_mult[i_level], 233 | dropout=0.2)) 234 | # Upsample. 235 | if i_level != 0: 236 | upsampling_block.append( 237 | Upsample( 238 | channels=self.channels*self.ch_mult[i_level], 239 | with_conv=self.resample_with_conv)) 240 | cond_track //= 2 241 | input_track *= 2 242 | channel_track = self.channels*self.ch_mult[i_level] 243 | self.upsampling.append(upsampling_block) 244 | 245 | # End. 246 | self.end = [ 247 | layers.Conv2D(self.channels, (3, 3), padding="same"), 248 | layers.Conv2D(1, (3, 3), (1, 1), padding='same') 249 | ] 250 | 251 | 252 | def call(self, inputs, temb, cond=None): 253 | 254 | x = inputs[..., None] 255 | 256 | if cond is not None: 257 | cond = self.vectorize_layer(cond) 258 | cond = self.word_embedding(cond) 259 | if len(cond.shape) < 3: 260 | cond = cond[None] 261 | cond = tf.transpose(cond[..., None], [0, 2, 1, 3]) 262 | 263 | for upsample in self.upsample_cond: 264 | cond = tf.nn.leaky_relu(upsample(cond), 0.4) 265 | cond = tf.transpose(cond, [0, 3, 1, 2]) 266 | for downsample in self.downsample_cond: 267 | cond = tf.nn.leaky_relu(downsample(cond), 0.4) 268 | cond = tf.transpose(cond, [0, 2, 1, 3]) 269 | 270 | for lay in self.in_embed: 271 | temb = lay(temb) 272 | # Downsampling. 273 | hs = [self.pre_process(x)] 274 | for block in self.downsampling: 275 | for idx_block in range(self.num_res_blocks): 276 | if isinstance(block[idx_block], list): 277 | if cond is not None: 278 | h = block[idx_block][0](hs[-1], temb, cond) 279 | else: 280 | h = block[idx_block][0](hs[-1], temb) 281 | h = block[idx_block][1](h) 282 | hs.append(h) 283 | else: 284 | if cond is not None: 285 | h = block[idx_block](hs[-1], temb, cond) 286 | else: 287 | h = block[idx_block](hs[-1], temb) 288 | hs.append(h) 289 | if len(block) > self.num_res_blocks: 290 | for extra_lay in block[self.num_res_blocks:]: 291 | hs.append(extra_lay(hs[-1])) 292 | 293 | # Middle. 
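# Bottleneck: take the deepest feature map from the downsampling path and run it through the two middle ResNetBlocks.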
294 | h = hs[-1] 295 | for _, lay in enumerate(self.middle): 296 | h = lay(h, temb) 297 | 298 | # Upsampling. 299 | for block in self.upsampling: 300 | # Residual blocks for this resolution. 301 | for idx_block in range(self.num_res_blocks+1): 302 | if isinstance(block[idx_block], list): 303 | if cond is not None: 304 | h = block[idx_block][0](tf.concat([h, hs.pop()], axis=-1), temb, cond) 305 | else: 306 | h = block[idx_block][0](tf.concat([h, hs.pop()], axis=-1), temb) 307 | h = block[idx_block][1](h) 308 | else: 309 | if cond is not None: 310 | h = block[idx_block](tf.concat([h, hs.pop()], axis=-1), temb, cond) 311 | else: 312 | h = block[idx_block](tf.concat([h, hs.pop()], axis=-1), temb) 313 | # Upsample. 314 | if len(block) > self.num_res_blocks+1: 315 | for extra_lay in block[self.num_res_blocks+1:]: 316 | h = extra_lay(h) 317 | 318 | # End. 319 | for lay in self.end: 320 | h = lay(h) 321 | 322 | h = tf.keras.activations.sigmoid(h) 323 | h = tf.squeeze(h, axis=-1) 324 | 325 | return tf.multiply(inputs, h) 326 | 327 | -------------------------------------------------------------------------------- /model/estnoise_ms.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue May 1 20:43:28 2018 4 | @author: eesungkim 5 | """ 6 | 7 | import numpy as np 8 | from scipy.special import jv 9 | 10 | def bessel(v, X): 11 | return ((1j**(-v))*jv(v,1j*X)).real 12 | 13 | def stft(x, n_fft=512, win_length=400, hop_length=160, window='hamming'): 14 | if window == 'hanning': 15 | window = np.hanning(win_length) 16 | elif window == 'hamming': 17 | window = np.hamming(win_length) 18 | elif window == 'rectangle': 19 | window = np.ones(win_length) 20 | return np.array([np.fft.rfft(window*x[i:i+win_length],n_fft,axis=0) for i in range(0, len(x)-win_length, hop_length)]) 21 | 22 | def estnoisem(pSpectrum,hop_length): 23 | """ 24 | This is python implementation of [1],[2], and [3]. 25 | 26 | Refs: 27 | [1] Rainer Martin. 28 | Noise power spectral density estimation based on optimal smoothing and minimum statistics. 29 | IEEE Trans. Speech and Audio Processing, 9(5):504-512, July 2001. 30 | [2] Rainer Martin. 31 | Bias compensation methods for minimum statistics noise power spectral density estimation 32 | Signal Processing, 2006, 86, 1215-1229 33 | [3] Dirk Mauler and Rainer Martin 34 | Noise power spectral density estimation on highly correlated data 35 | Proc IWAENC, 2006 36 | 37 | Copyright (C) Mike Brookes 2008 38 | Version: $Id: estnoisem.m 1718 2012-03-31 16:40:41Z dmb $ 39 | 40 | VOICEBOX is a MATLAB toolbox for speech processing. 
41 | Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html 42 | """ 43 | 44 | (nFrames,nFFT2)=np.shape(pSpectrum) # number of frames and freq bins 45 | x=np.array(np.zeros((nFrames,nFFT2)) ) # initialize output arrays 46 | xs=np.array(np.zeros((nFrames,nFFT2)) ) # will hold std error in the future 47 | 48 | # default algorithm constants 49 | taca= 0.0449 # smoothing time constant for alpha_c = -hop_length/log(0.7) in equ (11) 50 | tamax= 0.392 # max smoothing time constant in (3) = -hop_length/log(0.96) 51 | taminh= 0.0133 # min smoothing time constant (upper limit) in (3) = -hop_length/log(0.3) 52 | tpfall= 0.064 # time constant for P to fall (12) 53 | tbmax= 0.0717 # max smoothing time constant in (20) = -hop_length/log(0.8) 54 | qeqmin= 2.0 # minimum value of Qeq (23) 55 | qeqmax= 14.0 # max value of Qeq per frame 56 | av= 2.12 # fudge factor for bc calculation (23 + 13 lines) 57 | td= 1.536 # time to take minimum over 58 | nu= 8 # number of subwindows 59 | qith= np.array([0.03, 0.05, 0.06, np.Inf],dtype=float) # noise slope thresholds in dB/s 60 | nsmdb= np.array([47, 31.4, 15.7, 4.1],dtype=float) # maximum permitted +ve noise slope in dB/s 61 | 62 | 63 | # derived algorithm constants 64 | aca=np.exp(-hop_length/taca) # smoothing constant for alpha_c in equ (11) = 0.7 65 | acmax=aca # min value of alpha_c = 0.7 in equ (11) also = 0.7 66 | amax=np.exp(-hop_length/tamax) # max smoothing constant in (3) = 0.96 67 | aminh=np.exp(-hop_length/taminh) # min smoothing constant (upper limit) in (3) = 0.3 68 | bmax=np.exp(-hop_length/tbmax) # max smoothing constant in (20) = 0.8 69 | SNRexp = -hop_length/tpfall 70 | nv=round(td/(hop_length*nu)) # length of each subwindow in frames 71 | 72 | 73 | if nv<4: # algorithm doesn't work for miniscule frames 74 | nv=4 75 | nu=round(td/(hop_length*nv)) 76 | nd=nu*nv # length of total window in frames 77 | (md,hd,dd) = mhvals(nd) # calculate the constants M(D) and H(D) from Table III 78 | (mv,hv,dv) = mhvals(nv) # calculate the constants M(D) and H(D) from Table III 79 | nsms=np.array([10])**(nsmdb*nv*hop_length/10) # [8 4 2 1.2] in paper 80 | qeqimax=1/qeqmin # maximum value of Qeq inverse (23) 81 | qeqimin=1/qeqmax # minumum value of Qeq per frame inverse 82 | 83 | 84 | p=pSpectrum[0,:] # smoothed power spectrum 85 | ac=1 # correction factor (9) 86 | sn2=p # estimated noise power 87 | pb=p # smoothed noisy speech power (20) 88 | pb2=pb**2 89 | pminu=p 90 | actmin=np.array(np.ones(nFFT2) * np.Inf) # Running minimum estimate 91 | actminsub=np.array(np.ones(nFFT2) * np.Inf) # sub-window minimum estimate 92 | subwc=nv # force a buffer switch on first loop 93 | actbuf=np.array(np.ones((nu,nFFT2)) * np.Inf) # buffer to store subwindow minima 94 | ibuf=0 95 | lminflag=np.zeros(nFFT2) # flag to remember local minimum 96 | 97 | # loop for each frame 98 | for t in range(0,nFrames): # we use t instead of lambda in the paper 99 | pSpectrum_t=pSpectrum[t,:] # noise speech power spectrum 100 | acb=(1+(sum(p) / sum(pSpectrum_t)-1)**2)**(-1) # alpha_c-bar(t) (9) 101 | 102 | tmp=np.array([acb] ) 103 | tmp[tmp < acmax] = acmax 104 | #max_complex(np.array([acb] ),np.array([acmax] )) 105 | 106 | ac=aca*ac+(1-aca)*tmp # alpha_c(t) (10) 107 | 108 | ah=amax*ac*(1+(p/sn2-1)**2)**(-1) # alpha_hat: smoothing factor per frequency (11) 109 | SNR=sum(p)/sum(sn2) 110 | 111 | 112 | ah=max_complex(ah,min_complex(np.array([aminh] ),np.array([SNR**SNRexp] ))) # lower limit for alpha_hat (12) 113 | 114 | p=ah*p+(1-ah)*pSpectrum_t # smoothed noisy speech power (3) 115 | 116 
| b=min_complex(ah**2,np.array([bmax] )) # smoothing constant for estimating periodogram variance (22 + 2 lines) 117 | pb=b*pb + (1-b)*p # smoothed periodogram (20) 118 | pb2=b*pb2 + (1-b)*p**2 # smoothed periodogram squared (21) 119 | 120 | qeqi=max_complex(min_complex((pb2-pb**2)/(2*sn2**2),np.array([qeqimax] )),np.array([qeqimin/(t+1)] )) # Qeq inverse (23) 121 | qiav=sum(qeqi)/nFFT2 # Average over all frequencies (23+12 lines) (ignore non-duplication of DC and nyquist terms) 122 | bc=1+av*np.sqrt(qiav) # bias correction factor (23+11 lines) 123 | bmind=1+2*(nd-1)*(1-md)/(qeqi**(-1)-2*md) # we use the signalmplified form (17) instead of (15) 124 | bminv=1+2*(nv-1)*(1-mv)/(qeqi**(-1)-2*mv) # same expressignalon but for sub windows 125 | kmod=(bc*p*bmind) < actmin # Frequency mask for new minimum 126 | 127 | if any(kmod): 128 | actmin[kmod]=bc*p[kmod]*bmind[kmod] 129 | actminsub[kmod]=bc*p[kmod]*bminv[kmod] 130 | 131 | if subwc>1 and subwc=nv: # end of buffer - do a buffer switch 137 | ibuf=1+(ibuf%nu) # increment actbuf storage pointer 138 | actbuf[ibuf-1,:]=actmin.copy() # save sub-window minimum 139 | pminu=min_complex_mat(actbuf) 140 | i=np.nonzero(np.array(qiav )pminu) 143 | if any(lmin): 144 | pminu[lmin]=actminsub[lmin] 145 | actbuf[:,lmin]= np.ones((nu,1)) * pminu[lmin] 146 | lminflag[:]=0 147 | actmin[:]=np.Inf 148 | subwc=0 149 | 150 | subwc=subwc+1 151 | x[t,:]=sn2.copy() 152 | qisq=np.sqrt(qeqi) 153 | # empirical formula for standard error based on Fig 15 of [2] 154 | xs[t,:]=sn2*np.sqrt(0.266*(nd+100*qisq)*qisq/(1+0.005*nd+6/nd)/(0.5*qeqi**(-1)+nd-1)) 155 | 156 | 157 | return x 158 | 159 | def mhvals(*args): 160 | """ 161 | This is python implementation of [1],[2], and [3]. 162 | 163 | Refs: 164 | [1] Rainer Martin. 165 | Noise power spectral density estimation based on optimal smoothing and minimum statistics. 166 | IEEE Trans. Speech and Audio Processing, 9(5):504-512, July 2001. 167 | [2] Rainer Martin. 168 | Bias compensation methods for minimum statistics noise power spectral density estimation 169 | Signal Processing, 2006, 86, 1215-1229 170 | [3] Dirk Mauler and Rainer Martin 171 | Noise power spectral density estimation on highly correlated data 172 | Proc IWAENC, 2006 173 | 174 | Copyright (C) Mike Brookes 2008 175 | Version: $Id: estnoisem.m 1718 2012-03-31 16:40:41Z dmb $ 176 | 177 | VOICEBOX is a MATLAB toolbox for speech processing. 
178 | Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html 179 | """ 180 | nargin = len(args) 181 | 182 | dmh=np.array([ 183 | [1, 0, 0], 184 | [2, 0.26, 0.15], 185 | [5, 0.48, 0.48], 186 | [8, 0.58, 0.78], 187 | [10, 0.61, 0.98], 188 | [15, 0.668, 1.55], 189 | [20, 0.705, 2], 190 | [30, 0.762, 2.3], 191 | [40, 0.8, 2.52], 192 | [60, 0.841, 3.1], 193 | [80, 0.865, 3.38], 194 | [120, 0.89, 4.15], 195 | [140, 0.9, 4.35], 196 | [160, 0.91, 4.25], 197 | [180, 0.92, 3.9], 198 | [220, 0.93, 4.1], 199 | [260, 0.935, 4.7], 200 | [300, 0.94, 5] 201 | ],dtype=float) 202 | 203 | if nargin>=1: 204 | d=args[0] 205 | i=np.nonzero(d<=dmh[:,0]) 206 | if len(i)==0: 207 | i=np.shape(dmh)[0]-1 208 | j=i 209 | else: 210 | i=i[0][0] 211 | j=i-1 212 | if d==dmh[i,0]: 213 | m=dmh[i,1] 214 | h=dmh[i,2] 215 | else: 216 | qj=np.sqrt(dmh[i-1,0]) # interpolate usignalng sqrt(d) 217 | qi=np.sqrt(dmh[i,0]) 218 | q=np.sqrt(d) 219 | h=dmh[i,2]+(q-qi)*(dmh[j,2]-dmh[i,2])/(qj-qi) 220 | m=dmh[i,1]+(qi*qj/q-qj)*(dmh[j,1]-dmh[i,1])/(qi-qj) 221 | else: 222 | d=dmh[:,0].copy() 223 | m=dmh[:,1].copy() 224 | h=dmh[:,2].copy() 225 | 226 | return m,h,d 227 | 228 | 229 | def max_complex(a,b): 230 | """ 231 | This is python implementation of [1],[2], and [3]. 232 | 233 | Refs: 234 | [1] Rainer Martin. 235 | Noise power spectral density estimation based on optimal smoothing and minimum statistics. 236 | IEEE Trans. Speech and Audio Processing, 9(5):504-512, July 2001. 237 | [2] Rainer Martin. 238 | Bias compensation methods for minimum statistics noise power spectral density estimation 239 | Signal Processing, 2006, 86, 1215-1229 240 | [3] Dirk Mauler and Rainer Martin 241 | Noise power spectral density estimation on highly correlated data 242 | Proc IWAENC, 2006 243 | 244 | Copyright (C) Mike Brookes 2008 245 | Version: $Id: estnoisem.m 1718 2012-03-31 16:40:41Z dmb $ 246 | 247 | VOICEBOX is a MATLAB toolbox for speech processing. 248 | Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html 249 | """ 250 | if len(a)==1 and len(b)>1: 251 | a=np.tile(a,np.shape(b)) 252 | if len(b)==1 and len(a)>1: 253 | b=np.tile(b,np.shape(a)) 254 | 255 | i=np.logical_or(np.iscomplex(a),np.iscomplex(b)) 256 | 257 | aa = a.copy() 258 | bb = b.copy() 259 | 260 | if any(i): 261 | aa[i]=np.absolute(aa[i]) 262 | bb[i]=np.absolute(bb[i]) 263 | if a.dtype == 'complex' or b.dtype== 'complex': 264 | cc = np.array(np.zeros(np.shape(a)) ) 265 | else: 266 | cc = np.array(np.zeros(np.shape(a)),dtype=float) 267 | 268 | i=aa>bb 269 | cc[i]=a[i] 270 | cc[np.logical_not(i)] = b[np.logical_not(i)] 271 | 272 | return cc 273 | 274 | def min_complex(a,b): 275 | """ 276 | This is python implementation of [1],[2], and [3]. 277 | 278 | Refs: 279 | [1] Rainer Martin. 280 | Noise power spectral density estimation based on optimal smoothing and minimum statistics. 281 | IEEE Trans. Speech and Audio Processing, 9(5):504-512, July 2001. 282 | [2] Rainer Martin. 283 | Bias compensation methods for minimum statistics noise power spectral density estimation 284 | Signal Processing, 2006, 86, 1215-1229 285 | [3] Dirk Mauler and Rainer Martin 286 | Noise power spectral density estimation on highly correlated data 287 | Proc IWAENC, 2006 288 | 289 | Copyright (C) Mike Brookes 2008 290 | Version: $Id: estnoisem.m 1718 2012-03-31 16:40:41Z dmb $ 291 | 292 | VOICEBOX is a MATLAB toolbox for speech processing. 
293 | Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html 294 | """ 295 | if len(a)==1 and len(b)>1: 296 | a=np.tile(a,np.shape(b)) 297 | if len(b)==1 and len(a)>1: 298 | b=np.tile(b,np.shape(a)) 299 | 300 | i=np.logical_or(np.iscomplex(a),np.iscomplex(b)) 301 | 302 | aa = a.copy() 303 | bb = b.copy() 304 | 305 | if any(i): 306 | aa[i]=np.absolute(aa[i]) 307 | bb[i]=np.absolute(bb[i]) 308 | 309 | if a.dtype == 'complex' or b.dtype== 'complex': 310 | cc = np.array(np.zeros(np.shape(a)) ) 311 | else: 312 | cc = np.array(np.zeros(np.shape(a)),dtype=float) 313 | 314 | i=aa