├── LICENSE
├── README.md
├── Singularity
├── audio_examples
│   └── Cristina Vane - So Easy
│       ├── accompaniment_true.mp3
│       ├── mix.mp3
│       ├── mix.mp3_bass.wav
│       ├── mix.mp3_drums.wav
│       ├── mix.mp3_other.wav
│       ├── mix.mp3_vocals.wav
│       └── vocals_true.mp3
├── checkpoints
│   └── README.md
├── cog.yaml
├── cog_predict.py
├── data
│   ├── __init__.py
│   ├── dataset.py
│   ├── musdb.py
│   └── utils.py
├── hdf
│   └── README.md
├── logs
│   └── README.md
├── model
│   ├── __init__.py
│   ├── conv.py
│   ├── crop.py
│   ├── resample.py
│   ├── utils.py
│   └── waveunet.py
├── predict.py
├── requirements.txt
├── test.py
├── train.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Daniel Stoller

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Wave-U-Net (Pytorch)

Improved version of the [Wave-U-Net](https://arxiv.org/abs/1806.03185) for audio source separation, implemented in PyTorch.

Click [here](https://github.com/f90/Wave-U-Net) for the original Wave-U-Net implementation in TensorFlow.
You can find more information about the model and results there as well.

# Improvements

* Multi-instrument separation by default, using a separate standard Wave-U-Net for each source (can also be set to use one model for all sources)
* More scalable to larger data: A depth parameter D can be set that employs D convolutions for each single convolution in the original Wave-U-Net
* More configurable: Layer type, resampling factor at each level etc. can be easily changed (different normalization, residual connections...)
* Fast training: Preprocesses the given dataset by saving the audio into HDF files, which can be read very quickly during training, thereby avoiding slowdown due to resampling and decoding
* Modular thanks to PyTorch: Easily replace components of the model with your own variants/layers/losses
* Better output handling: Separate output convolution for each source estimate with linear activation, so amplitudes near 1 and -1 can be easily predicted; at test time, outputs are thresholded to the valid amplitude range [-1,1]
* Fixed or dynamic resampling: Either use a fixed lowpass filter to avoid aliasing during resampling, or use a learnable convolution

# Installation

A GPU is strongly recommended to avoid very long training times.
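
Since training without a GPU is very slow, it can help to confirm that PyTorch actually detects one before adding ``--cuda`` to the commands further below. The following is a minimal sketch of a hypothetical helper (``build_train_command`` and ``gpu_available`` are not part of this repository) that assembles a ``train.py`` call from the documented ``--dataset_dir``, ``--hdf_dir`` and ``--cuda`` flags, adding ``--cuda`` only when ``torch.cuda.is_available()`` returns True; if PyTorch is not installed, it assumes CPU.

```python
import importlib.util
import shlex


def gpu_available():
    """Return True if PyTorch is installed and reports a usable CUDA device."""
    if importlib.util.find_spec("torch") is None:
        return False  # no PyTorch yet: assume CPU
    import torch  # imported lazily so the helper also works without PyTorch

    return torch.cuda.is_available()


def build_train_command(dataset_dir, hdf_dir=None, cuda=None):
    """Assemble a train.py invocation (hypothetical helper, not part of this repo).

    cuda=None means "auto-detect with gpu_available()".
    """
    cmd = ["python3.6", "train.py", "--dataset_dir", dataset_dir]
    if hdf_dir is not None:
        cmd += ["--hdf_dir", hdf_dir]
    if cuda is None:
        cuda = gpu_available()
    if cuda:
        cmd.append("--cuda")
    return cmd


if __name__ == "__main__":
    # Print a shell-ready command, quoting any arguments that contain spaces
    print(" ".join(shlex.quote(part) for part in build_train_command("/PATH/TO/MUSDB18HQ")))
```

Run it after installing the requirements below; the printed command includes ``--cuda`` only on machines where PyTorch sees a GPU. The ``python3.6`` interpreter name simply mirrors the commands in this README.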

### Option 1: Direct install (recommended)

System requirements:
* Linux-based OS
* Python 3.6
* [libsndfile](http://mega-nerd.com/libsndfile/)
* [ffmpeg](https://www.ffmpeg.org/)
* CUDA 10.1 for GPU usage

Clone the repository:
```
git clone https://github.com/f90/Wave-U-Net-Pytorch.git
cd Wave-U-Net-Pytorch
```

Recommended: Create a new virtual environment to install the required Python packages into, then activate it:

```
virtualenv --python /usr/bin/python3.6 waveunet-env
source waveunet-env/bin/activate
```

Install all the required packages listed in ``requirements.txt``:

```
pip3 install -r requirements.txt
```

### Option 2: Singularity

We also provide a Singularity container, which saves you from installing the correct Python, CUDA and other system libraries yourself. However, we don't provide specific advice on how to run the container, so only use this option if you have to or know what you are doing (for example, you need to mount dataset paths into the container).

To pull the container, run
```
singularity pull shub://f90/Wave-U-Net-Pytorch
```

Then run the container from the directory you cloned this repository into, using the commands listed further below in this README.

# Download datasets

If you only want to use the pre-trained models we provide to separate your own songs, skip directly to the [last section](#test-trained-models-on-songs), since the datasets are not needed in that case.

To start training your own models, download the [full MUSDB18HQ dataset](https://sigsep.github.io/datasets/musdb.html) and extract it into a folder of your choice. It should have two subfolders, "test" and "train", as well as a README.md file.
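
Before starting the long preprocessing run, you can sanity-check the extracted folder. The sketch below is a hypothetical helper (``find_missing_files`` is not part of this repository) that mirrors what ``get_musdbhq`` in ``data/musdb.py`` expects: ``train`` and ``test`` subfolders with one folder per track, each holding ``mixture.wav``, ``bass.wav``, ``drums.wav``, ``other.wav`` and ``vocals.wav``. ``accompaniment.wav`` is not required, since the code writes it on first use. The helper only checks that files exist, not their audio contents.

```python
import os
import sys

# File names that get_musdbhq() in data/musdb.py looks for in every track folder
EXPECTED_STEMS = ["mixture", "bass", "drums", "other", "vocals"]


def find_missing_files(database_path):
    """Return {folder: [missing entries]} for an extracted MUSDB18HQ root (hypothetical helper)."""
    problems = {}
    for subset in ["train", "test"]:
        subset_dir = os.path.join(database_path, subset)
        if not os.path.isdir(subset_dir):
            problems[subset_dir] = ["<subfolder missing>"]
            continue
        for track in sorted(os.listdir(subset_dir)):
            track_dir = os.path.join(subset_dir, track)
            if not os.path.isdir(track_dir):
                continue  # ignore stray files next to the track folders
            missing = [stem + ".wav" for stem in EXPECTED_STEMS
                       if not os.path.isfile(os.path.join(track_dir, stem + ".wav"))]
            if missing:
                problems[track_dir] = missing
    return problems


if __name__ == "__main__":
    root = sys.argv[1] if len(sys.argv) > 1 else "."
    for folder, missing in find_missing_files(root).items():
        print(folder + ": missing " + ", ".join(missing))
```

An empty result means every track folder contains all five stems.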

You can of course train on your own datasets, but this requires modifying the code manually, which is not covered here. We do, however, also provide a loading function for the standard (non-HQ) MUSDB18 dataset.

# Training the models

To train a Wave-U-Net, the basic command to use is

```
python3.6 train.py --dataset_dir /PATH/TO/MUSDB18HQ
```
where ``/PATH/TO/MUSDB18HQ`` is the MUSDB18HQ root folder containing the ``train`` and ``test`` subfolders.

Add more command line parameters as needed:
* ``--cuda`` to activate GPU usage
* ``--hdf_dir PATH`` to save the preprocessed data (HDF files) to a custom location PATH, instead of the default ``hdf`` subfolder in this repository
* ``--checkpoint_dir`` and ``--log_dir`` to specify where checkpoint files and logs are saved/loaded
* ``--load_model checkpoints/model_name/checkpoint_X`` to start training with weights given by a certain checkpoint

For more config options, see ``train.py``.

Training progress can be monitored by using TensorBoard on the respective ``log_dir``.
After training, the model is evaluated on the MUSDB18HQ test set, and SDR/SIR/SAR metrics are reported for all instruments. They are written both to TensorBoard and, in more detail, to a ``results.pkl`` file in the ``checkpoint_dir``.

# Test trained models on songs!

We provide the default model in pre-trained form as a download, so you can separate your own songs right away.

## Downloading our pretrained models

Download our pretrained model [here](https://www.dropbox.com/s/r374hce896g4xlj/models.7z?dl=1).
Extract the archive into the ``checkpoints`` subfolder in this repository, so that you have one subfolder for each model (e.g. ``REPO/checkpoints/waveunet``).

## Run pretrained model

To apply our pretrained model to any of your own songs, simply point to its audio file path using the ``--input`` parameter:

```
python3.6 predict.py --load_model checkpoints/waveunet/model --input "audio_examples/Cristina Vane - So Easy/mix.mp3"
```

* Add ``--cuda`` when using a GPU; prediction should be much quicker
* Point ``--input`` to the music file you want to separate

By default, output is written to the folder containing the input music file, using the original file name plus the instrument name as output file name (e.g. ``mix.mp3_vocals.wav``). Use ``--output`` to customise the output directory.

To run your own model:
* Point ``--load_model`` to the checkpoint file of the model you are using. If you used non-default hyper-parameters to train your own model, you must specify them here again so the correct model is set up and can receive the weights!

--------------------------------------------------------------------------------
/Singularity:
--------------------------------------------------------------------------------
BootStrap: docker
From: nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04

%post
    # Downloads the latest package lists (important).
    apt-get update -y
    # Runs apt-get while ensuring that there are no user prompts that would
    # cause the build process to hang.
    # python3-tk is required by matplotlib.
    # python3-dev is needed to build some packages.
    DEBIAN_FRONTEND=noninteractive apt-get install -y \
        python3 \
        python3-tk \
        python3-pip \
        python3-dev \
        libsndfile1 \
        libsndfile1-dev \
        ffmpeg \
        git
    # Reduce the size of the image by deleting the package lists we downloaded,
    # which are useless now.
    rm -rf /var/lib/apt/lists/*

    # Install Pipenv.
    pip3 install pipenv

    # Install Python modules.
    pip3 install future numpy librosa musdb museval h5py tqdm sortedcontainers soundfile
    pip3 install torch==1.4.0 torchvision==0.5.0 tensorboard

%environment
    # Pipenv requires a certain terminal encoding.
    export LANG=C.UTF-8
    export LC_ALL=C.UTF-8
    # This configures Pipenv to store the packages in the current working
    # directory.
    export PIPENV_VENV_IN_PROJECT=1

--------------------------------------------------------------------------------
/audio_examples/Cristina Vane - So Easy/accompaniment_true.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/audio_examples/Cristina Vane - So Easy/accompaniment_true.mp3

--------------------------------------------------------------------------------
/audio_examples/Cristina Vane - So Easy/mix.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/audio_examples/Cristina Vane - So Easy/mix.mp3

--------------------------------------------------------------------------------
/audio_examples/Cristina Vane - So Easy/mix.mp3_bass.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/audio_examples/Cristina Vane - So Easy/mix.mp3_bass.wav

--------------------------------------------------------------------------------
/audio_examples/Cristina Vane - So Easy/mix.mp3_drums.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/audio_examples/Cristina Vane - So Easy/mix.mp3_drums.wav

--------------------------------------------------------------------------------
/audio_examples/Cristina Vane - So Easy/mix.mp3_other.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/audio_examples/Cristina Vane - So Easy/mix.mp3_other.wav

--------------------------------------------------------------------------------
/audio_examples/Cristina Vane - So Easy/mix.mp3_vocals.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/audio_examples/Cristina Vane - So Easy/mix.mp3_vocals.wav

--------------------------------------------------------------------------------
/audio_examples/Cristina Vane - So Easy/vocals_true.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/audio_examples/Cristina Vane - So Easy/vocals_true.mp3

--------------------------------------------------------------------------------
/checkpoints/README.md:
--------------------------------------------------------------------------------
This is where checkpoint files of models will be saved by default. It's recommended to use a subfolder for each different type of model you have.
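
If you keep one subfolder per model, you may want to resume from its most recent checkpoint. The following is a minimal sketch of a hypothetical helper (``latest_checkpoint`` is not part of this repository). It assumes checkpoints are named ``checkpoint_<number>``, as in the ``--load_model checkpoints/model_name/checkpoint_X`` example in the main README, and returns the path with the highest number, which can then be passed to ``--load_model``.

```python
import os
import re

# Assumed naming scheme: "checkpoint_<number>", e.g. checkpoint_12000
CHECKPOINT_PATTERN = re.compile(r"^checkpoint_(\d+)$")


def latest_checkpoint(model_dir):
    """Return the path of the highest-numbered checkpoint in model_dir, or None."""
    best_step, best_path = -1, None
    for name in os.listdir(model_dir):
        match = CHECKPOINT_PATTERN.match(name)
        if match is None:
            continue  # skip results.pkl and any other non-checkpoint files
        step = int(match.group(1))
        if step > best_step:
            best_step, best_path = step, os.path.join(model_dir, name)
    return best_path
```

Comparing the parsed numbers rather than the file names avoids the lexicographic trap where ``checkpoint_9`` would sort after ``checkpoint_10``.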

--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
build:
  python_version: "3.6"
  gpu: false
  python_packages:
    - future==0.18.2
    - numpy==1.19.5
    - librosa==0.8.1
    - soundfile==0.10.3.post1
    - musdb==0.4.0
    - museval==0.4.0
    - h5py==3.1.0
    - tqdm==4.62.1
    - torch==1.4.0
    - torchvision==0.5.0
    - tensorboard==2.6.0
    - sortedcontainers==2.4.0
  system_packages:
    - libsndfile-dev
    - ffmpeg
predict: "cog_predict.py:waveunetPredictor"

--------------------------------------------------------------------------------
/cog_predict.py:
--------------------------------------------------------------------------------
import os
import cog
import tempfile
import zipfile
from pathlib import Path
import argparse
import data.utils
import model.utils as model_utils
from test import predict_song
from model.waveunet import Waveunet


class waveunetPredictor(cog.Predictor):
    def setup(self):
        """Initialise the Wave-U-Net model"""
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--instruments",
            type=str,
            nargs="+",
            default=["bass", "drums", "other", "vocals"],
            help='List of instruments to separate (default: "bass drums other vocals")',
        )
        parser.add_argument(
            "--cuda", action="store_true", help="Use CUDA (default: False)"
        )
        parser.add_argument(
            "--features",
            type=int,
            default=32,
            help="Number of feature channels per layer",
        )
        parser.add_argument(
            "--load_model",
            type=str,
            default="checkpoints/waveunet/model",
            help="Reload a previously trained model",
        )
        parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
        parser.add_argument(
            "--levels", type=int, default=6, help="Number of DS/US blocks"
        )
        parser.add_argument(
            "--depth", type=int, default=1, help="Number of convs per block"
        )
        parser.add_argument("--sr", type=int, default=44100, help="Sampling rate")
        parser.add_argument(
            "--channels", type=int, default=2, help="Number of input audio channels"
        )
        parser.add_argument(
            "--kernel_size",
            type=int,
            default=5,
            help="Filter width of kernels. Has to be an odd number",
        )
        parser.add_argument(
            "--output_size", type=float, default=2.0, help="Output duration"
        )
        parser.add_argument(
            "--strides", type=int, default=4, help="Strides in Waveunet"
        )
        parser.add_argument(
            "--conv_type",
            type=str,
            default="gn",
            help="Type of convolution (normal, BN-normalised, GN-normalised): normal/bn/gn",
        )
        parser.add_argument(
            "--res",
            type=str,
            default="fixed",
            help="Resampling strategy: fixed sinc-based lowpass filtering or learned conv layer: fixed/learned",
        )
        parser.add_argument(
            "--separate",
            type=int,
            default=1,
            help="Train separate model for each source (1) or only one (0)",
        )
        parser.add_argument(
            "--feature_growth",
            type=str,
            default="double",
            help="How the features in each layer should grow, either (add) the initial number of features each time, or multiply by 2 (double)",
        )
        """
        parser.add_argument('--input', type=str, default=str(input),
                            help="Path to input mixture to be separated")
        parser.add_argument('--output', type=str, default=out_path, help="Output path (same folder as input path if not set)")
        """
        args = parser.parse_args([])
        self.args = args

        num_features = (
            [args.features * i for i in range(1, args.levels + 1)]
            if args.feature_growth == "add"
            else [args.features * 2 ** i for i in range(0, args.levels)]
        )
        target_outputs = int(args.output_size * args.sr)
        self.model = Waveunet(
            args.channels,
            num_features,
            args.channels,
            args.instruments,
            kernel_size=args.kernel_size,
            target_output_size=target_outputs,
            depth=args.depth,
            strides=args.strides,
            conv_type=args.conv_type,
            res=args.res,
            separate=args.separate,
        )

        if args.cuda:
            # Wrap the model instance itself (a bare "model" name is not defined in this scope)
            self.model = model_utils.DataParallel(self.model)
            print("move model to gpu")
            self.model.cuda()

        print("Loading model from checkpoint " + str(args.load_model))
        state = model_utils.load_model(self.model, None, args.load_model, args.cuda)
        print("Step", state["step"])

    @cog.input("input", type=Path, help="audio mixture path")
    def predict(self, input):
        """Separate tracks from input mixture audio"""

        out_path = Path(tempfile.mkdtemp())
        zip_path = Path(tempfile.mkdtemp()) / "output.zip"

        preds = predict_song(self.args, input, self.model)

        out_names = []
        for inst in preds.keys():
            temp_n = os.path.join(
                str(out_path), os.path.basename(str(input)) + "_" + inst + ".wav"
            )
            data.utils.write_wav(temp_n, preds[inst], self.args.sr)
            out_names.append(temp_n)

        with zipfile.ZipFile(str(zip_path), "w") as zf:
            for name in out_names:
                # Store only the file name, not the full temporary directory path
                zf.write(name, arcname=os.path.basename(name))

        return zip_path

--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/data/__init__.py

--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
import os

import h5py
import numpy as np
from sortedcontainers import SortedList
from torch.utils.data import Dataset
from tqdm import tqdm

from data.utils import load


class SeparationDataset(Dataset):
    def __init__(self, dataset, partition, instruments, sr, channels, shapes, random_hops, hdf_dir, audio_transform=None, in_memory=False):
        '''
        Initialises a source separation dataset backed by a preprocessed HDF5 file
        :param dataset: Dict mapping partition names to lists of examples, each example mapping "mix" and every instrument name to an audio file path
        :param partition: Partition to use (e.g. "train", "val" or "test"); also determines the HDF file name
        :param instruments: List of instrument names to load as separation targets
        :param sr: Sampling rate that all audio is resampled to
        :param channels: Number of audio channels (1 loads audio as mono)
        :param shapes: Dict with the model's input/output shapes ("input_frames", "output_frames", "output_start_frame", ...)
        :param random_hops: If False, sample examples evenly from the whole audio signal, moving forward by "output_frames" samples each time. If True, randomly sample a position from the audio
        :param hdf_dir: Directory in which the HDF file is created or loaded from
        :param audio_transform: Optional function (audio, targets) -> (audio, targets) applied to each example, e.g. for cropping targets or data augmentation
        :param in_memory: If True, load the whole HDF file into memory
        '''

        super(SeparationDataset, self).__init__()

        self.hdf_dataset = None
        os.makedirs(hdf_dir, exist_ok=True)
        self.hdf_dir = os.path.join(hdf_dir, partition + ".hdf5")

        self.random_hops = random_hops
        self.sr = sr
        self.channels = channels
        self.shapes = shapes
        self.audio_transform = audio_transform
        self.in_memory = in_memory
        self.instruments = instruments

        # PREPARE HDF FILE

        # Check if HDF file exists already
        if not os.path.exists(self.hdf_dir):
            # Create folder if it did not exist before
            if not os.path.exists(hdf_dir):
                os.makedirs(hdf_dir)

            # Create HDF file
            with h5py.File(self.hdf_dir, "w") as f:
                f.attrs["sr"] = sr
                f.attrs["channels"] = channels
                f.attrs["instruments"] = instruments

                print("Adding audio files to dataset (preprocessing)...")
                for idx, example in enumerate(tqdm(dataset[partition])):
                    # Load mix
                    mix_audio, _ = load(example["mix"], sr=self.sr, mono=(self.channels == 1))

                    source_audios = []
                    for source in instruments:
                        # Read in audio and convert to target sampling rate
                        source_audio, _ = load(example[source], sr=self.sr, mono=(self.channels == 1))
                        source_audios.append(source_audio)
                    source_audios = np.concatenate(source_audios, axis=0)
                    assert(source_audios.shape[1] == mix_audio.shape[1])

                    # Add to HDF5 file
                    grp = f.create_group(str(idx))
                    grp.create_dataset("inputs", shape=mix_audio.shape, dtype=mix_audio.dtype, data=mix_audio)
                    grp.create_dataset("targets", shape=source_audios.shape, dtype=source_audios.dtype, data=source_audios)
                    grp.attrs["length"] = mix_audio.shape[1]
                    grp.attrs["target_length"] = source_audios.shape[1]

        # Check that sr, channels and instruments of the HDF file match the requested ones, otherwise raise an error
        with h5py.File(self.hdf_dir, "r") as f:
            if f.attrs["sr"] != sr or \
                    f.attrs["channels"] != channels or \
                    list(f.attrs["instruments"]) != instruments:
                raise ValueError(
                    "Tried to load existing HDF file, but sampling rate and channel or instruments are not as expected. Did you load an out-dated HDF file?")

        # HDF FILE READY

        # SET SAMPLING POSITIONS

        # Go through HDF and collect lengths of all audio files
        with h5py.File(self.hdf_dir, "r") as f:
            lengths = [f[str(song_idx)].attrs["target_length"] for song_idx in range(len(f))]

        # Number of examples per song: one per "output_frames"-sized chunk of the targets, with the last chunk zero-padded as needed
        lengths = [(l // self.shapes["output_frames"]) + 1 for l in lengths]

        self.start_pos = SortedList(np.cumsum(lengths))
        self.length = self.start_pos[-1]

    def __getitem__(self, index):
        # Open HDF5
        if self.hdf_dataset is None:
            driver = "core" if self.in_memory else None  # Load HDF5 fully into memory if desired
            self.hdf_dataset = h5py.File(self.hdf_dir, 'r', driver=driver)

        # Find out which slice of targets we want to read
        audio_idx = self.start_pos.bisect_right(index)
        if audio_idx > 0:
            index = index - self.start_pos[audio_idx - 1]

        # Check length of audio signal
        audio_length = self.hdf_dataset[str(audio_idx)].attrs["length"]
        target_length = self.hdf_dataset[str(audio_idx)].attrs["target_length"]

        # Determine position where to start targets
        if self.random_hops:
            start_target_pos = np.random.randint(0, max(target_length - self.shapes["output_frames"] + 1, 1))
        else:
            # Map item index to sample position within song
            start_target_pos = index * self.shapes["output_frames"]

        # READ INPUTS
        # Check front padding
        start_pos = start_target_pos - self.shapes["output_start_frame"]
        if start_pos < 0:
            # Pad manually since audio signal was too short
            pad_front = abs(start_pos)
            start_pos = 0
        else:
            pad_front = 0

        # Check back padding
        end_pos = start_target_pos - self.shapes["output_start_frame"] + self.shapes["input_frames"]
        if end_pos > audio_length:
            # Pad manually since audio signal was too short
            pad_back = end_pos - audio_length
            end_pos = audio_length
        else:
            pad_back = 0

        # Read and return
        audio = self.hdf_dataset[str(audio_idx)]["inputs"][:, start_pos:end_pos].astype(np.float32)
        if pad_front > 0 or pad_back > 0:
            audio = np.pad(audio, [(0, 0), (pad_front, pad_back)], mode="constant", constant_values=0.0)

        # Targets are read over the same window as the inputs; cropping them to the model output
        # (e.g. with crop_targets from data/utils.py) is left to audio_transform
        targets = self.hdf_dataset[str(audio_idx)]["targets"][:, start_pos:end_pos].astype(np.float32)
        if pad_front > 0 or pad_back > 0:
            targets = np.pad(targets, [(0, 0), (pad_front, pad_back)], mode="constant", constant_values=0.0)

        targets = {inst : targets[idx*self.channels:(idx+1)*self.channels] for idx, inst in enumerate(self.instruments)}

        if hasattr(self, "audio_transform") and self.audio_transform is not None:
            audio, targets = self.audio_transform(audio, targets)

        return audio, targets

    def __len__(self):
        return self.length

--------------------------------------------------------------------------------
/data/musdb.py:
--------------------------------------------------------------------------------
import musdb
import os
import numpy as np
import glob

from data.utils import load, write_wav


def get_musdbhq(database_path):
    '''
    Retrieve audio file paths for MUSDB HQ dataset
    :param database_path: MUSDB HQ root directory
    :return: List [train_samples, test_samples]; each sample is a dict mapping "mix", "accompaniment" and each instrument to an audio file path
    '''
    subsets = list()

    for subset in ["train", "test"]:
        print("Loading " + subset + " set...")
        tracks = glob.glob(os.path.join(database_path, subset, "*"))
        samples = list()

        # Go through tracks
        for track_folder in sorted(tracks):
            # Collect the paths of the mixture and all stems of this track
            example = dict()
            for stem in ["mix", "bass", "drums", "other", "vocals"]:
                filename = stem if stem != "mix" else "mixture"
                audio_path = os.path.join(track_folder, filename + ".wav")
                example[stem] = audio_path

            # Sum bass, drums and other to form the accompaniment (only written once)
            acc_path = os.path.join(track_folder, "accompaniment.wav")

            if not os.path.exists(acc_path):
                print("Writing accompaniment to " + track_folder)
                stem_audio = []
                for stem in ["bass", "drums", "other"]:
                    audio, sr = load(example[stem], sr=None, mono=False)
                    stem_audio.append(audio)
                acc_audio = np.clip(sum(stem_audio), -1.0, 1.0)
                write_wav(acc_path, acc_audio, sr)

            example["accompaniment"] = acc_path

            samples.append(example)

        subsets.append(samples)

    return subsets


def get_musdb(database_path):
    '''
    Retrieve audio file paths for MUSDB dataset, converting its stems to WAV files next to the originals (tracks that were already converted are skipped)
    :param database_path: MUSDB root directory
    :return: List [train_samples, test_samples]; each sample is a dict mapping "mix", "accompaniment" and each instrument to a WAV file path
    '''
    mus = musdb.DB(root=database_path, is_wav=False)

    subsets = list()

    for subset in ["train", "test"]:
        tracks = mus.load_mus_tracks(subset)
        samples = list()

        # Go through tracks
        for track in sorted(tracks):
            # Skip track if mixture is already written, assuming this track is done already
            track_path = track.path[:-4]
            mix_path = track_path + "_mix.wav"
            acc_path = track_path + "_accompaniment.wav"
            if os.path.exists(mix_path):
                print("WARNING: Skipping track " + mix_path + " since it exists already")

                # Add paths and then skip
                paths = {"mix" : mix_path, "accompaniment" : acc_path}
                paths.update({key : track_path + "_" + key + ".wav" for key in ["bass", "drums", "other", "vocals"]})

                samples.append(paths)

                continue

            rate = track.rate

            # Go through each instrument
            paths = dict()
            stem_audio = dict()
            for stem in ["bass", "drums", "other", "vocals"]:
                path = track_path + "_" + stem + ".wav"
                audio = track.targets[stem].audio
                # musdb returns audio as (samples, channels), while write_wav expects (channels, samples)
                write_wav(path, audio.T, rate)
                stem_audio[stem] = audio
                paths[stem] = path

            # Add other instruments to form accompaniment
            acc_audio = np.clip(sum([stem_audio[key] for key in list(stem_audio.keys()) if key != "vocals"]), -1.0, 1.0)
            write_wav(acc_path, acc_audio.T, rate)
            paths["accompaniment"] = acc_path

            # Create mixture
            mix_audio = track.audio
            write_wav(mix_path, mix_audio.T, rate)
            paths["mix"] = mix_path

            # Check how well acc + vocals = mix holds
            diff_signal = np.abs(mix_audio - acc_audio - stem_audio["vocals"])
            print("Maximum absolute deviation from source additivity constraint: " + str(np.max(diff_signal)))
            print("Mean absolute deviation from source additivity constraint: " + str(np.mean(diff_signal)))

            samples.append(paths)

        subsets.append(samples)

    print("DONE preparing dataset!")
    return subsets


def get_musdb_folds(root_path, version="HQ"):
    '''
    Split MUSDB into partitions: 75 randomly chosen training songs form the training set, the remaining training songs the validation set
    :param root_path: MUSDB root directory
    :param version: "HQ" for MUSDB18HQ, anything else for the standard MUSDB18
    :return: Dict with "train", "val" and "test" keys, each containing a list of samples
    '''
    if version == "HQ":
        dataset = get_musdbhq(root_path)
    else:
        dataset = get_musdb(root_path)
    train_val_list = dataset[0]
    test_list = dataset[1]

    np.random.seed(1337)  # Ensure that partitioning is always the same on each run
    train_list = np.random.choice(train_val_list, 75, replace=False)
    val_list = [elem for elem in train_val_list if elem not in train_list]
    # print("First training song: " + str(train_list[0]))  # To debug whether partitioning is deterministic
    return {"train" : train_list, "val" : val_list, "test" : test_list}

--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
import librosa
import numpy as np
import soundfile
import torch


def random_amplify(mix, targets, shapes, min, max):
    '''
    Data augmentation by randomly amplifying sources before adding them to form a new mixture
    :param mix: Original mixture
    :param targets: Source targets
    :param shapes: Shape dict from model
    :param min: Minimum possible amplification
    :param max: Maximum possible amplification
    :return: New data point as tuple (mix, targets)
    '''
    residual = mix.copy()  # start with original mix (copied so the caller's array is not modified in place)
    for key in targets.keys():
        if key != "mix":
            residual -= targets[key]  # subtract all instruments (output is zero if all instruments add to mix)
    mix = residual * np.random.uniform(min, max)  # also apply gain data augmentation to residual
    for key in targets.keys():
        if key != "mix":
            targets[key] = targets[key] * np.random.uniform(min, max)
            mix += targets[key]  # add instrument with gain data augmentation to mix
    mix = np.clip(mix, -1.0, 1.0)
    return crop_targets(mix, targets, shapes)


def crop_targets(mix, targets, shapes):
    '''
    Crops target audio to the output shape required by the model given in "shapes"
    '''
    for key in targets.keys():
        if key != "mix":
            targets[key] = targets[key][:, shapes["output_start_frame"]:shapes["output_end_frame"]]
    return mix, targets


def load(path, sr=22050, mono=True, mode="numpy", offset=0.0, duration=None):
    y, curr_sr = librosa.load(path, sr=sr, mono=mono, res_type='kaiser_fast', offset=offset, duration=duration)

    if len(y.shape) == 1:
        # Expand channel dimension
        y = y[np.newaxis, :]

    if mode == "pytorch":
        y = torch.tensor(y)

    return y, curr_sr


def write_wav(path, audio, sr):
    # audio has shape (channels, samples); soundfile expects (samples, channels)
    soundfile.write(path, audio.T, sr, "PCM_16")


def resample(audio, orig_sr, new_sr, mode="numpy"):
    if orig_sr == new_sr:
        return audio

    if isinstance(audio, torch.Tensor):
        audio = audio.detach().cpu().numpy()

    # Keyword arguments keep this call valid in newer librosa versions, where they are keyword-only
    out = librosa.resample(audio, orig_sr=orig_sr, target_sr=new_sr, res_type='kaiser_fast')

    if mode == "pytorch":
        out = torch.tensor(out)
    return out

--------------------------------------------------------------------------------
/hdf/README.md:
--------------------------------------------------------------------------------
By default, the preprocessed data is saved in this directory as HDF files, one each for training, validation and test. These files can get quite large since they are uncompressed, so if this location is unsuitable because of disk space or speed, pass a different ``hdf_dir`` parameter to ``train.py``.

--------------------------------------------------------------------------------
/logs/README.md:
--------------------------------------------------------------------------------
This is where logs of the training process are stored by default.

--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/model/__init__.py

--------------------------------------------------------------------------------
/model/conv.py:
--------------------------------------------------------------------------------
from torch import nn as nn
from torch.nn import functional as F


class ConvLayer(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, conv_type, transpose=False):
        super(ConvLayer, self).__init__()
        self.transpose = transpose
        self.stride = stride
        self.kernel_size = kernel_size
        self.conv_type = conv_type

        # How many channels should be normalised as one group if GroupNorm is activated
        # WARNING: Number of channels has to be divisible by this number!
15 | NORM_CHANNELS = 8 16 | 17 | if self.transpose: 18 | self.filter = nn.ConvTranspose1d(n_inputs, n_outputs, self.kernel_size, stride, padding=kernel_size-1) 19 | else: 20 | self.filter = nn.Conv1d(n_inputs, n_outputs, self.kernel_size, stride) 21 | 22 | if conv_type == "gn": 23 | assert(n_outputs % NORM_CHANNELS == 0) 24 | self.norm = nn.GroupNorm(n_outputs // NORM_CHANNELS, n_outputs) 25 | elif conv_type == "bn": 26 | self.norm = nn.BatchNorm1d(n_outputs, momentum=0.01) 27 | # Add your own types of variations here! 28 | 29 | def forward(self, x): 30 | # Apply the convolution 31 | if self.conv_type == "gn" or self.conv_type == "bn": 32 | out = F.relu(self.norm((self.filter(x)))) 33 | else: # Add your own variations here with elifs conditioned on "conv_type" parameter! 34 | assert(self.conv_type == "normal") 35 | out = F.leaky_relu(self.filter(x)) 36 | return out 37 | 38 | def get_input_size(self, output_size): 39 | # Strided conv/decimation 40 | if not self.transpose: 41 | curr_size = (output_size - 1)*self.stride + 1 # o = (i-1)//s + 1 => i = (o - 1)*s + 1 42 | else: 43 | curr_size = output_size 44 | 45 | # Conv 46 | curr_size = curr_size + self.kernel_size - 1 # o = i + p - k + 1 47 | 48 | # Transposed 49 | if self.transpose: 50 | assert ((curr_size - 1) % self.stride == 0)# We need to have a value at the beginning and end 51 | curr_size = ((curr_size - 1) // self.stride) + 1 52 | assert(curr_size > 0) 53 | return curr_size 54 | 55 | def get_output_size(self, input_size): 56 | # Transposed 57 | if self.transpose: 58 | assert(input_size > 1) 59 | curr_size = (input_size - 1)*self.stride + 1 # o = (i-1)//s + 1 => i = (o - 1)*s + 1 60 | else: 61 | curr_size = input_size 62 | 63 | # Conv 64 | curr_size = curr_size - self.kernel_size + 1 # o = i + p - k + 1 65 | assert (curr_size > 0) 66 | 67 | # Strided conv/decimation 68 | if not self.transpose: 69 | assert ((curr_size - 1) % self.stride == 0) # We need to have a value at the beginning and end 70 | curr_size = 
((curr_size - 1) // self.stride) + 1 71 | 72 | return curr_size -------------------------------------------------------------------------------- /model/crop.py: -------------------------------------------------------------------------------- 1 | def centre_crop(x, target): 2 | ''' 3 | Center-crop 3-dim. input tensor along last axis so it fits the target tensor shape 4 | :param x: Input tensor 5 | :param target: Shape of this tensor will be used as target shape 6 | :return: Cropped input tensor 7 | ''' 8 | if x is None: 9 | return None 10 | if target is None: 11 | return x 12 | 13 | target_shape = target.shape 14 | diff = x.shape[-1] - target_shape[-1] 15 | assert (diff % 2 == 0) 16 | crop = diff // 2 17 | 18 | if crop == 0: 19 | return x 20 | if crop < 0: 21 | raise ArithmeticError 22 | 23 | return x[:, :, crop:-crop].contiguous() -------------------------------------------------------------------------------- /model/resample.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn as nn 4 | from torch.nn import functional as F 5 | 6 | class Resample1d(nn.Module): 7 | def __init__(self, channels, kernel_size, stride, transpose=False, padding="reflect", trainable=False): 8 | ''' 9 | Creates a resampling layer for time series data (using 1D convolution) - (N, C, W) input format 10 | :param channels: Number of features C at each time-step 11 | :param kernel_size: Width of sinc-based lowpass-filter (>= 15 recommended for good filtering performance) 12 | :param stride: Resampling factor (integer) 13 | :param transpose: False for down-, True for upsampling 14 | :param padding: Either "reflect" to pad or "valid" to not pad 15 | :param trainable: Optionally activate this to train the lowpass-filter, starting from the sinc initialisation 16 | ''' 17 | super(Resample1d, self).__init__() 18 | 19 | self.padding = padding 20 | self.kernel_size = kernel_size 21 | self.stride = stride 22 
| self.transpose = transpose 23 | self.channels = channels 24 | 25 | cutoff = 0.5 / stride 26 | 27 | assert(kernel_size > 2) 28 | assert ((kernel_size - 1) % 2 == 0) 29 | assert(padding == "reflect" or padding == "valid") 30 | 31 | filter = build_sinc_filter(kernel_size, cutoff) 32 | 33 | self.filter = torch.nn.Parameter(torch.from_numpy(np.repeat(np.reshape(filter, [1, 1, kernel_size]), channels, axis=0)), requires_grad=trainable) 34 | 35 | def forward(self, x): 36 | # Pad input (for both down- and upsampling) unless padding mode is "valid" 37 | input_size = x.shape[2] 38 | if self.padding != "valid": 39 | num_pad = (self.kernel_size-1)//2 40 | out = F.pad(x, (num_pad, num_pad), mode=self.padding) 41 | else: 42 | out = x 43 | 44 | # Lowpass filter (+ 0 insertion if transposed) 45 | if self.transpose: 46 | expected_steps = ((input_size - 1) * self.stride + 1) 47 | if self.padding == "valid": 48 | expected_steps = expected_steps - self.kernel_size + 1 49 | 50 | out = F.conv_transpose1d(out, self.filter, stride=self.stride, padding=0, groups=self.channels) 51 | diff_steps = out.shape[2] - expected_steps 52 | if diff_steps > 0: 53 | assert(diff_steps % 2 == 0) 54 | out = out[:,:,diff_steps//2:-diff_steps//2] 55 | else: 56 | assert(input_size % self.stride == 1) 57 | out = F.conv1d(out, self.filter, stride=self.stride, padding=0, groups=self.channels) 58 | 59 | return out 60 | 61 | def get_output_size(self, input_size): 62 | ''' 63 | Returns the output dimensionality (number of timesteps) for a given input size 64 | :param input_size: Number of input time steps (Scalar, each feature is one-dimensional) 65 | :return: Output size (scalar) 66 | ''' 67 | assert(input_size > 1) 68 | if self.transpose: 69 | if self.padding == "valid": 70 | return ((input_size - 1) * self.stride + 1) - self.kernel_size + 1 71 | else: 72 | return ((input_size - 1) * self.stride + 1) 73 | else: 74 | assert(input_size % self.stride == 1) # Want to take first and last sample 75 | if self.padding == "valid": 76 | return (input_size - self.kernel_size) // self.stride + 1 # Unpadded conv, then decimation
77 | else: 78 | return (input_size - 1) // self.stride + 1 # Length-preserving padded conv, then decimation 79 | 80 | def get_input_size(self, output_size): 81 | ''' 82 | Returns the input dimensionality (number of timesteps) for a given output size 83 | :param output_size: Number of output time steps (Scalar, each feature is one-dimensional) 84 | :return: Input size (scalar) 85 | ''' 86 | 87 | # Strided conv/decimation 88 | if not self.transpose: 89 | curr_size = (output_size - 1)*self.stride + 1 # o = (i-1)//s + 1 => i = (o - 1)*s + 1 90 | else: 91 | curr_size = output_size 92 | 93 | # Conv 94 | if self.padding == "valid": 95 | curr_size = curr_size + self.kernel_size - 1 # o = i + p - k + 1 96 | 97 | # Transposed 98 | if self.transpose: 99 | assert ((curr_size - 1) % self.stride == 0)# We need to have a value at the beginning and end 100 | curr_size = ((curr_size - 1) // self.stride) + 1 101 | assert(curr_size > 0) 102 | return curr_size 103 | 104 | def build_sinc_filter(kernel_size, cutoff): 105 | # FOLLOWING https://www.analog.com/media/en/technical-documentation/dsp-book/dsp_book_Ch16.pdf 106 | # Sinc lowpass filter 107 | # Build sinc kernel 108 | assert(kernel_size % 2 == 1) 109 | M = kernel_size - 1 110 | filter = np.zeros(kernel_size, dtype=np.float32) 111 | for i in range(kernel_size): 112 | if i == M//2: 113 | filter[i] = 2 * np.pi * cutoff 114 | else: 115 | filter[i] = (np.sin(2 * np.pi * cutoff * (i - M//2)) / (i - M//2)) * \ 116 | (0.42 - 0.5 * np.cos((2 * np.pi * i) / M) + 0.08 * np.cos((4 * np.pi * i) / M)) # Blackman window 117 | 118 | filter = filter / np.sum(filter) 119 | return filter -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | def save_model(model, optimizer, state, path): 5 | if isinstance(model, torch.nn.DataParallel): 6 | model = model.module # save state dict of wrapped module 7 | if len(os.path.dirname(path)) > 0 and not
os.path.exists(os.path.dirname(path)): 8 | os.makedirs(os.path.dirname(path)) 9 | torch.save({ 10 | 'model_state_dict': model.state_dict(), 11 | 'optimizer_state_dict': optimizer.state_dict(), 12 | 'state': state, # state of training loop (was 'step') 13 | }, path) 14 | 15 | 16 | def load_model(model, optimizer, path, cuda): 17 | if isinstance(model, torch.nn.DataParallel): 18 | model = model.module # load state dict of wrapped module 19 | if cuda: 20 | checkpoint = torch.load(path) 21 | else: 22 | checkpoint = torch.load(path, map_location='cpu') 23 | try: 24 | model.load_state_dict(checkpoint['model_state_dict']) 25 | except RuntimeError: 26 | # work-around for loading checkpoints where DataParallel was saved instead of inner module 27 | from collections import OrderedDict 28 | model_state_dict_fixed = OrderedDict() 29 | prefix = 'module.' 30 | for k, v in checkpoint['model_state_dict'].items(): 31 | if k.startswith(prefix): 32 | k = k[len(prefix):] 33 | model_state_dict_fixed[k] = v 34 | model.load_state_dict(model_state_dict_fixed) 35 | if optimizer is not None: 36 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 37 | if 'state' in checkpoint: 38 | state = checkpoint['state'] 39 | else: 40 | # older checkpoints only store step, rest of state won't be there 41 | state = {'step': checkpoint['step']} 42 | return state 43 | 44 | 45 | def compute_loss(model, inputs, targets, criterion, compute_grad=False): 46 | ''' 47 | Computes the model outputs and the loss for the given inputs, targets and loss function. 48 | Optionally backpropagates to compute gradients for weights. 49 | Procedure depends on whether we have one model for each source or not 50 | :param model: Model to train with 51 | :param inputs: Input mixture 52 | :param targets: Target sources 53 | :param criterion: Loss function to use (L1, L2, ..)
54 | :param compute_grad: Whether to compute gradients 55 | :return: Model outputs, Average loss over batch 56 | ''' 57 | all_outputs = {} 58 | 59 | if model.separate: 60 | avg_loss = 0.0 61 | num_sources = 0 62 | for inst in model.instruments: 63 | output = model(inputs, inst) 64 | loss = criterion(output[inst], targets[inst]) 65 | 66 | if compute_grad: 67 | loss.backward() 68 | 69 | avg_loss += loss.item() 70 | num_sources += 1 71 | 72 | all_outputs[inst] = output[inst].detach().clone() 73 | 74 | avg_loss /= float(num_sources) 75 | else: 76 | loss = 0 77 | all_outputs = model(inputs) 78 | for inst in all_outputs.keys(): 79 | loss += criterion(all_outputs[inst], targets[inst]) 80 | 81 | if compute_grad: 82 | loss.backward() 83 | 84 | avg_loss = loss.item() / float(len(all_outputs)) 85 | 86 | return all_outputs, avg_loss 87 | 88 | 89 | class DataParallel(torch.nn.DataParallel): 90 | def __init__(self, module, device_ids=None, output_device=None, dim=0): 91 | super(DataParallel, self).__init__(module, device_ids, output_device, dim) 92 | 93 | def __getattr__(self, name): 94 | try: 95 | return super().__getattr__(name) 96 | except AttributeError: 97 | return getattr(self.module, name) -------------------------------------------------------------------------------- /model/waveunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from model.crop import centre_crop 5 | from model.resample import Resample1d 6 | from model.conv import ConvLayer 7 | 8 | class UpsamplingBlock(nn.Module): 9 | def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res): 10 | super(UpsamplingBlock, self).__init__() 11 | assert(stride > 1) 12 | 13 | # CONV 1 for UPSAMPLING 14 | if res == "fixed": 15 | self.upconv = Resample1d(n_inputs, 15, stride, transpose=True) 16 | else: 17 | self.upconv = ConvLayer(n_inputs, n_inputs, kernel_size, stride, conv_type, transpose=True) 18 | 
19 | self.pre_shortcut_convs = nn.ModuleList([ConvLayer(n_inputs, n_outputs, kernel_size, 1, conv_type)] + 20 | [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)]) 21 | 22 | # CONVS to combine high- with low-level information (from shortcut) 23 | self.post_shortcut_convs = nn.ModuleList([ConvLayer(n_outputs + n_shortcut, n_outputs, kernel_size, 1, conv_type)] + 24 | [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)]) 25 | 26 | def forward(self, x, shortcut): 27 | # UPSAMPLE HIGH-LEVEL FEATURES 28 | upsampled = self.upconv(x) 29 | 30 | for conv in self.pre_shortcut_convs: 31 | upsampled = conv(upsampled) 32 | 33 | # Prepare shortcut connection 34 | combined = centre_crop(shortcut, upsampled) 35 | 36 | # Combine high- and low-level features 37 | for conv in self.post_shortcut_convs: 38 | combined = conv(torch.cat([combined, centre_crop(upsampled, combined)], dim=1)) 39 | return combined 40 | 41 | def get_output_size(self, input_size): 42 | curr_size = self.upconv.get_output_size(input_size) 43 | 44 | # Upsampling convs 45 | for conv in self.pre_shortcut_convs: 46 | curr_size = conv.get_output_size(curr_size) 47 | 48 | # Combine convolutions 49 | for conv in self.post_shortcut_convs: 50 | curr_size = conv.get_output_size(curr_size) 51 | 52 | return curr_size 53 | 54 | class DownsamplingBlock(nn.Module): 55 | def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res): 56 | super(DownsamplingBlock, self).__init__() 57 | assert(stride > 1) 58 | 59 | self.kernel_size = kernel_size 60 | self.stride = stride 61 | 62 | # CONV 1 63 | self.pre_shortcut_convs = nn.ModuleList([ConvLayer(n_inputs, n_shortcut, kernel_size, 1, conv_type)] + 64 | [ConvLayer(n_shortcut, n_shortcut, kernel_size, 1, conv_type) for _ in range(depth - 1)]) 65 | 66 | self.post_shortcut_convs = nn.ModuleList([ConvLayer(n_shortcut, n_outputs, kernel_size, 1, conv_type)] + 67 | [ConvLayer(n_outputs, 
n_outputs, kernel_size, 1, conv_type) for _ in 68 | range(depth - 1)]) 69 | 70 | # CONV 2 with decimation 71 | if res == "fixed": 72 | self.downconv = Resample1d(n_outputs, 15, stride) # Resampling with fixed-size sinc lowpass filter 73 | else: 74 | self.downconv = ConvLayer(n_outputs, n_outputs, kernel_size, stride, conv_type) 75 | 76 | def forward(self, x): 77 | # PREPARING SHORTCUT FEATURES 78 | shortcut = x 79 | for conv in self.pre_shortcut_convs: 80 | shortcut = conv(shortcut) 81 | 82 | # PREPARING FOR DOWNSAMPLING 83 | out = shortcut 84 | for conv in self.post_shortcut_convs: 85 | out = conv(out) 86 | 87 | # DOWNSAMPLING 88 | out = self.downconv(out) 89 | 90 | return out, shortcut 91 | 92 | def get_input_size(self, output_size): 93 | curr_size = self.downconv.get_input_size(output_size) 94 | 95 | for conv in reversed(self.post_shortcut_convs): 96 | curr_size = conv.get_input_size(curr_size) 97 | 98 | for conv in reversed(self.pre_shortcut_convs): 99 | curr_size = conv.get_input_size(curr_size) 100 | return curr_size 101 | 102 | class Waveunet(nn.Module): 103 | def __init__(self, num_inputs, num_channels, num_outputs, instruments, kernel_size, target_output_size, conv_type, res, separate=False, depth=1, strides=2): 104 | super(Waveunet, self).__init__() 105 | 106 | self.num_levels = len(num_channels) 107 | self.strides = strides 108 | self.kernel_size = kernel_size 109 | self.num_inputs = num_inputs 110 | self.num_outputs = num_outputs 111 | self.depth = depth 112 | self.instruments = instruments 113 | self.separate = separate 114 | 115 | # Only odd filter kernels allowed 116 | assert(kernel_size % 2 == 1) 117 | 118 | self.waveunets = nn.ModuleDict() 119 | 120 | model_list = instruments if separate else ["ALL"] 121 | # Create a model for each source if we separate sources separately, otherwise only one (model_list=["ALL"]) 122 | for instrument in model_list: 123 | module = nn.Module() 124 | 125 | module.downsampling_blocks = nn.ModuleList() 126 | 
module.upsampling_blocks = nn.ModuleList() 127 | 128 | for i in range(self.num_levels - 1): 129 | in_ch = num_inputs if i == 0 else num_channels[i] 130 | 131 | module.downsampling_blocks.append( 132 | DownsamplingBlock(in_ch, num_channels[i], num_channels[i+1], kernel_size, strides, depth, conv_type, res)) 133 | 134 | for i in range(0, self.num_levels - 1): 135 | module.upsampling_blocks.append( 136 | UpsamplingBlock(num_channels[-1-i], num_channels[-2-i], num_channels[-2-i], kernel_size, strides, depth, conv_type, res)) 137 | 138 | module.bottlenecks = nn.ModuleList( 139 | [ConvLayer(num_channels[-1], num_channels[-1], kernel_size, 1, conv_type) for _ in range(depth)]) 140 | 141 | # Output conv 142 | outputs = num_outputs if separate else num_outputs * len(instruments) 143 | module.output_conv = nn.Conv1d(num_channels[0], outputs, 1) 144 | 145 | self.waveunets[instrument] = module 146 | 147 | self.set_output_size(target_output_size) 148 | 149 | def set_output_size(self, target_output_size): 150 | self.target_output_size = target_output_size 151 | 152 | self.input_size, self.output_size = self.check_padding(target_output_size) 153 | print("Using valid convolutions with " + str(self.input_size) + " inputs and " + str(self.output_size) + " outputs") 154 | 155 | assert((self.input_size - self.output_size) % 2 == 0) 156 | self.shapes = {"output_start_frame" : (self.input_size - self.output_size) // 2, 157 | "output_end_frame" : (self.input_size - self.output_size) // 2 + self.output_size, 158 | "output_frames" : self.output_size, 159 | "input_frames" : self.input_size} 160 | 161 | def check_padding(self, target_output_size): 162 | # Ensure number of outputs covers a whole number of cycles so each output in the cycle is weighted equally during training 163 | bottleneck = 1 164 | 165 | while True: 166 | out = self.check_padding_for_bottleneck(bottleneck, target_output_size) 167 | if out is not False: 168 | return out 169 | bottleneck += 1 170 | 171 | def 
check_padding_for_bottleneck(self, bottleneck, target_output_size): 172 | module = self.waveunets[[k for k in self.waveunets.keys()][0]] 173 | try: 174 | curr_size = bottleneck 175 | for idx, block in enumerate(module.upsampling_blocks): 176 | curr_size = block.get_output_size(curr_size) 177 | output_size = curr_size 178 | 179 | # Bottleneck-Conv 180 | curr_size = bottleneck 181 | for block in reversed(module.bottlenecks): 182 | curr_size = block.get_input_size(curr_size) 183 | for idx, block in enumerate(reversed(module.downsampling_blocks)): 184 | curr_size = block.get_input_size(curr_size) 185 | 186 | assert(output_size >= target_output_size) 187 | return curr_size, output_size 188 | except AssertionError as e: 189 | return False 190 | 191 | def forward_module(self, x, module): 192 | ''' 193 | A forward pass through a single Wave-U-Net (multiple Wave-U-Nets might be used, one for each source) 194 | :param x: Input mix 195 | :param module: Network module to be used for prediction 196 | :return: Source estimates 197 | ''' 198 | shortcuts = [] 199 | out = x 200 | 201 | # DOWNSAMPLING BLOCKS 202 | for block in module.downsampling_blocks: 203 | out, short = block(out) 204 | shortcuts.append(short) 205 | 206 | # BOTTLENECK CONVOLUTION 207 | for conv in module.bottlenecks: 208 | out = conv(out) 209 | 210 | # UPSAMPLING BLOCKS 211 | for idx, block in enumerate(module.upsampling_blocks): 212 | out = block(out, shortcuts[-1 - idx]) 213 | 214 | # OUTPUT CONV 215 | out = module.output_conv(out) 216 | if not self.training: # At test time clip predictions to valid amplitude range 217 | out = out.clamp(min=-1.0, max=1.0) 218 | return out 219 | 220 | def forward(self, x, inst=None): 221 | curr_input_size = x.shape[-1] 222 | assert(curr_input_size == self.input_size) # Caller must feed exactly self.input_size samples to get the pre-calculated (NOT the originally desired) output size 223 | 224 | if self.separate: 225 | return {inst : self.forward_module(x, self.waveunets[inst])}
226 | else: 227 | assert(len(self.waveunets) == 1) 228 | out = self.forward_module(x, self.waveunets["ALL"]) 229 | 230 | out_dict = {} 231 | for idx, inst in enumerate(self.instruments): 232 | out_dict[inst] = out[:, idx * self.num_outputs:(idx + 1) * self.num_outputs] 233 | return out_dict -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import data.utils 5 | import model.utils as model_utils 6 | 7 | from test import predict_song 8 | from model.waveunet import Waveunet 9 | 10 | def main(args): 11 | # MODEL 12 | num_features = [args.features*i for i in range(1, args.levels+1)] if args.feature_growth == "add" else \ 13 | [args.features*2**i for i in range(0, args.levels)] 14 | target_outputs = int(args.output_size * args.sr) 15 | model = Waveunet(args.channels, num_features, args.channels, args.instruments, kernel_size=args.kernel_size, 16 | target_output_size=target_outputs, depth=args.depth, strides=args.strides, 17 | conv_type=args.conv_type, res=args.res, separate=args.separate) 18 | 19 | if args.cuda: 20 | model = model_utils.DataParallel(model) 21 | print("move model to gpu") 22 | model.cuda() 23 | 24 | print("Loading model from checkpoint " + str(args.load_model)) 25 | state = model_utils.load_model(model, None, args.load_model, args.cuda) 26 | print('Step', state['step']) 27 | 28 | preds = predict_song(args, args.input, model) 29 | 30 | output_folder = os.path.dirname(args.input) if args.output is None else args.output 31 | for inst in preds.keys(): 32 | data.utils.write_wav(os.path.join(output_folder, os.path.basename(args.input) + "_" + inst + ".wav"), preds[inst], args.sr) 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('--instruments', type=str, nargs='+', default=["bass", "drums", "other", "vocals"], 37 | help="List of instruments to 
separate (default: \"bass drums other vocals\")") 38 | parser.add_argument('--cuda', action='store_true', 39 | help='Use CUDA (default: False)') 40 | parser.add_argument('--features', type=int, default=32, 41 | help='Number of feature channels per layer') 42 | parser.add_argument('--load_model', type=str, default='checkpoints/waveunet/model', 43 | help='Reload a previously trained model') 44 | parser.add_argument('--batch_size', type=int, default=4, 45 | help="Batch size") 46 | parser.add_argument('--levels', type=int, default=6, 47 | help="Number of DS/US blocks") 48 | parser.add_argument('--depth', type=int, default=1, 49 | help="Number of convs per block") 50 | parser.add_argument('--sr', type=int, default=44100, 51 | help="Sampling rate") 52 | parser.add_argument('--channels', type=int, default=2, 53 | help="Number of input audio channels") 54 | parser.add_argument('--kernel_size', type=int, default=5, 55 | help="Filter width of kernels. Has to be an odd number") 56 | parser.add_argument('--output_size', type=float, default=2.0, 57 | help="Output duration") 58 | parser.add_argument('--strides', type=int, default=4, 59 | help="Strides in Waveunet") 60 | parser.add_argument('--conv_type', type=str, default="gn", 61 | help="Type of convolution (normal, BN-normalised, GN-normalised): normal/bn/gn") 62 | parser.add_argument('--res', type=str, default="fixed", 63 | help="Resampling strategy: fixed sinc-based lowpass filtering or learned conv layer: fixed/learned") 64 | parser.add_argument('--separate', type=int, default=1, 65 | help="Train separate model for each source (1) or only one (0)") 66 | parser.add_argument('--feature_growth', type=str, default="double", 67 | help="How the features in each layer should grow, either (add) the initial number of features each time, or multiply by 2 (double)") 68 | 69 | parser.add_argument('--input', type=str, default=os.path.join("audio_examples", "Cristina Vane - So Easy", "mix.mp3"), 70 | help="Path to input mixture to be 
separated") 71 | parser.add_argument('--output', type=str, default=None, help="Output path (same folder as input path if not set)") 72 | 73 | args = parser.parse_args() 74 | 75 | main(args) 76 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | future 2 | numpy 3 | librosa 4 | soundfile 5 | musdb 6 | museval 7 | h5py 8 | tqdm 9 | torch==1.4.0 10 | torchvision==0.5.0 11 | tensorboard 12 | sortedcontainers -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import museval 2 | from tqdm import tqdm 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import data.utils 8 | import model.utils as model_utils 9 | import utils 10 | 11 | def compute_model_output(model, inputs): 12 | ''' 13 | Computes outputs of model with given inputs. Does NOT allow propagating gradients! See compute_loss for training. 14 | Procedure depends on whether we have one model for each source or not 15 | :param model: Model to use for prediction 16 | :param inputs: Input mixture 17 | :return: Model outputs, dictionary with source names as keys 18 | ''' 19 | all_outputs = {} 20 | 21 | if model.separate: 22 | for inst in model.instruments: 23 | output = model(inputs, inst) 24 | all_outputs[inst] = output[inst].detach().clone() 25 | else: 26 | all_outputs = model(inputs) 27 | 28 | return all_outputs 29 | 30 | def predict(audio, model): 31 | ''' 32 | Predict sources for a given audio input signal, with a given model. Audio is split into chunks to make predictions on each chunk before they are concatenated.
33 | :param audio: Audio input of shape (channels, samples), either Pytorch tensor or numpy array 34 | :param model: Pytorch model 35 | :return: Source predictions, dictionary with source names as keys 36 | ''' 37 | if isinstance(audio, torch.Tensor): 38 | is_cuda = audio.is_cuda 39 | audio = audio.detach().cpu().numpy() 40 | return_mode = "pytorch" 41 | else: 42 | return_mode = "numpy" 43 | 44 | expected_outputs = audio.shape[1] 45 | 46 | # Pad input at the end so its length is a multiple of the frame shift (output_frames) 47 | output_shift = model.shapes["output_frames"] 48 | pad_back = audio.shape[1] % output_shift 49 | pad_back = 0 if pad_back == 0 else output_shift - pad_back 50 | if pad_back > 0: 51 | audio = np.pad(audio, [(0,0), (0, pad_back)], mode="constant", constant_values=0.0) 52 | 53 | target_outputs = audio.shape[1] 54 | outputs = {key: np.zeros(audio.shape, np.float32) for key in model.instruments} 55 | 56 | # Pad mixture across time at beginning and end so that neural network can make prediction at the beginning and end of signal 57 | pad_front_context = model.shapes["output_start_frame"] 58 | pad_back_context = model.shapes["input_frames"] - model.shapes["output_end_frame"] 59 | audio = np.pad(audio, [(0,0), (pad_front_context, pad_back_context)], mode="constant", constant_values=0.0) 60 | 61 | # Iterate over the mixture in steps of output_frames samples, fetch network prediction 62 | with torch.no_grad(): 63 | for target_start_pos in range(0, target_outputs, model.shapes["output_frames"]): 64 | # Prepare mixture excerpt by selecting time interval 65 | curr_input = audio[:, target_start_pos:target_start_pos + model.shapes["input_frames"]] # Since audio was front-padded, input [targetpos:targetpos+inputframes] predicts the target range [targetpos:targetpos+outputframes] 66 | 67 | # Convert to Pytorch tensor for model prediction 68 | curr_input = torch.from_numpy(curr_input).unsqueeze(0) 69 | 70 | # Predict 71 | for key, curr_targets in compute_model_output(model, curr_input).items(): 72 |
outputs[key][:,target_start_pos:target_start_pos+model.shapes["output_frames"]] = curr_targets.squeeze(0).cpu().numpy() 73 | 74 | # Crop to expected length (since we padded to handle the frame shift) 75 | outputs = {key : outputs[key][:,:expected_outputs] for key in outputs.keys()} 76 | 77 | if return_mode == "pytorch": 78 | outputs = {key : torch.from_numpy(outputs[key]) for key in outputs.keys()} 79 | if is_cuda: 80 | outputs = {key : outputs[key].cuda() for key in outputs.keys()} 81 | return outputs 82 | 83 | def predict_song(args, audio_path, model): 84 | ''' 85 | Predicts sources for an audio file for which the file path is given, using a given model. 86 | Takes care of resampling the input audio to the model's sampling rate and resampling predictions back to input sampling rate. 87 | :param args: Options dictionary 88 | :param audio_path: Path to mixture audio file 89 | :param model: Pytorch model 90 | :return: Source estimates given as dictionary with keys as source names 91 | ''' 92 | model.eval() 93 | 94 | # Load mixture in original sampling rate 95 | mix_audio, mix_sr = data.utils.load(audio_path, sr=None, mono=False) 96 | mix_channels = mix_audio.shape[0] 97 | mix_len = mix_audio.shape[1] 98 | 99 | # Adapt mixture channels to required input channels 100 | if args.channels == 1: 101 | mix_audio = np.mean(mix_audio, axis=0, keepdims=True) 102 | else: 103 | if mix_channels == 1: # Duplicate channels if input is mono but model is stereo 104 | mix_audio = np.tile(mix_audio, [args.channels, 1]) 105 | else: 106 | assert(mix_channels == args.channels) 107 | 108 | # resample to model sampling rate 109 | mix_audio = data.utils.resample(mix_audio, mix_sr, args.sr) 110 | 111 | sources = predict(mix_audio, model) 112 | 113 | # Resample back to mixture sampling rate in case we had model on different sampling rate 114 | sources = {key : data.utils.resample(sources[key], args.sr, mix_sr) for key in sources.keys()} 115 | 116 | # In case we had to pad the mixture at the end, or we have a few samples too many due to inconsistent down- and upsampling,
remove those samples from source prediction now 117 | for key in sources.keys(): 118 | diff = sources[key].shape[1] - mix_len 119 | if diff > 0: 120 | print("WARNING: Cropping " + str(diff) + " samples") 121 | sources[key] = sources[key][:, :-diff] 122 | elif diff < 0: 123 | print("WARNING: Padding output by " + str(-diff) + " samples") 124 | sources[key] = np.pad(sources[key], [(0,0), (0, -diff)], mode="constant", constant_values=0.0) 125 | 126 | # Adapt channels 127 | if mix_channels > args.channels: 128 | assert(args.channels == 1) 129 | # Duplicate mono predictions 130 | sources[key] = np.tile(sources[key], [mix_channels, 1]) 131 | elif mix_channels < args.channels: 132 | assert(mix_channels == 1) 133 | # Reduce model output to mono 134 | sources[key] = np.mean(sources[key], axis=0, keepdims=True) 135 | 136 | sources[key] = np.asfortranarray(sources[key]) # So librosa does not complain if we want to save it 137 | 138 | return sources 139 | 140 | def evaluate(args, dataset, model, instruments): 141 | ''' 142 | Evaluates a given model on a given dataset 143 | :param args: Options dict 144 | :param dataset: Dataset object 145 | :param model: Pytorch model 146 | :param instruments: List of source names 147 | :return: List of performance metric dictionaries, one per dataset sample (keyed by source name) 148 | ''' 149 | perfs = list() 150 | model.eval() 151 | with torch.no_grad(): 152 | for example in dataset: 153 | print("Evaluating " + example["mix"]) 154 | 155 | # Load source references in their original sr and channel number 156 | target_sources = np.stack([data.utils.load(example[instrument], sr=None, mono=False)[0].T for instrument in instruments]) 157 | 158 | # Predict using mixture 159 | pred_sources = predict_song(args, example["mix"], model) 160 | pred_sources = np.stack([pred_sources[key].T for key in instruments]) 161 | 162 | # Evaluate 163 | SDR, ISR, SIR, SAR, _ = museval.metrics.bss_eval(target_sources, pred_sources) 164 | song = {} 165 | for idx, name in
enumerate(instruments): 166 | song[name] = {"SDR" : SDR[idx], "ISR" : ISR[idx], "SIR" : SIR[idx], "SAR" : SAR[idx]} 167 | perfs.append(song) 168 | 169 | return perfs 170 | 171 | 172 | def validate(args, model, criterion, test_data): 173 | ''' 174 | Iterate with a given model over a given test dataset and compute the desired loss 175 | :param args: Options dictionary 176 | :param model: Pytorch model 177 | :param criterion: Loss function to use (similar to Pytorch criterions) 178 | :param test_data: Test dataset (Pytorch dataset) 179 | :return: Mean of the per-batch average losses over the test dataset 180 | ''' 181 | # PREPARE DATA 182 | dataloader = torch.utils.data.DataLoader(test_data, 183 | batch_size=args.batch_size, 184 | shuffle=False, 185 | num_workers=args.num_workers) 186 | 187 | # VALIDATE 188 | model.eval() 189 | total_loss = 0. 190 | with tqdm(total=len(test_data) // args.batch_size) as pbar, torch.no_grad(): 191 | for example_num, (x, targets) in enumerate(dataloader): 192 | if args.cuda: 193 | x = x.cuda() 194 | for k in list(targets.keys()): 195 | targets[k] = targets[k].cuda() 196 | 197 | _, avg_loss = model_utils.compute_loss(model, x, targets, criterion) 198 | 199 | total_loss += (1.
/ float(example_num + 1)) * (avg_loss - total_loss) 200 | 201 | pbar.set_description("Current loss: {:.4f}".format(total_loss)) 202 | pbar.update(1) 203 | 204 | return total_loss -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | from functools import partial 5 | 6 | import torch 7 | import pickle 8 | import numpy as np 9 | 10 | import torch.nn as nn 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch.optim import Adam 13 | from tqdm import tqdm 14 | 15 | import model.utils as model_utils 16 | import utils 17 | from data.dataset import SeparationDataset 18 | from data.musdb import get_musdb_folds 19 | from data.utils import crop_targets, random_amplify 20 | from test import evaluate, validate 21 | from model.waveunet import Waveunet 22 | 23 | def main(args): 24 | #torch.backends.cudnn.benchmark=True # This makes dilated conv much faster for CuDNN 7.5 25 | 26 | # MODEL 27 | num_features = [args.features*i for i in range(1, args.levels+1)] if args.feature_growth == "add" else \ 28 | [args.features*2**i for i in range(0, args.levels)] 29 | target_outputs = int(args.output_size * args.sr) 30 | model = Waveunet(args.channels, num_features, args.channels, args.instruments, kernel_size=args.kernel_size, 31 | target_output_size=target_outputs, depth=args.depth, strides=args.strides, 32 | conv_type=args.conv_type, res=args.res, separate=args.separate) 33 | 34 | if args.cuda: 35 | model = model_utils.DataParallel(model) 36 | print("move model to gpu") 37 | model.cuda() 38 | 39 | print('model: ', model) 40 | print('parameter count: ', str(sum(p.numel() for p in model.parameters()))) 41 | 42 | writer = SummaryWriter(args.log_dir) 43 | 44 | ### DATASET 45 | musdb = get_musdb_folds(args.dataset_dir) 46 | # If not data augmentation, at least crop targets to fit model output shape 47 | 
    crop_func = partial(crop_targets, shapes=model.shapes)
    # Data augmentation function for training
    augment_func = partial(random_amplify, shapes=model.shapes, min=0.7, max=1.0)
    train_data = SeparationDataset(musdb, "train", args.instruments, args.sr, args.channels, model.shapes, True, args.hdf_dir, audio_transform=augment_func)
    val_data = SeparationDataset(musdb, "val", args.instruments, args.sr, args.channels, model.shapes, False, args.hdf_dir, audio_transform=crop_func)
    test_data = SeparationDataset(musdb, "test", args.instruments, args.sr, args.channels, model.shapes, False, args.hdf_dir, audio_transform=crop_func)

    dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, worker_init_fn=utils.worker_init_fn)

    ##### TRAINING ####

    # Set up the loss function
    if args.loss == "L1":
        criterion = nn.L1Loss()
    elif args.loss == "L2":
        criterion = nn.MSELoss()
    else:
        raise NotImplementedError("Couldn't find this loss!")

    # Set up optimiser
    optimizer = Adam(params=model.parameters(), lr=args.lr)

    # Set up training state dict that will also be saved into checkpoints
    state = {"step" : 0,
             "worse_epochs" : 0,
             "epochs" : 0,
             "best_loss" : np.inf}  # np.Inf was removed in NumPy 2.0

    # LOAD MODEL CHECKPOINT IF DESIRED
    if args.load_model is not None:
        print("Continuing training full model from checkpoint " + str(args.load_model))
        state = model_utils.load_model(model, optimizer, args.load_model, args.cuda)

    print('TRAINING START')
    while state["worse_epochs"] < args.patience:
        print("Training one epoch from iteration " + str(state["step"]))
        avg_time = 0.
        model.train()
        with tqdm(total=len(train_data) // args.batch_size) as pbar:
            np.random.seed()
            for example_num, (x, targets) in enumerate(dataloader):
                if args.cuda:
                    x = x.cuda()
                    for k in list(targets.keys()):
                        targets[k] = targets[k].cuda()

                t = time.time()

                # Set LR for this iteration
                utils.set_cyclic_lr(optimizer, example_num, len(train_data) // args.batch_size, args.cycles, args.min_lr, args.lr)
                writer.add_scalar("lr", utils.get_lr(optimizer), state["step"])

                # Compute loss for each instrument/model
                optimizer.zero_grad()
                outputs, avg_loss = model_utils.compute_loss(model, x, targets, criterion, compute_grad=True)

                optimizer.step()

                state["step"] += 1

                t = time.time() - t
                avg_time += (1. / float(example_num + 1)) * (t - avg_time)

                writer.add_scalar("train_loss", avg_loss, state["step"])

                if example_num % args.example_freq == 0:
                    input_centre = torch.mean(x[0, :, model.shapes["output_start_frame"]:model.shapes["output_end_frame"]], 0) # Stereo not supported for logs yet
                    writer.add_audio("input", input_centre, state["step"], sample_rate=args.sr)

                    for inst in outputs.keys():
                        writer.add_audio(inst + "_pred", torch.mean(outputs[inst][0], 0), state["step"], sample_rate=args.sr)
                        writer.add_audio(inst + "_target", torch.mean(targets[inst][0], 0), state["step"], sample_rate=args.sr)

                pbar.update(1)

        # VALIDATE
        val_loss = validate(args, model, criterion, val_data)
        print("VALIDATION FINISHED: LOSS: " + str(val_loss))
        writer.add_scalar("val_loss", val_loss, state["step"])

        # EARLY STOPPING CHECK
        checkpoint_path = os.path.join(args.checkpoint_dir, "checkpoint_" + str(state["step"]))
        if val_loss >= state["best_loss"]:
            state["worse_epochs"] += 1
        else:
            print("MODEL IMPROVED ON VALIDATION SET!")
            state["worse_epochs"] = 0
            state["best_loss"] = val_loss
            state["best_checkpoint"] = checkpoint_path

        state["epochs"] += 1
        # CHECKPOINT
        print("Saving model...")
        model_utils.save_model(model, optimizer, state, checkpoint_path)


    #### TESTING ####
    # Test loss
    print("TESTING")

    # Load best model based on validation loss
    state = model_utils.load_model(model, None, state["best_checkpoint"], args.cuda)
    test_loss = validate(args, model, criterion, test_data)
    print("TEST FINISHED: LOSS: " + str(test_loss))
    writer.add_scalar("test_loss", test_loss, state["step"])

    # BSS Eval metrics (SDR, ISR, SIR, SAR) computed with museval
    test_metrics = evaluate(args, musdb["test"], model, args.instruments)

    # Dump all metrics results into pickle file for later analysis if needed
    with open(os.path.join(args.checkpoint_dir, "results.pkl"), "wb") as f:
        pickle.dump(test_metrics, f)

    # Write most important metrics into Tensorboard log
    avg_SDRs = {inst : np.mean([np.nanmean(song[inst]["SDR"]) for song in test_metrics]) for inst in args.instruments}
    avg_SIRs = {inst : np.mean([np.nanmean(song[inst]["SIR"]) for song in test_metrics]) for inst in args.instruments}
    for inst in args.instruments:
        writer.add_scalar("test_SDR_" + inst, avg_SDRs[inst], state["step"])
        writer.add_scalar("test_SIR_" + inst, avg_SIRs[inst], state["step"])
    overall_SDR = np.mean([v for v in avg_SDRs.values()])
    writer.add_scalar("test_SDR", overall_SDR)
    print("SDR: " + str(overall_SDR))

    writer.close()

if __name__ == '__main__':
    ## TRAIN PARAMETERS
    parser = argparse.ArgumentParser()
    parser.add_argument('--instruments', type=str, nargs='+', default=["bass", "drums", "other", "vocals"],
                        help="List of instruments to separate (default: \"bass drums other vocals\")")
    parser.add_argument('--cuda', action='store_true',
                        help='Use CUDA (default: False)')
    parser.add_argument('--num_workers', type=int, default=1,
                        help='Number of data loader worker processes (default: 1)')
    parser.add_argument('--features', type=int, default=32,
                        help='Number of feature channels per layer')
    parser.add_argument('--log_dir', type=str, default='logs/waveunet',
                        help='Folder to write logs into')
    parser.add_argument('--dataset_dir', type=str, default="/mnt/windaten/Datasets/MUSDB18HQ",
                        help='Dataset path')
    parser.add_argument('--hdf_dir', type=str, default="hdf",
                        help='Folder to store the preprocessed HDF dataset files')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints/waveunet',
                        help='Folder to write checkpoints into')
    parser.add_argument('--load_model', type=str, default=None,
                        help='Reload a previously trained model (whole task model)')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='Initial learning rate in LR cycle (default: 1e-3)')
    parser.add_argument('--min_lr', type=float, default=5e-5,
                        help='Minimum learning rate in LR cycle (default: 5e-5)')
    parser.add_argument('--cycles', type=int, default=2,
                        help='Number of LR cycles per epoch')
    parser.add_argument('--batch_size', type=int, default=4,
                        help="Batch size")
    parser.add_argument('--levels', type=int, default=6,
                        help="Number of DS/US blocks")
    parser.add_argument('--depth', type=int, default=1,
                        help="Number of convs per block")
    parser.add_argument('--sr', type=int, default=44100,
                        help="Sampling rate")
    parser.add_argument('--channels', type=int, default=2,
                        help="Number of input audio channels")
    parser.add_argument('--kernel_size', type=int, default=5,
                        help="Filter width of kernels. Has to be an odd number")
    parser.add_argument('--output_size', type=float, default=2.0,
                        help="Output duration in seconds")
    parser.add_argument('--strides', type=int, default=4,
                        help="Strides in Waveunet")
    parser.add_argument('--patience', type=int, default=20,
                        help="Patience for early stopping on validation set")
    parser.add_argument('--example_freq', type=int, default=200,
                        help="Write an audio summary into Tensorboard logs every X training iterations")
    parser.add_argument('--loss', type=str, default="L1",
                        help="L1 or L2")
    parser.add_argument('--conv_type', type=str, default="gn",
                        help="Type of convolution (normal, BN-normalised, GN-normalised): normal/bn/gn")
    parser.add_argument('--res', type=str, default="fixed",
                        help="Resampling strategy: fixed sinc-based lowpass filtering or learned conv layer: fixed/learned")
    parser.add_argument('--separate', type=int, default=1,
                        help="Train separate model for each source (1) or only one (0)")
    parser.add_argument('--feature_growth', type=str, default="double",
                        help="How the features in each layer should grow, either (add) the initial number of features each time, or multiply by 2 (double)")

    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np

def worker_init_fn(worker_id):
    # Forked DataLoader workers inherit the same NumPy RNG state; offsetting the seed by worker_id makes each worker draw different examples
    # Wrap into the valid seed range [0, 2**32 - 1]
    np.random.seed((int(np.random.get_state()[1][0]) + worker_id) % 2**32)

def get_lr(optim):
    return optim.param_groups[0]["lr"]

def set_lr(optim, lr):
    for g in optim.param_groups:
        g['lr'] = lr

def set_cyclic_lr(optimizer, it, epoch_it, cycles, min_lr, max_lr):
    # Cosine annealing from max_lr towards min_lr, restarted `cycles` times per epoch
    cycle_length = epoch_it // cycles
    curr_cycle = min(it // cycle_length, cycles-1)
    curr_it = it - cycle_length * curr_cycle

    new_lr = min_lr + 0.5*(max_lr - min_lr)*(1 + np.cos((float(curr_it) / float(cycle_length)) * np.pi))
    set_lr(optimizer, new_lr)
--------------------------------------------------------------------------------
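The schedule applied by `set_cyclic_lr` can be inspected without building an optimizer. The sketch below uses a hypothetical standalone helper, `cyclic_lr`, that reproduces the same formula with the standard `math` module. It assumes the `train.py` defaults (`--lr 1e-3`, `--min_lr 5e-5`, `--cycles 2`) and, purely for illustration, 100 iterations per epoch.

```python
import math


def cyclic_lr(it, epoch_it, cycles, min_lr, max_lr):
    """Hypothetical helper: the learning rate set_cyclic_lr would assign at iteration `it`.

    The epoch is split into `cycles` equal cycles; within each cycle the rate
    follows half a cosine period from max_lr down towards min_lr.
    """
    cycle_length = epoch_it // cycles
    # Iterations past the end of the last full cycle stay in the final cycle
    curr_cycle = min(it // cycle_length, cycles - 1)
    curr_it = it - cycle_length * curr_cycle
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(curr_it / cycle_length * math.pi))


if __name__ == "__main__":
    # Assumed: 100 iterations per epoch, train.py defaults for the LR bounds and cycle count
    for it in (0, 25, 49, 50, 75, 99):
        print(it, cyclic_lr(it, 100, 2, 5e-5, 1e-3))
```

With these numbers each cycle is 50 iterations long, so the rate starts at `max_lr` at iterations 0 and 50 and decays towards `min_lr` in between. Because `train.py` passes the per-epoch `example_num` as `it`, the cycles also restart at the beginning of every epoch.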