├── .gitignore
├── LICENSE
├── README.md
├── appendix
│   └── plot_log.py
├── augment.py
├── inference.py
├── lib
│   ├── __init__.py
│   ├── dataset.py
│   ├── layers.py
│   ├── nets.py
│   ├── spec_utils.py
│   └── utils.py
├── models
│   └── .gitkeep
├── pseudo.py
├── requirements.txt
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.npy
3 | *.npz
4 | *.pth
5 | *.json
6 | *.log
7 | 
8 | *.jpg
9 | *.png
10 | 
11 | *.wav
12 | *.m4a
13 | *.mp3
14 | *.mp4
15 | *.flac
16 | 
17 | *~
18 | .vscode/
19 | venv/
20 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 tsurumeso
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # vocal-remover
2 | 
3 | [![Release](https://img.shields.io/github/release/tsurumeso/vocal-remover.svg)](https://github.com/tsurumeso/vocal-remover/releases/latest)
4 | [![Release](https://img.shields.io/github/downloads/tsurumeso/vocal-remover/total.svg)](https://github.com/tsurumeso/vocal-remover/releases)
5 | 
6 | This is a deep-learning-based tool to extract the instrumental track from your songs.
7 | 
8 | ## Installation
9 | 
10 | ### Getting vocal-remover
11 | Download the latest version from [here](https://github.com/tsurumeso/vocal-remover/releases).
12 | 
13 | ### Install PyTorch
14 | **See**: [GET STARTED](https://pytorch.org/get-started/locally/)
15 | 
16 | ### Install the other packages
17 | ```
18 | cd vocal-remover
19 | pip install -r requirements.txt
20 | ```
21 | 
22 | ## Usage
23 | The following command separates the input into instrumental and vocal tracks. They are saved as `*_Instruments.wav` and `*_Vocals.wav`.
24 | 
25 | ### Run on CPU
26 | ```
27 | python inference.py --input path/to/an/audio/file
28 | ```
29 | 
30 | ### Run on GPU
31 | ```
32 | python inference.py --input path/to/an/audio/file --gpu 0
33 | ```
34 | 
35 | ### Advanced options
36 | The `--tta` option performs test-time augmentation to improve separation quality.
37 | ```
38 | python inference.py --input path/to/an/audio/file --tta --gpu 0
39 | ```
40 | 
41 | The `--postprocess` option masks the instrumental part according to the volume of the vocals to improve separation quality.
42 | > [!WARNING]
43 | > This is an experimental feature. If this option causes any problems, disable it.
44 | ```
45 | python inference.py --input path/to/an/audio/file --postprocess --gpu 0
46 | ```
47 | 
48 | ## Train your own model
49 | 
50 | ### Place your dataset
51 | ```
52 | path/to/dataset/
53 |   +- instruments/
54 |   |    +- 01_foo_inst.wav
55 |   |    +- 02_bar_inst.mp3
56 |   |    +- ...
57 |   +- mixtures/
58 |        +- 01_foo_mix.wav
59 |        +- 02_bar_mix.mp3
60 |        +- ...
61 | ```
62 | 
63 | ### Train a model
64 | ```
65 | python train.py --dataset path/to/dataset --mixup_rate 0.5 --reduction_rate 0.5 --gpu 0
66 | ```
67 | 
68 | ## References
69 | - [1] Jansson et al., "Singing Voice Separation with Deep U-Net Convolutional Networks", https://ejhumphrey.com/assets/pdf/jansson2017singing.pdf
70 | - [2] Takahashi et al., "Multi-scale Multi-band DenseNets for Audio Source Separation", https://arxiv.org/pdf/1706.09588.pdf
71 | - [3] Takahashi et al., "MMDenseLSTM: An Efficient Combination of Convolutional and Recurrent Neural Networks for Audio Source Separation", https://arxiv.org/pdf/1805.02410.pdf
72 | - [4] Choi et al., "Phase-Aware Speech Enhancement with Deep Complex U-Net", https://openreview.net/pdf?id=SkeRTsAcYm
73 | - [5] Jansson et al., "Learned complex masks for multi-instrument source separation", https://arxiv.org/pdf/2103.12864.pdf
74 | - [6] Liutkus et al., "The 2016 Signal Separation Evaluation Campaign", Latent Variable Analysis and Signal Separation - 12th International Conference
75 | 
--------------------------------------------------------------------------------
/appendix/plot_log.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | 
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | 
7 | 
8 | if __name__ == '__main__':
9 |     with open(sys.argv[1], 'r', encoding='utf8') as f:
10 |         log = np.asarray(json.load(f))
11 |     print(np.min(log, axis=0))
12 |     trn_loss = log[:, 0]
13 |     val_loss = log[:, 1]
14 | 
15 |     plt.rcParams['font.size'] = 12
16 |     plt.rcParams['legend.fontsize'] = 12
17 | 
18 |     x_val = np.arange(len(val_loss))
19 |     plt.plot(x_val, val_loss, label='validation loss', c='r')
20 | 
21 |     x_trn = np.arange(len(trn_loss))
22 |     plt.plot(x_trn, trn_loss, label='training loss', c='b')
23 | 
24 |     plt.grid(which='both', color='gray', linestyle='--')
25 |     plt.xlabel('Epoch')
26 |     plt.ylabel('Loss')
27 |     plt.legend(edgecolor='white')
28 |     plt.show()
29 | 
--------------------------------------------------------------------------------
/augment.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import subprocess
4 | 
5 | import librosa
6 | import numpy as np
7 | import soundfile as sf
8 | from tqdm import tqdm
9 | 
10 | from lib import dataset
11 | from lib import spec_utils
12 | 
13 | 
14 | if __name__ == '__main__':
15 |     p = argparse.ArgumentParser()
16 |     p.add_argument('--sr', '-r', type=int, default=44100)
17 |     p.add_argument('--hop_length', '-l', type=int, default=1024)
18 |     p.add_argument('--n_fft', '-f', type=int, default=2048)
19 |     p.add_argument('--pitch', '-p', type=int, default=-1)
20 |     p.add_argument('--mixtures', '-m', required=True)
21 |     p.add_argument('--instruments', '-i', required=True)
22 |     args = p.parse_args()
23 | 
24 |     input_i = 'input_i_{}.wav'.format(args.pitch)
25 |     input_v = 'input_v_{}.wav'.format(args.pitch)
26 |     output_i = 'output_i_{}.wav'.format(args.pitch)
27 |     output_v = 'output_v_{}.wav'.format(args.pitch)
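    # NOTE: the `soundstretch` CLI (part of the SoundTouch package) is assumed
    # to be installed and on PATH; it shifts the intermediate WAVs by
    # `--pitch` semitones without changing tempo.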
28 |     cmd_i = 'soundstretch {} {} -pitch={}'.format(input_i, output_i, args.pitch)
29 |     cmd_v = 'soundstretch {} {} -pitch={}'.format(input_v, output_v, args.pitch)
30 |     cache_suffix = '_pitch{}.npy'.format(args.pitch)
31 | 
32 |     cache_dir = 'sr{}_hl{}_nf{}'.format(args.sr, args.hop_length, args.n_fft)
33 |     mix_cache_dir = os.path.join(args.mixtures, cache_dir)
34 |     inst_cache_dir = os.path.join(args.instruments, cache_dir)
35 |     os.makedirs(mix_cache_dir, exist_ok=True)
36 |     os.makedirs(inst_cache_dir, exist_ok=True)
37 | 
38 |     filelist = dataset.make_pair(args.mixtures, args.instruments)
39 |     for mix_path, inst_path in tqdm(filelist):
40 |         mix_basename = os.path.splitext(os.path.basename(mix_path))[0]
41 |         mix_cache_path = os.path.join(mix_cache_dir, mix_basename + cache_suffix)
42 | 
43 |         inst_basename = os.path.splitext(os.path.basename(inst_path))[0]
44 |         inst_cache_path = os.path.join(inst_cache_dir, inst_basename + cache_suffix)
45 | 
46 |         if os.path.exists(mix_cache_path) and os.path.exists(inst_cache_path):
47 |             continue
48 | 
49 |         X, _ = librosa.load(
50 |             mix_path, sr=args.sr, mono=False, dtype=np.float32, res_type='kaiser_fast')
51 |         y, _ = librosa.load(
52 |             inst_path, sr=args.sr, mono=False, dtype=np.float32, res_type='kaiser_fast')
53 | 
54 |         X, y = spec_utils.align_wave_head_and_tail(X, y, args.sr)
55 |         v = X - y
56 | 
57 |         sf.write(input_i, y.T, args.sr)
58 |         sf.write(input_v, v.T, args.sr)
59 |         subprocess.call(cmd_i.split(), stderr=subprocess.DEVNULL)  # argv list works without a shell on all platforms
60 |         subprocess.call(cmd_v.split(), stderr=subprocess.DEVNULL)
61 | 
62 |         y, _ = librosa.load(
63 |             output_i, sr=args.sr, mono=False, dtype=np.float32, res_type='kaiser_fast')
64 |         v, _ = librosa.load(
65 |             output_v, sr=args.sr, mono=False, dtype=np.float32, res_type='kaiser_fast')
66 | 
67 |         X = y + v
68 | 
69 |         spec = spec_utils.wave_to_spectrogram(X, args.hop_length, args.n_fft)
70 |         np.save(mix_cache_path, spec)
71 | 
72 |         spec = spec_utils.wave_to_spectrogram(y, args.hop_length, args.n_fft)
73 |         np.save(inst_cache_path, spec)
74 | 
75 |     os.remove(input_i)
76 |     os.remove(input_v)
77 |     os.remove(output_i)
78 |     os.remove(output_v)
79 | 
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | import librosa
5 | import numpy as np
6 | import soundfile as sf
7 | import torch
8 | from tqdm import tqdm
9 | 
10 | from lib import dataset
11 | from lib import nets
12 | from lib import spec_utils
13 | from lib import utils
14 | 
15 | 
16 | class Separator(object):
17 | 
18 |     def __init__(self, model, device=None, batchsize=1, cropsize=256, postprocess=False):
19 |         self.model = model
20 |         self.offset = model.offset
21 |         self.device = device
22 |         self.batchsize = batchsize
23 |         self.cropsize = cropsize
24 |         self.postprocess = postprocess
25 | 
26 |     def _postprocess(self, X_spec, mask):
27 |         if self.postprocess:
28 |             mask_mag = np.abs(mask)
29 |             mask_mag = spec_utils.merge_artifacts(mask_mag)
30 |             mask = mask_mag * np.exp(1.j * np.angle(mask))
31 | 
32 |         X_mag = np.abs(X_spec)
33 |         X_phase = np.angle(X_spec)
34 | 
35 |         y_spec = mask * X_mag * np.exp(1.j * X_phase)
36 |         v_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)
37 |         # y_spec = X_spec * mask
38 |         # v_spec = X_spec - y_spec
39 | 
40 |         return y_spec, v_spec
41 | 
42 |     def _separate(self, X_spec_pad, roi_size):
43 |         X_dataset = []
44 |         patches = (X_spec_pad.shape[2] - 2 * self.offset) // roi_size
45 |         for i in range(patches):
46 |             start = i * roi_size
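            # each crop is `cropsize` frames wide; consecutive crops overlap by
            # 2 * offset frames, and the model trims `offset` frames from each side of its output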
X_spec_crop = X_spec_pad[:, :, start:start + self.cropsize] 48 | X_dataset.append(X_spec_crop) 49 | 50 | X_dataset = np.asarray(X_dataset) 51 | 52 | self.model.eval() 53 | with torch.no_grad(): 54 | mask_list = [] 55 | # To reduce the overhead, dataloader is not used. 56 | for i in tqdm(range(0, patches, self.batchsize)): 57 | X_batch = X_dataset[i: i + self.batchsize] 58 | X_batch = torch.from_numpy(X_batch).to(self.device) 59 | 60 | mask = self.model.predict_mask(torch.abs(X_batch)) 61 | 62 | mask = mask.detach().cpu().numpy() 63 | mask = np.concatenate(mask, axis=2) 64 | mask_list.append(mask) 65 | 66 | mask = np.concatenate(mask_list, axis=2) 67 | 68 | return mask 69 | 70 | def separate(self, X_spec): 71 | n_frame = X_spec.shape[2] 72 | pad_l, pad_r, roi_size = dataset.make_padding(n_frame, self.cropsize, self.offset) 73 | X_spec_pad = np.pad(X_spec, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant') 74 | X_spec_pad /= np.abs(X_spec).max() 75 | 76 | mask = self._separate(X_spec_pad, roi_size) 77 | mask = mask[:, :, :n_frame] 78 | 79 | y_spec, v_spec = self._postprocess(X_spec, mask) 80 | 81 | return y_spec, v_spec 82 | 83 | def separate_tta(self, X_spec): 84 | n_frame = X_spec.shape[2] 85 | pad_l, pad_r, roi_size = dataset.make_padding(n_frame, self.cropsize, self.offset) 86 | X_spec_pad = np.pad(X_spec, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant') 87 | X_spec_pad /= X_spec_pad.max() 88 | 89 | mask = self._separate(X_spec_pad, roi_size) 90 | 91 | pad_l += roi_size // 2 92 | pad_r += roi_size // 2 93 | X_spec_pad = np.pad(X_spec, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant') 94 | X_spec_pad /= X_spec_pad.max() 95 | 96 | mask_tta = self._separate(X_spec_pad, roi_size) 97 | mask_tta = mask_tta[:, :, roi_size // 2:] 98 | mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5 99 | 100 | y_spec, v_spec = self._postprocess(X_spec, mask) 101 | 102 | return y_spec, v_spec 103 | 104 | MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models') 105 | DEFAULT_MODEL_PATH = os.path.join(MODEL_DIR, 'baseline.pth') 106 | 107 | def main(): 108 | p = argparse.ArgumentParser() 109 | p.add_argument('--gpu', '-g', type=int, default=-1) 110 | p.add_argument('--pretrained_model', '-P', type=str, default=DEFAULT_MODEL_PATH) 111 | p.add_argument('--input', '-i', required=True) 112 | p.add_argument('--sr', '-r', type=int, default=44100) 113 | p.add_argument('--n_fft', '-f', type=int, default=2048) 114 | p.add_argument('--hop_length', '-H', type=int, default=1024) 115 | p.add_argument('--batchsize', '-B', type=int, default=4) 116 | p.add_argument('--cropsize', '-c', type=int, default=256) 117 | p.add_argument('--output_image', '-I', action='store_true') 118 | p.add_argument('--tta', '-t', action='store_true') 119 | p.add_argument('--postprocess', '-p', action='store_true') 120 | p.add_argument('--output_dir', '-o', type=str, default="") 121 | args = p.parse_args() 122 | 123 | print('loading model...', end=' ') 124 | device = torch.device('cpu') 125 | if args.gpu >= 0: 126 | if torch.cuda.is_available(): 127 | device = torch.device('cuda:{}'.format(args.gpu)) 128 | elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): 129 | device = torch.device('mps') 130 | model = nets.CascadedNet(args.n_fft, args.hop_length, 32, 128) 131 | model.load_state_dict(torch.load(args.pretrained_model, map_location='cpu')) 132 | model.to(device) 133 | print('done') 134 | 135 | print('loading wave source...', end=' ') 136 | X, sr = librosa.load( 137 | args.input, sr=args.sr, 
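        # mono=False keeps stereo; these load settings mirror lib/spec_utils.cache_or_load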
        mono=False, dtype=np.float32, res_type='kaiser_fast'
138 |     )
139 |     basename = os.path.splitext(os.path.basename(args.input))[0]
140 |     print('done')
141 | 
142 |     if X.ndim == 1:
143 |         # mono to stereo
144 |         X = np.asarray([X, X])
145 | 
146 |     print('stft of wave source...', end=' ')
147 |     X_spec = spec_utils.wave_to_spectrogram(X, args.hop_length, args.n_fft)
148 |     print('done')
149 | 
150 |     sp = Separator(
151 |         model=model,
152 |         device=device,
153 |         batchsize=args.batchsize,
154 |         cropsize=args.cropsize,
155 |         postprocess=args.postprocess
156 |     )
157 | 
158 |     if args.tta:
159 |         y_spec, v_spec = sp.separate_tta(X_spec)
160 |     else:
161 |         y_spec, v_spec = sp.separate(X_spec)
162 | 
163 |     print('validating output directory...', end=' ')
164 |     output_dir = args.output_dir
165 |     if output_dir != "":  # normalize the output directory if one was specified
166 |         output_dir = output_dir.rstrip('/') + '/'
167 |         os.makedirs(output_dir, exist_ok=True)
168 |     print('done')
169 | 
170 |     print('inverse stft of instruments...', end=' ')
171 |     wave = spec_utils.spectrogram_to_wave(y_spec, hop_length=args.hop_length)
172 |     print('done')
173 |     sf.write('{}{}_Instruments.wav'.format(output_dir, basename), wave.T, sr)
174 | 
175 |     print('inverse stft of vocals...', end=' ')
176 |     wave = spec_utils.spectrogram_to_wave(v_spec, hop_length=args.hop_length)
177 |     print('done')
178 |     sf.write('{}{}_Vocals.wav'.format(output_dir, basename), wave.T, sr)
179 | 
180 |     if args.output_image:
181 |         image = spec_utils.spectrogram_to_image(y_spec)
182 |         utils.imwrite('{}{}_Instruments.jpg'.format(output_dir, basename), image)
183 | 
184 |         image = spec_utils.spectrogram_to_image(v_spec)
185 |         utils.imwrite('{}{}_Vocals.jpg'.format(output_dir, basename), image)
186 | 
187 | 
188 | if __name__ == '__main__':
189 |     main()
190 | 
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsurumeso/vocal-remover/99f92fe4b6bfe37bf4ff5bf4110ce224007312e5/lib/__init__.py
--------------------------------------------------------------------------------
/lib/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | 
4 | import numpy as np
5 | import torch
6 | import torch.utils.data
7 | from tqdm import tqdm
8 | 
9 | try:
10 |     from lib import spec_utils
11 | except ModuleNotFoundError:
12 |     import spec_utils
13 | 
14 | 
15 | class VocalRemoverTrainingSet(torch.utils.data.Dataset):
16 | 
17 |     def __init__(self, training_set, cropsize, reduction_rate, reduction_weight, mixup_rate, mixup_alpha):
18 |         self.training_set = training_set
19 |         self.cropsize = cropsize
20 |         self.reduction_rate = reduction_rate
21 |         self.reduction_weight = reduction_weight
22 |         self.mixup_rate = mixup_rate
23 |         self.mixup_alpha = mixup_alpha
24 | 
25 |     def __len__(self):
26 |         return len(self.training_set)
27 | 
28 |     def read_npy_shape(self, path):
29 |         with open(path, 'rb') as fhandle:
30 |             _, _ = np.lib.format.read_magic(fhandle)
31 |             shape, _, _ = np.lib.format.read_array_header_1_0(fhandle)
32 |             return shape
33 | 
34 |     def read_npy_chunk(self, path, start_row):
35 |         with open(path, 'rb') as fhandle:
36 |             _, _ = np.lib.format.read_magic(fhandle)
37 |             shape, fortran, dtype = np.lib.format.read_array_header_1_0(fhandle)
38 | 
39 |             assert not fortran, 'Fortran order arrays are not supported'
40 | 
41 |             row_size = np.prod(shape[1:])
42 |             start_byte = start_row * row_size
* dtype.itemsize 43 | fhandle.seek(start_byte, 1) 44 | n_items = row_size * self.cropsize 45 | flat = np.fromfile(fhandle, count=n_items, dtype=dtype) 46 | 47 | return flat.reshape((-1,) + shape[1:]) 48 | 49 | def aggressively_remove_vocal(self, X, y): 50 | X_mag = np.abs(X) 51 | y_mag = np.abs(y) 52 | v_mag = X_mag - y_mag 53 | v_mag *= v_mag > y_mag 54 | 55 | y_mag = np.clip(y_mag - v_mag * self.reduction_weight, 0, np.inf) 56 | 57 | return y_mag * np.exp(1.j * np.angle(y)) 58 | 59 | def do_crop(self, X_path, y_path): 60 | shape = self.read_npy_shape(X_path) 61 | start_row = np.random.randint(0, shape[0] - self.cropsize) 62 | 63 | X_crop = self.read_npy_chunk(X_path, start_row).transpose(1, 2, 0) 64 | y_crop = self.read_npy_chunk(y_path, start_row).transpose(1, 2, 0) 65 | 66 | return X_crop, y_crop 67 | 68 | def do_aug(self, X, y): 69 | if np.random.uniform() < self.reduction_rate: 70 | y = self.aggressively_remove_vocal(X, y) 71 | 72 | if np.random.uniform() < 0.5: 73 | # swap channel 74 | X = X[::-1].copy() 75 | y = y[::-1].copy() 76 | 77 | if np.random.uniform() < 0.01: 78 | # inst 79 | X = y.copy() 80 | 81 | # if np.random.uniform() < 0.01: 82 | # # mono 83 | # X[:] = X.mean(axis=0, keepdims=True) 84 | # y[:] = y.mean(axis=0, keepdims=True) 85 | 86 | return X, y 87 | 88 | def do_mixup(self, X, y): 89 | idx = np.random.randint(0, len(self)) 90 | X_path, y_path, coef = self.training_set[idx] 91 | 92 | X_i, y_i = self.do_crop(X_path, y_path) 93 | X_i /= coef 94 | y_i /= coef 95 | 96 | X_i, y_i = self.do_aug(X_i, y_i) 97 | 98 | lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) 99 | X = lam * X + (1 - lam) * X_i 100 | y = lam * y + (1 - lam) * y_i 101 | 102 | return X, y 103 | 104 | def __getitem__(self, idx): 105 | X_path, y_path, coef = self.training_set[idx] 106 | 107 | X, y = self.do_crop(X_path, y_path) 108 | X /= coef 109 | y /= coef 110 | 111 | X, y = self.do_aug(X, y) 112 | 113 | if np.random.uniform() < self.mixup_rate: 114 | X, y = self.do_mixup(X, y) 115 | 116 | X_mag = np.abs(X) 117 | y_mag = np.abs(y) 118 | 119 | return X_mag, y_mag 120 | # return X, y 121 | 122 | 123 | class VocalRemoverValidationSet(torch.utils.data.Dataset): 124 | 125 | def __init__(self, patch_list): 126 | self.patch_list = patch_list 127 | 128 | def __len__(self): 129 | return len(self.patch_list) 130 | 131 | def __getitem__(self, idx): 132 | path = self.patch_list[idx] 133 | data = np.load(path) 134 | 135 | X, y = data['X'], data['y'] 136 | 137 | X_mag = np.abs(X) 138 | y_mag = np.abs(y) 139 | 140 | return X_mag, y_mag 141 | # return X, y 142 | 143 | 144 | def make_pair(mix_dir, inst_dir): 145 | input_exts = ['.wav', '.m4a', '.mp3', '.mp4', '.flac'] 146 | 147 | X_list = sorted([ 148 | os.path.join(mix_dir, fname) 149 | for fname in os.listdir(mix_dir) 150 | if os.path.splitext(fname)[1] in input_exts 151 | ]) 152 | y_list = sorted([ 153 | os.path.join(inst_dir, fname) 154 | for fname in os.listdir(inst_dir) 155 | if os.path.splitext(fname)[1] in input_exts 156 | ]) 157 | 158 | filelist = list(zip(X_list, y_list)) 159 | 160 | return filelist 161 | 162 | 163 | def train_val_split(dataset_dir, split_mode, val_rate, val_filelist): 164 | if split_mode == 'random': 165 | filelist = make_pair( 166 | os.path.join(dataset_dir, 'mixtures'), 167 | os.path.join(dataset_dir, 'instruments') 168 | ) 169 | 170 | random.shuffle(filelist) 171 | 172 | if len(val_filelist) == 0: 173 | val_size = int(len(filelist) * val_rate) 174 | train_filelist = filelist[:-val_size] 175 | val_filelist = filelist[-val_size:] 176 | 
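        # a validation filelist was supplied, so exclude those pairs from the training split instead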
else: 177 | train_filelist = [ 178 | pair for pair in filelist 179 | if list(pair) not in val_filelist 180 | ] 181 | elif split_mode == 'subdirs': 182 | if len(val_filelist) != 0: 183 | raise ValueError('`val_filelist` option is not available with `subdirs` mode') 184 | 185 | train_filelist = make_pair( 186 | os.path.join(dataset_dir, 'training/mixtures'), 187 | os.path.join(dataset_dir, 'training/instruments') 188 | ) 189 | 190 | val_filelist = make_pair( 191 | os.path.join(dataset_dir, 'validation/mixtures'), 192 | os.path.join(dataset_dir, 'validation/instruments') 193 | ) 194 | 195 | return train_filelist, val_filelist 196 | 197 | 198 | def make_padding(width, cropsize, offset): 199 | left = offset 200 | roi_size = cropsize - offset * 2 201 | if roi_size == 0: 202 | roi_size = cropsize 203 | right = roi_size - (width % roi_size) + left 204 | 205 | return left, right, roi_size 206 | 207 | 208 | def make_training_set(filelist, sr, hop_length, n_fft): 209 | ret = [] 210 | for X_path, y_path in tqdm(filelist): 211 | X, y, X_cache_path, y_cache_path = spec_utils.cache_or_load( 212 | X_path, y_path, sr, hop_length, n_fft 213 | ) 214 | coef = np.max([np.abs(X).max(), np.abs(y).max()]) 215 | ret.append([X_cache_path, y_cache_path, coef]) 216 | 217 | return ret 218 | 219 | 220 | def make_validation_set(filelist, cropsize, sr, hop_length, n_fft, offset): 221 | patch_list = [] 222 | patch_dir = 'cs{}_sr{}_hl{}_nf{}_of{}'.format(cropsize, sr, hop_length, n_fft, offset) 223 | os.makedirs(patch_dir, exist_ok=True) 224 | 225 | for X_path, y_path in tqdm(filelist): 226 | basename = os.path.splitext(os.path.basename(X_path))[0] 227 | 228 | X, y, _, _ = spec_utils.cache_or_load(X_path, y_path, sr, hop_length, n_fft) 229 | coef = np.max([np.abs(X).max(), np.abs(y).max()]) 230 | X, y = X / coef, y / coef 231 | 232 | l, r, roi_size = make_padding(X.shape[2], cropsize, offset) 233 | X_pad = np.pad(X, ((0, 0), (0, 0), (l, r)), mode='constant') 234 | y_pad = np.pad(y, ((0, 0), (0, 0), (l, r)), mode='constant') 235 | 236 | len_dataset = int(np.ceil(X.shape[2] / roi_size)) 237 | for j in range(len_dataset): 238 | outpath = os.path.join(patch_dir, '{}_p{}.npz'.format(basename, j)) 239 | start = j * roi_size 240 | if not os.path.exists(outpath): 241 | np.savez( 242 | outpath, 243 | X=X_pad[:, :, start:start + cropsize], 244 | y=y_pad[:, :, start:start + cropsize] 245 | ) 246 | patch_list.append(outpath) 247 | 248 | return patch_list 249 | 250 | 251 | def get_oracle_data(X, y, oracle_loss, oracle_rate, oracle_drop_rate): 252 | k = int(len(X) * oracle_rate * (1 / (1 - oracle_drop_rate))) 253 | n = int(len(X) * oracle_rate) 254 | indices = np.argsort(oracle_loss)[::-1][:k] 255 | indices = np.random.choice(indices, n, replace=False) 256 | oracle_X = X[indices].copy() 257 | oracle_y = y[indices].copy() 258 | 259 | return oracle_X, oracle_y, indices 260 | 261 | 262 | if __name__ == "__main__": 263 | import sys 264 | import utils 265 | 266 | mix_dir = sys.argv[1] 267 | inst_dir = sys.argv[2] 268 | outdir = sys.argv[3] 269 | 270 | os.makedirs(outdir, exist_ok=True) 271 | 272 | filelist = make_pair(mix_dir, inst_dir) 273 | for mix_path, inst_path in tqdm(filelist): 274 | mix_basename = os.path.splitext(os.path.basename(mix_path))[0] 275 | 276 | X_spec, y_spec, _, _ = spec_utils.cache_or_load( 277 | mix_path, inst_path, 44100, 1024, 2048 278 | ) 279 | 280 | X_mag = np.abs(X_spec) 281 | y_mag = np.abs(y_spec) 282 | v_mag = X_mag - y_mag 283 | v_mag *= v_mag > y_mag 284 | 285 | outpath = '{}/{}_Vocal.jpg'.format(outdir, 
mix_basename) 286 | v_image = spec_utils.spectrogram_to_image(v_mag) 287 | utils.imwrite(outpath, v_image) 288 | -------------------------------------------------------------------------------- /lib/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from lib import spec_utils 6 | 7 | 8 | class Conv2DBNActiv(nn.Module): 9 | 10 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): 11 | super(Conv2DBNActiv, self).__init__() 12 | self.conv = nn.Sequential( 13 | nn.Conv2d( 14 | nin, nout, 15 | kernel_size=ksize, 16 | stride=stride, 17 | padding=pad, 18 | dilation=dilation, 19 | bias=False 20 | ), 21 | nn.BatchNorm2d(nout), 22 | activ() 23 | ) 24 | 25 | def __call__(self, x): 26 | return self.conv(x) 27 | 28 | 29 | class Encoder(nn.Module): 30 | 31 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU): 32 | super(Encoder, self).__init__() 33 | self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ) 34 | self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ) 35 | 36 | def __call__(self, x): 37 | h = self.conv1(x) 38 | h = self.conv2(h) 39 | 40 | return h 41 | 42 | 43 | class Decoder(nn.Module): 44 | 45 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False): 46 | super(Decoder, self).__init__() 47 | self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) 48 | # self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ) 49 | self.dropout = nn.Dropout2d(0.1) if dropout else None 50 | 51 | def __call__(self, x, skip=None): 52 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) 53 | 54 | if skip is not None: 55 | skip = spec_utils.crop_center(skip, x) 56 | x = torch.cat([x, skip], dim=1) 57 | 58 | h = self.conv1(x) 59 | # h = self.conv2(h) 60 | 61 | if self.dropout is not None: 62 | h = self.dropout(h) 63 | 64 | return h 65 | 66 | 67 | class ASPPModule(nn.Module): 68 | 69 | def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False): 70 | super(ASPPModule, self).__init__() 71 | self.conv1 = nn.Sequential( 72 | nn.AdaptiveAvgPool2d((1, None)), 73 | Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ) 74 | ) 75 | self.conv2 = Conv2DBNActiv( 76 | nin, nout, 1, 1, 0, activ=activ 77 | ) 78 | self.conv3 = Conv2DBNActiv( 79 | nin, nout, 3, 1, dilations[0], dilations[0], activ=activ 80 | ) 81 | self.conv4 = Conv2DBNActiv( 82 | nin, nout, 3, 1, dilations[1], dilations[1], activ=activ 83 | ) 84 | self.conv5 = Conv2DBNActiv( 85 | nin, nout, 3, 1, dilations[2], dilations[2], activ=activ 86 | ) 87 | self.bottleneck = Conv2DBNActiv( 88 | nout * 5, nout, 1, 1, 0, activ=activ 89 | ) 90 | self.dropout = nn.Dropout2d(0.1) if dropout else None 91 | 92 | def forward(self, x): 93 | _, _, h, w = x.size() 94 | feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True) 95 | feat2 = self.conv2(x) 96 | feat3 = self.conv3(x) 97 | feat4 = self.conv4(x) 98 | feat5 = self.conv5(x) 99 | out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) 100 | out = self.bottleneck(out) 101 | 102 | if self.dropout is not None: 103 | out = self.dropout(out) 104 | 105 | return out 106 | 107 | 108 | class LSTMModule(nn.Module): 109 | 110 | def __init__(self, nin_conv, nin_lstm, nout_lstm): 111 | super(LSTMModule, self).__init__() 112 | self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0) 113 | self.lstm = nn.LSTM( 114 | 
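            # hidden_size is halved because the bidirectional outputs are
            # concatenated, giving exactly nout_lstm features per frame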
input_size=nin_lstm, 115 | hidden_size=nout_lstm // 2, 116 | bidirectional=True 117 | ) 118 | self.dense = nn.Sequential( 119 | nn.Linear(nout_lstm, nin_lstm), 120 | nn.BatchNorm1d(nin_lstm), 121 | nn.ReLU() 122 | ) 123 | 124 | def forward(self, x): 125 | N, _, nbins, nframes = x.size() 126 | h = self.conv(x)[:, 0] # N, nbins, nframes 127 | h = h.permute(2, 0, 1) # nframes, N, nbins 128 | h, _ = self.lstm(h) 129 | h = self.dense(h.reshape(-1, h.size()[-1])) # nframes * N, nbins 130 | h = h.reshape(nframes, N, 1, nbins) 131 | h = h.permute(1, 2, 3, 0) 132 | 133 | return h 134 | -------------------------------------------------------------------------------- /lib/nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from lib import layers 6 | 7 | 8 | class BaseNet(nn.Module): 9 | 10 | def __init__(self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))): 11 | super(BaseNet, self).__init__() 12 | self.enc1 = layers.Conv2DBNActiv(nin, nout, 3, 1, 1) 13 | self.enc2 = layers.Encoder(nout, nout * 2, 3, 2, 1) 14 | self.enc3 = layers.Encoder(nout * 2, nout * 4, 3, 2, 1) 15 | self.enc4 = layers.Encoder(nout * 4, nout * 6, 3, 2, 1) 16 | self.enc5 = layers.Encoder(nout * 6, nout * 8, 3, 2, 1) 17 | 18 | self.aspp = layers.ASPPModule(nout * 8, nout * 8, dilations, dropout=True) 19 | 20 | self.dec4 = layers.Decoder(nout * (6 + 8), nout * 6, 3, 1, 1) 21 | self.dec3 = layers.Decoder(nout * (4 + 6), nout * 4, 3, 1, 1) 22 | self.dec2 = layers.Decoder(nout * (2 + 4), nout * 2, 3, 1, 1) 23 | self.lstm_dec2 = layers.LSTMModule(nout * 2, nin_lstm, nout_lstm) 24 | self.dec1 = layers.Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1) 25 | 26 | def __call__(self, x): 27 | e1 = self.enc1(x) 28 | e2 = self.enc2(e1) 29 | e3 = self.enc3(e2) 30 | e4 = self.enc4(e3) 31 | e5 = self.enc5(e4) 32 | 33 | h = self.aspp(e5) 34 | 35 | h = self.dec4(h, e4) 36 | h = self.dec3(h, e3) 37 | h = self.dec2(h, e2) 38 | h = torch.cat([h, self.lstm_dec2(h)], dim=1) 39 | h = self.dec1(h, e1) 40 | 41 | return h 42 | 43 | 44 | class CascadedNet(nn.Module): 45 | 46 | def __init__(self, n_fft, hop_length, nout=32, nout_lstm=128, is_complex=False): 47 | super(CascadedNet, self).__init__() 48 | self.n_fft = n_fft 49 | self.hop_length = hop_length 50 | self.is_complex = is_complex 51 | 52 | self.max_bin = n_fft // 2 53 | self.output_bin = n_fft // 2 + 1 54 | self.nin_lstm = self.max_bin // 2 55 | self.offset = 64 56 | 57 | nin = 4 if is_complex else 2 58 | 59 | self.stg1_low_band_net = nn.Sequential( 60 | BaseNet(nin, nout // 2, self.nin_lstm // 2, nout_lstm), 61 | layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0) 62 | ) 63 | self.stg1_high_band_net = BaseNet( 64 | nin, nout // 4, self.nin_lstm // 2, nout_lstm // 2 65 | ) 66 | 67 | self.stg2_low_band_net = nn.Sequential( 68 | BaseNet(nout // 4 + nin, nout, self.nin_lstm // 2, nout_lstm), 69 | layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0) 70 | ) 71 | self.stg2_high_band_net = BaseNet( 72 | nout // 4 + nin, nout // 2, self.nin_lstm // 2, nout_lstm // 2 73 | ) 74 | 75 | self.stg3_full_band_net = BaseNet( 76 | 3 * nout // 4 + nin, nout, self.nin_lstm, nout_lstm 77 | ) 78 | 79 | self.out = nn.Conv2d(nout, nin, 1, bias=False) 80 | self.aux_out = nn.Conv2d(3 * nout // 4, nin, 1, bias=False) 81 | 82 | def forward(self, x): 83 | if self.is_complex: 84 | x = torch.cat([x.real, x.imag], dim=1) 85 | 86 | x = x[:, :, :self.max_bin] 87 | 88 | bandw = x.size()[2] // 2 89 | 
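        # stages 1 and 2 process the low and high halves of the spectrum with
        # separate band networks before stage 3 sees the full band (cf. the multi-band design of [2][3])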
l1_in = x[:, :, :bandw] 90 | h1_in = x[:, :, bandw:] 91 | l1 = self.stg1_low_band_net(l1_in) 92 | h1 = self.stg1_high_band_net(h1_in) 93 | aux1 = torch.cat([l1, h1], dim=2) 94 | 95 | l2_in = torch.cat([l1_in, l1], dim=1) 96 | h2_in = torch.cat([h1_in, h1], dim=1) 97 | l2 = self.stg2_low_band_net(l2_in) 98 | h2 = self.stg2_high_band_net(h2_in) 99 | aux2 = torch.cat([l2, h2], dim=2) 100 | 101 | f3_in = torch.cat([x, aux1, aux2], dim=1) 102 | f3 = self.stg3_full_band_net(f3_in) 103 | 104 | if self.is_complex: 105 | mask = self.out(f3) 106 | mask = torch.complex(mask[:, :2], mask[:, 2:]) 107 | mask = self.bounded_mask(mask) 108 | else: 109 | mask = torch.sigmoid(self.out(f3)) 110 | 111 | mask = F.pad( 112 | input=mask, 113 | pad=(0, 0, 0, self.output_bin - mask.size()[2]), 114 | mode='replicate' 115 | ) 116 | 117 | return mask 118 | 119 | def bounded_mask(self, mask, eps=1e-8): 120 | mask_mag = torch.abs(mask) 121 | mask = torch.tanh(mask_mag) * mask / (mask_mag + eps) 122 | return mask 123 | 124 | def predict_mask(self, x): 125 | mask = self.forward(x) 126 | 127 | if self.offset > 0: 128 | mask = mask[:, :, :, self.offset:-self.offset] 129 | assert mask.size()[3] > 0 130 | 131 | return mask 132 | 133 | def predict(self, x): 134 | mask = self.forward(x) 135 | pred = x * mask 136 | 137 | if self.offset > 0: 138 | pred = pred[:, :, :, self.offset:-self.offset] 139 | assert pred.size()[3] > 0 140 | 141 | return pred 142 | -------------------------------------------------------------------------------- /lib/spec_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import numpy as np 5 | import soundfile as sf 6 | 7 | 8 | def crop_center(h1, h2): 9 | h1_shape = h1.size() 10 | h2_shape = h2.size() 11 | 12 | if h1_shape[3] == h2_shape[3]: 13 | return h1 14 | elif h1_shape[3] < h2_shape[3]: 15 | raise ValueError('h1_shape[3] must be greater than h2_shape[3]') 16 | 17 | # s_freq = (h2_shape[2] - h1_shape[2]) // 2 18 | # e_freq = s_freq + h1_shape[2] 19 | s_time = (h1_shape[3] - h2_shape[3]) // 2 20 | e_time = s_time + h2_shape[3] 21 | h1 = h1[:, :, :, s_time:e_time] 22 | 23 | return h1 24 | 25 | 26 | def wave_to_spectrogram(wave, hop_length, n_fft): 27 | spec_left = librosa.stft(wave[0], n_fft=n_fft, hop_length=hop_length) 28 | spec_right = librosa.stft(wave[1], n_fft=n_fft, hop_length=hop_length) 29 | spec = np.asarray([spec_left, spec_right]) 30 | 31 | return spec 32 | 33 | 34 | def spectrogram_to_image(spec, mode='magnitude'): 35 | if mode == 'magnitude': 36 | if np.iscomplexobj(spec): 37 | y = np.abs(spec) 38 | else: 39 | y = spec 40 | y = np.log10(y ** 2 + 1e-8) 41 | elif mode == 'phase': 42 | if np.iscomplexobj(spec): 43 | y = np.angle(spec) 44 | else: 45 | y = spec 46 | 47 | y -= y.min() 48 | y *= 255 / y.max() 49 | img = np.uint8(y) 50 | 51 | if y.ndim == 3: 52 | img = img.transpose(1, 2, 0) 53 | img = np.concatenate([ 54 | np.max(img, axis=2, keepdims=True), img 55 | ], axis=2) 56 | 57 | return img 58 | 59 | 60 | def merge_artifacts(y_mask, thres=0.05, min_range=64, fade_size=32): 61 | if min_range < fade_size * 2: 62 | raise ValueError('min_range must be >= fade_size * 2') 63 | 64 | idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0] 65 | start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0]) 66 | end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1]) 67 | artifact_idx = np.where(end_idx - start_idx > min_range)[0] 68 | weight = np.zeros_like(y_mask) 69 | if len(artifact_idx) > 0: 70 | 
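        # regions where every mask bin stays above `thres` for more than
        # `min_range` frames are pushed to 1, with `fade_size`-frame linear fades at the edges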
start_idx = start_idx[artifact_idx] 71 | end_idx = end_idx[artifact_idx] 72 | old_e = None 73 | for s, e in zip(start_idx, end_idx): 74 | if old_e is not None and s - old_e < fade_size: 75 | s = old_e - fade_size * 2 76 | 77 | if s != 0: 78 | weight[:, :, s:s + fade_size] = np.linspace(0, 1, fade_size) 79 | else: 80 | s -= fade_size 81 | 82 | if e != y_mask.shape[2]: 83 | weight[:, :, e - fade_size:e] = np.linspace(1, 0, fade_size) 84 | else: 85 | e += fade_size 86 | 87 | weight[:, :, s + fade_size:e - fade_size] = 1 88 | old_e = e 89 | 90 | v_mask = 1 - y_mask 91 | y_mask += weight * v_mask 92 | 93 | return y_mask 94 | 95 | 96 | def align_wave_head_and_tail(a, b, sr): 97 | a, _ = librosa.effects.trim(a) 98 | b, _ = librosa.effects.trim(b) 99 | 100 | a_mono = a[:, :sr * 4].sum(axis=0) 101 | b_mono = b[:, :sr * 4].sum(axis=0) 102 | 103 | a_mono -= a_mono.mean() 104 | b_mono -= b_mono.mean() 105 | 106 | offset = len(a_mono) - 1 107 | delay = np.argmax(np.correlate(a_mono, b_mono, 'full')) - offset 108 | 109 | if delay > 0: 110 | a = a[:, delay:] 111 | else: 112 | b = b[:, np.abs(delay):] 113 | 114 | if a.shape[1] < b.shape[1]: 115 | b = b[:, :a.shape[1]] 116 | else: 117 | a = a[:, :b.shape[1]] 118 | 119 | return a, b 120 | 121 | 122 | def cache_or_load(mix_path, inst_path, sr, hop_length, n_fft): 123 | mix_basename = os.path.splitext(os.path.basename(mix_path))[0] 124 | inst_basename = os.path.splitext(os.path.basename(inst_path))[0] 125 | 126 | cache_dir = 'sr{}_hl{}_nf{}'.format(sr, hop_length, n_fft) 127 | mix_cache_dir = os.path.join(os.path.dirname(mix_path), cache_dir) 128 | inst_cache_dir = os.path.join(os.path.dirname(inst_path), cache_dir) 129 | os.makedirs(mix_cache_dir, exist_ok=True) 130 | os.makedirs(inst_cache_dir, exist_ok=True) 131 | 132 | mix_cache_path = os.path.join(mix_cache_dir, mix_basename + '.npy') 133 | inst_cache_path = os.path.join(inst_cache_dir, inst_basename + '.npy') 134 | 135 | if os.path.exists(mix_cache_path) and os.path.exists(inst_cache_path): 136 | X = np.load(mix_cache_path).transpose(1, 2, 0) 137 | y = np.load(inst_cache_path).transpose(1, 2, 0) 138 | else: 139 | X, _ = librosa.load( 140 | mix_path, sr=sr, mono=False, dtype=np.float32, res_type='kaiser_fast') 141 | y, _ = librosa.load( 142 | inst_path, sr=sr, mono=False, dtype=np.float32, res_type='kaiser_fast') 143 | 144 | X, y = align_wave_head_and_tail(X, y, sr) 145 | 146 | X = wave_to_spectrogram(X, hop_length, n_fft) 147 | y = wave_to_spectrogram(y, hop_length, n_fft) 148 | 149 | np.save(mix_cache_path, X.transpose(2, 0, 1)) 150 | np.save(inst_cache_path, y.transpose(2, 0, 1)) 151 | 152 | assert X.shape == y.shape 153 | 154 | return X, y, mix_cache_path, inst_cache_path 155 | 156 | 157 | def spectrogram_to_wave(spec, hop_length=1024): 158 | if spec.ndim == 2: 159 | wave = librosa.istft(spec, hop_length=hop_length) 160 | elif spec.ndim == 3: 161 | wave_left = librosa.istft(spec[0], hop_length=hop_length) 162 | wave_right = librosa.istft(spec[1], hop_length=hop_length) 163 | wave = np.asarray([wave_left, wave_right]) 164 | 165 | return wave 166 | 167 | 168 | if __name__ == "__main__": 169 | import cv2 170 | import sys 171 | 172 | X, _ = librosa.load( 173 | sys.argv[1], sr=44100, mono=False, dtype=np.float32, res_type='kaiser_fast' 174 | ) 175 | y, _ = librosa.load( 176 | sys.argv[2], sr=44100, mono=False, dtype=np.float32, res_type='kaiser_fast' 177 | ) 178 | 179 | X, y = align_wave_head_and_tail(X, y, 44100) 180 | X_spec = wave_to_spectrogram(X, 1024, 2048) 181 | y_spec = wave_to_spectrogram(y, 
1024, 2048) 182 | 183 | # X_spec = np.load(sys.argv[1]).transpose(1, 2, 0) 184 | # y_spec = np.load(sys.argv[2]).transpose(1, 2, 0) 185 | 186 | v_spec = X_spec - y_spec 187 | 188 | X_image = spectrogram_to_image(X_spec) 189 | y_image = spectrogram_to_image(y_spec) 190 | v_image = spectrogram_to_image(v_spec) 191 | 192 | cv2.imwrite('test_X.jpg', X_image) 193 | cv2.imwrite('test_y.jpg', y_image) 194 | cv2.imwrite('test_v.jpg', v_image) 195 | 196 | sf.write('test_X.wav', spectrogram_to_wave(X_spec).T, 44100) 197 | sf.write('test_y.wav', spectrogram_to_wave(y_spec).T, 44100) 198 | sf.write('test_v.wav', spectrogram_to_wave(v_spec).T, 44100) 199 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | 7 | def imread(filename, flags=cv2.IMREAD_COLOR, dtype=np.uint8): 8 | try: 9 | n = np.fromfile(filename, dtype) 10 | img = cv2.imdecode(n, flags) 11 | return img 12 | except Exception as e: 13 | print(e) 14 | return None 15 | 16 | 17 | def imwrite(filename, img, params=None): 18 | try: 19 | ext = os.path.splitext(filename)[1] 20 | result, n = cv2.imencode(ext, img, params) 21 | 22 | if result: 23 | with open(filename, mode='w+b') as f: 24 | n.tofile(f) 25 | return True 26 | else: 27 | return False 28 | except Exception as e: 29 | print(e) 30 | return False 31 | -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsurumeso/vocal-remover/99f92fe4b6bfe37bf4ff5bf4110ce224007312e5/models/.gitkeep -------------------------------------------------------------------------------- /pseudo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import librosa 5 | import numpy as np 6 | import soundfile as sf 7 | import torch 8 | 9 | from lib import dataset 10 | from lib import nets 11 | from lib import spec_utils 12 | 13 | import inference 14 | 15 | 16 | def main(): 17 | p = argparse.ArgumentParser() 18 | p.add_argument('--gpu', '-g', type=int, default=-1) 19 | p.add_argument('--pretrained_model', '-P', type=str, default='models/baseline.pth') 20 | p.add_argument('--mixtures', '-m', required=True) 21 | p.add_argument('--instruments', '-i', required=True) 22 | p.add_argument('--sr', '-r', type=int, default=44100) 23 | p.add_argument('--n_fft', '-f', type=int, default=2048) 24 | p.add_argument('--hop_length', '-H', type=int, default=1024) 25 | p.add_argument('--batchsize', '-B', type=int, default=4) 26 | p.add_argument('--cropsize', '-c', type=int, default=256) 27 | p.add_argument('--postprocess', '-p', action='store_true') 28 | args = p.parse_args() 29 | 30 | print('loading model...', end=' ') 31 | device = torch.device('cpu') 32 | model = nets.CascadedNet(args.n_fft, args.hop_length) 33 | model.load_state_dict(torch.load(args.pretrained_model, map_location=device)) 34 | if torch.cuda.is_available() and args.gpu >= 0: 35 | device = torch.device('cuda:{}'.format(args.gpu)) 36 | model.to(device) 37 | print('done') 38 | 39 | filelist = dataset.make_pair(args.mixtures, args.instruments) 40 | for mix_path, inst_path in filelist: 41 | # if '_mixture' in mix_path and '_inst' in inst_path: 42 | # continue 43 | # else: 44 | # pass 45 | 46 | basename = os.path.splitext(os.path.basename(mix_path))[0] 47 | 
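        # the pseudo-instrumental target below is the real instrumental plus
        # whatever the model fails to remove from the vocal residue (X - y)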
print(basename) 48 | 49 | print('loading wave source...', end=' ') 50 | X, sr = librosa.load( 51 | mix_path, sr=args.sr, mono=False, dtype=np.float32, res_type='kaiser_fast') 52 | y, sr = librosa.load( 53 | inst_path, sr=args.sr, mono=False, dtype=np.float32, res_type='kaiser_fast') 54 | print('done') 55 | 56 | if X.ndim == 1: 57 | # mono to stereo 58 | X = np.asarray([X, X]) 59 | 60 | print('stft of wave source...', end=' ') 61 | X, y = spec_utils.align_wave_head_and_tail(X, y, sr) 62 | X = spec_utils.wave_to_spectrogram(X, args.hop_length, args.n_fft) 63 | y = spec_utils.wave_to_spectrogram(y, args.hop_length, args.n_fft) 64 | print('done') 65 | 66 | sp = inference.Separator(model, device, args.batchsize, args.cropsize, args.postprocess) 67 | a_spec, _ = sp.separate_tta(X - y) 68 | 69 | print('inverse stft of pseudo instruments...', end=' ') 70 | pseudo_inst = y + a_spec 71 | print('done') 72 | 73 | sf.write('pseudo/{}_PseudoInstruments.wav'.format(basename), [0], sr) 74 | np.save('pseudo/{}_PseudoInstruments.npy'.format(basename), pseudo_inst) 75 | 76 | 77 | if __name__ == '__main__': 78 | main() 79 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # install from https://pytorch.org/get-started/locally/ 2 | # torch~=2.1.0 3 | # torchvision~=0.16.0 4 | librosa~=0.10.0 5 | matplotlib~=3.8.0 6 | opencv_python~=4.8.0 7 | resampy~=0.4.0 8 | tqdm~=4.66.0 9 | numpy~=1.26.4 10 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | import json 4 | import logging 5 | import os 6 | import random 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.utils.data 12 | 13 | from lib import dataset 14 | from lib import nets 15 | from lib import spec_utils 16 | 17 | 18 | def setup_logger(name, logfile='LOGFILENAME.log'): 19 | logger = logging.getLogger(name) 20 | logger.setLevel(logging.DEBUG) 21 | logger.propagate = False 22 | 23 | fh = logging.FileHandler(logfile, encoding='utf8') 24 | fh.setLevel(logging.DEBUG) 25 | fh_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 26 | fh.setFormatter(fh_formatter) 27 | 28 | sh = logging.StreamHandler() 29 | sh.setLevel(logging.INFO) 30 | 31 | logger.addHandler(fh) 32 | logger.addHandler(sh) 33 | 34 | return logger 35 | 36 | 37 | def to_wave(spec, n_fft, hop_length, window): 38 | B, _, N, T = spec.shape 39 | wave = spec.reshape(-1, N, T) 40 | wave = torch.istft(wave, n_fft, hop_length, window=window) 41 | wave = wave.reshape(B, 2, -1) 42 | 43 | return wave 44 | 45 | 46 | def sdr_loss(y, y_pred, eps=1e-8): 47 | sdr = (y * y_pred).sum() 48 | sdr /= torch.linalg.norm(y) * torch.linalg.norm(y_pred) + eps 49 | 50 | return -sdr 51 | 52 | 53 | def weighted_sdr_loss(y, y_pred, n, n_pred, eps=1e-8): 54 | y_sdr = (y * y_pred).sum() 55 | y_sdr /= torch.linalg.norm(y) * torch.linalg.norm(y_pred) + eps 56 | 57 | noise_sdr = (n * n_pred).sum() 58 | noise_sdr /= torch.linalg.norm(n) * torch.linalg.norm(n_pred) + eps 59 | 60 | a = torch.sum(y ** 2) 61 | a /= torch.sum(y ** 2) + torch.sum(n ** 2) + eps 62 | 63 | loss = a * y_sdr + (1 - a) * noise_sdr 64 | 65 | return -loss 66 | 67 | 68 | def train_epoch(dataloader, model, device, optimizer, accumulation_steps): 69 | model.train() 70 | # n_fft = model.n_fft 71 | # 
hop_length = model.hop_length 72 | # window = torch.hann_window(n_fft).to(device) 73 | 74 | sum_loss = 0 75 | crit_l1 = nn.L1Loss() 76 | 77 | for itr, (X_batch, y_batch) in enumerate(dataloader): 78 | X_batch = X_batch.to(device) 79 | y_batch = y_batch.to(device) 80 | 81 | mask = model(X_batch) 82 | 83 | # y_pred = X_batch * mask 84 | # y_wave_batch = to_wave(y_batch, n_fft, hop_length, window) 85 | # y_wave_pred = to_wave(y_pred, n_fft, hop_length, window) 86 | 87 | # loss = crit_l1(torch.abs(y_batch), torch.abs(y_pred)) 88 | # loss += sdr_loss(y_wave_batch, y_wave_pred) * 0.01 89 | loss = crit_l1(mask * X_batch, y_batch) 90 | 91 | accum_loss = loss / accumulation_steps 92 | accum_loss.backward() 93 | 94 | if (itr + 1) % accumulation_steps == 0: 95 | optimizer.step() 96 | model.zero_grad() 97 | 98 | sum_loss += loss.item() * len(X_batch) 99 | 100 | # the rest batch 101 | if (itr + 1) % accumulation_steps != 0: 102 | optimizer.step() 103 | model.zero_grad() 104 | 105 | return sum_loss / len(dataloader.dataset) 106 | 107 | 108 | def validate_epoch(dataloader, model, device): 109 | model.eval() 110 | # n_fft = model.n_fft 111 | # hop_length = model.hop_length 112 | # window = torch.hann_window(n_fft).to(device) 113 | 114 | sum_loss = 0 115 | crit_l1 = nn.L1Loss() 116 | 117 | with torch.no_grad(): 118 | for X_batch, y_batch in dataloader: 119 | X_batch = X_batch.to(device) 120 | y_batch = y_batch.to(device) 121 | 122 | y_pred = model.predict(X_batch) 123 | 124 | y_batch = spec_utils.crop_center(y_batch, y_pred) 125 | # y_wave_batch = to_wave(y_batch, n_fft, hop_length, window) 126 | # y_wave_pred = to_wave(y_pred, n_fft, hop_length, window) 127 | 128 | # loss = crit_l1(torch.abs(y_batch), torch.abs(y_pred)) 129 | # loss += sdr_loss(y_wave_batch, y_wave_pred) * 0.01 130 | loss = crit_l1(y_pred, y_batch) 131 | 132 | sum_loss += loss.item() * len(X_batch) 133 | 134 | return sum_loss / len(dataloader.dataset) 135 | 136 | 137 | def main(): 138 | p = argparse.ArgumentParser() 139 | p.add_argument('--gpu', '-g', type=int, default=-1) 140 | p.add_argument('--seed', '-s', type=int, default=2019) 141 | p.add_argument('--sr', '-r', type=int, default=44100) 142 | p.add_argument('--hop_length', '-H', type=int, default=1024) 143 | p.add_argument('--n_fft', '-f', type=int, default=2048) 144 | p.add_argument('--dataset', '-d', required=True) 145 | p.add_argument('--split_mode', '-S', type=str, choices=['random', 'subdirs'], default='random') 146 | p.add_argument('--learning_rate', '-l', type=float, default=0.001) 147 | p.add_argument('--lr_min', type=float, default=0.0001) 148 | p.add_argument('--lr_decay_factor', type=float, default=0.9) 149 | p.add_argument('--lr_decay_patience', type=int, default=6) 150 | p.add_argument('--batchsize', '-B', type=int, default=4) 151 | p.add_argument('--accumulation_steps', '-A', type=int, default=1) 152 | p.add_argument('--cropsize', '-C', type=int, default=256) 153 | p.add_argument('--patches', '-p', type=int, default=16) 154 | p.add_argument('--val_rate', '-v', type=float, default=0.2) 155 | p.add_argument('--val_filelist', '-V', type=str, default=None) 156 | p.add_argument('--val_batchsize', '-b', type=int, default=4) 157 | p.add_argument('--val_cropsize', '-c', type=int, default=256) 158 | p.add_argument('--num_workers', '-w', type=int, default=4) 159 | p.add_argument('--epoch', '-E', type=int, default=200) 160 | p.add_argument('--reduction_rate', '-R', type=float, default=0.0) 161 | p.add_argument('--reduction_level', '-L', type=float, default=0.2) 162 | 
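    # mixup blends two random training examples with a Beta(alpha, alpha) weight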
p.add_argument('--mixup_rate', '-M', type=float, default=0.0) 163 | p.add_argument('--mixup_alpha', '-a', type=float, default=1.0) 164 | p.add_argument('--pretrained_model', '-P', type=str, default=None) 165 | p.add_argument('--debug', action='store_true') 166 | args = p.parse_args() 167 | 168 | logger.debug(vars(args)) 169 | 170 | random.seed(args.seed) 171 | np.random.seed(args.seed) 172 | torch.manual_seed(args.seed) 173 | 174 | val_filelist = [] 175 | if args.val_filelist is not None: 176 | with open(args.val_filelist, 'r', encoding='utf8') as f: 177 | val_filelist = json.load(f) 178 | 179 | train_filelist, val_filelist = dataset.train_val_split( 180 | dataset_dir=args.dataset, 181 | split_mode=args.split_mode, 182 | val_rate=args.val_rate, 183 | val_filelist=val_filelist 184 | ) 185 | 186 | if args.debug: 187 | logger.info('### DEBUG MODE') 188 | train_filelist = train_filelist[:1] 189 | val_filelist = val_filelist[:1] 190 | elif args.val_filelist is None and args.split_mode == 'random': 191 | with open('val_{}.json'.format(timestamp), 'w', encoding='utf8') as f: 192 | json.dump(val_filelist, f, ensure_ascii=False) 193 | 194 | for i, (X_fname, y_fname) in enumerate(val_filelist): 195 | logger.info('{} {} {}'.format(i + 1, os.path.basename(X_fname), os.path.basename(y_fname))) 196 | 197 | bins = args.n_fft // 2 + 1 198 | freq_to_bin = 2 * bins / args.sr 199 | unstable_bins = int(200 * freq_to_bin) 200 | stable_bins = int(22050 * freq_to_bin) 201 | reduction_weight = np.concatenate([ 202 | np.linspace(0, 1, unstable_bins, dtype=np.float32)[:, None], 203 | np.linspace(1, 0, stable_bins - unstable_bins, dtype=np.float32)[:, None], 204 | np.zeros((bins - stable_bins, 1), dtype=np.float32), 205 | ], axis=0) * args.reduction_level 206 | 207 | device = torch.device('cpu') 208 | model = nets.CascadedNet(args.n_fft, args.hop_length, 32, 128) 209 | if args.pretrained_model is not None: 210 | model.load_state_dict(torch.load(args.pretrained_model, map_location=device)) 211 | if torch.cuda.is_available() and args.gpu >= 0: 212 | device = torch.device('cuda:{}'.format(args.gpu)) 213 | model.to(device) 214 | 215 | optimizer = torch.optim.Adam( 216 | filter(lambda p: p.requires_grad, model.parameters()), 217 | lr=args.learning_rate 218 | ) 219 | 220 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 221 | optimizer, 222 | factor=args.lr_decay_factor, 223 | patience=args.lr_decay_patience, 224 | threshold=1e-6, 225 | min_lr=args.lr_min, 226 | verbose=True 227 | ) 228 | 229 | training_set = dataset.make_training_set( 230 | filelist=train_filelist, 231 | sr=args.sr, 232 | hop_length=args.hop_length, 233 | n_fft=args.n_fft 234 | ) 235 | 236 | train_dataset = dataset.VocalRemoverTrainingSet( 237 | training_set * args.patches, 238 | cropsize=args.cropsize, 239 | reduction_rate=args.reduction_rate, 240 | reduction_weight=reduction_weight, 241 | mixup_rate=args.mixup_rate, 242 | mixup_alpha=args.mixup_alpha 243 | ) 244 | 245 | train_dataloader = torch.utils.data.DataLoader( 246 | dataset=train_dataset, 247 | batch_size=args.batchsize, 248 | shuffle=True, 249 | num_workers=args.num_workers 250 | ) 251 | 252 | patch_list = dataset.make_validation_set( 253 | filelist=val_filelist, 254 | cropsize=args.val_cropsize, 255 | sr=args.sr, 256 | hop_length=args.hop_length, 257 | n_fft=args.n_fft, 258 | offset=model.offset 259 | ) 260 | 261 | val_dataset = dataset.VocalRemoverValidationSet( 262 | patch_list=patch_list 263 | ) 264 | 265 | val_dataloader = torch.utils.data.DataLoader( 266 | dataset=val_dataset, 267 
| batch_size=args.val_batchsize, 268 | shuffle=False, 269 | num_workers=args.num_workers 270 | ) 271 | 272 | log = [] 273 | best_loss = np.inf 274 | for epoch in range(args.epoch): 275 | logger.info('# epoch {}'.format(epoch)) 276 | train_loss = train_epoch(train_dataloader, model, device, optimizer, args.accumulation_steps) 277 | val_loss = validate_epoch(val_dataloader, model, device) 278 | 279 | logger.info( 280 | ' * training loss = {:.6f}, validation loss = {:.6f}' 281 | .format(train_loss, val_loss) 282 | ) 283 | 284 | scheduler.step(val_loss) 285 | 286 | if val_loss < best_loss: 287 | best_loss = val_loss 288 | logger.info(' * best validation loss') 289 | model_path = 'models/model_iter{}.pth'.format(epoch) 290 | torch.save(model.state_dict(), model_path) 291 | 292 | log.append([train_loss, val_loss]) 293 | with open('loss_{}.json'.format(timestamp), 'w', encoding='utf8') as f: 294 | json.dump(log, f, ensure_ascii=False) 295 | 296 | 297 | if __name__ == '__main__': 298 | timestamp = datetime.now().strftime('%Y%m%d%H%M%S') 299 | logger = setup_logger(__name__, 'train_{}.log'.format(timestamp)) 300 | 301 | try: 302 | main() 303 | except Exception as e: 304 | logger.exception(e) 305 | --------------------------------------------------------------------------------
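For programmatic use without the CLI, the snippet below is a minimal sketch along the lines of `inference.py`. The checkpoint path `models/baseline.pth` and the input file `song.wav` are placeholder assumptions; `n_fft` and `hop_length` must match the values the model was trained with.

```
import librosa
import numpy as np
import soundfile as sf
import torch

from inference import Separator
from lib import nets, spec_utils

# assumed checkpoint; n_fft=2048 and hop_length=1024 are the repo defaults
model = nets.CascadedNet(2048, 1024, 32, 128)
model.load_state_dict(torch.load('models/baseline.pth', map_location='cpu'))

# load as stereo float32 and duplicate the channel if the source is mono
X, sr = librosa.load('song.wav', sr=44100, mono=False, dtype=np.float32)
if X.ndim == 1:
    X = np.asarray([X, X])

X_spec = spec_utils.wave_to_spectrogram(X, 1024, 2048)  # hop_length, n_fft
y_spec, v_spec = Separator(model, torch.device('cpu')).separate(X_spec)

sf.write('song_Instruments.wav', spec_utils.spectrogram_to_wave(y_spec, hop_length=1024).T, sr)
sf.write('song_Vocals.wav', spec_utils.spectrogram_to_wave(v_spec, hop_length=1024).T, sr)
```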