├── LICENSE
├── README.md
├── Singularity
├── audio_examples
│   └── Cristina Vane - So Easy
│       ├── accompaniment_true.mp3
│       ├── mix.mp3
│       ├── mix.mp3_bass.wav
│       ├── mix.mp3_drums.wav
│       ├── mix.mp3_other.wav
│       ├── mix.mp3_vocals.wav
│       └── vocals_true.mp3
├── checkpoints
│   └── README.md
├── cog.yaml
├── cog_predict.py
├── data
│   ├── __init__.py
│   ├── dataset.py
│   ├── musdb.py
│   └── utils.py
├── hdf
│   └── README.md
├── logs
│   └── README.md
├── model
│   ├── __init__.py
│   ├── conv.py
│   ├── crop.py
│   ├── resample.py
│   ├── utils.py
│   └── waveunet.py
├── predict.py
├── requirements.txt
├── test.py
├── train.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Daniel Stoller
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Wave-U-Net (Pytorch)
2 |
3 |
4 | Improved version of the [Wave-U-Net](https://arxiv.org/abs/1806.03185) for audio source separation, implemented in Pytorch.
5 |
6 | Click [here](https://github.com/f90/Wave-U-Net) for the original Wave-U-Net implementation in Tensorflow.
7 | You can find more information about the model and results there as well.
8 |
9 | # Improvements
10 |
11 | * Multi-instrument separation by default, using a separate standard Wave-U-Net for each source (can be set to one model as well)
12 | * More scalable to larger data: A depth parameter D can be set that employs D convolutions for each single convolution in the original Wave-U-Net
13 | * More configurable: Layer type, resampling factor at each level etc. can be easily changed (different normalization, residual connections...)
14 | * Fast training: Preprocesses the given dataset by saving the audio into HDF files, which can be read very quickly during training, thereby avoiding slowdown due to resampling and decoding
15 | * Modular thanks to Pytorch: Easily replace components of the model with your own variants/layers/losses
16 | * Better output handling: Separate output convolution for each source estimate with linear activation, so amplitudes near 1 and -1 can be easily predicted; at test time, predictions are clipped to the valid amplitude range [-1,1]
17 | * Fixed or dynamic resampling: Either use fixed lowpass filter to avoid aliasing during resampling, or use a learnable convolution
18 |
19 | # Installation
20 |
21 | GPU strongly recommended to avoid very long training times.
22 |
23 | ### Option 1: Direct install (recommended)
24 |
25 | System requirements:
26 | * Linux-based OS
27 | * Python 3.6
28 |
29 | * [libsndfile](http://mega-nerd.com/libsndfile/)
30 |
31 | * [ffmpeg](https://www.ffmpeg.org/)
32 | * CUDA 10.1 for GPU usage
33 |
34 | Clone the repository:
35 | ```
36 | git clone https://github.com/f90/Wave-U-Net-Pytorch.git
37 | ```
38 |
39 | Recommended: Create a new virtual environment to install the required Python packages into, then activate the virtual environment:
40 |
41 | ```
42 | virtualenv --python /usr/bin/python3.6 waveunet-env
43 | source waveunet-env/bin/activate
44 | ```
45 |
46 | Install all the required packages listed in the ``requirements.txt``:
47 |
48 | ```
49 | pip3 install -r requirements.txt
50 | ```
51 |
52 | ### Option 2: Singularity
53 |
54 | We also provide a Singularity container, which allows you to avoid installing the correct Python, CUDA and other system libraries yourself. However, we don't provide specific advice on how to run the container, so only use this option if you have to or know what you are doing (you need to mount dataset paths into the container etc.).
55 |
56 | To pull the container, run
57 | ```
58 | singularity pull shub://f90/Wave-U-Net-Pytorch
59 | ```
60 |
61 | Then run the container from the directory into which you cloned this repository, using the commands listed further below in this readme.
62 |
63 | # Download datasets
64 |
65 | If you only want to use the pre-trained models we provide for download to separate your own songs, skip directly to the [last section](#test), since the datasets are not needed in that case.
66 |
67 | To start training your own models, download the [full MUSDB18HQ dataset](https://sigsep.github.io/datasets/musdb.html) and extract it into a folder of your choice. It should have two subfolders: "test" and "train" as well as a README.md file.
68 |
69 | You can of course use your own datasets for training, but for this you would need to modify the code manually, which will not be discussed here. However, we provide a loading function for the normal MUSDB18 dataset as well.
70 |
71 | # Training the models
72 |
73 | To train a Wave-U-Net, the basic command to use is
74 |
75 | ```
76 | python3.6 train.py --dataset_dir /PATH/TO/MUSDB18HQ
77 | ```
78 | where the path to the MUSDB18HQ dataset (containing the ``train`` and ``test`` subfolders) needs to be specified.
79 |
80 | Add more command line parameters as needed:
81 | * ``--cuda`` to activate GPU usage
82 | * ``--hdf_dir PATH`` to save the preprocessed data (HDF files) to custom location PATH, instead of the default ``hdf`` subfolder in this repository
83 | * ``--checkpoint_dir`` and ``--log_dir`` to specify where checkpoint files and logs are saved/loaded
84 | * ``--load_model checkpoints/model_name/checkpoint_X`` to start training with weights given by a certain checkpoint
85 |
86 | For more config options, see ``train.py``.
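
For example, a full training run on a GPU with custom HDF, checkpoint and log locations might look like this (paths are placeholders):

```
python3.6 train.py --dataset_dir /PATH/TO/MUSDB18HQ --cuda --hdf_dir /PATH/TO/FAST_DISK/hdf --checkpoint_dir checkpoints/my_waveunet --log_dir logs/my_waveunet
```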
87 |
88 | Training progress can be monitored by using Tensorboard on the respective ``log_dir``.
89 | After training, the model is evaluated on the MUSDB18HQ test set, and SDR/SIR/SAR metrics are reported for all instruments, written both to Tensorboard and, in more detail, to a ``results.pkl`` file in the ``checkpoint_dir``.
90 |
91 | # Test trained models on songs!
92 |
93 | We provide the default model in pre-trained form as a download so you can separate your own songs right away.
94 |
95 | ## Downloading our pretrained models
96 |
97 | Download our pretrained model [here](https://www.dropbox.com/s/r374hce896g4xlj/models.7z?dl=1).
98 | Extract the archive into the ``checkpoints`` subfolder in this repository, so that you have one subfolder for each model (e.g. ``REPO/checkpoints/waveunet``)
99 |
100 | ## Run pretrained model
101 |
102 | To apply our pretrained model to any of your own songs, simply point to its audio file path using the ``input_path`` parameter:
103 |
104 | ```
105 | python3.6 predict.py --load_model checkpoints/waveunet/model --input "audio_examples/Cristina Vane - So Easy/mix.mp3"
106 | ```
107 |
108 | * Add ``--cuda`` when using a GPU; separation should be much quicker
109 | * Point ``--input`` to the music file you want to separate
110 |
111 | By default, output is written where the input music file is located, using the original file name plus the instrument name as output file name. Use ``--output`` to customise the output directory.
112 |
113 | To run your own model:
114 | * Point ``--load_model`` to the checkpoint file of the model you are using. If you used non-default hyper-parameters to train your own model, you must specify them here again so the correct model is set up and can receive the weights! See the example below.
115 |
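To illustrate (checkpoint path and hyper-parameter values are placeholders), a model trained with ``--features 24 --levels 7`` would be applied like this:

```
python3.6 predict.py --load_model checkpoints/my_waveunet/checkpoint_100000 --features 24 --levels 7 --input "/PATH/TO/SONG.mp3"
```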
--------------------------------------------------------------------------------
/Singularity:
--------------------------------------------------------------------------------
1 | BootStrap: docker
2 | From: nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04
3 |
4 | %post
5 | # Downloads the latest package lists (important).
6 | apt-get update -y
7 | # Runs apt-get while ensuring that there are no user prompts that would
8 | # cause the build process to hang.
9 | # python3-tk is required by matplotlib.
10 |     # python3-dev is needed to build some Python packages.
11 | DEBIAN_FRONTEND=noninteractive apt-get install -y \
12 | python3 \
13 | python3-tk \
14 | python3-pip \
15 | python3-dev \
16 | libsndfile1 \
17 | libsndfile1-dev \
18 | ffmpeg \
19 | git
20 | # Reduce the size of the image by deleting the package lists we downloaded,
21 | # which are useless now.
22 | rm -rf /var/lib/apt/lists/*
23 |
24 | # Install Pipenv.
25 | pip3 install pipenv
26 |
27 | # Install Python modules.
28 | pip3 install future numpy librosa musdb museval h5py tqdm sortedcontainers soundfile
29 | pip3 install torch==1.4.0 torchvision==0.5.0 tensorboard
30 |
31 | %environment
32 | # Pipenv requires a certain terminal encoding.
33 | export LANG=C.UTF-8
34 | export LC_ALL=C.UTF-8
35 | # This configures Pipenv to store the packages in the current working
36 | # directory.
37 | export PIPENV_VENV_IN_PROJECT=1
38 |
--------------------------------------------------------------------------------
/audio_examples/Cristina Vane - So Easy/accompaniment_true.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/audio_examples/Cristina Vane - So Easy/accompaniment_true.mp3
--------------------------------------------------------------------------------
/audio_examples/Cristina Vane - So Easy/mix.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/audio_examples/Cristina Vane - So Easy/mix.mp3
--------------------------------------------------------------------------------
/audio_examples/Cristina Vane - So Easy/mix.mp3_bass.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/audio_examples/Cristina Vane - So Easy/mix.mp3_bass.wav
--------------------------------------------------------------------------------
/audio_examples/Cristina Vane - So Easy/mix.mp3_drums.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/audio_examples/Cristina Vane - So Easy/mix.mp3_drums.wav
--------------------------------------------------------------------------------
/audio_examples/Cristina Vane - So Easy/mix.mp3_other.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/audio_examples/Cristina Vane - So Easy/mix.mp3_other.wav
--------------------------------------------------------------------------------
/audio_examples/Cristina Vane - So Easy/mix.mp3_vocals.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/audio_examples/Cristina Vane - So Easy/mix.mp3_vocals.wav
--------------------------------------------------------------------------------
/audio_examples/Cristina Vane - So Easy/vocals_true.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/audio_examples/Cristina Vane - So Easy/vocals_true.mp3
--------------------------------------------------------------------------------
/checkpoints/README.md:
--------------------------------------------------------------------------------
1 | This is where checkpoint files of models will be saved by default, it's recommended to use a subfolder for each different type of model you have.
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | build:
2 | python_version: "3.6"
3 | gpu: false
4 | python_packages:
5 | - future==0.18.2
6 | - numpy==1.19.5
7 | - librosa==0.8.1
8 | - soundfile==0.10.3.post1
9 | - musdb==0.4.0
10 | - museval==0.4.0
11 | - h5py==3.1.0
12 | - tqdm==4.62.1
13 | - torch==1.4.0
14 | - torchvision==0.5.0
15 | - tensorboard==2.6.0
16 | - sortedcontainers==2.4.0
17 | system_packages:
18 | - libsndfile-dev
19 | - ffmpeg
20 | predict: "cog_predict.py:waveunetPredictor"
21 |
--------------------------------------------------------------------------------
/cog_predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cog
3 | import tempfile
4 | import zipfile
5 | from pathlib import Path
6 | import argparse
7 | import data.utils
8 | import model.utils as model_utils
9 | from test import predict_song
10 | from model.waveunet import Waveunet
11 |
12 |
13 | class waveunetPredictor(cog.Predictor):
14 | def setup(self):
15 | """Init wave u net model"""
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument(
18 | "--instruments",
19 | type=str,
20 | nargs="+",
21 | default=["bass", "drums", "other", "vocals"],
22 | help='List of instruments to separate (default: "bass drums other vocals")',
23 | )
24 | parser.add_argument(
25 | "--cuda", action="store_true", help="Use CUDA (default: False)"
26 | )
27 | parser.add_argument(
28 | "--features",
29 | type=int,
30 | default=32,
31 | help="Number of feature channels per layer",
32 | )
33 | parser.add_argument(
34 | "--load_model",
35 | type=str,
36 | default="checkpoints/waveunet/model",
37 | help="Reload a previously trained model",
38 | )
39 | parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
40 | parser.add_argument(
41 | "--levels", type=int, default=6, help="Number of DS/US blocks"
42 | )
43 | parser.add_argument(
44 | "--depth", type=int, default=1, help="Number of convs per block"
45 | )
46 | parser.add_argument("--sr", type=int, default=44100, help="Sampling rate")
47 | parser.add_argument(
48 | "--channels", type=int, default=2, help="Number of input audio channels"
49 | )
50 | parser.add_argument(
51 | "--kernel_size",
52 | type=int,
53 | default=5,
54 | help="Filter width of kernels. Has to be an odd number",
55 | )
56 | parser.add_argument(
57 | "--output_size", type=float, default=2.0, help="Output duration"
58 | )
59 | parser.add_argument(
60 | "--strides", type=int, default=4, help="Strides in Waveunet"
61 | )
62 | parser.add_argument(
63 | "--conv_type",
64 | type=str,
65 | default="gn",
66 | help="Type of convolution (normal, BN-normalised, GN-normalised): normal/bn/gn",
67 | )
68 | parser.add_argument(
69 | "--res",
70 | type=str,
71 | default="fixed",
72 | help="Resampling strategy: fixed sinc-based lowpass filtering or learned conv layer: fixed/learned",
73 | )
74 | parser.add_argument(
75 | "--separate",
76 | type=int,
77 | default=1,
78 | help="Train separate model for each source (1) or only one (0)",
79 | )
80 | parser.add_argument(
81 | "--feature_growth",
82 | type=str,
83 | default="double",
84 | help="How the features in each layer should grow, either (add) the initial number of features each time, or multiply by 2 (double)",
85 | )
86 | """
87 | parser.add_argument('--input', type=str, default=str(input),
88 | help="Path to input mixture to be separated")
89 | parser.add_argument('--output', type=str, default=out_path, help="Output path (same folder as input path if not set)")
90 | """
91 | args = parser.parse_args([])
92 | self.args = args
93 |
94 | num_features = (
95 | [args.features * i for i in range(1, args.levels + 1)]
96 | if args.feature_growth == "add"
97 | else [args.features * 2 ** i for i in range(0, args.levels)]
98 | )
99 | target_outputs = int(args.output_size * args.sr)
100 | self.model = Waveunet(
101 | args.channels,
102 | num_features,
103 | args.channels,
104 | args.instruments,
105 | kernel_size=args.kernel_size,
106 | target_output_size=target_outputs,
107 | depth=args.depth,
108 | strides=args.strides,
109 | conv_type=args.conv_type,
110 | res=args.res,
111 | separate=args.separate,
112 | )
113 |
114 | if args.cuda:
115 |             self.model = model_utils.DataParallel(self.model)
116 | print("move model to gpu")
117 | self.model.cuda()
118 |
119 | print("Loading model from checkpoint " + str(args.load_model))
120 | state = model_utils.load_model(self.model, None, args.load_model, args.cuda)
121 | print("Step", state["step"])
122 |
123 | @cog.input("input", type=Path, help="audio mixture path")
124 | def predict(self, input):
125 | """Separate tracks from input mixture audio"""
126 |
127 | out_path = Path(tempfile.mkdtemp())
128 | zip_path = Path(tempfile.mkdtemp()) / "output.zip"
129 |
130 | preds = predict_song(self.args, input, self.model)
131 |
132 | out_names = []
133 | for inst in preds.keys():
134 | temp_n = os.path.join(
135 | str(out_path), os.path.basename(str(input)) + "_" + inst + ".wav"
136 | )
137 | data.utils.write_wav(temp_n, preds[inst], self.args.sr)
138 | out_names.append(temp_n)
139 |
140 | with zipfile.ZipFile(str(zip_path), "w") as zf:
141 | for i in out_names:
142 | zf.write(str(i))
143 |
144 | return zip_path
145 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/data/__init__.py
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import h5py
4 | import numpy as np
5 | from sortedcontainers import SortedList
6 | from torch.utils.data import Dataset
7 | from tqdm import tqdm
8 |
9 | from data.utils import load
10 |
11 |
12 | class SeparationDataset(Dataset):
13 | def __init__(self, dataset, partition, instruments, sr, channels, shapes, random_hops, hdf_dir, audio_transform=None, in_memory=False):
14 | '''
15 | Initialises a source separation dataset
16 | :param data: HDF audio data object
17 | :param input_size: Number of input samples for each example
18 | :param context_front: Number of extra context samples to prepend to input
19 | :param context_back: NUmber of extra context samples to append to input
20 | :param hop_size: Skip hop_size - 1 sample positions in the audio for each example (subsampling the audio)
21 | :param random_hops: If False, sample examples evenly from whole audio signal according to hop_size parameter. If True, randomly sample a position from the audio
22 | '''
23 |
24 | super(SeparationDataset, self).__init__()
25 |
26 | self.hdf_dataset = None
27 | os.makedirs(hdf_dir, exist_ok=True)
28 | self.hdf_dir = os.path.join(hdf_dir, partition + ".hdf5")
29 |
30 | self.random_hops = random_hops
31 | self.sr = sr
32 | self.channels = channels
33 | self.shapes = shapes
34 | self.audio_transform = audio_transform
35 | self.in_memory = in_memory
36 | self.instruments = instruments
37 |
38 | # PREPARE HDF FILE
39 |
40 | # Check if HDF file exists already
41 | if not os.path.exists(self.hdf_dir):
42 | # Create folder if it did not exist before
43 | if not os.path.exists(hdf_dir):
44 | os.makedirs(hdf_dir)
45 |
46 | # Create HDF file
47 | with h5py.File(self.hdf_dir, "w") as f:
48 | f.attrs["sr"] = sr
49 | f.attrs["channels"] = channels
50 | f.attrs["instruments"] = instruments
51 |
52 | print("Adding audio files to dataset (preprocessing)...")
53 | for idx, example in enumerate(tqdm(dataset[partition])):
54 | # Load mix
55 | mix_audio, _ = load(example["mix"], sr=self.sr, mono=(self.channels == 1))
56 |
57 | source_audios = []
58 | for source in instruments:
59 | # In this case, read in audio and convert to target sampling rate
60 | source_audio, _ = load(example[source], sr=self.sr, mono=(self.channels == 1))
61 | source_audios.append(source_audio)
62 | source_audios = np.concatenate(source_audios, axis=0)
63 | assert(source_audios.shape[1] == mix_audio.shape[1])
64 |
65 | # Add to HDF5 file
66 | grp = f.create_group(str(idx))
67 | grp.create_dataset("inputs", shape=mix_audio.shape, dtype=mix_audio.dtype, data=mix_audio)
68 | grp.create_dataset("targets", shape=source_audios.shape, dtype=source_audios.dtype, data=source_audios)
69 | grp.attrs["length"] = mix_audio.shape[1]
70 | grp.attrs["target_length"] = source_audios.shape[1]
71 |
72 |         # If the HDF file already existed, check whether sr, channels and instruments match the requested settings, otherwise raise error
73 | with h5py.File(self.hdf_dir, "r") as f:
74 | if f.attrs["sr"] != sr or \
75 | f.attrs["channels"] != channels or \
76 | list(f.attrs["instruments"]) != instruments:
77 | raise ValueError(
78 | "Tried to load existing HDF file, but sampling rate and channel or instruments are not as expected. Did you load an out-dated HDF file?")
79 |
80 | # HDF FILE READY
81 |
82 | # SET SAMPLING POSITIONS
83 |
84 | # Go through HDF and collect lengths of all audio files
85 | with h5py.File(self.hdf_dir, "r") as f:
86 | lengths = [f[str(song_idx)].attrs["target_length"] for song_idx in range(len(f))]
87 |
88 |         # Divide each song length by the model output size to determine the number of starting positions per song
89 | lengths = [(l // self.shapes["output_frames"]) + 1 for l in lengths]
90 |
91 | self.start_pos = SortedList(np.cumsum(lengths))
92 | self.length = self.start_pos[-1]
93 |
94 | def __getitem__(self, index):
95 | # Open HDF5
96 | if self.hdf_dataset is None:
97 | driver = "core" if self.in_memory else None # Load HDF5 fully into memory if desired
98 | self.hdf_dataset = h5py.File(self.hdf_dir, 'r', driver=driver)
99 |
100 | # Find out which slice of targets we want to read
101 | audio_idx = self.start_pos.bisect_right(index)
102 | if audio_idx > 0:
103 | index = index - self.start_pos[audio_idx - 1]
104 |
105 | # Check length of audio signal
106 | audio_length = self.hdf_dataset[str(audio_idx)].attrs["length"]
107 | target_length = self.hdf_dataset[str(audio_idx)].attrs["target_length"]
108 |
109 | # Determine position where to start targets
110 | if self.random_hops:
111 | start_target_pos = np.random.randint(0, max(target_length - self.shapes["output_frames"] + 1, 1))
112 | else:
113 | # Map item index to sample position within song
114 | start_target_pos = index * self.shapes["output_frames"]
115 |
116 | # READ INPUTS
117 | # Check front padding
118 | start_pos = start_target_pos - self.shapes["output_start_frame"]
119 | if start_pos < 0:
120 | # Pad manually since audio signal was too short
121 | pad_front = abs(start_pos)
122 | start_pos = 0
123 | else:
124 | pad_front = 0
125 |
126 | # Check back padding
127 | end_pos = start_target_pos - self.shapes["output_start_frame"] + self.shapes["input_frames"]
128 | if end_pos > audio_length:
129 | # Pad manually since audio signal was too short
130 | pad_back = end_pos - audio_length
131 | end_pos = audio_length
132 | else:
133 | pad_back = 0
134 |
135 | # Read and return
136 | audio = self.hdf_dataset[str(audio_idx)]["inputs"][:, start_pos:end_pos].astype(np.float32)
137 | if pad_front > 0 or pad_back > 0:
138 | audio = np.pad(audio, [(0, 0), (pad_front, pad_back)], mode="constant", constant_values=0.0)
139 |
140 | targets = self.hdf_dataset[str(audio_idx)]["targets"][:, start_pos:end_pos].astype(np.float32)
141 | if pad_front > 0 or pad_back > 0:
142 | targets = np.pad(targets, [(0, 0), (pad_front, pad_back)], mode="constant", constant_values=0.0)
143 |
144 | targets = {inst : targets[idx*self.channels:(idx+1)*self.channels] for idx, inst in enumerate(self.instruments)}
145 |
146 | if hasattr(self, "audio_transform") and self.audio_transform is not None:
147 | audio, targets = self.audio_transform(audio, targets)
148 |
149 | return audio, targets
150 |
151 | def __len__(self):
152 | return self.length
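
# Illustrative usage sketch: assumes a MUSDB18HQ path and a constructed Waveunet `model`
# providing the `shapes` dict; all names and paths below are placeholders.
#
#   from torch.utils.data import DataLoader
#   from data.musdb import get_musdb_folds
#
#   musdb = get_musdb_folds("/PATH/TO/MUSDB18HQ")
#   train_data = SeparationDataset(musdb, "train", ["bass", "drums", "other", "vocals"],
#                                  sr=44100, channels=2, shapes=model.shapes,
#                                  random_hops=True, hdf_dir="hdf")
#   loader = DataLoader(train_data, batch_size=4, shuffle=True, num_workers=1)
#   mix, targets = next(iter(loader))  # mix: (batch, channels, input_frames); targets: dict of per-source tensors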
--------------------------------------------------------------------------------
/data/musdb.py:
--------------------------------------------------------------------------------
1 | import musdb
2 | import os
3 | import numpy as np
4 | import glob
5 |
6 | from data.utils import load, write_wav
7 |
8 |
9 | def get_musdbhq(database_path):
10 | '''
11 | Retrieve audio file paths for MUSDB HQ dataset
12 | :param database_path: MUSDB HQ root directory
13 | :return: dictionary with train and test keys, each containing list of samples, each sample containing all audio paths
14 | '''
15 | subsets = list()
16 |
17 | for subset in ["train", "test"]:
18 | print("Loading " + subset + " set...")
19 | tracks = glob.glob(os.path.join(database_path, subset, "*"))
20 | samples = list()
21 |
22 | # Go through tracks
23 | for track_folder in sorted(tracks):
24 |             # Collect the audio paths for the mixture and all stems
25 | example = dict()
26 | for stem in ["mix", "bass", "drums", "other", "vocals"]:
27 | filename = stem if stem != "mix" else "mixture"
28 | audio_path = os.path.join(track_folder, filename + ".wav")
29 | example[stem] = audio_path
30 |
31 | # Add other instruments to form accompaniment
32 | acc_path = os.path.join(track_folder, "accompaniment.wav")
33 |
34 | if not os.path.exists(acc_path):
35 | print("Writing accompaniment to " + track_folder)
36 | stem_audio = []
37 | for stem in ["bass", "drums", "other"]:
38 | audio, sr = load(example[stem], sr=None, mono=False)
39 | stem_audio.append(audio)
40 | acc_audio = np.clip(sum(stem_audio), -1.0, 1.0)
41 | write_wav(acc_path, acc_audio, sr)
42 |
43 | example["accompaniment"] = acc_path
44 |
45 | samples.append(example)
46 |
47 | subsets.append(samples)
48 |
49 | return subsets
50 |
51 | def get_musdb(database_path):
52 | '''
53 | Retrieve audio file paths for MUSDB dataset
54 | :param database_path: MUSDB root directory
55 | :return: dictionary with train and test keys, each containing list of samples, each sample containing all audio paths
56 | '''
57 | mus = musdb.DB(root=database_path, is_wav=False)
58 |
59 | subsets = list()
60 |
61 | for subset in ["train", "test"]:
62 | tracks = mus.load_mus_tracks(subset)
63 | samples = list()
64 |
65 | # Go through tracks
66 | for track in sorted(tracks):
67 | # Skip track if mixture is already written, assuming this track is done already
68 | track_path = track.path[:-4]
69 | mix_path = track_path + "_mix.wav"
70 | acc_path = track_path + "_accompaniment.wav"
71 | if os.path.exists(mix_path):
72 | print("WARNING: Skipping track " + mix_path + " since it exists already")
73 |
74 | # Add paths and then skip
75 | paths = {"mix" : mix_path, "accompaniment" : acc_path}
76 | paths.update({key : track_path + "_" + key + ".wav" for key in ["bass", "drums", "other", "vocals"]})
77 |
78 | samples.append(paths)
79 |
80 | continue
81 |
82 | rate = track.rate
83 |
84 | # Go through each instrument
85 | paths = dict()
86 | stem_audio = dict()
87 | for stem in ["bass", "drums", "other", "vocals"]:
88 | path = track_path + "_" + stem + ".wav"
89 | audio = track.targets[stem].audio
90 | write_wav(path, audio, rate)
91 | stem_audio[stem] = audio
92 | paths[stem] = path
93 |
94 | # Add other instruments to form accompaniment
95 | acc_audio = np.clip(sum([stem_audio[key] for key in list(stem_audio.keys()) if key != "vocals"]), -1.0, 1.0)
96 | write_wav(acc_path, acc_audio, rate)
97 | paths["accompaniment"] = acc_path
98 |
99 | # Create mixture
100 | mix_audio = track.audio
101 | write_wav(mix_path, mix_audio, rate)
102 | paths["mix"] = mix_path
103 |
104 | diff_signal = np.abs(mix_audio - acc_audio - stem_audio["vocals"])
105 | print("Maximum absolute deviation from source additivity constraint: " + str(np.max(diff_signal)))# Check if acc+vocals=mix
106 | print("Mean absolute deviation from source additivity constraint: " + str(np.mean(diff_signal)))
107 |
108 | samples.append(paths)
109 |
110 | subsets.append(samples)
111 |
112 | print("DONE preparing dataset!")
113 | return subsets
114 |
115 | def get_musdb_folds(root_path, version="HQ"):
116 | if version == "HQ":
117 | dataset = get_musdbhq(root_path)
118 | else:
119 | dataset = get_musdb(root_path)
120 | train_val_list = dataset[0]
121 | test_list = dataset[1]
122 |
123 | np.random.seed(1337) # Ensure that partitioning is always the same on each run
124 | train_list = np.random.choice(train_val_list, 75, replace=False)
125 | val_list = [elem for elem in train_val_list if elem not in train_list]
126 | # print("First training song: " + str(train_list[0])) # To debug whether partitioning is deterministic
127 | return {"train" : train_list, "val" : val_list, "test" : test_list}
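
# Illustrative sketch of the structure returned by get_musdb_folds (paths are placeholders):
#
#   folds = get_musdb_folds("/PATH/TO/MUSDB18HQ", version="HQ")
#   folds.keys()       # dict_keys(['train', 'val', 'test'])
#   folds["train"][0]  # {"mix": ".../mixture.wav", "bass": ".../bass.wav", "drums": ".../drums.wav",
#                      #  "other": ".../other.wav", "vocals": ".../vocals.wav", "accompaniment": ".../accompaniment.wav"}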
--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import numpy as np
3 | import soundfile
4 | import torch
5 |
6 |
7 | def random_amplify(mix, targets, shapes, min, max):
8 | '''
9 | Data augmentation by randomly amplifying sources before adding them to form a new mixture
10 | :param mix: Original mixture
11 | :param targets: Source targets
12 | :param shapes: Shape dict from model
13 | :param min: Minimum possible amplification
14 | :param max: Maximum possible amplification
15 | :return: New data point as tuple (mix, targets)
16 | '''
17 | residual = mix # start with original mix
18 | for key in targets.keys():
19 | if key != "mix":
20 | residual -= targets[key] # subtract all instruments (output is zero if all instruments add to mix)
21 | mix = residual * np.random.uniform(min, max) # also apply gain data augmentation to residual
22 | for key in targets.keys():
23 | if key != "mix":
24 | targets[key] = targets[key] * np.random.uniform(min, max)
25 | mix += targets[key] # add instrument with gain data augmentation to mix
26 | mix = np.clip(mix, -1.0, 1.0)
27 | return crop_targets(mix, targets, shapes)
28 |
29 |
30 | def crop_targets(mix, targets, shapes):
31 | '''
32 | Crops target audio to the output shape required by the model given in "shapes"
33 | '''
34 | for key in targets.keys():
35 | if key != "mix":
36 | targets[key] = targets[key][:, shapes["output_start_frame"]:shapes["output_end_frame"]]
37 | return mix, targets
38 |
39 |
40 | def load(path, sr=22050, mono=True, mode="numpy", offset=0.0, duration=None):
41 | y, curr_sr = librosa.load(path, sr=sr, mono=mono, res_type='kaiser_fast', offset=offset, duration=duration)
42 |
43 | if len(y.shape) == 1:
44 | # Expand channel dimension
45 | y = y[np.newaxis, :]
46 |
47 | if mode == "pytorch":
48 | y = torch.tensor(y)
49 |
50 | return y, curr_sr
51 |
52 |
53 | def write_wav(path, audio, sr):
54 | soundfile.write(path, audio.T, sr, "PCM_16")
55 |
56 |
57 | def resample(audio, orig_sr, new_sr, mode="numpy"):
58 | if orig_sr == new_sr:
59 | return audio
60 |
61 | if isinstance(audio, torch.Tensor):
62 | audio = audio.detach().cpu().numpy()
63 |
64 | out = librosa.resample(audio, orig_sr, new_sr, res_type='kaiser_fast')
65 |
66 | if mode == "pytorch":
67 | out = torch.tensor(out)
68 | return out
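
# Illustrative usage sketch of the helpers above (file paths are placeholders):
#
#   audio, sr = load("/PATH/TO/SONG.mp3", sr=44100, mono=False)   # audio shape: (channels, samples)
#   audio_22k = resample(audio, orig_sr=44100, new_sr=22050)
#   write_wav("/PATH/TO/SONG_22k.wav", audio_22k, 22050)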
--------------------------------------------------------------------------------
/hdf/README.md:
--------------------------------------------------------------------------------
1 | This directory is where the preprocessed data is saved by default, in the form of HDF files (one each for training, validation and test). These files can get quite large since they are uncompressed by default, so if this location is unsuitable due to space or disk speed constraints, pass a different ``hdf_dir`` parameter to ``train.py``.
--------------------------------------------------------------------------------
/logs/README.md:
--------------------------------------------------------------------------------
1 | This is where logs of the training process are going to be stored!
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/Wave-U-Net-Pytorch/86c113da51540c94b01a91fbf17227f0e9c3c274/model/__init__.py
--------------------------------------------------------------------------------
/model/conv.py:
--------------------------------------------------------------------------------
1 | from torch import nn as nn
2 | from torch.nn import functional as F
3 |
4 |
5 | class ConvLayer(nn.Module):
6 | def __init__(self, n_inputs, n_outputs, kernel_size, stride, conv_type, transpose=False):
7 | super(ConvLayer, self).__init__()
8 | self.transpose = transpose
9 | self.stride = stride
10 | self.kernel_size = kernel_size
11 | self.conv_type = conv_type
12 |
13 | # How many channels should be normalised as one group if GroupNorm is activated
14 | # WARNING: Number of channels has to be divisible by this number!
15 | NORM_CHANNELS = 8
16 |
17 | if self.transpose:
18 | self.filter = nn.ConvTranspose1d(n_inputs, n_outputs, self.kernel_size, stride, padding=kernel_size-1)
19 | else:
20 | self.filter = nn.Conv1d(n_inputs, n_outputs, self.kernel_size, stride)
21 |
22 | if conv_type == "gn":
23 | assert(n_outputs % NORM_CHANNELS == 0)
24 | self.norm = nn.GroupNorm(n_outputs // NORM_CHANNELS, n_outputs)
25 | elif conv_type == "bn":
26 | self.norm = nn.BatchNorm1d(n_outputs, momentum=0.01)
27 |         # Add your own types of variations here!
28 |
29 | def forward(self, x):
30 | # Apply the convolution
31 | if self.conv_type == "gn" or self.conv_type == "bn":
32 | out = F.relu(self.norm((self.filter(x))))
33 | else: # Add your own variations here with elifs conditioned on "conv_type" parameter!
34 | assert(self.conv_type == "normal")
35 | out = F.leaky_relu(self.filter(x))
36 | return out
37 |
38 | def get_input_size(self, output_size):
39 | # Strided conv/decimation
40 | if not self.transpose:
41 | curr_size = (output_size - 1)*self.stride + 1 # o = (i-1)//s + 1 => i = (o - 1)*s + 1
42 | else:
43 | curr_size = output_size
44 |
45 | # Conv
46 | curr_size = curr_size + self.kernel_size - 1 # o = i + p - k + 1
47 |
48 | # Transposed
49 | if self.transpose:
50 | assert ((curr_size - 1) % self.stride == 0)# We need to have a value at the beginning and end
51 | curr_size = ((curr_size - 1) // self.stride) + 1
52 | assert(curr_size > 0)
53 | return curr_size
54 |
55 | def get_output_size(self, input_size):
56 | # Transposed
57 | if self.transpose:
58 | assert(input_size > 1)
59 | curr_size = (input_size - 1)*self.stride + 1 # o = (i-1)//s + 1 => i = (o - 1)*s + 1
60 | else:
61 | curr_size = input_size
62 |
63 | # Conv
64 | curr_size = curr_size - self.kernel_size + 1 # o = i + p - k + 1
65 | assert (curr_size > 0)
66 |
67 | # Strided conv/decimation
68 | if not self.transpose:
69 | assert ((curr_size - 1) % self.stride == 0) # We need to have a value at the beginning and end
70 | curr_size = ((curr_size - 1) // self.stride) + 1
71 |
72 | return curr_size
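
# Illustrative sanity check of the size bookkeeping above (values are arbitrary):
#
#   layer = ConvLayer(n_inputs=2, n_outputs=32, kernel_size=5, stride=1, conv_type="gn")
#   out_size = layer.get_output_size(1000)          # valid convolution: 1000 - 5 + 1 = 996
#   assert layer.get_input_size(out_size) == 1000   # the two methods are exact inverses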
--------------------------------------------------------------------------------
/model/crop.py:
--------------------------------------------------------------------------------
1 | def centre_crop(x, target):
2 | '''
3 | Center-crop 3-dim. input tensor along last axis so it fits the target tensor shape
4 | :param x: Input tensor
5 | :param target: Shape of this tensor will be used as target shape
6 | :return: Cropped input tensor
7 | '''
8 | if x is None:
9 | return None
10 | if target is None:
11 | return x
12 |
13 | target_shape = target.shape
14 | diff = x.shape[-1] - target_shape[-1]
15 | assert (diff % 2 == 0)
16 | crop = diff // 2
17 |
18 | if crop == 0:
19 | return x
20 | if crop < 0:
21 | raise ArithmeticError
22 |
23 | return x[:, :, crop:-crop].contiguous()
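
# Illustrative example (only the last dimension is cropped):
#
#   import torch
#   x = torch.zeros(1, 32, 100)       # longer feature map
#   target = torch.zeros(1, 32, 90)   # tensor whose length we want to match
#   centre_crop(x, target).shape      # torch.Size([1, 32, 90]): 5 frames removed on each side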
--------------------------------------------------------------------------------
/model/resample.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn as nn
4 | from torch.nn import functional as F
5 |
6 | class Resample1d(nn.Module):
7 | def __init__(self, channels, kernel_size, stride, transpose=False, padding="reflect", trainable=False):
8 | '''
9 | Creates a resampling layer for time series data (using 1D convolution) - (N, C, W) input format
10 | :param channels: Number of features C at each time-step
11 | :param kernel_size: Width of sinc-based lowpass-filter (>= 15 recommended for good filtering performance)
12 | :param stride: Resampling factor (integer)
13 | :param transpose: False for down-, true for upsampling
14 | :param padding: Either "reflect" to pad or "valid" to not pad
15 | :param trainable: Optionally activate this to train the lowpass-filter, starting from the sinc initialisation
16 | '''
17 | super(Resample1d, self).__init__()
18 |
19 | self.padding = padding
20 | self.kernel_size = kernel_size
21 | self.stride = stride
22 | self.transpose = transpose
23 | self.channels = channels
24 |
25 | cutoff = 0.5 / stride
26 |
27 | assert(kernel_size > 2)
28 | assert ((kernel_size - 1) % 2 == 0)
29 | assert(padding == "reflect" or padding == "valid")
30 |
31 | filter = build_sinc_filter(kernel_size, cutoff)
32 |
33 | self.filter = torch.nn.Parameter(torch.from_numpy(np.repeat(np.reshape(filter, [1, 1, kernel_size]), channels, axis=0)), requires_grad=trainable)
34 |
35 | def forward(self, x):
36 | # Pad here if not using transposed conv
37 | input_size = x.shape[2]
38 | if self.padding != "valid":
39 | num_pad = (self.kernel_size-1)//2
40 | out = F.pad(x, (num_pad, num_pad), mode=self.padding)
41 | else:
42 | out = x
43 |
44 | # Lowpass filter (+ 0 insertion if transposed)
45 | if self.transpose:
46 | expected_steps = ((input_size - 1) * self.stride + 1)
47 | if self.padding == "valid":
48 | expected_steps = expected_steps - self.kernel_size + 1
49 |
50 | out = F.conv_transpose1d(out, self.filter, stride=self.stride, padding=0, groups=self.channels)
51 | diff_steps = out.shape[2] - expected_steps
52 | if diff_steps > 0:
53 | assert(diff_steps % 2 == 0)
54 | out = out[:,:,diff_steps//2:-diff_steps//2]
55 | else:
56 | assert(input_size % self.stride == 1)
57 | out = F.conv1d(out, self.filter, stride=self.stride, padding=0, groups=self.channels)
58 |
59 | return out
60 |
61 | def get_output_size(self, input_size):
62 | '''
63 | Returns the output dimensionality (number of timesteps) for a given input size
64 | :param input_size: Number of input time steps (Scalar, each feature is one-dimensional)
65 | :return: Output size (scalar)
66 | '''
67 | assert(input_size > 1)
68 | if self.transpose:
69 | if self.padding == "valid":
70 | return ((input_size - 1) * self.stride + 1) - self.kernel_size + 1
71 | else:
72 | return ((input_size - 1) * self.stride + 1)
73 | else:
74 | assert(input_size % self.stride == 1) # Want to take first and last sample
75 | if self.padding == "valid":
76 | return input_size - self.kernel_size + 1
77 | else:
78 | return input_size
79 |
80 | def get_input_size(self, output_size):
81 | '''
82 | Returns the input dimensionality (number of timesteps) for a given output size
83 | :param input_size: Number of input time steps (Scalar, each feature is one-dimensional)
84 | :return: Output size (scalar)
85 | '''
86 |
87 | # Strided conv/decimation
88 | if not self.transpose:
89 | curr_size = (output_size - 1)*self.stride + 1 # o = (i-1)//s + 1 => i = (o - 1)*s + 1
90 | else:
91 | curr_size = output_size
92 |
93 | # Conv
94 | if self.padding == "valid":
95 | curr_size = curr_size + self.kernel_size - 1 # o = i + p - k + 1
96 |
97 | # Transposed
98 | if self.transpose:
99 | assert ((curr_size - 1) % self.stride == 0)# We need to have a value at the beginning and end
100 | curr_size = ((curr_size - 1) // self.stride) + 1
101 | assert(curr_size > 0)
102 | return curr_size
103 |
104 | def build_sinc_filter(kernel_size, cutoff):
105 | # FOLLOWING https://www.analog.com/media/en/technical-documentation/dsp-book/dsp_book_Ch16.pdf
106 | # Sinc lowpass filter
107 | # Build sinc kernel
108 | assert(kernel_size % 2 == 1)
109 | M = kernel_size - 1
110 | filter = np.zeros(kernel_size, dtype=np.float32)
111 | for i in range(kernel_size):
112 | if i == M//2:
113 | filter[i] = 2 * np.pi * cutoff
114 | else:
115 | filter[i] = (np.sin(2 * np.pi * cutoff * (i - M//2)) / (i - M//2)) * \
116 |                         (0.42 - 0.5 * np.cos((2 * np.pi * i) / M) + 0.08 * np.cos((4 * np.pi * i) / M))
117 |
118 | filter = filter / np.sum(filter)
119 | return filter
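
# Illustrative shape check for the resampling layers (for downsampling, the input length
# must satisfy input_size % stride == 1, see the assertions above):
#
#   import torch
#   down = Resample1d(channels=32, kernel_size=15, stride=2)
#   up = Resample1d(channels=32, kernel_size=15, stride=2, transpose=True)
#   x = torch.randn(1, 32, 101)
#   y = down(x)   # shape (1, 32, 51): lowpass filtering followed by decimation
#   z = up(y)     # shape (1, 32, 101): zero insertion followed by lowpass filtering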
--------------------------------------------------------------------------------
/model/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | def save_model(model, optimizer, state, path):
5 | if isinstance(model, torch.nn.DataParallel):
6 | model = model.module # save state dict of wrapped module
7 | if len(os.path.dirname(path)) > 0 and not os.path.exists(os.path.dirname(path)):
8 | os.makedirs(os.path.dirname(path))
9 | torch.save({
10 | 'model_state_dict': model.state_dict(),
11 | 'optimizer_state_dict': optimizer.state_dict(),
12 | 'state': state, # state of training loop (was 'step')
13 | }, path)
14 |
15 |
16 | def load_model(model, optimizer, path, cuda):
17 | if isinstance(model, torch.nn.DataParallel):
18 | model = model.module # load state dict of wrapped module
19 | if cuda:
20 | checkpoint = torch.load(path)
21 | else:
22 | checkpoint = torch.load(path, map_location='cpu')
23 | try:
24 | model.load_state_dict(checkpoint['model_state_dict'])
25 | except:
26 | # work-around for loading checkpoints where DataParallel was saved instead of inner module
27 | from collections import OrderedDict
28 | model_state_dict_fixed = OrderedDict()
29 | prefix = 'module.'
30 | for k, v in checkpoint['model_state_dict'].items():
31 | if k.startswith(prefix):
32 | k = k[len(prefix):]
33 | model_state_dict_fixed[k] = v
34 | model.load_state_dict(model_state_dict_fixed)
35 | if optimizer is not None:
36 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
37 | if 'state' in checkpoint:
38 | state = checkpoint['state']
39 | else:
40 | # older checkpoints only store step, rest of state won't be there
41 | state = {'step': checkpoint['step']}
42 | return state
43 |
44 |
45 | def compute_loss(model, inputs, targets, criterion, compute_grad=False):
46 | '''
47 | Computes gradients of model with given inputs and targets and loss function.
48 | Optionally backpropagates to compute gradients for weights.
49 | Procedure depends on whether we have one model for each source or not
50 | :param model: Model to train with
51 | :param inputs: Input mixture
52 | :param targets: Target sources
53 | :param criterion: Loss function to use (L1, L2, ..)
54 | :param compute_grad: Whether to compute gradients
55 | :return: Model outputs, Average loss over batch
56 | '''
57 | all_outputs = {}
58 |
59 | if model.separate:
60 | avg_loss = 0.0
61 | num_sources = 0
62 | for inst in model.instruments:
63 | output = model(inputs, inst)
64 | loss = criterion(output[inst], targets[inst])
65 |
66 | if compute_grad:
67 | loss.backward()
68 |
69 | avg_loss += loss.item()
70 | num_sources += 1
71 |
72 | all_outputs[inst] = output[inst].detach().clone()
73 |
74 | avg_loss /= float(num_sources)
75 | else:
76 | loss = 0
77 | all_outputs = model(inputs)
78 | for inst in all_outputs.keys():
79 | loss += criterion(all_outputs[inst], targets[inst])
80 |
81 | if compute_grad:
82 | loss.backward()
83 |
84 | avg_loss = loss.item() / float(len(all_outputs))
85 |
86 | return all_outputs, avg_loss
87 |
88 |
89 | class DataParallel(torch.nn.DataParallel):
90 | def __init__(self, module, device_ids=None, output_device=None, dim=0):
91 | super(DataParallel, self).__init__(module, device_ids, output_device, dim)
92 |
93 | def __getattr__(self, name):
94 | try:
95 | return super().__getattr__(name)
96 | except AttributeError:
97 | return getattr(self.module, name)
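
# Illustrative training-step sketch using compute_loss; how it is wired up in train.py is an
# assumption here, and `model` and `dataloader` are placeholders matching the Waveunet setup:
#
#   import torch.nn as nn
#   criterion = nn.L1Loss()
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   for mix, targets in dataloader:
#       optimizer.zero_grad()
#       outputs, avg_loss = compute_loss(model, mix, targets, criterion, compute_grad=True)  # backprops internally
#       optimizer.step()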
--------------------------------------------------------------------------------
/model/waveunet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from model.crop import centre_crop
5 | from model.resample import Resample1d
6 | from model.conv import ConvLayer
7 |
8 | class UpsamplingBlock(nn.Module):
9 | def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res):
10 | super(UpsamplingBlock, self).__init__()
11 | assert(stride > 1)
12 |
13 | # CONV 1 for UPSAMPLING
14 | if res == "fixed":
15 | self.upconv = Resample1d(n_inputs, 15, stride, transpose=True)
16 | else:
17 | self.upconv = ConvLayer(n_inputs, n_inputs, kernel_size, stride, conv_type, transpose=True)
18 |
19 | self.pre_shortcut_convs = nn.ModuleList([ConvLayer(n_inputs, n_outputs, kernel_size, 1, conv_type)] +
20 | [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])
21 |
22 | # CONVS to combine high- with low-level information (from shortcut)
23 | self.post_shortcut_convs = nn.ModuleList([ConvLayer(n_outputs + n_shortcut, n_outputs, kernel_size, 1, conv_type)] +
24 | [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])
25 |
26 | def forward(self, x, shortcut):
27 | # UPSAMPLE HIGH-LEVEL FEATURES
28 | upsampled = self.upconv(x)
29 |
30 | for conv in self.pre_shortcut_convs:
31 | upsampled = conv(upsampled)
32 |
33 | # Prepare shortcut connection
34 | combined = centre_crop(shortcut, upsampled)
35 |
36 | # Combine high- and low-level features
37 | for conv in self.post_shortcut_convs:
38 | combined = conv(torch.cat([combined, centre_crop(upsampled, combined)], dim=1))
39 | return combined
40 |
41 | def get_output_size(self, input_size):
42 | curr_size = self.upconv.get_output_size(input_size)
43 |
44 | # Upsampling convs
45 | for conv in self.pre_shortcut_convs:
46 | curr_size = conv.get_output_size(curr_size)
47 |
48 | # Combine convolutions
49 | for conv in self.post_shortcut_convs:
50 | curr_size = conv.get_output_size(curr_size)
51 |
52 | return curr_size
53 |
54 | class DownsamplingBlock(nn.Module):
55 | def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res):
56 | super(DownsamplingBlock, self).__init__()
57 | assert(stride > 1)
58 |
59 | self.kernel_size = kernel_size
60 | self.stride = stride
61 |
62 | # CONV 1
63 | self.pre_shortcut_convs = nn.ModuleList([ConvLayer(n_inputs, n_shortcut, kernel_size, 1, conv_type)] +
64 | [ConvLayer(n_shortcut, n_shortcut, kernel_size, 1, conv_type) for _ in range(depth - 1)])
65 |
66 | self.post_shortcut_convs = nn.ModuleList([ConvLayer(n_shortcut, n_outputs, kernel_size, 1, conv_type)] +
67 | [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in
68 | range(depth - 1)])
69 |
70 | # CONV 2 with decimation
71 | if res == "fixed":
72 | self.downconv = Resample1d(n_outputs, 15, stride) # Resampling with fixed-size sinc lowpass filter
73 | else:
74 | self.downconv = ConvLayer(n_outputs, n_outputs, kernel_size, stride, conv_type)
75 |
76 | def forward(self, x):
77 | # PREPARING SHORTCUT FEATURES
78 | shortcut = x
79 | for conv in self.pre_shortcut_convs:
80 | shortcut = conv(shortcut)
81 |
82 | # PREPARING FOR DOWNSAMPLING
83 | out = shortcut
84 | for conv in self.post_shortcut_convs:
85 | out = conv(out)
86 |
87 | # DOWNSAMPLING
88 | out = self.downconv(out)
89 |
90 | return out, shortcut
91 |
92 | def get_input_size(self, output_size):
93 | curr_size = self.downconv.get_input_size(output_size)
94 |
95 | for conv in reversed(self.post_shortcut_convs):
96 | curr_size = conv.get_input_size(curr_size)
97 |
98 | for conv in reversed(self.pre_shortcut_convs):
99 | curr_size = conv.get_input_size(curr_size)
100 | return curr_size
101 |
102 | class Waveunet(nn.Module):
103 | def __init__(self, num_inputs, num_channels, num_outputs, instruments, kernel_size, target_output_size, conv_type, res, separate=False, depth=1, strides=2):
104 | super(Waveunet, self).__init__()
105 |
106 | self.num_levels = len(num_channels)
107 | self.strides = strides
108 | self.kernel_size = kernel_size
109 | self.num_inputs = num_inputs
110 | self.num_outputs = num_outputs
111 | self.depth = depth
112 | self.instruments = instruments
113 | self.separate = separate
114 |
115 | # Only odd filter kernels allowed
116 | assert(kernel_size % 2 == 1)
117 |
118 | self.waveunets = nn.ModuleDict()
119 |
120 | model_list = instruments if separate else ["ALL"]
121 | # Create a model for each source if we separate sources separately, otherwise only one (model_list=["ALL"])
122 | for instrument in model_list:
123 | module = nn.Module()
124 |
125 | module.downsampling_blocks = nn.ModuleList()
126 | module.upsampling_blocks = nn.ModuleList()
127 |
128 | for i in range(self.num_levels - 1):
129 | in_ch = num_inputs if i == 0 else num_channels[i]
130 |
131 | module.downsampling_blocks.append(
132 | DownsamplingBlock(in_ch, num_channels[i], num_channels[i+1], kernel_size, strides, depth, conv_type, res))
133 |
134 | for i in range(0, self.num_levels - 1):
135 | module.upsampling_blocks.append(
136 | UpsamplingBlock(num_channels[-1-i], num_channels[-2-i], num_channels[-2-i], kernel_size, strides, depth, conv_type, res))
137 |
138 | module.bottlenecks = nn.ModuleList(
139 | [ConvLayer(num_channels[-1], num_channels[-1], kernel_size, 1, conv_type) for _ in range(depth)])
140 |
141 | # Output conv
142 | outputs = num_outputs if separate else num_outputs * len(instruments)
143 | module.output_conv = nn.Conv1d(num_channels[0], outputs, 1)
144 |
145 | self.waveunets[instrument] = module
146 |
147 | self.set_output_size(target_output_size)
148 |
149 | def set_output_size(self, target_output_size):
150 | self.target_output_size = target_output_size
151 |
152 | self.input_size, self.output_size = self.check_padding(target_output_size)
153 | print("Using valid convolutions with " + str(self.input_size) + " inputs and " + str(self.output_size) + " outputs")
154 |
155 | assert((self.input_size - self.output_size) % 2 == 0)
156 | self.shapes = {"output_start_frame" : (self.input_size - self.output_size) // 2,
157 | "output_end_frame" : (self.input_size - self.output_size) // 2 + self.output_size,
158 | "output_frames" : self.output_size,
159 | "input_frames" : self.input_size}
160 |
161 | def check_padding(self, target_output_size):
162 | # Ensure number of outputs covers a whole number of cycles so each output in the cycle is weighted equally during training
163 | bottleneck = 1
164 |
165 | while True:
166 | out = self.check_padding_for_bottleneck(bottleneck, target_output_size)
167 | if out is not False:
168 | return out
169 | bottleneck += 1
170 |
171 | def check_padding_for_bottleneck(self, bottleneck, target_output_size):
172 | module = self.waveunets[[k for k in self.waveunets.keys()][0]]
173 | try:
174 | curr_size = bottleneck
175 | for idx, block in enumerate(module.upsampling_blocks):
176 | curr_size = block.get_output_size(curr_size)
177 | output_size = curr_size
178 |
179 | # Bottleneck-Conv
180 | curr_size = bottleneck
181 | for block in reversed(module.bottlenecks):
182 | curr_size = block.get_input_size(curr_size)
183 | for idx, block in enumerate(reversed(module.downsampling_blocks)):
184 | curr_size = block.get_input_size(curr_size)
185 |
186 | assert(output_size >= target_output_size)
187 | return curr_size, output_size
188 | except AssertionError as e:
189 | return False
190 |
191 | def forward_module(self, x, module):
192 | '''
193 | A forward pass through a single Wave-U-Net (multiple Wave-U-Nets might be used, one for each source)
194 | :param x: Input mix
195 | :param module: Network module to be used for prediction
196 | :return: Source estimates
197 | '''
198 | shortcuts = []
199 | out = x
200 |
201 | # DOWNSAMPLING BLOCKS
202 | for block in module.downsampling_blocks:
203 | out, short = block(out)
204 | shortcuts.append(short)
205 |
206 | # BOTTLENECK CONVOLUTION
207 | for conv in module.bottlenecks:
208 | out = conv(out)
209 |
210 | # UPSAMPLING BLOCKS
211 | for idx, block in enumerate(module.upsampling_blocks):
212 | out = block(out, shortcuts[-1 - idx])
213 |
214 | # OUTPUT CONV
215 | out = module.output_conv(out)
216 | if not self.training: # At test time clip predictions to valid amplitude range
217 | out = out.clamp(min=-1.0, max=1.0)
218 | return out
219 |
220 | def forward(self, x, inst=None):
221 | curr_input_size = x.shape[-1]
222 | assert(curr_input_size == self.input_size) # User promises to feed the proper input himself, to get the pre-calculated (NOT the originally desired) output size
223 |
224 | if self.separate:
225 | return {inst : self.forward_module(x, self.waveunets[inst])}
226 | else:
227 | assert(len(self.waveunets) == 1)
228 | out = self.forward_module(x, self.waveunets["ALL"])
229 |
230 | out_dict = {}
231 | for idx, inst in enumerate(self.instruments):
232 | out_dict[inst] = out[:, idx * self.num_outputs:(idx + 1) * self.num_outputs]
233 | return out_dict
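
# Illustrative construction sketch, mirroring the default hyper-parameters from predict.py
# ("double" feature growth, 6 levels, 32 base features, 2-second output at 44.1 kHz):
#
#   instruments = ["bass", "drums", "other", "vocals"]
#   num_features = [32 * 2 ** i for i in range(6)]   # [32, 64, 128, 256, 512, 1024]
#   model = Waveunet(2, num_features, 2, instruments, kernel_size=5,
#                    target_output_size=int(2.0 * 44100), depth=1, strides=4,
#                    conv_type="gn", res="fixed", separate=1)
#   print(model.shapes)  # input/output frame counts chosen so that all convolutions stay valid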
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import data.utils
5 | import model.utils as model_utils
6 |
7 | from test import predict_song
8 | from model.waveunet import Waveunet
9 |
10 | def main(args):
11 | # MODEL
12 | num_features = [args.features*i for i in range(1, args.levels+1)] if args.feature_growth == "add" else \
13 | [args.features*2**i for i in range(0, args.levels)]
14 | target_outputs = int(args.output_size * args.sr)
15 | model = Waveunet(args.channels, num_features, args.channels, args.instruments, kernel_size=args.kernel_size,
16 | target_output_size=target_outputs, depth=args.depth, strides=args.strides,
17 | conv_type=args.conv_type, res=args.res, separate=args.separate)
18 |
19 | if args.cuda:
20 | model = model_utils.DataParallel(model)
21 | print("move model to gpu")
22 | model.cuda()
23 |
24 | print("Loading model from checkpoint " + str(args.load_model))
25 | state = model_utils.load_model(model, None, args.load_model, args.cuda)
26 | print('Step', state['step'])
27 |
28 | preds = predict_song(args, args.input, model)
29 |
30 | output_folder = os.path.dirname(args.input) if args.output is None else args.output
31 | for inst in preds.keys():
32 | data.utils.write_wav(os.path.join(output_folder, os.path.basename(args.input) + "_" + inst + ".wav"), preds[inst], args.sr)
33 |
34 | if __name__ == '__main__':
35 | parser = argparse.ArgumentParser()
36 | parser.add_argument('--instruments', type=str, nargs='+', default=["bass", "drums", "other", "vocals"],
37 | help="List of instruments to separate (default: \"bass drums other vocals\")")
38 | parser.add_argument('--cuda', action='store_true',
39 | help='Use CUDA (default: False)')
40 | parser.add_argument('--features', type=int, default=32,
41 | help='Number of feature channels per layer')
42 | parser.add_argument('--load_model', type=str, default='checkpoints/waveunet/model',
43 | help='Reload a previously trained model')
44 | parser.add_argument('--batch_size', type=int, default=4,
45 | help="Batch size")
46 | parser.add_argument('--levels', type=int, default=6,
47 | help="Number of DS/US blocks")
48 | parser.add_argument('--depth', type=int, default=1,
49 | help="Number of convs per block")
50 | parser.add_argument('--sr', type=int, default=44100,
51 | help="Sampling rate")
52 | parser.add_argument('--channels', type=int, default=2,
53 | help="Number of input audio channels")
54 | parser.add_argument('--kernel_size', type=int, default=5,
55 | help="Filter width of kernels. Has to be an odd number")
56 | parser.add_argument('--output_size', type=float, default=2.0,
57 | help="Output duration")
58 | parser.add_argument('--strides', type=int, default=4,
59 | help="Strides in Waveunet")
60 | parser.add_argument('--conv_type', type=str, default="gn",
61 | help="Type of convolution (normal, BN-normalised, GN-normalised): normal/bn/gn")
62 | parser.add_argument('--res', type=str, default="fixed",
63 | help="Resampling strategy: fixed sinc-based lowpass filtering or learned conv layer: fixed/learned")
64 | parser.add_argument('--separate', type=int, default=1,
65 | help="Train separate model for each source (1) or only one (0)")
66 | parser.add_argument('--feature_growth', type=str, default="double",
 67 |                     help="How the number of feature channels grows with depth: add the initial number at each level (add) or double it (double)")
68 |
69 | parser.add_argument('--input', type=str, default=os.path.join("audio_examples", "Cristina Vane - So Easy", "mix.mp3"),
70 | help="Path to input mixture to be separated")
71 | parser.add_argument('--output', type=str, default=None, help="Output path (same folder as input path if not set)")
72 |
73 | args = parser.parse_args()
74 |
75 | main(args)
76 |
--------------------------------------------------------------------------------
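predict.py can also be driven programmatically. A minimal sketch that mirrors the argument defaults defined above; it assumes a trained checkpoint exists at the default `checkpoints/waveunet/model` path:

```python
import argparse
import os

from predict import main

# All values mirror the argparse defaults in predict.py; adjust paths as needed.
args = argparse.Namespace(
    instruments=["bass", "drums", "other", "vocals"], cuda=False, features=32,
    load_model="checkpoints/waveunet/model", batch_size=4, levels=6, depth=1,
    sr=44100, channels=2, kernel_size=5, output_size=2.0, strides=4,
    conv_type="gn", res="fixed", separate=1, feature_growth="double",
    input=os.path.join("audio_examples", "Cristina Vane - So Easy", "mix.mp3"),
    output=None,
)
main(args)  # writes one <input>_<instrument>.wav per source next to the input file
```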
/requirements.txt:
--------------------------------------------------------------------------------
1 | future
2 | numpy
3 | librosa
4 | soundfile
5 | musdb
6 | museval
7 | h5py
8 | tqdm
9 | torch==1.4.0
10 | torchvision==0.5.0
11 | tensorboard
12 | sortedcontainers
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import museval
2 | from tqdm import tqdm
3 |
4 | import numpy as np
5 | import torch
6 |
7 | import data.utils
8 | import model.utils as model_utils
9 | import utils
10 |
11 | def compute_model_output(model, inputs):
 12 |     '''
 13 |     Computes the outputs of the model for the given inputs without propagating gradients. See compute_loss for training.
 14 |     The procedure depends on whether we use a separate model for each source or a single joint model.
 15 |     :param model: Model to predict with
 16 |     :param inputs: Input mixture batch
 17 |     :return: Model outputs as a dictionary with source names as keys
 18 |     '''
19 | all_outputs = {}
20 |
21 | if model.separate:
22 | for inst in model.instruments:
23 | output = model(inputs, inst)
24 | all_outputs[inst] = output[inst].detach().clone()
25 | else:
26 | all_outputs = model(inputs)
27 |
28 | return all_outputs
29 |
30 | def predict(audio, model):
31 | '''
32 | Predict sources for a given audio input signal, with a given model. Audio is split into chunks to make predictions on each chunk before they are concatenated.
33 | :param audio: Audio input tensor, either Pytorch tensor or numpy array
34 | :param model: Pytorch model
35 | :return: Source predictions, dictionary with source names as keys
36 | '''
37 | if isinstance(audio, torch.Tensor):
 38 |         is_cuda = audio.is_cuda  # Tensor.is_cuda is an attribute, not a method
39 | audio = audio.detach().cpu().numpy()
40 | return_mode = "pytorch"
41 | else:
42 | return_mode = "numpy"
43 |
44 | expected_outputs = audio.shape[1]
45 |
46 | # Pad input if it is not divisible in length by the frame shift number
47 | output_shift = model.shapes["output_frames"]
48 | pad_back = audio.shape[1] % output_shift
49 | pad_back = 0 if pad_back == 0 else output_shift - pad_back
50 | if pad_back > 0:
51 | audio = np.pad(audio, [(0,0), (0, pad_back)], mode="constant", constant_values=0.0)
52 |
53 | target_outputs = audio.shape[1]
54 | outputs = {key: np.zeros(audio.shape, np.float32) for key in model.instruments}
55 |
56 | # Pad mixture across time at beginning and end so that neural network can make prediction at the beginning and end of signal
57 | pad_front_context = model.shapes["output_start_frame"]
58 | pad_back_context = model.shapes["input_frames"] - model.shapes["output_end_frame"]
59 | audio = np.pad(audio, [(0,0), (pad_front_context, pad_back_context)], mode="constant", constant_values=0.0)
60 |
61 | # Iterate over mixture magnitudes, fetch network prediction
62 | with torch.no_grad():
63 | for target_start_pos in range(0, target_outputs, model.shapes["output_frames"]):
64 | # Prepare mixture excerpt by selecting time interval
65 | curr_input = audio[:, target_start_pos:target_start_pos + model.shapes["input_frames"]] # Since audio was front-padded input of [targetpos:targetpos+inputframes] actually predicts [targetpos:targetpos+outputframes] target range
66 |
67 | # Convert to Pytorch tensor for model prediction
68 | curr_input = torch.from_numpy(curr_input).unsqueeze(0)
69 |
70 | # Predict
71 | for key, curr_targets in compute_model_output(model, curr_input).items():
72 | outputs[key][:,target_start_pos:target_start_pos+model.shapes["output_frames"]] = curr_targets.squeeze(0).cpu().numpy()
73 |
74 | # Crop to expected length (since we padded to handle the frame shift)
75 | outputs = {key : outputs[key][:,:expected_outputs] for key in outputs.keys()}
76 |
77 | if return_mode == "pytorch":
 78 |         outputs = {key: torch.from_numpy(outputs[key]) for key in outputs.keys()}  # outputs is a dict of arrays
 79 |         if is_cuda:
 80 |             outputs = {key: outputs[key].cuda() for key in outputs.keys()}
81 | return outputs
82 |
83 | def predict_song(args, audio_path, model):
84 | '''
85 | Predicts sources for an audio file for which the file path is given, using a given model.
 86 |     Takes care of resampling the input audio to the model's sampling rate and resampling the predictions back to the input sampling rate.
87 | :param args: Options dictionary
88 | :param audio_path: Path to mixture audio file
89 | :param model: Pytorch model
90 | :return: Source estimates given as dictionary with keys as source names
91 | '''
92 | model.eval()
93 |
94 | # Load mixture in original sampling rate
95 | mix_audio, mix_sr = data.utils.load(audio_path, sr=None, mono=False)
96 | mix_channels = mix_audio.shape[0]
97 | mix_len = mix_audio.shape[1]
98 |
99 | # Adapt mixture channels to required input channels
100 | if args.channels == 1:
101 | mix_audio = np.mean(mix_audio, axis=0, keepdims=True)
102 | else:
103 | if mix_channels == 1: # Duplicate channels if input is mono but model is stereo
104 | mix_audio = np.tile(mix_audio, [args.channels, 1])
105 | else:
106 | assert(mix_channels == args.channels)
107 |
108 | # resample to model sampling rate
109 | mix_audio = data.utils.resample(mix_audio, mix_sr, args.sr)
110 |
111 | sources = predict(mix_audio, model)
112 |
113 | # Resample back to mixture sampling rate in case we had model on different sampling rate
114 | sources = {key : data.utils.resample(sources[key], args.sr, mix_sr) for key in sources.keys()}
115 |
116 |     # In case we had to pad the mixture at the end, or we have a few samples too many due to inconsistent down- and upsampling, remove those samples from source prediction now
117 | for key in sources.keys():
118 | diff = sources[key].shape[1] - mix_len
119 | if diff > 0:
120 | print("WARNING: Cropping " + str(diff) + " samples")
121 | sources[key] = sources[key][:, :-diff]
122 | elif diff < 0:
123 |             print("WARNING: Padding output by " + str(-diff) + " samples")
124 |             sources[key] = np.pad(sources[key], [(0,0), (0, -diff)], mode="constant", constant_values=0.0)
125 |
126 | # Adapt channels
127 | if mix_channels > args.channels:
128 | assert(args.channels == 1)
129 | # Duplicate mono predictions
130 | sources[key] = np.tile(sources[key], [mix_channels, 1])
131 | elif mix_channels < args.channels:
132 | assert(mix_channels == 1)
133 | # Reduce model output to mono
134 | sources[key] = np.mean(sources[key], axis=0, keepdims=True)
135 |
136 | sources[key] = np.asfortranarray(sources[key]) # So librosa does not complain if we want to save it
137 |
138 | return sources
139 |
140 | def evaluate(args, dataset, model, instruments):
141 | '''
142 | Evaluates a given model on a given dataset
143 | :param args: Options dict
144 | :param dataset: Dataset object
145 | :param model: Pytorch model
146 | :param instruments: List of source names
147 |     :return: List of performance metric dictionaries, one per dataset sample
148 | '''
149 | perfs = list()
150 | model.eval()
151 | with torch.no_grad():
152 | for example in dataset:
153 | print("Evaluating " + example["mix"])
154 |
155 | # Load source references in their original sr and channel number
156 | target_sources = np.stack([data.utils.load(example[instrument], sr=None, mono=False)[0].T for instrument in instruments])
157 |
158 | # Predict using mixture
159 | pred_sources = predict_song(args, example["mix"], model)
160 | pred_sources = np.stack([pred_sources[key].T for key in instruments])
161 |
162 | # Evaluate
163 | SDR, ISR, SIR, SAR, _ = museval.metrics.bss_eval(target_sources, pred_sources)
164 | song = {}
165 | for idx, name in enumerate(instruments):
166 | song[name] = {"SDR" : SDR[idx], "ISR" : ISR[idx], "SIR" : SIR[idx], "SAR" : SAR[idx]}
167 | perfs.append(song)
168 |
169 | return perfs
170 |
171 |
172 | def validate(args, model, criterion, test_data):
173 | '''
174 | Iterate with a given model over a given test dataset and compute the desired loss
175 | :param args: Options dictionary
176 | :param model: Pytorch model
177 | :param criterion: Loss function to use (similar to Pytorch criterions)
178 | :param test_data: Test dataset (Pytorch dataset)
179 |     :return: Mean loss over the test dataset
180 | '''
181 | # PREPARE DATA
182 | dataloader = torch.utils.data.DataLoader(test_data,
183 | batch_size=args.batch_size,
184 | shuffle=False,
185 | num_workers=args.num_workers)
186 |
187 | # VALIDATE
188 | model.eval()
189 | total_loss = 0.
190 | with tqdm(total=len(test_data) // args.batch_size) as pbar, torch.no_grad():
191 | for example_num, (x, targets) in enumerate(dataloader):
192 | if args.cuda:
193 | x = x.cuda()
194 | for k in list(targets.keys()):
195 | targets[k] = targets[k].cuda()
196 |
197 | _, avg_loss = model_utils.compute_loss(model, x, targets, criterion)
198 |
199 | total_loss += (1. / float(example_num + 1)) * (avg_loss - total_loss)
200 |
201 | pbar.set_description("Current loss: {:.4f}".format(total_loss))
202 | pbar.update(1)
203 |
204 | return total_loss
--------------------------------------------------------------------------------
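The chunked prediction in `predict` above pads the mixture twice: once at the end so its length becomes a multiple of the hop (`output_frames`), and once with front and back context so the network can predict the very first and last samples. A small numeric sketch of that arithmetic; the `shapes` values are made up, the real ones come from `model.shapes`:

```python
# Made-up but internally consistent shape values, for illustration only.
shapes = {"input_frames": 121843, "output_frames": 88409,
          "output_start_frame": 16717, "output_end_frame": 105126}

audio_len = 200000                                    # hypothetical mixture length in samples
pad_back = audio_len % shapes["output_frames"]
pad_back = 0 if pad_back == 0 else shapes["output_frames"] - pad_back
padded_len = audio_len + pad_back                     # exact multiple of the hop: 3 * 88409

pad_front_context = shapes["output_start_frame"]      # context so the first output sample is t = 0
pad_back_context = shapes["input_frames"] - shapes["output_end_frame"]
print(pad_back, padded_len, pad_front_context, pad_back_context)   # 65227 265227 16717 16717
```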
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | from functools import partial
5 |
6 | import torch
7 | import pickle
8 | import numpy as np
9 |
10 | import torch.nn as nn
11 | from torch.utils.tensorboard import SummaryWriter
12 | from torch.optim import Adam
13 | from tqdm import tqdm
14 |
15 | import model.utils as model_utils
16 | import utils
17 | from data.dataset import SeparationDataset
18 | from data.musdb import get_musdb_folds
19 | from data.utils import crop_targets, random_amplify
20 | from test import evaluate, validate
21 | from model.waveunet import Waveunet
22 |
23 | def main(args):
24 | #torch.backends.cudnn.benchmark=True # This makes dilated conv much faster for CuDNN 7.5
25 |
26 | # MODEL
27 | num_features = [args.features*i for i in range(1, args.levels+1)] if args.feature_growth == "add" else \
28 | [args.features*2**i for i in range(0, args.levels)]
29 | target_outputs = int(args.output_size * args.sr)
30 | model = Waveunet(args.channels, num_features, args.channels, args.instruments, kernel_size=args.kernel_size,
31 | target_output_size=target_outputs, depth=args.depth, strides=args.strides,
32 | conv_type=args.conv_type, res=args.res, separate=args.separate)
33 |
34 | if args.cuda:
35 | model = model_utils.DataParallel(model)
36 | print("move model to gpu")
37 | model.cuda()
38 |
39 | print('model: ', model)
40 | print('parameter count: ', str(sum(p.numel() for p in model.parameters())))
41 |
42 | writer = SummaryWriter(args.log_dir)
43 |
44 | ### DATASET
45 | musdb = get_musdb_folds(args.dataset_dir)
 46 |     # Even without data augmentation, targets must at least be cropped to fit the model output shape
47 | crop_func = partial(crop_targets, shapes=model.shapes)
48 | # Data augmentation function for training
49 | augment_func = partial(random_amplify, shapes=model.shapes, min=0.7, max=1.0)
50 | train_data = SeparationDataset(musdb, "train", args.instruments, args.sr, args.channels, model.shapes, True, args.hdf_dir, audio_transform=augment_func)
51 | val_data = SeparationDataset(musdb, "val", args.instruments, args.sr, args.channels, model.shapes, False, args.hdf_dir, audio_transform=crop_func)
52 | test_data = SeparationDataset(musdb, "test", args.instruments, args.sr, args.channels, model.shapes, False, args.hdf_dir, audio_transform=crop_func)
53 |
54 | dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, worker_init_fn=utils.worker_init_fn)
55 |
56 | ##### TRAINING ####
57 |
58 | # Set up the loss function
59 | if args.loss == "L1":
60 | criterion = nn.L1Loss()
61 | elif args.loss == "L2":
62 | criterion = nn.MSELoss()
63 | else:
64 | raise NotImplementedError("Couldn't find this loss!")
65 |
66 | # Set up optimiser
67 | optimizer = Adam(params=model.parameters(), lr=args.lr)
68 |
69 | # Set up training state dict that will also be saved into checkpoints
70 | state = {"step" : 0,
71 | "worse_epochs" : 0,
72 | "epochs" : 0,
73 | "best_loss" : np.Inf}
74 |
75 | # LOAD MODEL CHECKPOINT IF DESIRED
76 | if args.load_model is not None:
77 | print("Continuing training full model from checkpoint " + str(args.load_model))
78 | state = model_utils.load_model(model, optimizer, args.load_model, args.cuda)
79 |
80 | print('TRAINING START')
81 | while state["worse_epochs"] < args.patience:
82 | print("Training one epoch from iteration " + str(state["step"]))
83 | avg_time = 0.
84 | model.train()
85 | with tqdm(total=len(train_data) // args.batch_size) as pbar:
86 | np.random.seed()
87 | for example_num, (x, targets) in enumerate(dataloader):
88 | if args.cuda:
89 | x = x.cuda()
90 | for k in list(targets.keys()):
91 | targets[k] = targets[k].cuda()
92 |
93 | t = time.time()
94 |
95 | # Set LR for this iteration
96 | utils.set_cyclic_lr(optimizer, example_num, len(train_data) // args.batch_size, args.cycles, args.min_lr, args.lr)
97 | writer.add_scalar("lr", utils.get_lr(optimizer), state["step"])
98 |
99 | # Compute loss for each instrument/model
100 | optimizer.zero_grad()
101 | outputs, avg_loss = model_utils.compute_loss(model, x, targets, criterion, compute_grad=True)
102 |
103 | optimizer.step()
104 |
105 | state["step"] += 1
106 |
107 | t = time.time() - t
108 | avg_time += (1. / float(example_num + 1)) * (t - avg_time)
109 |
110 | writer.add_scalar("train_loss", avg_loss, state["step"])
111 |
112 | if example_num % args.example_freq == 0:
113 | input_centre = torch.mean(x[0, :, model.shapes["output_start_frame"]:model.shapes["output_end_frame"]], 0) # Stereo not supported for logs yet
114 | writer.add_audio("input", input_centre, state["step"], sample_rate=args.sr)
115 |
116 | for inst in outputs.keys():
117 | writer.add_audio(inst + "_pred", torch.mean(outputs[inst][0], 0), state["step"], sample_rate=args.sr)
118 | writer.add_audio(inst + "_target", torch.mean(targets[inst][0], 0), state["step"], sample_rate=args.sr)
119 |
120 | pbar.update(1)
121 |
122 | # VALIDATE
123 | val_loss = validate(args, model, criterion, val_data)
124 | print("VALIDATION FINISHED: LOSS: " + str(val_loss))
125 | writer.add_scalar("val_loss", val_loss, state["step"])
126 |
127 | # EARLY STOPPING CHECK
128 | checkpoint_path = os.path.join(args.checkpoint_dir, "checkpoint_" + str(state["step"]))
129 | if val_loss >= state["best_loss"]:
130 | state["worse_epochs"] += 1
131 | else:
132 | print("MODEL IMPROVED ON VALIDATION SET!")
133 | state["worse_epochs"] = 0
134 | state["best_loss"] = val_loss
135 | state["best_checkpoint"] = checkpoint_path
136 |
137 | state["epochs"] += 1
138 | # CHECKPOINT
139 | print("Saving model...")
140 | model_utils.save_model(model, optimizer, state, checkpoint_path)
141 |
142 |
143 | #### TESTING ####
144 | # Test loss
145 | print("TESTING")
146 |
147 | # Load best model based on validation loss
148 | state = model_utils.load_model(model, None, state["best_checkpoint"], args.cuda)
149 | test_loss = validate(args, model, criterion, test_data)
150 | print("TEST FINISHED: LOSS: " + str(test_loss))
151 | writer.add_scalar("test_loss", test_loss, state["step"])
152 |
153 | # Mir_eval metrics
154 | test_metrics = evaluate(args, musdb["test"], model, args.instruments)
155 |
156 | # Dump all metrics results into pickle file for later analysis if needed
157 | with open(os.path.join(args.checkpoint_dir, "results.pkl"), "wb") as f:
158 | pickle.dump(test_metrics, f)
159 |
160 | # Write most important metrics into Tensorboard log
161 | avg_SDRs = {inst : np.mean([np.nanmean(song[inst]["SDR"]) for song in test_metrics]) for inst in args.instruments}
162 | avg_SIRs = {inst : np.mean([np.nanmean(song[inst]["SIR"]) for song in test_metrics]) for inst in args.instruments}
163 | for inst in args.instruments:
164 | writer.add_scalar("test_SDR_" + inst, avg_SDRs[inst], state["step"])
165 | writer.add_scalar("test_SIR_" + inst, avg_SIRs[inst], state["step"])
166 | overall_SDR = np.mean([v for v in avg_SDRs.values()])
167 |     writer.add_scalar("test_SDR", overall_SDR, state["step"])
168 | print("SDR: " + str(overall_SDR))
169 |
170 | writer.close()
171 |
172 | if __name__ == '__main__':
173 | ## TRAIN PARAMETERS
174 | parser = argparse.ArgumentParser()
175 | parser.add_argument('--instruments', type=str, nargs='+', default=["bass", "drums", "other", "vocals"],
176 | help="List of instruments to separate (default: \"bass drums other vocals\")")
177 | parser.add_argument('--cuda', action='store_true',
178 | help='Use CUDA (default: False)')
179 | parser.add_argument('--num_workers', type=int, default=1,
180 | help='Number of data loader worker threads (default: 1)')
181 | parser.add_argument('--features', type=int, default=32,
182 | help='Number of feature channels per layer')
183 | parser.add_argument('--log_dir', type=str, default='logs/waveunet',
184 | help='Folder to write logs into')
185 | parser.add_argument('--dataset_dir', type=str, default="/mnt/windaten/Datasets/MUSDB18HQ",
186 | help='Dataset path')
187 | parser.add_argument('--hdf_dir', type=str, default="hdf",
188 |                         help='Folder where the preprocessed HDF dataset files are stored')
189 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoints/waveunet',
190 | help='Folder to write checkpoints into')
191 | parser.add_argument('--load_model', type=str, default=None,
192 | help='Reload a previously trained model (whole task model)')
193 | parser.add_argument('--lr', type=float, default=1e-3,
194 | help='Initial learning rate in LR cycle (default: 1e-3)')
195 | parser.add_argument('--min_lr', type=float, default=5e-5,
196 | help='Minimum learning rate in LR cycle (default: 5e-5)')
197 | parser.add_argument('--cycles', type=int, default=2,
198 | help='Number of LR cycles per epoch')
199 | parser.add_argument('--batch_size', type=int, default=4,
200 | help="Batch size")
201 | parser.add_argument('--levels', type=int, default=6,
202 | help="Number of DS/US blocks")
203 | parser.add_argument('--depth', type=int, default=1,
204 | help="Number of convs per block")
205 | parser.add_argument('--sr', type=int, default=44100,
206 | help="Sampling rate")
207 | parser.add_argument('--channels', type=int, default=2,
208 | help="Number of input audio channels")
209 | parser.add_argument('--kernel_size', type=int, default=5,
210 | help="Filter width of kernels. Has to be an odd number")
211 | parser.add_argument('--output_size', type=float, default=2.0,
212 |                         help="Output duration in seconds")
213 | parser.add_argument('--strides', type=int, default=4,
214 | help="Strides in Waveunet")
215 | parser.add_argument('--patience', type=int, default=20,
216 | help="Patience for early stopping on validation set")
217 | parser.add_argument('--example_freq', type=int, default=200,
218 | help="Write an audio summary into Tensorboard logs every X training iterations")
219 | parser.add_argument('--loss', type=str, default="L1",
220 | help="L1 or L2")
221 | parser.add_argument('--conv_type', type=str, default="gn",
222 | help="Type of convolution (normal, BN-normalised, GN-normalised): normal/bn/gn")
223 | parser.add_argument('--res', type=str, default="fixed",
224 | help="Resampling strategy: fixed sinc-based lowpass filtering or learned conv layer: fixed/learned")
225 | parser.add_argument('--separate', type=int, default=1,
226 | help="Train separate model for each source (1) or only one (0)")
227 | parser.add_argument('--feature_growth', type=str, default="double",
228 |                         help="How the number of feature channels grows with depth: add the initial number at each level (add) or double it (double)")
229 |
230 | args = parser.parse_args()
231 |
232 | main(args)
233 |
--------------------------------------------------------------------------------
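The `avg_time` update in the training loop above (and the running validation loss in test.py) uses an incremental form of the arithmetic mean, which avoids keeping a growing sum. A short sketch verifying the equivalence:

```python
# After processing n+1 values, avg equals their arithmetic mean.
vals = [3.0, 5.0, 10.0]
avg = 0.0
for n, v in enumerate(vals):
    avg += (1.0 / float(n + 1)) * (v - avg)
assert abs(avg - sum(vals) / len(vals)) < 1e-12
```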
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def worker_init_fn(worker_id): # This is apparently needed to ensure workers have different random seeds and draw different examples!
4 | np.random.seed(np.random.get_state()[1][0] + worker_id)
5 |
6 | def get_lr(optim):
7 | return optim.param_groups[0]["lr"]
8 |
9 | def set_lr(optim, lr):
10 | for g in optim.param_groups:
11 | g['lr'] = lr
12 |
13 | def set_cyclic_lr(optimizer, it, epoch_it, cycles, min_lr, max_lr):
14 | cycle_length = epoch_it // cycles
15 | curr_cycle = min(it // cycle_length, cycles-1)
16 | curr_it = it - cycle_length * curr_cycle
17 |
18 | new_lr = min_lr + 0.5*(max_lr - min_lr)*(1 + np.cos((float(curr_it) / float(cycle_length)) * np.pi))
19 | set_lr(optimizer, new_lr)
--------------------------------------------------------------------------------
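`set_cyclic_lr` above implements a cosine-annealed schedule that restarts `cycles` times per epoch, decaying from `max_lr` to `min_lr` within each cycle. A minimal sketch of how it behaves with a dummy optimiser; the `epoch_it` and cycle values are made up:

```python
import torch

from utils import get_lr, set_cyclic_lr

# A throwaway parameter and optimiser, used only to hold a learning rate.
param = torch.nn.Parameter(torch.zeros(1))
optim = torch.optim.SGD([param], lr=1e-3)

epoch_it, cycles, min_lr, max_lr = 100, 2, 5e-5, 1e-3
for it in (0, 25, 49, 50, 75, 99):
    set_cyclic_lr(optim, it, epoch_it, cycles, min_lr, max_lr)
    # LR starts at max_lr, decays towards min_lr, and restarts at it = 50.
    print(it, get_lr(optim))
```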