├── dataset
│   ├── __init__.py
│   ├── config.py
│   └── musdb.py
├── requirements.txt
├── .gitignore
├── utils
│   └── noam_schedule.py
├── model
│   ├── config.py
│   ├── __init__.py
│   └── wavenet.py
├── config.py
├── README.md
├── augmentation_utils.py
├── process_musdb.py
├── separate.py
├── separate_musdb_track.py
└── train.py

/dataset/__init__.py:
--------------------------------------------------------------------------------
from .musdb import MUSDB
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
librosa==0.7.0
matplotlib==3.3.1
numba==0.48
numpy==1.19.1
tensorflow>=2.1.0
tqdm==4.48.2
torchaudio>0.7.0
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# python
__pycache__

# train
output
ckpt
log

# test sample
sample

# additional files not needed or being fixed
inf_sar.py
medley.py
inference.py
run_inference.py
run_inference_2.py
perceptual.py
OLD_README.md

--------------------------------------------------------------------------------
/utils/noam_schedule.py:
--------------------------------------------------------------------------------
import tensorflow as tf


class NoamScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Noam learning rate scheduler from Vaswani et al., 2017.
    """
    def __init__(self, learning_rate, warmup_steps, channels):
        """Initializer.
        Args:
            learning_rate: float, initial learning rate.
            warmup_steps: int, warmup steps.
            channels: int, base hidden size of the model.
        """
        super(NoamScheduler, self).__init__()
        self.learning_rate = learning_rate
        self.warmup_steps = warmup_steps
        self.channels = channels

    def __call__(self, step):
        """Compute learning rate.
        """
        return self.learning_rate * self.channels ** -0.5 * \
            tf.minimum(step ** -0.5, step * self.warmup_steps ** -1.5)

    def get_config(self):
        """Serialize configurations.
        """
        return {
            'learning_rate': self.learning_rate,
            'warmup_steps': self.warmup_steps,
            'channels': self.channels,
        }

--------------------------------------------------------------------------------
/dataset/config.py:
--------------------------------------------------------------------------------
import tensorflow as tf

class Config:
    """Configuration for dataset construction.
    """
    def __init__(self):
        # audio config
        self.sr = 22050

        # stft
        self.hop = 256
        self.win = 1024
        self.fft = self.win
        self.win_fn = "hann"

        # mel-scale filter bank
        self.mel = 80
        self.fmin = 0
        self.fmax = 8000

        self.eps = 1e-5

        # sample size
        self.frames = self.hop * 32  # 8192
        self.batch = 8

        self.eval_tracks = ["20", "40", "93", "99"]  # Used for the experiments

    def window_fn(self):
        """Return window generator.
        Returns:
            Callable, window function of tf.signal
            , which corresponds to self.win_fn.
34 | """ 35 | mapper = { 36 | "hann": tf.signal.hann_window, 37 | "hamming": tf.signal.hamming_window 38 | } 39 | if self.win_fn in mapper: 40 | return mapper[self.win_fn] 41 | 42 | raise ValueError("invalid window function: " + self.win_fn) 43 | -------------------------------------------------------------------------------- /model/config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Config: 5 | """Configuration for DiffWave implementation. 6 | """ 7 | def __init__(self): 8 | self.sr = 22050 9 | 10 | # mel-scale filter bank 11 | self.mel = 80 12 | self.fmin = 0 13 | self.fmax = 8000 14 | 15 | self.eps = 1e-5 16 | 17 | # sample size 18 | self.frames = 256 * 32 # 8192 19 | self.batch = 8 20 | 21 | # leaky relu coefficient 22 | self.leak = 0.4 23 | 24 | # embdding config 25 | self.embedding_size = 128 26 | self.embedding_proj = 512 27 | self.embedding_layers = 2 28 | self.embedding_factor = 4 29 | 30 | # upsampler config 31 | self.upsample_stride = [4, 1] 32 | self.upsample_kernel = [32, 3] 33 | self.upsample_layers = 4 34 | # computed hop size 35 | self.hop = self.upsample_stride[0] ** self.upsample_layers 36 | 37 | # block config 38 | self.channels = 64 39 | self.kernel_size = 3 40 | self.dilation_rate = 2 41 | self.num_layers = 30 42 | self.num_cycles = 3 43 | 44 | # noise schedule 45 | self.iter = 20 # 20, 40, 50 46 | self.noise_policy = "linear" 47 | self.noise_start = 1e-4 48 | self.noise_end = 0.2 # 0.02 for 200 49 | 50 | def beta(self): 51 | """Generate beta-sequence. 52 | Returns: 53 | List[float], [iter], beta values. 54 | """ 55 | mapper = { 56 | 'linear': self._linear_sched, 57 | } 58 | if self.noise_policy not in mapper: 59 | raise ValueError('invalid beta policy') 60 | return mapper[self.noise_policy]() 61 | 62 | def _linear_sched(self): 63 | """Linearly generated noise. 64 | Returns: 65 | List[float], [iter], beta values. 66 | """ 67 | return np.linspace( 68 | self.noise_start, self.noise_end, self.iter, dtype=np.float32) 69 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from dataset.config import Config as DataConfig 2 | from model.config import Config as ModelConfig 3 | from utils.noam_schedule import NoamScheduler 4 | 5 | 6 | class TrainConfig: 7 | """Configuration for training loop. 8 | """ 9 | def __init__(self): 10 | # optimizer 11 | self.lr_policy = "fixed" 12 | self.learning_rate = 0.0002 13 | # self.lr_policy = "noam" 14 | # self.learning_rate = 1 15 | # self.lr_params = { 16 | # "warmup_steps": 4000, 17 | # "channels": 64 18 | # } 19 | 20 | self.beta1 = 0.9 21 | self.beta2 = 0.98 22 | self.eps = 1e-9 23 | 24 | # 13000:100 25 | self.split = 559*8 # Make sure this is the actual number of chunks 26 | self.bufsiz = 16 27 | 28 | self.epoch = 100000 # Select number of epocs to run 29 | 30 | # path config 31 | self.log = "./log" 32 | self.ckpt = "./ckpt" 33 | 34 | # model name 35 | self.name = "20-step-vocal" 36 | self.target = "vocals" 37 | self.sr = 22050 38 | 39 | # interval configuration 40 | self.ckpt_intval = 20000 41 | 42 | def lr(self): 43 | """Generate proper learning rate scheduler. 
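        With the "fixed" policy the constant learning_rate is returned directly; with the
        "noam" policy a NoamScheduler is built from lr_params and follows
            learning_rate * channels**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
        (see utils/noam_schedule.py).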
44 | """ 45 | mapper = { 46 | "noam": NoamScheduler 47 | } 48 | if self.lr_policy == "fixed": 49 | return self.learning_rate 50 | if self.lr_policy in mapper: 51 | return mapper[self.lr_policy](self.learning_rate, **self.lr_params) 52 | raise ValueError("invalid lr_policy") 53 | 54 | class Config: 55 | """Integrated configuration. 56 | """ 57 | def __init__(self): 58 | self.data = DataConfig() 59 | self.model = ModelConfig() 60 | self.train = TrainConfig() 61 | 62 | def dump(self): 63 | """Dump configurations into serializable dictionary. 64 | """ 65 | return {k: vars(v) for k, v in vars(self).items()} 66 | 67 | @staticmethod 68 | def load(dump_): 69 | """Load dumped configurations into new configuration. 70 | """ 71 | conf = Config() 72 | for k, v in dump_.items(): 73 | if hasattr(conf, k): 74 | obj = getattr(conf, k) 75 | load_state(obj, v) 76 | return conf 77 | 78 | 79 | def load_state(obj, dump_): 80 | """Load dictionary items to attributes. 81 | """ 82 | for k, v in dump_.items(): 83 | if hasattr(obj, k): 84 | setattr(obj, k, v) 85 | return obj 86 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from .wavenet import WaveNet 5 | 6 | class DiffWave(tf.keras.Model): 7 | """DiffWave: A Versatile Diffusion Model for Audio Synthesis. 8 | Zhifeng Kong et al., 2020. 9 | *** Slighly modified version of the original DiffWave code to a 10 | ccount for the singing voice separation approach. If re-using this 11 | code make sure to reference DiffWave! 12 | """ 13 | def __init__(self, config): 14 | """Initializer. 15 | Args: 16 | config: Config, model configuration. 17 | """ 18 | super(DiffWave, self).__init__() 19 | self.config = config 20 | self.wavenet = WaveNet(config) 21 | 22 | def call(self, signal): 23 | """Generate denoised audio. 24 | Args: 25 | signal: tf.Tensor, [B, T], starting signal for transformation. 26 | Returns: 27 | signal: tf.Tensor, [B, T], predicted output. 28 | """ 29 | alpha = 1 - self.config.beta() 30 | alpha_bar = np.cumprod(alpha) 31 | base = tf.ones([tf.shape(signal)[0]], dtype=tf.int32) 32 | for t in range(self.config.iter, 0, -1): 33 | eps = self.pred_noise(signal, base * t) 34 | mu = self.pred_signal(signal, eps, alpha[t - 1], alpha_bar[t - 1]) 35 | signal = mu 36 | return signal 37 | 38 | def diffusion(self, perturbation, target, pert_to_est, alpha_bar): 39 | if isinstance(alpha_bar, tf.Tensor): 40 | alpha_bar = alpha_bar[:, None] 41 | return tf.sqrt(alpha_bar) * target + \ 42 | tf.sqrt(1 - alpha_bar) * perturbation, pert_to_est 43 | 44 | def pred_noise(self, signal, timestep): 45 | """Predict noise from signal. 46 | Args: 47 | signal: tf.Tensor, [B, T], noised signal. 48 | timestep: tf.Tensor, [B], timesteps of current markov chain. 49 | Returns: 50 | tf.Tensor, [B, T], predicted noise. 51 | """ 52 | return self.wavenet(signal, timestep) 53 | 54 | def pred_signal(self, signal, eps, alpha, alpha_bar): 55 | """Compute mean of denoised signal. 56 | Args: 57 | signal: tf.Tensor, [B, T], noised signal. 58 | eps: tf.Tensor, [B, T], estimated noise. 59 | alpha: float, 1 - beta. 60 | alpha_bar: float, cumprod(1 - beta). 61 | Returns: 62 | tuple, 63 | mean: tf.Tensor, [B, T], estimated mean of denoised signal. 
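            Following the standard DDPM posterior-mean parametrization used below,
                mean = (signal - (1 - alpha) / sqrt(1 - alpha_bar) * eps) / sqrt(alpha).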
64 | """ 65 | signal = tf.dtypes.cast(signal, tf.float64) 66 | eps = tf.dtypes.cast(eps, tf.float64) 67 | 68 | # Compute mean (our estimation) from the original diffusion parametrization 69 | mean = (signal - (1 - alpha) / tf.dtypes.cast(tf.sqrt(1 - alpha_bar), tf.float64) * eps) \ 70 | / tf.dtypes.cast(tf.sqrt(alpha), tf.float64) 71 | return mean 72 | 73 | def write(self, path, optim=None): 74 | """Write checkpoint with `tf.train.Checkpoint`. 75 | Args: 76 | path: str, path to write. 77 | optim: Optional[tf.keras.optimizers.Optimizer] 78 | , optional optimizer. 79 | """ 80 | kwargs = {'model': self} 81 | if optim is not None: 82 | kwargs['optim'] = optim 83 | ckpt = tf.train.Checkpoint(**kwargs) 84 | ckpt.save(path) 85 | 86 | def restore(self, path, optim=None): 87 | """Restore checkpoint with `tf.train.Checkpoint`. 88 | Args: 89 | path: str, path to restore. 90 | optim: Optional[tf.keras.optimizers.Optimizer] 91 | , optional optimizer. 92 | """ 93 | kwargs = {'model': self} 94 | if optim is not None: 95 | kwargs['optim'] = optim 96 | ckpt = tf.train.Checkpoint(**kwargs) 97 | return ckpt.restore(path) 98 | -------------------------------------------------------------------------------- /dataset/musdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import tensorflow as tf 4 | 5 | class MUSDB: 6 | """MUSDB dataset loader. 7 | Use other opensource vocoder settings, 16bit, sr: 22050. 8 | """ 9 | SR = 22050 10 | 11 | def __init__(self, config, data_dir=None): 12 | """Initializer. 13 | Args: 14 | config: Config, dataset configuration. 15 | data_dir: str, dataset directory 16 | , defaults to "~/tensorflow_datasets". 17 | download: bool, download dataset or not. 18 | from_tfds: bool, load from tfrecord generated by tfds or read raw audio. 19 | """ 20 | self.config = config 21 | self.rawset = self.load_data("train", data_dir) 22 | self.valset = self.load_data("val", data_dir) 23 | 24 | self.normalized = None 25 | 26 | def load_data(self, subset="train", data_dir=None): 27 | """Load dataset from tfrecord or raw audio files. 28 | Args: 29 | data_dir: str, dataset directory. 30 | Returns: 31 | tf.data.Dataset, data loader. 32 | """ 33 | if subset == "train": 34 | mixture_files = glob.glob(os.path.join(data_dir, "*mixture.wav")) 35 | for track in self.config.eval_tracks: 36 | mixture_files = [x for x in mixture_files if track + "_" not in x] 37 | else: 38 | mixture_files = [] 39 | for track in self.config.eval_tracks: 40 | mixture_files += [x for x in glob.glob(os.path.join(data_dir, "*mixture.wav")) if (track + "_" in x) and \ 41 | ("_" + track + "_" not in track)] 42 | mixture_files = [x for x in mixture_files if "silence" not in x] 43 | # generate file lists 44 | files = tf.data.Dataset.from_tensor_slices( 45 | [(mix, mix.replace("_mixture.", "_vocals."), mix.replace("_mixture.", "_accompaniment.")) for mix in mixture_files]) 46 | 47 | return files.map(MUSDB._load_audio) 48 | 49 | @staticmethod 50 | def _load_audio(paths): 51 | """Load audio with tf apis. 52 | Args: 53 | path: str, wavfile path to read. 54 | Returns: 55 | tf.Tensor, [T], mono audio in range (-1, 1). 
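            One tensor is returned per input path, i.e. a (mixture, vocals, accompaniment) tuple.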
56 | """ 57 | mixture_audio, _ = tf.audio.decode_wav(tf.io.read_file(paths[0]), desired_channels=1) 58 | vocal_audio, _ = tf.audio.decode_wav(tf.io.read_file(paths[1]), desired_channels=1) 59 | accomp_audio, _ = tf.audio.decode_wav(tf.io.read_file(paths[2]), desired_channels=1) 60 | return tf.squeeze(mixture_audio, axis=-1), tf.squeeze(vocal_audio, axis=-1), tf.squeeze(accomp_audio, axis=-1) 61 | 62 | def normalizer(self, frames=16000): 63 | """Create dataset normalizer, make fixed size segment in range(-1, 1). 64 | Args: 65 | frames: int, segment size, frame unit. 66 | from_tfds: bool, whether use tfds tfrecord or raw audio. 67 | Returns: 68 | Callable, normalizer. 69 | """ 70 | def normalize(mixture_signal, vocal_signal, accomp_signal): 71 | """Normalize datum. 72 | Args: 73 | mixture_signal: tf.Tensor, [T], mono audio in range (-1, 1). 74 | vocal_signal: tf.Tensor, [T], mono audio in range (-1, 1). 75 | accomp_signal: tf.Tensor, [T], mono audio in range (-1, 1). 76 | Returns: 77 | tf.Tensor, [frames], fixed size mixture signal in range (-1, 1). 78 | tf.Tensor, [frames], fixed size vocal signal in range (-1, 1). 79 | tf.Tensor, [frames], fixed size accomp signal in range (-1, 1). 80 | """ 81 | nonlocal frames 82 | frames = frames // self.config.hop * self.config.hop 83 | start = tf.random.uniform( 84 | (), 0, tf.shape(vocal_signal)[0] - frames, dtype=tf.int32) 85 | return mixture_signal[start:start + frames], vocal_signal[start:start + frames], accomp_signal[start:start + frames] 86 | return normalize 87 | 88 | def dataset(self): 89 | """Generate dataset. 90 | """ 91 | if self.normalized is None: 92 | self.normalized = self.rawset \ 93 | .map(self.normalizer(self.config.frames)) \ 94 | .batch(self.config.batch) 95 | return self.normalized 96 | 97 | def test_dataset(self): 98 | """Generate dataset. 99 | """ 100 | return self.valset \ 101 | .map(self.normalizer(self.config.frames)) \ 102 | .batch(self.config.batch) 103 | 104 | def validation(self): 105 | """Generate dataset. 106 | """ 107 | # Getting longer samples for evaluation 108 | return self.valset \ 109 | .map(self.normalizer(self.config.frames*4)) \ 110 | .batch(1) 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # diffusion-vocal-sep 2 | Code for training and inferencing using the method presented in "A diffusion-inspired training strategy for singing voice extraction in the waveform domain", presented in ISMIR 2022. 3 | 4 | This code is an adaptation of [DiffWave code](https://github.com/revsic/tf-diffwave). If re-using the implementation itself, please refer to 5 | the original repository of DiffWave. 6 | 7 | 8 | ## Quick reference 9 | ### Data pre-processing 10 | 11 | To enhance computation, the model is configured to be trained at ``22050Hz``. You may change that if desired. Bear in mind that at some parts of the code, 12 | this sampling rate is hard-coded. To train the dataset, you first need to pre-process MUSDB18HQ, to resample, create the accompaniments, and trim the 13 | recordings into chunks of 4 seconds, in order to enhance the training stage. You can do that by running: 14 | 15 | ```python 16 | python3 process_musdb.py --data-dir --output-dir --train 17 | ``` 18 | 19 | If ``--train`` is ``False``, in order to pre-process the testing data, the tracks will be resampled but not chunked, which is better for evaluation. 
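
Each chunk is written to the output directory as three aligned mono files, roughly following the naming pattern
below. Chunks whose vocal track is completely silent get an additional ``silence_`` prefix and are skipped by the dataloader:

```
<song-index>_<chunk-index>_mixture.wav
<song-index>_<chunk-index>_vocals.wav
<song-index>_<chunk-index>_accompaniment.wav
```

The validation tracks are selected by song index through ``eval_tracks`` in ``./dataset/config.py``.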

### Training

To start the training, run:

```python
python3 train.py --data-dir <path-to-processed-musdb>
```

If the training is interrupted, you can continue training by using:

```python
python3 train.py --data-dir <path-to-processed-musdb> --config <path-to-config.json> --load-step <step>
```

The configuration file is dumped to ``./ckpt/<model-name>.json`` by default. The weights are stored, by default, in
``./ckpt/<model-name>/<model-name>_<step>.ckpt-1.data-00000-of-00001``. If the training process
is stopped and you need to continue from where you left off, set the argument ``--load-step`` and the training code takes the closest stored
step. Note that you can manually set, in the ``config.py`` file, how often (in terms of steps) the weights are stored.

We also store the model weights that obtain the best source separation metrics in the validation runs. This model is stored in the
following format: ``./ckpt/<model-name>/<model-name>_BEST-MODEL_<step>.ckpt-1.data-00000-of-00001``.

Each model you train has a particular name, which is set in the ``config.py`` file; it is used to relate the configuration and the stored
weights, and it is also needed at inference time.

In the main configuration file, you can also set the target you want to train the model for. Bear in mind that there are configuration files also
for the dataloader and the model structure (found, respectively, in ``./dataset/`` and ``./model/``).


### Inference

To run inference on a particular recording, run:

```python
python3 separate.py --input-file <path-to-wav> --output-dir <output-dir> --model-name <model-name> --batch <chunk-seconds> --wiener <True/False>
```

The ``separate.py`` script also takes a ``--ckpt`` parameter, which you can set manually. If ``--ckpt`` is not set, the path is built from the
given ``--model-name`` (which is required), and the ``BEST-MODEL`` checkpoint for the last available step is taken. During inference, the audio array used
as input is chunked using a pre-defined size of 20 seconds, but the user can select any chunk duration. This is done to prevent filling the
available memory.

We include an additional script to run inference on a MUSDB18HQ track. Assuming that the folder structure is that of MUSDB18HQ, you can run
it as specified below, providing the path to the ``mixture.wav`` in MUSDB you would like to separate. The script will automatically
take the references to compute, for this particular track, the SDR metric using the default implementation that is used in the MDX Challenge 2021-2023.

```python
python3 separate_musdb_track.py --input-file <path-to-musdb-mixture.wav> --output-dir <output-dir> --model-name <model-name> --batch <chunk-seconds> --wiener <True/False>
```


### Citing

```
"A diffusion-inspired training strategy for singing voice extraction in the waveform domain"
Genís Plaja-Roglans, Marius Miron, Xavier Serra
in Proceedings of the International Society for Music Information Retrieval (ISMIR) Conference, 2022 (Bengaluru, India)
```

```
@inproceedings{
    Plaja-Roglans_2022,
    title={A diffusion-inspired training strategy for singing voice extraction in the waveform domain},
    author={Plaja-Roglans, Genís and Miron, Marius and Serra, Xavier},
    booktitle={International Society for Music Information Retrieval (ISMIR) Conference},
    year={2022}
}
```

Once again, this implementation is broadly based on the [TensorFlow implementation of DiffWave](https://github.com/revsic/tf-diffwave).
If you are 92 | willing to use parts of the implementation itself, we kindly request that you refer to the TensorFlow DiffWave release and also cite it. 93 | -------------------------------------------------------------------------------- /augmentation_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | def _pitch_shift(single_audio, _shift): 5 | r_fft = tf.signal.rfft(single_audio) 6 | r_fft = tf.roll(r_fft, _shift, axis=0) 7 | zeros = tf.complex(tf.zeros([tf.abs(_shift)]), tf.zeros([tf.abs(_shift)])) 8 | if _shift < 0: 9 | r_fft = tf.concat([r_fft[:_shift], zeros], axis=0) 10 | else: 11 | r_fft = tf.concat([zeros, r_fft[_shift:]], axis=0) 12 | return tf.signal.irfft(r_fft) 13 | 14 | def _time_stretch(self, single_audio, _stretch): 15 | single_audio = tf.signal.stft( 16 | single_audio, 17 | frame_length=1024, 18 | frame_step=256, 19 | fft_length=1024, 20 | window_fn=tf.signal.hann_window) 21 | single_audio = phase_vocoder(single_audio, rate=_stretch) 22 | single_audio = tf.signal.inverse_stft( 23 | single_audio, 24 | frame_length=1024, 25 | frame_step=256, 26 | window_fn=tf.signal.inverse_stft_window_fn( 27 | 256, 28 | forward_window_fn=tf.signal.hann_window)) 29 | if single_audio.shape[0] > self.config.model.frames: 30 | single_audio = single_audio[:self.config.model.frames] 31 | if single_audio.shape[0] < self.config.model.frames: 32 | single_audio = tf.concat( 33 | [single_audio, tf.zeros([self.config.model.frames - single_audio.shape[0]])], 0) 34 | return single_audio 35 | 36 | def pitch_augment(self, mixture, vocal, accomp): 37 | """Compute conditions 38 | """ 39 | _shift = tf.random.uniform( 40 | shape=(self.config.model.frames,), minval=-4, maxval=4, dtype=tf.int64) 41 | augmentation = lambda x : _pitch_shift(x[0], x[1], x[2]) 42 | mixture = tf.map_fn( 43 | fn=augmentation, elems=[mixture, _shift], 44 | fn_output_signature=tf.float32) 45 | vocal = tf.map_fn( 46 | fn=augmentation, elems=[vocal, _shift], 47 | fn_output_signature=tf.float32) 48 | accomp = tf.map_fn( 49 | fn=augmentation, elems=[accomp, _shift], 50 | fn_output_signature=tf.float32) 51 | return mixture, vocal, accomp 52 | 53 | def time_augment(self, mixture, vocal, accomp): 54 | """Compute conditions 55 | """ 56 | _stretch = tf.random.uniform( 57 | shape=(self.config.model.frames,), minval=0.5, maxval=1.75, dtype=tf.float32) 58 | augmentation = lambda x : _time_stretch(x[0], x[1], x[2]) 59 | mixture = tf.map_fn( 60 | fn=augmentation, elems=[mixture, _stretch], 61 | fn_output_signature=tf.float32) 62 | vocal = tf.map_fn( 63 | fn=augmentation, elems=[vocal, _stretch], 64 | fn_output_signature=tf.float32) 65 | accomp = tf.map_fn( 66 | fn=augmentation, elems=[accomp, _stretch], 67 | fn_output_signature=tf.float32) 68 | return mixture, vocal, accomp 69 | 70 | def phase_vocoder(D, hop_len=256, rate=0.8): 71 | """Phase vocoder. Given an STFT matrix D, speed up by a factor of `rate`. 
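    A `rate` greater than 1 compresses the STFT in time (speed-up); a `rate` below 1 expands it (slow-down).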
72 | Based on implementation provided by: 73 | https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html#phase_vocoder 74 | :param D: tf.complex64([num_frames, num_bins]): the STFT tensor 75 | :param hop_len: float: the hop length param of the STFT 76 | :param rate: float > 0: the speed-up factor 77 | :return: D_stretched: tf.complex64([num_frames, num_bins]): the stretched STFT tensor 78 | """ 79 | # get shape 80 | sh = tf.shape(D, name="STFT_shape") 81 | frames = sh[0] 82 | fbins = sh[1] 83 | 84 | # time steps range 85 | t = tf.range(0.0, tf.cast(frames, tf.float32), rate, dtype=tf.float32, name="time_steps") 86 | 87 | # Expected phase advance in each bin 88 | dphi = tf.linspace(0.0, np.pi * hop_len, fbins, name="dphi_expected_phase_advance") 89 | phase_acc = tf.math.angle(D[0, :], name="phase_acc_init") 90 | 91 | # Pad 0 columns to simplify boundary logic 92 | D = tf.pad(D, [(0, 2), (0, 0)], mode='CONSTANT', name="padded_STFT") 93 | 94 | # def fn(previous_output, current_input): 95 | def _pvoc_mag_and_cum_phase(previous_output, current_input): 96 | # unpack prev phase 97 | _, prev = previous_output 98 | 99 | # grab the two current columns of the STFT 100 | i = tf.cast((tf.floor(current_input) + [0, 1]), tf.int32) 101 | bcols = tf.gather_nd(D, [[i[0]], [i[1]]]) 102 | 103 | # Weighting for linear magnitude interpolation 104 | t_dif = current_input - tf.floor(current_input) 105 | bmag = (1 - t_dif) * tf.abs(bcols[0, :]) + t_dif * (tf.abs(bcols[1, :])) 106 | 107 | # Compute phase advance 108 | dp = tf.math.angle(bcols[1, :]) - tf.math.angle(bcols[0, :]) - dphi 109 | dp = dp - 2 * np.pi * tf.round(dp / (2.0 * np.pi)) 110 | 111 | # return linear mag, accumulated phase 112 | return bmag, tf.squeeze(prev + dp + dphi) 113 | 114 | # initializer of zeros of correct shape for mag, and phase_acc for phase 115 | initializer = (tf.zeros(fbins, tf.float32), phase_acc) 116 | mag, phase = tf.scan(_pvoc_mag_and_cum_phase, t, initializer=initializer, 117 | parallel_iterations=10, back_prop=False, 118 | name="pvoc_cum_phase") 119 | 120 | # add the original phase_acc in 121 | phase2 = tf.concat([tf.expand_dims(phase_acc, 0), phase], 0)[:-1, :] 122 | D_stretched = tf.cast(mag, tf.complex64) * tf.exp(1.j * tf.cast(phase2, tf.complex64), name="stretched_STFT") 123 | 124 | return D_stretched -------------------------------------------------------------------------------- /model/wavenet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | class DilatedConv1d(tf.keras.layers.Layer): 4 | """Custom implementation of dilated convolution 1D 5 | because of the issue https://github.com/tensorflow/tensorflow/issues/26797. 6 | """ 7 | def __init__(self, 8 | in_channels, 9 | out_channels, 10 | kernel_size, 11 | dilation_rate): 12 | """Initializer. 13 | Args: 14 | in_channels: int, input channels. 15 | out_channels: int, output channels. 16 | kernel_size: int, size of the kernel. 17 | dilation_rate: int, dilation rate. 18 | """ 19 | super(DilatedConv1d, self).__init__() 20 | self.dilations = dilation_rate 21 | 22 | init = tf.keras.initializers.GlorotUniform() 23 | self.kernel = tf.Variable( 24 | init([kernel_size, in_channels, out_channels], dtype=tf.float32), 25 | trainable=True) 26 | self.bias = tf.Variable( 27 | tf.zeros([1, 1, out_channels], dtype=tf.float32), 28 | trainable=True) 29 | 30 | def call(self, inputs): 31 | """Pass to dilated convolution 1d. 32 | Args: 33 | inputs: tf.Tensor, [B, T, Cin], input tensor. 
34 | Returns: 35 | outputs: tf.Tensor, [B, T', Cout], output tensor. 36 | """ 37 | conv = tf.nn.conv1d( 38 | inputs, self.kernel, 1, padding='SAME', dilations=self.dilations) 39 | return conv + self.bias 40 | 41 | class Block(tf.keras.Model): 42 | """WaveNet Block. 43 | """ 44 | def __init__(self, channels, kernel_size, dilation, last=False): 45 | """Initializer. 46 | Args: 47 | channels: int, basic channel size. 48 | kernel_size: int, kernel size of the dilated convolution. 49 | dilation: int, dilation rate. 50 | last: bool, last block or not. 51 | """ 52 | super(Block, self).__init__() 53 | self.channels = channels 54 | self.last = last 55 | 56 | self.proj_embed = tf.keras.layers.Dense(channels) 57 | self.conv = DilatedConv1d( 58 | channels, channels * 2, kernel_size, dilation) 59 | self.proj_mel = tf.keras.layers.Conv1D(channels * 2, 1) 60 | if not last: 61 | self.proj_res = tf.keras.layers.Conv1D(channels, 1) 62 | self.proj_skip = tf.keras.layers.Conv1D(channels, 1) 63 | 64 | def call(self, inputs, embedding, cond=None): 65 | """Pass wavenet block. 66 | Args: 67 | inputs: tf.Tensor, [B, T, C(=channels)], input tensor. 68 | embedding: tf.Tensor, [B, E], embedding tensor for noise schedules. 69 | mel: tf.Tensor, [B, T // hop, M], mel-spectrogram conditions. 70 | Returns: 71 | residual: tf.Tensor, [B, T, C], output tensor for residual connection. 72 | skip: tf.Tensor, [B, T, C], output tensor for skip connection. 73 | """ 74 | # [B, C] 75 | embedding = self.proj_embed(embedding) 76 | # [B, T, C] 77 | x = inputs + embedding[:, None] 78 | # [B, T, Cx2] 79 | x = self.conv(x) 80 | # [B, T, C] 81 | context = tf.math.tanh(x[..., :self.channels]) 82 | gate = tf.math.sigmoid(x[..., self.channels:]) 83 | x = context * gate 84 | # [B, T, C] 85 | residual = (self.proj_res(x) + inputs) / 2 ** 0.5 if not self.last else None 86 | skip = self.proj_skip(x) 87 | return residual, skip 88 | 89 | class WaveNet(tf.keras.Model): 90 | """WaveNet structure. 91 | """ 92 | def __init__(self, config): 93 | """Initializer. 94 | Args: 95 | config: Config, model configuration. 96 | """ 97 | super(WaveNet, self).__init__() 98 | self.config = config 99 | # signal proj 100 | self.proj = tf.keras.layers.Conv1D(config.channels, 1) 101 | # embedding 102 | self.embed = self.embedding(config.iter) 103 | self.proj_embed = [ 104 | tf.keras.layers.Dense(config.embedding_proj) 105 | for _ in range(config.embedding_layers)] 106 | # wavenet blocks 107 | self.blocks = [] 108 | layers_per_cycle = config.num_layers // config.num_cycles 109 | for i in range(config.num_layers): 110 | dilation = config.dilation_rate ** (i % layers_per_cycle) 111 | self.blocks.append( 112 | Block( 113 | config.channels, 114 | config.kernel_size, 115 | dilation, 116 | last=i == config.num_layers - 1)) 117 | # for output 118 | self.proj_out = [ 119 | tf.keras.layers.Conv1D(config.channels, 1, activation=tf.nn.relu), 120 | tf.keras.layers.Conv1D(1, 1)] 121 | 122 | def call(self, signal, timestep): 123 | """Generate output signal. 124 | Args: 125 | signal: tf.Tensor, [B, T], noised signal. 126 | timestep: tf.Tensor, [B], int, timesteps of current markov chain. 127 | spec: tf.Tensor, TODO 128 | Returns: 129 | tf.Tensor, [B, T], generated. 
130 | """ 131 | # [B, T, C(=channels)] 132 | x = tf.nn.relu(self.proj(signal[..., None])) 133 | # [B, E'] 134 | embed = tf.gather(self.embed, timestep - 1) 135 | # [B, E] 136 | for proj in self.proj_embed: 137 | embed = tf.nn.swish(proj(embed)) 138 | 139 | context = [] 140 | for block in self.blocks: 141 | # [B, T, C], [B, T, C] 142 | x, skip = block(x, embed) 143 | context.append(skip) 144 | # [B, T, C] 145 | scale = self.config.num_layers ** 0.5 146 | context = tf.reduce_sum(context, axis=0) / scale 147 | # [B, T, 1] 148 | for proj in self.proj_out: 149 | context = proj(context) 150 | # [B, T] 151 | return tf.squeeze(context, axis=-1) 152 | 153 | def embedding(self, iter): 154 | """Generate embedding. 155 | Args: 156 | iter: int, maximum iteration. 157 | Returns: 158 | tf.Tensor, [iter, E(=embedding_size)], embedding vectors. 159 | """ 160 | # [E // 2] 161 | logit = tf.linspace(0., 1., self.config.embedding_size // 2) 162 | exp = tf.pow(10, logit * self.config.embedding_factor) 163 | # [iter] 164 | timestep = tf.range(1, iter + 1) 165 | # [iter, E // 2] 166 | comp = exp[None] * tf.cast(timestep[:, None], tf.float32) 167 | # [iter, E] 168 | return tf.concat([tf.sin(comp), tf.cos(comp)], axis=-1) -------------------------------------------------------------------------------- /process_musdb.py: -------------------------------------------------------------------------------- 1 | """ 2 | To present more sparse and diverse data batches during training we comply with the training style of 3 | DiffWave and split the music recordings in chunks of 4 seconds. 4 | 5 | To get the proper encoding for TensorFlow to properly read the wav files, we use torchaudio which 6 | includes a very versatile way to get the audio files encoded as such. 7 | """ 8 | 9 | import os 10 | import math 11 | import torch 12 | import argparse 13 | 14 | import numpy as np 15 | import torchaudio as T 16 | 17 | from glob import glob 18 | from tqdm import tqdm 19 | 20 | def load_resample_downmix(path, new_sr=22050): 21 | # Loading 22 | audio, sr = T.load(path) 23 | # Resampling 24 | resampling = T.transforms.Resample(sr, new_sr) 25 | audio = resampling(audio) 26 | # Processing 27 | audio = torch.mean(audio, dim=0).unsqueeze(0) 28 | return audio 29 | 30 | 31 | def main_train(data_dir, output_dir, sample_len, sample_rate): 32 | 33 | if (data_dir is None) or (output_dir is None): 34 | raise ValueError("You must enter both directory of MUSDB18HQ and output directory") 35 | 36 | # Creating output dir if it does not exist 37 | if not os.path.exists(output_dir): 38 | os.mkdir(output_dir) 39 | 40 | # Get path list of songs 41 | musdb_songs = glob(os.path.join(data_dir, "*/")) 42 | 43 | for song_id, i in tqdm(enumerate(musdb_songs)): 44 | if T.__version__ > "0.7.0": 45 | audio_mix = load_resample_downmix(os.path.join(i, "mixture.wav"), sample_rate) 46 | audio_vocals = load_resample_downmix(os.path.join(i, "vocals.wav"), sample_rate) 47 | audio_bass = load_resample_downmix(os.path.join(i, "bass.wav"), sample_rate) 48 | audio_drums = load_resample_downmix(os.path.join(i, "drums.wav"), sample_rate) 49 | audio_other = load_resample_downmix(os.path.join(i, "other.wav"), sample_rate) 50 | 51 | # Get accomp track 52 | audio_accomp = audio_drums + audio_bass + audio_other 53 | 54 | audio_mix = torch.clamp(audio_mix, -1.0, 1.0) 55 | audio_vocals = torch.clamp(audio_vocals, -1.0, 1.0) 56 | audio_accomp = torch.clamp(audio_accomp, -1.0, 1.0) 57 | 58 | for trim in np.arange(math.floor((audio_mix.shape[1])/(sample_rate*sample_len))): 59 | 
audio_mix_trim = audio_mix[ 60 | :, trim*(sample_rate*sample_len):(trim+1)*(sample_rate*sample_len) 61 | ] 62 | audio_voc_trim = audio_vocals[ 63 | :, trim*(sample_rate*sample_len):(trim+1)*(sample_rate*sample_len) 64 | ] 65 | audio_accomp_trim = audio_accomp[ 66 | :, trim*(sample_rate*sample_len):(trim+1)*(sample_rate*sample_len) 67 | ] 68 | 69 | # Formatting filename 70 | if torch.max(audio_voc_trim[0]) == torch.tensor(0.0): 71 | track_id = "silence_" + song_id + "_" + str(trim) 72 | else: 73 | track_id = song_id + "_" + str(trim) 74 | 75 | # Saving 76 | T.save( 77 | os.path.join(output_dir, track_id + "_mixture.wav"), 78 | audio_mix_trim.cpu(), 79 | sample_rate=sample_rate, 80 | bits_per_sample=16 81 | ) 82 | T.save( 83 | os.path.join(output_dir, track_id + "_vocals.wav"), 84 | audio_voc_trim.cpu(), 85 | sample_rate=sample_rate, 86 | bits_per_sample=16 87 | ) 88 | T.save( 89 | os.path.join(output_dir, track_id + "_accompaniment.wav"), 90 | audio_accomp_trim.cpu(), 91 | sample_rate=sample_rate, 92 | bits_per_sample=16 93 | ) 94 | 95 | else: 96 | raise ModuleNotFoundError("Need a version > 0.7.0 for torchaudio!") 97 | 98 | 99 | def main_validation(data_dir, output_dir, sample_rate): 100 | 101 | if (data_dir is None) or (output_dir is None): 102 | raise ValueError("You must enter both directory of MUSDB18HQ test and output directory") 103 | 104 | # Creating output dir if it does not exist 105 | if not os.path.exists(output_dir): 106 | os.mkdir(output_dir) 107 | 108 | # Get path list of songs 109 | musdb_songs = glob(os.path.join(data_dir, "*/")) 110 | 111 | for i in tqdm(musdb_songs): 112 | if T.__version__ > "0.7.0": 113 | song_name = i.split("/")[-2] 114 | audio_mix = load_resample_downmix(os.path.join(i, "mixture.wav"), sample_rate) 115 | audio_vocals = load_resample_downmix(os.path.join(i, "vocals.wav"), sample_rate) 116 | audio_bass = load_resample_downmix(os.path.join(i, "bass.wav"), sample_rate) 117 | audio_drums = load_resample_downmix(os.path.join(i, "drums.wav"), sample_rate) 118 | audio_other = load_resample_downmix(os.path.join(i, "other.wav"), sample_rate) 119 | 120 | # Get accomp track 121 | audio_accomp = audio_drums + audio_bass + audio_other 122 | 123 | audio_mix = torch.clamp(audio_mix, -1.0, 1.0) 124 | audio_vocals = torch.clamp(audio_vocals, -1.0, 1.0) 125 | audio_accomp = torch.clamp(audio_accomp, -1.0, 1.0) 126 | 127 | # Saving 128 | T.save( 129 | os.path.join(output_dir, song_name, "mixture.wav"), 130 | audio_mix.cpu(), 131 | sample_rate=sample_rate, 132 | bits_per_sample=16 133 | ) 134 | T.save( 135 | os.path.join(output_dir, song_name, "vocals.wav"), 136 | audio_vocals.cpu(), 137 | sample_rate=sample_rate, 138 | bits_per_sample=16 139 | ) 140 | T.save( 141 | os.path.join(output_dir, song_name, "accompaniment.wav"), 142 | audio_accomp.cpu(), 143 | sample_rate=sample_rate, 144 | bits_per_sample=16 145 | ) 146 | 147 | else: 148 | raise ModuleNotFoundError("Need a version > 0.7.0 for torchaudio!") 149 | 150 | 151 | if __name__ == "__main__": 152 | parser = argparse.ArgumentParser() 153 | parser.add_argument("--data-dir", default=None) 154 | parser.add_argument("--output-dir", default=None) 155 | parser.add_argument("--sample-len", default=4) 156 | parser.add_argument("--sample-rate", default=22050) 157 | args = parser.parse_args() 158 | if args.train: 159 | main_train(args.data_dir, args.output_dir, args.sample_len, args.sample_rate) 160 | else: 161 | main_validation(args.data_dir, args.output_dir, args.sample_rate) 
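
# For reference: with the default settings (sample_len=4 seconds, sample_rate=22050 Hz), each saved
# chunk holds 4 * 22050 = 88200 samples, and the training dataloader later crops random windows of
# config.frames = 256 * 32 = 8192 samples (roughly 0.37 seconds) from these chunks.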
-------------------------------------------------------------------------------- /separate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import glob 4 | import tqdm 5 | import json 6 | import norbert 7 | import librosa 8 | import argparse 9 | 10 | import numpy as np 11 | import soundfile as sf 12 | import tensorflow as tf 13 | 14 | from config import Config 15 | from model import DiffWave 16 | 17 | SAMPLING_RATE = 22050 18 | 19 | 20 | def get_window(signal, boundary=None): 21 | window_out = np.ones(signal.shape) 22 | midpoint = window_out.shape[0] // 2 23 | if boundary == "start": 24 | window_out[midpoint:] = np.linspace(1, 0, window_out.shape[0]-midpoint) 25 | elif boundary == "end": 26 | window_out[:midpoint] = np.linspace(0, 1, window_out.shape[0]-midpoint) 27 | else: 28 | window_out[:midpoint] = np.linspace(0, 1, window_out.shape[0]-midpoint) 29 | window_out[midpoint:] = np.linspace(1, 0, window_out.shape[0]-midpoint) 30 | return window_out 31 | 32 | def my_special_round(x, base): 33 | return math.ceil(base * round(float(x)/base)) 34 | 35 | 36 | def main(args): 37 | 38 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 39 | 40 | if not os.path.exists(os.path.join(".", "ckpt", args.model_name + ".json")): 41 | return ValueError("Please make sure the model exists and have a config file in ./ckpt") 42 | if not args.input_file: 43 | return ValueError("Please enter the input file through the --input-file argument") 44 | if not os.path.exists(args.input_file): 45 | return ValueError("Input file not found: please make sure the input file exists!") 46 | 47 | # prepare directory for samples 48 | if not os.path.exists(args.output_dir): 49 | os.makedirs(args.output_dir) 50 | 51 | with open(os.path.join(".", "ckpt", args.model_name + ".json"), "r") as f: 52 | config = Config.load(json.load(f)) 53 | 54 | # Create and initialize the model 55 | config = Config() 56 | diffwave = DiffWave(config.model) 57 | if args.ckpt is not None: 58 | diffwave.restore(args.ckpt).expect_partial() 59 | else: 60 | ckpts = glob.glob("./ckpt/" + args.model_name + "*.ckpt-1.data-00000-of-00001") 61 | ckpts = [x for x in ckpts if "BEST-MODEL" in x] 62 | latest_step = np.max([float(x.split("_")[-1].replace(".ckpt-1.data-00000-of-00001")) for x in ckpts]) 63 | ckpt_path = "./ckpt/" + args.model_name + "/" + args.model_name + "_BEST-MODEL_" + \ 64 | str(latest_step) + ".ckpt-1.data-00000-of-00001" 65 | diffwave.restore(ckpt_path).expect_partial() 66 | 67 | print("Separating: {}".format(args.input_file)) 68 | filename = args.input_file.split("/")[-1] 69 | mixture = tf.io.read_file(args.input_file) 70 | mixture, sr = tf.audio.decode_wav(mixture, desired_channels=1) 71 | 72 | if sr != SAMPLING_RATE: 73 | return ValueError("Please resample audio to {}Hz before running inference.".format(SAMPLING_RATE)) 74 | 75 | mixture = tf.squeeze(mixture, axis=-1) 76 | mixture_hopped_shape = math.ceil(mixture.shape[0] / config.data.hop) * config.data.hop 77 | output_vocals = np.zeros(mixture_hopped_shape) 78 | output_accomp = np.zeros(mixture_hopped_shape) 79 | hopsized_batch = ((int(args.batch)*SAMPLING_RATE) / 2) // config.data.hop * config.data.hop 80 | sec = math.floor(mixture_hopped_shape / hopsized_batch) 81 | 82 | for trim in tqdm.tqdm(np.arange(sec)): 83 | trim_low = int(trim*hopsized_batch) 84 | trim_high = int(trim_low + (hopsized_batch*2)) 85 | mixture_analyse = mixture[trim_low:trim_high] 86 | 87 | # Last batch (might be shorter than what expected) 88 | if 
mixture_analyse.shape[0] < hopsized_batch*2:
            padded_len = my_special_round(mixture_analyse.shape[0], base=config.data.hop)
            difference = int(padded_len - mixture_analyse.shape[0])
            mixture_analyse = tf.concat([mixture_analyse, tf.zeros([difference])], axis=0)

        output_signal = diffwave(mixture_analyse[None])
        pred_audio = tf.squeeze(output_signal, axis=0).numpy()

        mixture_analyse = mixture_analyse.numpy()
        pred_audio = pred_audio * (np.max(np.abs(mixture_analyse)) / np.max(np.abs(pred_audio)))
        pred_accomp = mixture_analyse - pred_audio

        if args.wiener:
            # Compute stft
            vocal_spec = np.transpose(librosa.stft(pred_audio), [1, 0])
            accomp_spec = np.transpose(librosa.stft(pred_accomp), [1, 0])

            # Separate mags and phases
            vocal_mag = np.abs(vocal_spec)
            vocal_phase = np.angle(vocal_spec)
            accomp_mag = np.abs(accomp_spec)
            accomp_phase = np.angle(accomp_spec)

            # Preparing inputs for wiener filtering
            mix_spec = np.transpose(librosa.stft(mixture_analyse), [1, 0])
            sources = np.transpose(np.vstack([vocal_mag[None], accomp_mag[None]]), [1, 2, 0])
            mix_spec = np.expand_dims(mix_spec, axis=-1)
            sources = np.expand_dims(sources, axis=2)

            # Wiener
            specs = norbert.wiener(sources, mix_spec)

            # Building output specs with filtered mags and original phases
            vocal_spec = np.abs(np.squeeze(specs[:, :, :, 0], axis=-1)) * np.exp(1j * vocal_phase)
            accomp_spec = np.abs(np.squeeze(specs[:, :, :, 1], axis=-1)) * np.exp(1j * accomp_phase)
            pred_audio = librosa.istft(np.transpose(vocal_spec, [1, 0]))
            pred_accomp = librosa.istft(np.transpose(accomp_spec, [1, 0]))

        # Get boundary: fade in only at the start of the track, fade out only at the end,
        # cross-fade on both sides everywhere else
        boundary = None
        if trim == 0:
            boundary = "start"
        elif trim == sec - 1:
            boundary = "end"

        placehold_voc = np.zeros(output_vocals.shape)
        placehold_acc = np.zeros(output_accomp.shape)
        placehold_voc[trim_low:trim_high] = pred_audio * get_window(pred_audio, boundary=boundary)
        placehold_acc[trim_low:trim_high] = pred_accomp * get_window(pred_accomp, boundary=boundary)
        output_vocals += placehold_voc
        output_accomp += placehold_acc

    output_vocals = output_vocals[:mixture.shape[0]]
    output_accomp = output_accomp[:mixture.shape[0]]

    # Write output to file
    sf.write(os.path.join(args.output_dir, filename.replace(".wav", "_separated-vocals.wav")), output_vocals, SAMPLING_RATE)
    sf.write(os.path.join(args.output_dir, filename.replace(".wav", "_separated-accompaniment.wav")), output_accomp, SAMPLING_RATE)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-file", default=None)
    parser.add_argument("--output-dir", default="./output/")
    parser.add_argument("--model-name", default="20-step-vocal")
    parser.add_argument("--ckpt", default=None)
    parser.add_argument("--batch", default=20)
    parser.add_argument("--wiener", default=False)
    parser.add_argument("--gpu", default=-1)
    args = parser.parse_args()
    main(args)

--------------------------------------------------------------------------------
/separate_musdb_track.py:
--------------------------------------------------------------------------------
1 | """ 2 | This file performs separation of a musdb file, assuming the structure of the dataset is correct. The user inputs the mixture that is going to 3 | be separated, and the corresponding singing voice and accompaniment are used to provide, in addition to the separated sources, the SDR metric 4 | that the model has achieved on the selected track. 5 | """ 6 | 7 | import os 8 | import math 9 | import glob 10 | import tqdm 11 | import json 12 | import norbert 13 | import librosa 14 | import argparse 15 | 16 | import numpy as np 17 | import soundfile as sf 18 | import tensorflow as tf 19 | 20 | from config import Config 21 | from model import DiffWave 22 | 23 | SAMPLING_RATE = 22050 24 | 25 | 26 | def get_window(signal, boundary=None): 27 | window_out = np.ones(signal.shape) 28 | midpoint = window_out.shape[0] // 2 29 | if boundary == "start": 30 | window_out[midpoint:] = np.linspace(1, 0, window_out.shape[0]-midpoint) 31 | elif boundary == "end": 32 | window_out[:midpoint] = np.linspace(0, 1, window_out.shape[0]-midpoint) 33 | else: 34 | window_out[:midpoint] = np.linspace(0, 1, window_out.shape[0]-midpoint) 35 | window_out[midpoint:] = np.linspace(1, 0, window_out.shape[0]-midpoint) 36 | return window_out 37 | 38 | def my_special_round(x, base): 39 | return math.ceil(base * round(float(x)/base)) 40 | 41 | def GlobalSDR(references, separations): 42 | """ Global SDR: main (or standard) metric from SiSEC 2021 and MDX""" 43 | delta = 1e-7 # avoid numerical errors 44 | num = np.sum(np.square(references), axis=(1, 2)) 45 | den = np.sum(np.square(references - separations), axis=(1, 2)) 46 | num += delta 47 | den += delta 48 | return 10 * np.log10(num / den) 49 | 50 | 51 | def main(args): 52 | 53 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 54 | 55 | if not os.path.exists(os.path.join(".", "ckpt", args.model_name + ".json")): 56 | return ValueError("Please make sure the model exists and have a config file in ./ckpt") 57 | if not args.input_file: 58 | return ValueError("Please enter the input file through the --input-file argument") 59 | if not os.path.exists(args.input_file): 60 | return ValueError("Input file not found: please make sure the input file exists!") 61 | 62 | # prepare directory for samples 63 | if not os.path.exists(args.output_dir): 64 | os.makedirs(args.output_dir) 65 | 66 | with open(os.path.join(".", "ckpt", args.model_name + ".json"), "r") as f: 67 | config = Config.load(json.load(f)) 68 | 69 | # Create and initialize the model 70 | config = Config() 71 | diffwave = DiffWave(config.model) 72 | if args.ckpt is not None: 73 | diffwave.restore(args.ckpt).expect_partial() 74 | else: 75 | ckpts = glob.glob(os.path.join(".", "ckpt", args.model_name + "*.ckpt-1.data-00000-of-00001")) 76 | ckpts = [x for x in ckpts if "BEST-MODEL" in x] 77 | latest_step = np.max([float(x.split("_")[-1].replace(".ckpt-1.data-00000-of-00001")) for x in ckpts]) 78 | ckpt_path = os.path.join(".", "ckpt", args.model_name, \ 79 | args.model_name + "_BEST-MODEL_" + str(latest_step) + ".ckpt-1.data-00000-of-00001") 80 | diffwave.restore(ckpt_path).expect_partial() 81 | 82 | print("Separating: {}".format(args.input_file)) 83 | filename = args.input_file.split("/")[-1] 84 | mixture = tf.io.read_file(args.input_file) 85 | mixture, sr = tf.audio.decode_wav(mixture, desired_channels=1) 86 | vocals = tf.io.read_file(args.input_file.replace("mixture.wav", "vocals.wav")) 87 | 88 | # Check if accompaniment is available, otherwise we create it 89 | if os.path.exists(args.input_file.replace("mixture.wav", 
"accompaniment.wav")): 90 | accomp = args.input_file.replace("mixture.wav", "accompaniment.wav") 91 | accomp, sr = tf.audio.decode_wav(accomp, desired_channels=1) 92 | else: 93 | bass = tf.audio.decode_wav(tf.io.read_file(args.input_file.replace("mixture.wav", "bass.wav"))) 94 | drums = tf.audio.decode_wav(tf.io.read_file(args.input_file.replace("mixture.wav", "drums.wav"))) 95 | other = tf.audio.decode_wav(tf.io.read_file(args.input_file.replace("mixture.wav", "other.wav"))) 96 | accomp = bass + drums + other 97 | 98 | vocals, sr = tf.audio.decode_wav(vocals, desired_channels=1) 99 | 100 | if sr != SAMPLING_RATE: 101 | return ValueError("Please resample MUSDB audio to {}Hz before running inference.".format(SAMPLING_RATE)) 102 | 103 | mixture = tf.squeeze(mixture, axis=-1) 104 | mixture_hopped_shape = math.ceil(mixture.shape[0] / config.data.hop) * config.data.hop 105 | output_vocals = np.zeros(mixture_hopped_shape) 106 | output_accomp = np.zeros(mixture_hopped_shape) 107 | hopsized_batch = ((int(args.batch)*SAMPLING_RATE) / 2) // config.data.hop * config.data.hop 108 | sec = math.floor(mixture_hopped_shape / hopsized_batch) 109 | 110 | for trim in tqdm.tqdm(np.arange(sec)): 111 | trim_low = int(trim*hopsized_batch) 112 | trim_high = int(trim_low + (hopsized_batch*2)) 113 | mixture_analyse = mixture[trim_low:trim_high] 114 | vocals_analyse = vocals[trim_low:trim_high] 115 | 116 | # Last batch (might be shorter than hopsized batch sized) 117 | if mixture_analyse.shape[0] < hopsized_batch*2: 118 | padded_len = my_special_round(mixture_analyse.shape[0], base=config.data.hop) 119 | difference = int(padded_len - mixture_analyse.shape[0]) 120 | mixture_analyse = tf.concat([mixture_analyse, tf.zeros([difference])], axis=0) 121 | 122 | output_signal = diffwave(mixture_analyse[None]) 123 | pred_audio = tf.squeeze(output_signal, axis=0).numpy() 124 | 125 | mixture_analyse = mixture_analyse.numpy() 126 | pred_audio = pred_audio * (np.max(np.abs(vocals_analyse)) / np.max(np.abs(pred_audio))) 127 | pred_accomp = mixture_analyse - pred_audio 128 | 129 | if args.wiener: 130 | pred_audio = np.squeeze(pred_audio, axis=0) 131 | 132 | # Compute stft 133 | vocal_spec = np.transpose(librosa.stft(pred_audio), [1, 0]) 134 | accomp_spec = np.transpose(librosa.stft(pred_accomp), [1, 0]) 135 | 136 | # Separate mags and phases 137 | vocal_mag = np.abs(vocal_spec) 138 | vocal_phase = np.angle(vocal_spec) 139 | accomp_mag = np.abs(accomp_spec) 140 | accomp_phase = np.angle(accomp_spec) 141 | 142 | # Preparing inputs for wiener filtering 143 | mix_spec = np.transpose(librosa.stft(mixture_analyse), [1, 0]) 144 | sources = np.transpose(np.vstack([vocal_mag[None], accomp_mag[None]]), [1, 2, 0]) 145 | mix_spec = np.expand_dims(mix_spec, axis=-1) 146 | sources = np.expand_dims(sources, axis=2) 147 | 148 | # Wiener 149 | specs = norbert.wiener(sources, mix_spec) 150 | 151 | # Building output specs with filtered mags and original phases 152 | vocal_spec = np.abs(np.squeeze(specs[:, :, :, 0], axis=-1)) * np.exp(1j * vocal_phase) 153 | accomp_spec = np.abs(np.squeeze(specs[:, :, :, 1], axis=-1)) * np.exp(1j * accomp_phase) 154 | pred_audio = librosa.istft(np.transpose(vocal_spec, [1, 0])) 155 | pred_accomp = librosa.istft(np.transpose(accomp_spec, [1, 0])) 156 | pred_audio = np.squeeze(pred_audio, axis=0) 157 | pred_accomp = np.squeeze(pred_accomp, axis=0) 158 | 159 | # Get boundary 160 | boundary = None 161 | boundary = "start" if trim == 0 else None 162 | boundary = "end" if trim == sec-1 else None 163 | 164 | placehold_voc = 
np.zeros(output_vocals.shape)
        placehold_acc = np.zeros(output_accomp.shape)
        placehold_voc[trim_low:trim_high] = pred_audio * get_window(pred_audio, boundary=boundary)
        placehold_acc[trim_low:trim_high] = pred_accomp * get_window(pred_accomp, boundary=boundary)
        output_vocals += placehold_voc
        output_accomp += placehold_acc

    output_vocals = output_vocals[:mixture.shape[0]]
    output_accomp = output_accomp[:mixture.shape[0]]

    scores = GlobalSDR(np.array([vocals, accomp]), np.array([output_vocals, output_accomp])[..., None])
    print("VOCALS ==> SDR:", scores[0])
    print("ACCOMP ==> SDR:", scores[1])

    # Write output to file
    sf.write(os.path.join(args.output_dir, filename.replace(".wav", "_separated-vocals.wav")), output_vocals, SAMPLING_RATE)
    sf.write(os.path.join(args.output_dir, filename.replace(".wav", "_separated-accompaniment.wav")), output_accomp, SAMPLING_RATE)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-file", default=None)
    parser.add_argument("--output-dir", default="./output/")
    parser.add_argument("--model-name", default="20-step-vocal")
    parser.add_argument("--ckpt", default=None)
    parser.add_argument("--batch", default=20)
    parser.add_argument("--wiener", default=False)
    parser.add_argument("--gpu", default=-1)
    args = parser.parse_args()
    main(args)

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import mir_eval
import argparse
import json
import math
import tqdm

import numpy as np
import tensorflow as tf

#from augmentation_utils import time_augment, pitch_augment

from config import Config
from dataset import MUSDB
from model import DiffWave

###os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # Manually add your GPU id here!

class Trainer:
    """DiffWave trainer.
    """
    def __init__(self, model, dataset, config):
        """Initializer.
        Args:
            model: DiffWave, diffwave model.
            dataset: Dataset, input dataset to train the diffusion model,
                which provides already batched and normalized audio segments.
            config: Config, unified configurations.
29 | """ 30 | self.model = model 31 | self.dataset = dataset 32 | self.config = config 33 | 34 | self.split = config.train.split // config.data.batch 35 | self.trainset = self.dataset.dataset().take(self.split) \ 36 | .shuffle(config.train.bufsiz) \ 37 | .prefetch(tf.data.experimental.AUTOTUNE) 38 | self.testset = self.dataset.test_dataset() \ 39 | .prefetch(tf.data.experimental.AUTOTUNE) 40 | 41 | self.optim = tf.keras.optimizers.Adam( 42 | config.train.lr(), 43 | config.train.beta1, 44 | config.train.beta2, 45 | config.train.eps) 46 | 47 | self.ckpt_intval = config.train.ckpt_intval // config.data.batch 48 | 49 | self.train_log = tf.summary.create_file_writer( 50 | os.path.join(config.train.log, config.train.name, "train")) 51 | self.test_log = tf.summary.create_file_writer( 52 | os.path.join(config.train.log, config.train.name, "test")) 53 | 54 | self.ckpt_path = os.path.join( 55 | config.train.ckpt, config.train.name, config.train.name) 56 | 57 | self.alpha = 1 - config.model.beta() 58 | self.alpha_bar = np.cumprod(self.alpha) 59 | 60 | def compute_loss(self, mixture, vocals, accomp, target="vocals"): 61 | """Compute loss for noise estimation. 62 | Args: 63 | mixture: tf.Tensor, [B, T], raw audio signal mixture. 64 | vocals: tf.Tensor, [B, T], raw audio signal vocals. 65 | accomp: tf.Tensor, [B, T], raw audio signal accompaniment. 66 | target: str, indicating for which target the model is trained. 67 | Returns: 68 | loss: tf.Tensor, [], L1-loss between noise and estimation. 69 | """ 70 | bsize = tf.shape(vocals)[0] 71 | # [B] 72 | timesteps = tf.random.uniform( 73 | [bsize], 1, self.config.model.iter + 1, dtype=tf.int32) 74 | # [B] 75 | noise_level = tf.gather(self.alpha_bar, timesteps - 1) 76 | # [B, T], [B, T] 77 | if target == "vocals": 78 | noised, noise = self.model.diffusion(mixture, vocals, accomp, noise_level) 79 | else: 80 | noised, noise = self.model.diffusion(mixture, accomp, vocals, noise_level) 81 | # [B, T] 82 | eps = self.model.pred_noise(noised, timesteps) 83 | # [] 84 | loss = tf.reduce_mean(tf.abs(eps - noise)) 85 | return loss 86 | 87 | def train(self, step=0): 88 | """Train wavegrad. 89 | Args: 90 | step: int, starting step. 91 | ir_unit: int, log ir units. 
92 | """ 93 | best_SDR = 0 94 | best_step = 0 95 | 96 | # Start training 97 | print("\n \n") ## Just to separate bars from Tensorflow-related warnings 98 | pbar_gen = tqdm.trange(step // self.split, self.config.train.epoch) 99 | for _ in pbar_gen: 100 | pbar_gen.set_description("General training process") 101 | train_loss = [] 102 | with tqdm.tqdm(total=self.split, leave=False) as pbar: 103 | pbar.set_description("Training epoch ({} steps)".format(self.split)) 104 | for mixture, vocal, accomp in self.trainset: 105 | 106 | # APPLY DATA AUGMENTATION HERE IF DESIRED 107 | #mixture, vocal, accomp = self.pitch_augment(mixture, vocal, accomp) 108 | #mixture, vocal, accomp = self.time_augment(mixture, vocal, accomp) 109 | 110 | with tf.GradientTape() as tape: 111 | tape.watch(self.model.trainable_variables) 112 | loss = self.compute_loss( 113 | mixture, 114 | vocal, 115 | accomp, 116 | target=self.config.train.target) 117 | train_loss.append(loss) 118 | 119 | grad = tape.gradient(loss, self.model.trainable_variables) 120 | self.optim.apply_gradients( 121 | zip(grad, self.model.trainable_variables)) 122 | 123 | norm = tf.reduce_mean([tf.norm(g) for g in grad]) 124 | del grad 125 | 126 | step += 1 127 | pbar.update() 128 | pbar.set_postfix( 129 | {"loss": loss.numpy().item(), 130 | "step": step, 131 | "grad": norm.numpy().item()}) 132 | 133 | if step % self.ckpt_intval == 0: 134 | self.model.write( 135 | "{}_{}.ckpt".format(self.ckpt_path, step), 136 | self.optim) 137 | 138 | train_loss = sum(train_loss) / len(train_loss) 139 | validation_loss = [] 140 | for mixture, vocal, accomp in self.testset: 141 | actual_loss = self.compute_loss( 142 | mixture, 143 | vocal, 144 | accomp, 145 | target=self.config.train.target).numpy().item() 146 | validation_loss.append(actual_loss) 147 | 148 | del vocal, accomp 149 | validation_loss = sum(validation_loss) / len(validation_loss) 150 | 151 | with self.test_log.as_default(): 152 | if step > 150000: 153 | best_SDR, best_step = self.eval( 154 | best_SDR, 155 | best_step, 156 | step, 157 | target=self.config.train.target) 158 | 159 | print("==> Current train loss: {}, and validation loss: {}".format(train_loss, validation_loss)) 160 | del train_loss, validation_loss 161 | 162 | 163 | def eval(self, best_SDR, best_step, step, target="vocals"): 164 | """Generate evaluation purpose audio. 165 | Returns: 166 | speech: np.ndarray, [T], ground truth. 167 | pred: np.ndarray, [T], predicted. 168 | ir: List[np.ndarray], config.model.iter x [T], 169 | intermediate representations. 
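        In this adaptation, the method evaluates the median SDR over the validation tracks,
        writes a BEST-MODEL checkpoint whenever that median improves, and returns the
        updated (best_SDR, best_step) pair.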
170 | """ 171 | # [T] 172 | sdr_target = [] 173 | pbar_val = tqdm.tqdm(self.dataset.validation()) 174 | for mixture, vocals, accomp in pbar_val: 175 | pbar_val.set_description("Validating model") 176 | # Prepare data for eval 177 | hop = self.config.data.hop 178 | nearest_hop = hop * math.floor(mixture.shape[1]/hop) 179 | mixture_analyze = mixture[:, :nearest_hop] 180 | vocals_analyze = vocals[:, :nearest_hop] 181 | accomp_analyze = accomp[:, :nearest_hop] 182 | 183 | if target == "vocals": 184 | gt_target = vocals_analyze 185 | gt_rest = accomp_analyze 186 | else: 187 | gt_target = accomp_analyze 188 | gt_rest = vocals_analyze 189 | 190 | # Check vocal track is not silent 191 | if tf.reduce_max(gt_target).numpy() != 0.0: 192 | gt_target = tf.squeeze(gt_target, axis=0).numpy() 193 | gt_rest = tf.squeeze(gt_rest, axis=0).numpy() 194 | 195 | # Predict 196 | pred_target = self.model(mixture_analyze) 197 | 198 | # Get accompaniment by substraction 199 | mixture_analyze = tf.squeeze(mixture_analyze, axis=0).numpy() 200 | pred_target = tf.squeeze(pred_target, axis=0).numpy() 201 | pred_target = pred_target * (np.max(np.abs(gt_target)) / np.max(np.abs(pred_target))) 202 | pred_rest = mixture_analyze - pred_target 203 | 204 | # Evaluate 205 | ref = np.array([gt_target, gt_rest]) 206 | est = np.array([pred_target, pred_rest]) 207 | sdr, _, _, _, _ = mir_eval.separation.bss_eval_images( 208 | ref, est, compute_permutation=False) 209 | sdr_target.append(sdr[0]) 210 | 211 | # Updating best new model taking SDR into account 212 | if np.median(sdr_target) > best_SDR: 213 | print("Saving best new model with SDR: {}".format(str(np.median(sdr_target)))) 214 | self.model.write("{}_BEST-MODEL_{}.ckpt".format(self.ckpt_path, str(step)), self.optim) 215 | best_SDR = np.median(sdr_target) 216 | best_step = step 217 | else: 218 | print("Current best model: {} from step {}".format(str(best_SDR), str(best_step))) 219 | print("The median SDR of this evaluation step is: {}".format(np.median(sdr_target))) 220 | return best_SDR, best_step 221 | 222 | if __name__ == "__main__": 223 | parser = argparse.ArgumentParser() 224 | parser.add_argument("--config", default=None) 225 | parser.add_argument("--load-step", default=0, type=int) 226 | parser.add_argument("--data-dir", default=None) 227 | args = parser.parse_args() 228 | 229 | config = Config() 230 | if args.config is not None: 231 | print("[*] load config: " + args.config) 232 | with open(args.config) as f: 233 | config = Config.load(json.load(f)) 234 | 235 | log_path = os.path.join(config.train.log, config.train.name) 236 | if not os.path.exists(log_path): 237 | os.makedirs(log_path) 238 | 239 | ckpt_path = os.path.join(config.train.ckpt, config.train.name) 240 | if not os.path.exists(ckpt_path): 241 | os.makedirs(ckpt_path) 242 | 243 | dataset = MUSDB(config.data, data_dir=args.data_dir) 244 | diffwave = DiffWave(config.model) 245 | trainer = Trainer(diffwave, dataset, config) 246 | 247 | if args.load_step > 0: 248 | super_path = os.path.join(config.train.ckpt, config.train.name) 249 | ckpt_path = "{}_{}.ckpt".format(config.train.name, args.load_step) 250 | ckpt_path = next( 251 | name for name in os.listdir(super_path) 252 | if name.startswith(ckpt_path) and name.endswith(".index")) 253 | ckpt_path = os.path.join(super_path, ckpt_path[:-6]) 254 | 255 | print("[*] load checkpoint: " + ckpt_path) 256 | trainer.model.restore(ckpt_path, trainer.optim) 257 | 258 | with open(os.path.join(config.train.ckpt, config.train.name + ".json"), "w") as f: 259 | 
json.dump(config.dump(), f) 260 | 261 | trainer.train(args.load_step) 262 | --------------------------------------------------------------------------------