├── .gitignore ├── LICENSE ├── README.md └── code ├── averaging.py ├── bss_eval_images.m ├── bss_eval_mix.m ├── bss_eval_sources.m ├── build_model_gcn.py ├── build_model_original.py ├── cyclicAnnealing.py ├── data_loader.py ├── evaluate.m ├── evaluate.py ├── post_processing.py ├── pre_processing.py ├── stiching.py ├── test_model.py └── train_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | Project/ 2 | DSD100subset/ 3 | monoaural-audio-source.pdf 4 | Processed/ 5 | Val/ 6 | dsd100/ 7 | code/runs/ 8 | code/__pycache__/ 9 | code/Weights 10 | AudioResults/ 11 | Recovered_Songs/ 12 | Test/ 13 | iml_unshuffled.png 14 | AudioResults/ 15 | Recovered_Songs/ 16 | Recovered_Songs_longer/ 17 | Recovered_Songs_longer2/ 18 | Recovered_Songs_same_genre/ 19 | 20 | 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Sarthak Consul 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Audio Source Separation using Low Latency Neural Network 2 | This repository contains the code for our course project for Machine Learning (CS419) at IIT Bombay. We have used the PyTorch library to construct a neural network 3 | to separate instruments from a music file. We have implemented the paper "[Monoaural Audio Source Separation Using Deep 4 | Convolutional Neural Networks](https://pdfs.semanticscholar.org/fede/f8eedef76692d805a6a3380159a95b79b4de.pdf)", along with a few modifications and experiments inspired by other papers. 5 | 6 | 7 | ## Team Members 8 | * [16D100012] Sarthak Consul ([**@SConsul**](https://github.com/SConsul)) 9 | * [160110085] Archiki Prasad ([**@archiki**](https://github.com/archiki)) 10 | * [16D070001] Parthasarathi Khirwadkar ([**@kparth98**](https://github.com/kparth98)) 11 | * [16D100001] Deepak Gopalan ([**@DeepakGopalan**](https://github.com/DeepakGopalan)) 12 | 13 | ## Bibliography 14 | * [[1]](https://pdfs.semanticscholar.org/fede/f8eedef76692d805a6a3380159a95b79b4de.pdf) Pritish Chandna, M. Miron, Jordi Janer, and Emilia Gómez. Monoaural audio source separation using deep convolutional neural networks. 
In 13th International Conference on Latent Variable Analysis and Signal Separation (LVA/ICA 2017), February 2017 15 | * [[2]](https://hal.inria.fr/inria-00544230/document) E. Vincent, R. Gribonval, and C. Févotte. Performance measurement in blind audio source separation. IEEE Transactions on Audio, Speech, and Language Processing, 14(4):1462–1469, July 2006 16 | * [[3]](https://arxiv.org/abs/1703.02719) Chao Peng, Xiangyu Zhang, Gang Yu, Guiming Luo, and Jian Sun. Large kernel matters - improve semantic segmentation by global convolutional network. CoRR, abs/1703.02719, 2017 17 | -------------------------------------------------------------------------------- /code/averaging.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SConsul/audio-source-separation/ccce030502430380edc2599f4d4a9d68e742330b/code/averaging.py -------------------------------------------------------------------------------- /code/bss_eval_images.m: -------------------------------------------------------------------------------- 1 | function [SDR,ISR,SIR,SAR,perm]=bss_eval_images(ie,i) 2 | 3 | % BSS_EVAL_IMAGES Ordering and measurement of the separation quality for 4 | % estimated source spatial image signals in terms of true source, spatial 5 | % (or filtering) distortion, interference and artifacts. 6 | % 7 | % [SDR,ISR,SIR,SAR,perm]=bss_eval_images(ie,i) 8 | % 9 | % Inputs: 10 | % ie: nsrc x nsampl x nchan matrix containing estimated source images 11 | % i: nsrc x nsampl x nchan matrix containing true source images 12 | % 13 | % Outputs: 14 | % SDR: nsrc x 1 vector of Signal to Distortion Ratios 15 | % ISR: nsrc x 1 vector of source Image to Spatial distortion Ratios 16 | % SIR: nsrc x 1 vector of Source to Interference Ratios 17 | % SAR: nsrc x 1 vector of Sources to Artifacts Ratios 18 | % perm: nsrc x 1 vector containing the best ordering of estimated source 19 | % images in the mean SIR sense (estimated source image number perm(j) 20 | % corresponds to true source image number j) 21 | % 22 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 23 | % Copyright 2007-2008 Emmanuel Vincent 24 | % This software is distributed under the terms of the GNU Public License 25 | % version 3 (http://www.gnu.org/licenses/gpl.txt) 26 | % If you find it useful, please cite the following reference: 27 | % Emmanuel Vincent, Hiroshi Sawada, Pau Bofill, Shoji Makino and Justinian 28 | % P. Rosca, "First stereo audio source separation evaluation campaign: 29 | % data, algorithms and results," In Proc. Int. Conf. on Independent 30 | % Component Analysis and Blind Source Separation (ICA), pp. 552-559, 2007. 
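%
% Example usage on synthetic stereo data (illustrative values only):
%   nsrc = 2; nsampl = 44100; nchan = 2;
%   i  = randn(nsrc,nsampl,nchan);          % true source images
%   ie = i + 0.1*randn(nsrc,nsampl,nchan);  % perturbed estimates
%   [SDR,ISR,SIR,SAR,perm] = bss_eval_images(ie,i);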
31 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 32 | 33 | 34 | %%% Errors %%% 35 | if nargin<2, error('Not enough input arguments.'); end 36 | [nsrc,nsampl,nchan]=size(ie); 37 | [nsrc2,nsampl2,nchan2]=size(i); 38 | if nsrc2~=nsrc, error('The number of estimated source images and reference source images must be equal.'); end 39 | if nsampl2~=nsampl, error('The estimated source images and reference source images must have the same duration.'); end 40 | if nchan2~=nchan, error('The estimated source images and reference source images must have the same number of channels.'); end 41 | 42 | %%% Performance criteria %%% 43 | % Computation of the criteria for all possible pair matches 44 | SDR=zeros(nsrc,nsrc); 45 | ISR=zeros(nsrc,nsrc); 46 | SIR=zeros(nsrc,nsrc); 47 | SAR=zeros(nsrc,nsrc); 48 | for jest=1:nsrc, 49 | for jtrue=1:nsrc, 50 | [s_true,e_spat,e_interf,e_artif]=bss_decomp_mtifilt(reshape(ie(jest,:,:),nsampl,nchan).',i,jtrue,512); 51 | [SDR(jest,jtrue),ISR(jest,jtrue),SIR(jest,jtrue),SAR(jest,jtrue)]=bss_image_crit(s_true,e_spat,e_interf,e_artif); 52 | end 53 | end 54 | % Selection of the best ordering 55 | perm=perms(1:nsrc); 56 | nperm=size(perm,1); 57 | meanSIR=zeros(nperm,1); 58 | for p=1:nperm, 59 | meanSIR(p)=mean(SIR((0:nsrc-1)*nsrc+perm(p,:))); 60 | end 61 | [meanSIR,popt]=max(meanSIR); 62 | perm=perm(popt,:).'; 63 | SDR=SDR((0:nsrc-1).'*nsrc+perm); 64 | ISR=ISR((0:nsrc-1).'*nsrc+perm); 65 | SIR=SIR((0:nsrc-1).'*nsrc+perm); 66 | SAR=SAR((0:nsrc-1).'*nsrc+perm); 67 | 68 | return; 69 | 70 | 71 | 72 | function [s_true,e_spat,e_interf,e_artif]=bss_decomp_mtifilt(se,s,j,flen) 73 | 74 | % BSS_DECOMP_MTIFILT Decomposition of an estimated source image into four 75 | % components representing respectively the true source image, spatial (or 76 | % filtering) distortion, interference and artifacts, derived from the true 77 | % source images using multichannel time-invariant filters. 
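% By construction of the decomposition below, the four components sum back
% to the zero-padded estimate:
%   [se,zeros(nchan,flen-1)] = s_true + e_spat + e_interf + e_artif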
78 | % 79 | % [s_true,e_spat,e_interf,e_artif]=bss_decomp_mtifilt(se,s,j,flen) 80 | % 81 | % Inputs: 82 | % se: nchan x nsampl matrix containing the estimated source image (one row per channel) 83 | % s: nsrc x nsampl x nchan matrix containing the true source images 84 | % j: source index corresponding to the estimated source image in s 85 | % flen: length of the multichannel time-invariant filters in samples 86 | % 87 | % Outputs: 88 | % s_true: nchan x nsampl matrix containing the true source image (one row per channel) 89 | % e_spat: nchan x nsampl matrix containing the spatial (or filtering) distortion component 90 | % e_interf: nchan x nsampl matrix containing the interference component 91 | % e_artif: nchan x nsampl matrix containing the artifacts component 92 | 93 | %%% Errors %%% 94 | if nargin<4, error('Not enough input arguments.'); end 95 | [nchan2,nsampl2]=size(se); 96 | [nsrc,nsampl,nchan]=size(s); 97 | if nchan2~=nchan, error('The number of channels of the true source images and the estimated source image must be equal.'); end 98 | if nsampl2~=nsampl, error('The duration of the true source images and the estimated source image must be equal.'); end 99 | 100 | %%% Decomposition %%% 101 | % True source image 102 | s_true=[reshape(s(j,:,:),nsampl,nchan).',zeros(nchan,flen-1)]; 103 | % Spatial (or filtering) distortion 104 | e_spat=project(se,s(j,:,:),flen)-s_true; 105 | % Interference 106 | e_interf=project(se,s,flen)-s_true-e_spat; 107 | % Artifacts 108 | e_artif=[se,zeros(nchan,flen-1)]-s_true-e_spat-e_interf; 109 | 110 | return; 111 | 112 | 113 | 114 | function sproj=project(se,s,flen) 115 | 116 | % SPROJ Least-squares projection of each channel of se on the subspace 117 | % spanned by delayed versions of the channels of s, with delays between 0 118 | % and flen-1 119 | 120 | [nsrc,nsampl,nchan]=size(s); 121 | s=reshape(permute(s,[3 1 2]),nchan*nsrc,nsampl); 122 | 123 | %%% Computing coefficients of least squares problem via FFT %%% 124 | % Zero padding and FFT of input data 125 | s=[s,zeros(nchan*nsrc,flen-1)]; 126 | se=[se,zeros(nchan,flen-1)]; 127 | fftlen=2^nextpow2(nsampl+flen-1); 128 | sf=fft(s,fftlen,2); 129 | sef=fft(se,fftlen,2); 130 | % Inner products between delayed versions of s 131 | G=zeros(nchan*nsrc*flen); 132 | for k1=0:nchan*nsrc-1, 133 | for k2=0:k1, 134 | ssf=sf(k1+1,:).*conj(sf(k2+1,:)); 135 | ssf=real(ifft(ssf)); 136 | ss=toeplitz(ssf([1 fftlen:-1:fftlen-flen+2]),ssf(1:flen)); 137 | G(k1*flen+1:k1*flen+flen,k2*flen+1:k2*flen+flen)=ss; 138 | G(k2*flen+1:k2*flen+flen,k1*flen+1:k1*flen+flen)=ss.'; 139 | end 140 | end 141 | % Inner products between se and delayed versions of s 142 | D=zeros(nchan*nsrc*flen,nchan); 143 | for k=0:nchan*nsrc-1, 144 | for i=1:nchan, 145 | ssef=sf(k+1,:).*conj(sef(i,:)); 146 | ssef=real(ifft(ssef,[],2)); 147 | D(k*flen+1:k*flen+flen,i)=ssef(:,[1 fftlen:-1:fftlen-flen+2]).'; 148 | end 149 | end 150 | 151 | %%% Computing projection %%% 152 | % Distortion filters 153 | C=G\D; 154 | C=reshape(C,flen,nchan*nsrc,nchan); 155 | % Filtering 156 | sproj=zeros(nchan,nsampl+flen-1); 157 | for k=1:nchan*nsrc, 158 | for i=1:nchan, 159 | sproj(i,:)=sproj(i,:)+fftfilt(C(:,k,i).',s(k,:)); 160 | end 161 | end 162 | 163 | return; 164 | 165 | 166 | 167 | function [SDR,ISR,SIR,SAR]=bss_image_crit(s_true,e_spat,e_interf,e_artif) 168 | 169 | % BSS_IMAGE_CRIT Measurement of the separation quality for a given source 170 | % image in terms of true source, spatial (or filtering) distortion, 171 | % interference and artifacts. 
172 | % 173 | % [SDR,ISR,SIR,SAR]=bss_image_crit(s_true,e_spat,e_interf,e_artif) 174 | % 175 | % Inputs: 176 | % s_true: nchan x nsampl matrix containing the true source image (one row per channel) 177 | % e_spat: nchan x nsampl matrix containing the spatial (or filtering) distortion component 178 | % e_interf: nchan x nsampl matrix containing the interference component 179 | % e_artif: nchan x nsampl matrix containing the artifacts component 180 | % 181 | % Outputs: 182 | % SDR: Signal to Distortion Ratio 183 | % ISR: source Image to Spatial distortion Ratio 184 | % SIR: Source to Interference Ratio 185 | % SAR: Sources to Artifacts Ratio 186 | 187 | %%% Errors %%% 188 | if nargin<4, error('Not enough input arguments.'); end 189 | [nchant,nsamplt]=size(s_true); 190 | [nchans,nsampls]=size(e_spat); 191 | [nchani,nsampli]=size(e_interf); 192 | [nchana,nsampla]=size(e_artif); 193 | if ~((nchant==nchans)&&(nchant==nchani)&&(nchant==nchana)), error('All the components must have the same number of channels.'); end 194 | if ~((nsamplt==nsampls)&&(nsamplt==nsampli)&&(nsamplt==nsampla)), error('All the components must have the same duration.'); end 195 | 196 | %%% Energy ratios %%% 197 | % SDR 198 | SDR=10*log10(sum(sum(s_true.^2))/sum(sum((e_spat+e_interf+e_artif).^2))); 199 | % ISR 200 | ISR=10*log10(sum(sum(s_true.^2))/sum(sum(e_spat.^2))); 201 | % SIR 202 | SIR=10*log10(sum(sum((s_true+e_spat).^2))/sum(sum(e_interf.^2))); 203 | % SAR 204 | SAR=10*log10(sum(sum((s_true+e_spat+e_interf).^2))/sum(sum(e_artif.^2))); 205 | 206 | return; -------------------------------------------------------------------------------- /code/bss_eval_mix.m: -------------------------------------------------------------------------------- 1 | function [MER,perm]=bss_eval_mix(Ae,A) 2 | 3 | % BSS_EVAL_MIX Ordering and measurement of the quality of an estimated 4 | % (possibly frequency-dependent) mixing matrix 5 | % 6 | % [MER,perm]=bss_eval_mix(Ae,A) 7 | % 8 | % Inputs: 9 | % Ae: either a nchan x nsrc estimated mixing matrix (for instantaneous 10 | % mixtures) or a nchan x nsrc x nbin estimated frequency-dependent mixing 11 | % matrix (for convolutive mixtures) 12 | % A: the true nchan x nsrc or nchan x nsrc x nbin mixing matrix 13 | % 14 | % Outputs: 15 | % MER: nsrc x 1 vector of Mixing Error Ratios (SNR-like criterion averaged 16 | % over frequency and expressed in decibels, allowing arbitrary scaling for 17 | % each source in each frequency bin) 18 | % perm: nsrc x 1 vector containing the best ordering of estimated sources 19 | % in the maximum MER sense (estimated source number perm(j) corresponds to 20 | % true source number j) 21 | % 22 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 23 | % Copyright 2008 Emmanuel Vincent 24 | % This software is distributed under the terms of the GNU Public License 25 | % version 3 (http://www.gnu.org/licenses/gpl.txt) 26 | % If you find it useful, please cite the following reference: 27 | % Emmanuel Vincent, Shoko Araki and Pau Bofill, "The 2008 Signal Separation 28 | % Evaluation Campaign: A community-based approach to large-scale 29 | % evaluation," In Proc. Int. Conf. on Independent Component Analysis and 30 | % Signal Separation (ICA), pp. 734-741, 2009. 
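%
% Example usage on a synthetic instantaneous mixture (illustrative values only):
%   A  = randn(2,3);               % true 2-channel x 3-source mixing matrix
%   Ae = A + 0.05*randn(2,3);      % perturbed estimate
%   [MER,perm] = bss_eval_mix(Ae,A);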
31 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 32 | 33 | 34 | %%% Errors %%% 35 | if nargin<2, error('Not enough input arguments.'); end 36 | [nchan,nsrc,nbin]=size(Ae); 37 | [nchan2,nsrc2,nbin2]=size(A); 38 | if ~((nchan2==nchan)&&(nsrc2==nsrc)&&(nbin2==nbin)), error('The estimated and true mixing matrix must have the same size.'); end 39 | 40 | %%% Performance criterion %%% 41 | % Computation of the criterion for all possible pair matches 42 | MER=zeros(nsrc,nsrc,nbin); 43 | for f=1:nbin, 44 | for jest=1:nsrc, 45 | for jtrue=1:nsrc, 46 | Aproj=A(:,jtrue,f)'*Ae(:,jest,f)/sum(abs(A(:,jtrue,f)).^2)*A(:,jtrue,f); 47 | MER(jest,jtrue,f)=10*log10(sum(abs(Aproj).^2)/sum(abs(Ae(:,jest,f)-Aproj).^2)); 48 | end 49 | end 50 | end 51 | MER=mean(MER,3); 52 | % Selection of the best ordering 53 | perm=perms(1:nsrc); 54 | nperm=size(perm,1); 55 | meanMER=zeros(nperm,1); 56 | for p=1:nperm, 57 | meanMER(p)=mean(MER((0:nsrc-1)*nsrc+perm(p,:))); 58 | end 59 | [meanMER,popt]=max(meanMER); 60 | perm=perm(popt,:).'; 61 | MER=MER((0:nsrc-1).'*nsrc+perm); 62 | 63 | return; 64 | -------------------------------------------------------------------------------- /code/bss_eval_sources.m: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SConsul/audio-source-separation/ccce030502430380edc2599f4d4a9d68e742330b/code/bss_eval_sources.m -------------------------------------------------------------------------------- /code/build_model_gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SepConvNet(nn.Module): 7 | def __init__(self, t1, f1, t2, f2, N1, N2, input_shape=[513, 345], NN=128): 8 | super(SepConvNet, self).__init__() 9 | self.vconv_left = nn.Conv2d(1, N1, kernel_size=(f1, t1), padding=0) 10 | self.hconv_left = nn.Conv2d(N1, N2, kernel_size=(f2, t2)) 11 | self.hconv_right = nn.Conv2d(1, N1, kernel_size=(f2, t2)) 12 | self.vconv_right = nn.Conv2d(N1, N2, kernel_size=(f1, t1), padding=0) 13 | 14 | self.fc0 = nn.Linear(N2*(input_shape[0]-f1-f2+2)*(input_shape[1]-t1-t2+2), NN) 15 | self.fc1 = nn.Linear(NN, N2*(input_shape[0]-f1-f2+2)*(input_shape[1]-t1-t2+2)) 16 | self.fc2 = nn.Linear(NN, N2*(input_shape[0]-f1-f2+2)*(input_shape[1]-t1-t2+2)) 17 | self.fc3 = nn.Linear(NN, N2*(input_shape[0]-f1-f2+2)*(input_shape[1]-t1-t2+2)) 18 | self.fc4 = nn.Linear(NN, N2*(input_shape[0]-f1-f2+2)*(input_shape[1]-t1-t2+2)) 19 | self.hdeconv1 = nn.ConvTranspose2d(N2, N1, kernel_size=(f2, t2)) 20 | self.hdeconv2 = nn.ConvTranspose2d(N2, N1, kernel_size=(f2, t2)) 21 | self.hdeconv3 = nn.ConvTranspose2d(N2, N1, kernel_size=(f2, t2)) 22 | self.hdeconv4 = nn.ConvTranspose2d(N2, N1, kernel_size=(f2, t2)) 23 | self.vdeconv1 = nn.ConvTranspose2d(N1, 1, kernel_size=(f1, t1)) 24 | self.vdeconv2 = nn.ConvTranspose2d(N1, 1, kernel_size=(f1, t1)) 25 | self.vdeconv3 = nn.ConvTranspose2d(N1, 1, kernel_size=(f1, t1)) 26 | self.vdeconv4 = nn.ConvTranspose2d(N1, 1, kernel_size=(f1, t1)) 27 | 28 | def forward(self, x): 29 | x_left = self.vconv_left(x) 30 | x_left = self.hconv_left(x_left) 31 | 32 | x_right = self.hconv_right(x) 33 | x_right = self.vconv_right(x_right) 34 | 35 | x = x_left + x_right 36 | 37 | s1 = x.shape 38 | 39 | x = x.view(s1[0], -1) 40 | 41 | x = F.relu(self.fc0(x)) 42 | 43 | x1 = F.relu(self.fc1(x)) 44 | x2 = F.relu(self.fc2(x)) 45 | x3 = F.relu(self.fc3(x)) 46 | x4 = F.relu(self.fc4(x)) 47 | 48 | x1 = 
x1.view(s1[0], s1[1], s1[2], s1[3]) 49 | x2 = x2.view(s1[0], s1[1], s1[2], s1[3]) 50 | x3 = x3.view(s1[0], s1[1], s1[2], s1[3]) 51 | x4 = x4.view(s1[0], s1[1], s1[2], s1[3]) 52 | 53 | x1 = self.hdeconv1(x1) 54 | x2 = self.hdeconv2(x2) 55 | x3 = self.hdeconv3(x3) 56 | x4 = self.hdeconv4(x4) 57 | 58 | x1 = self.vdeconv1(x1) 59 | x2 = self.vdeconv2(x2) 60 | x3 = self.vdeconv3(x3) 61 | x4 = self.vdeconv4(x4) 62 | 63 | return x1, x2, x3, x4 64 | -------------------------------------------------------------------------------- /code/build_model_original.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | class SepConvNet(nn.Module): 5 | def __init__(self,t1,f1,t2,f2,N1,N2,input_shape=[513,862],NN=128): 6 | super(SepConvNet, self).__init__() 7 | self.vconv = nn.Conv2d(1,N1, kernel_size=(f1,t1),padding=0) 8 | self.hconv = nn.Conv2d(N1,N2, kernel_size=(f2,t2)) 9 | 10 | self.fc0 = nn.Linear(N2*(input_shape[0]-f1-f2+2)*(input_shape[1]-t1-t2+2), NN) 11 | self.fc1 = nn.Linear(NN,N2*(input_shape[0]-f1-f2+2)*(input_shape[1]-t1-t2+2)) 12 | self.fc2 = nn.Linear(NN,N2*(input_shape[0]-f1-f2+2)*(input_shape[1]-t1-t2+2)) 13 | self.fc3 = nn.Linear(NN,N2*(input_shape[0]-f1-f2+2)*(input_shape[1]-t1-t2+2)) 14 | self.fc4 = nn.Linear(NN,N2*(input_shape[0]-f1-f2+2)*(input_shape[1]-t1-t2+2)) 15 | self.hdeconv1 = nn.ConvTranspose2d(N2, N1, kernel_size=(f2,t2)) 16 | self.hdeconv2 = nn.ConvTranspose2d(N2, N1, kernel_size=(f2,t2)) 17 | self.hdeconv3 = nn.ConvTranspose2d(N2, N1, kernel_size=(f2,t2)) 18 | self.hdeconv4 = nn.ConvTranspose2d(N2, N1, kernel_size=(f2,t2)) 19 | self.vdeconv1 = nn.ConvTranspose2d(N1, 1, kernel_size=(f1,t1)) 20 | self.vdeconv2 = nn.ConvTranspose2d(N1, 1, kernel_size=(f1,t1)) 21 | self.vdeconv3 = nn.ConvTranspose2d(N1, 1, kernel_size=(f1,t1)) 22 | self.vdeconv4 = nn.ConvTranspose2d(N1, 1, kernel_size=(f1,t1)) 23 | def forward(self, x): 24 | x = self.vconv(x) 25 | 26 | x = self.hconv(x) 27 | 28 | s1 = x.shape 29 | 30 | x = x.view(s1[0],-1) 31 | 32 | 33 | 34 | x = F.relu(self.fc0(x)) 35 | 36 | x1 = F.relu(self.fc1(x)) 37 | x2 = F.relu(self.fc2(x)) 38 | x3 = F.relu(self.fc3(x)) 39 | x4 = F.relu(self.fc4(x)) 40 | 41 | x1 = x1.view(s1[0], s1[1],s1[2],s1[3]) 42 | x2 = x2.view(s1[0], s1[1],s1[2],s1[3]) 43 | x3 = x3.view(s1[0], s1[1],s1[2],s1[3]) 44 | x4 = x4.view(s1[0], s1[1],s1[2],s1[3]) 45 | 46 | x1 = self.hdeconv1(x1) 47 | x2 = self.hdeconv2(x2) 48 | x3 = self.hdeconv3(x3) 49 | x4 = self.hdeconv4(x4) 50 | 51 | x1 = self.vdeconv1(x1) 52 | x2 = self.vdeconv2(x2) 53 | x3 = self.vdeconv3(x3) 54 | x4 = self.vdeconv4(x4) 55 | 56 | return x1, x2, x3, x4 57 | -------------------------------------------------------------------------------- /code/cyclicAnnealing.py: -------------------------------------------------------------------------------- 1 | import math 2 | from bisect import bisect_right,bisect_left 3 | 4 | import torch 5 | import numpy as np 6 | from torch.optim.lr_scheduler import _LRScheduler 7 | from torch.optim.optimizer import Optimizer 8 | 9 | class CyclicCosAnnealingLR(_LRScheduler): 10 | r""" 11 | Implements reset on milestones inspired from CosineAnnealingLR pytorch 12 | 13 | Set the learning rate of each parameter group using a cosine annealing 14 | schedule, where :math:`\eta_{max}` is set to the initial lr and 15 | :math:`T_{cur}` is the number of epochs since the last restart in SGDR: 16 | 17 | .. 
math:: 18 | 19 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 + 20 | \cos(\frac{T_{cur}}{T_{max}}\pi)) 21 | 22 | When last_epoch > last set milestone, lr is automatically set to \eta_{min} 23 | 24 | It has been proposed in 25 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this 26 | implements both the cosine annealing of SGDR and its warm restarts, which occur at the given milestones. 27 | 28 | Args: 29 | optimizer (Optimizer): Wrapped optimizer. 30 | milestones (list of ints): List of epoch indices. Must be increasing. 31 | eta_min (float): Minimum learning rate. Default: 0. 32 | last_epoch (int): The index of last epoch. Default: -1. 33 | 34 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: 35 | https://arxiv.org/abs/1608.03983 36 | """ 37 | 38 | def __init__(self, optimizer,milestones, eta_min=0, last_epoch=-1): 39 | if not list(milestones) == sorted(milestones): 40 | raise ValueError('Milestones should be a list of' 41 | ' increasing integers. Got {}', milestones) 42 | self.eta_min = eta_min 43 | self.milestones=milestones 44 | super(CyclicCosAnnealingLR, self).__init__(optimizer, last_epoch) 45 | 46 | def get_lr(self): 47 | 48 | if self.last_epoch >= self.milestones[-1]: 49 | return [self.eta_min for base_lr in self.base_lrs] 50 | 51 | idx = bisect_right(self.milestones,self.last_epoch) 52 | 53 | left_barrier = 0 if idx==0 else self.milestones[idx-1] 54 | right_barrier = self.milestones[idx] 55 | 56 | width = right_barrier - left_barrier 57 | curr_pos = self.last_epoch- left_barrier 58 | 59 | return [self.eta_min + (base_lr - self.eta_min) * 60 | (1 + math.cos(math.pi * curr_pos/ width)) / 2 61 | for base_lr in self.base_lrs] 62 | 63 | 64 | class CyclicLinearLR(_LRScheduler): 65 | r""" 66 | Implements reset on milestones inspired from Linear learning rate decay 67 | 68 | Set the learning rate of each parameter group using a linear decay 69 | schedule, where :math:`\eta_{max}` is set to the initial lr and 70 | :math:`T_{cur}` is the number of epochs since the last restart: 71 | 72 | .. math:: 73 | 74 | \eta_t = \eta_{min} + (\eta_{max} - \eta_{min})(1 -\frac{T_{cur}}{T_{max}}) 75 | 76 | When last_epoch > last set milestone, lr is automatically set to \eta_{min} 77 | 78 | Args: 79 | optimizer (Optimizer): Wrapped optimizer. 80 | milestones (list of ints): List of epoch indices. Must be increasing. 81 | eta_min (float): Minimum learning rate. Default: 0. 82 | last_epoch (int): The index of last epoch. Default: -1. 83 | 84 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: 85 | https://arxiv.org/abs/1608.03983 86 | """ 87 | 88 | def __init__(self, optimizer,milestones, eta_min=0, last_epoch=-1): 89 | if not list(milestones) == sorted(milestones): 90 | raise ValueError('Milestones should be a list of' 91 | ' increasing integers. Got {}', milestones) 92 | self.eta_min = eta_min 93 | self.milestones=milestones 94 | super(CyclicLinearLR, self).__init__(optimizer, last_epoch) 95 | 96 | def get_lr(self): 97 | 98 | if self.last_epoch >= self.milestones[-1]: 99 | return [self.eta_min for base_lr in self.base_lrs] 100 | 101 | idx = bisect_right(self.milestones,self.last_epoch) 102 | 103 | left_barrier = 0 if idx==0 else self.milestones[idx-1] 104 | right_barrier = self.milestones[idx] 105 | 106 | width = right_barrier - left_barrier 107 | curr_pos = self.last_epoch- left_barrier 108 | 109 | return [self.eta_min + (base_lr - self.eta_min) * 110 | (1.
- 1.0*curr_pos/ width) 111 | for base_lr in self.base_lrs] 112 | 113 | ''' 114 | ################################# 115 | # TEST FOR SCHEDULER 116 | ################################# 117 | import matplotlib.pyplot as plt 118 | import torch.nn as nn 119 | import torch.optim as optim 120 | 121 | net = nn.Sequential(nn.Linear(2,2)) 122 | milestones = [(2**x)*300 for x in range(30)] 123 | optimizer = optim.SGD(net.parameters(),lr=1e-3,momentum=0.9,weight_decay=0.0005,nesterov=True) 124 | scheduler = CyclicCosAnnealingLR(optimizer,milestones=milestones,eta_min=1e-6) 125 | 126 | lr_log = [] 127 | 128 | for i in range(20*300): 129 | optimizer.step() 130 | scheduler.step() 131 | for param_group in optimizer.param_groups: 132 | lr_log.append(param_group['lr']) 133 | 134 | plt.plot(lr_log) 135 | plt.show() 136 | ''' 137 | -------------------------------------------------------------------------------- /code/data_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import Dataset 2 | import torch 3 | from torchvision import transforms 4 | #from skimage import io, transform 5 | import os 6 | import numpy as np 7 | import re 8 | 9 | class SourceSepTrain(Dataset): 10 | def __init__(self, path='../Processed/Mixtures', transforms=None): 11 | # assuming this to be the directory containing all the magnitude spectrum 12 | #for all songs and all segments used in training 13 | self.path = path 14 | self.list = os.listdir(self.path) 15 | self.transforms = transforms 16 | 17 | def __getitem__(self, index): 18 | mixture_path = '../Processed/Mixtures/' 19 | bass_path = '../Processed/Bass/' 20 | vocals_path = '../Processed/Vocals/' 21 | drums_path = '../Processed/Drums/' 22 | others_path = '../Processed/Others/' 23 | mixture = torch.load(mixture_path+self.list[index]) 24 | #phase = torch.load(mixture_path+self.list[index]+'_p') 25 | bass = torch.load(bass_path+self.list[index]) 26 | vocals = torch.load(vocals_path+self.list[index]) 27 | drums = torch.load(drums_path+self.list[index]) 28 | others = torch.load(others_path+self.list[index]) 29 | #print(mixture) 30 | if self.transforms is not None: 31 | mixture = self.transforms(mixture) 32 | 33 | bass = self.transforms(bass) 34 | vocals = self.transforms(vocals) 35 | drums = self.transforms(drums) 36 | others = self.transforms(others) 37 | return (mixture,bass, vocals, drums, others) 38 | 39 | def __len__(self): 40 | return len(self.list) # length of how much data you have 41 | 42 | 43 | class SourceSepVal(Dataset): 44 | def __init__(self, path='../Val/Mixtures', transforms=None): 45 | # assuming this to be the directory containing all the magnitude spectrum 46 | #for all songs and all segments used in training 47 | self.path = path 48 | self.list = os.listdir(self.path) 49 | self.transforms = transforms 50 | 51 | def __getitem__(self, index): 52 | # stuff 53 | mixture_path = '../Val/Mixtures/' 54 | bass_path = '../Val/Bass/' 55 | vocals_path = '../Val/Vocals/' 56 | drums_path = '../Val/Drums/' 57 | others_path = '../Val/Others/' 58 | 59 | mixture = torch.load(mixture_path+self.list[index]) 60 | #phase = torch.load(mixture_path+self.list[index]+'_p') 61 | bass = torch.load(bass_path+self.list[index]) 62 | vocals = torch.load(vocals_path+self.list[index]) 63 | drums = torch.load(drums_path+self.list[index]) 64 | others = torch.load(others_path+self.list[index]) 65 | 66 | if self.transforms is not None: 67 | mixture = self.transforms(mixture) 68 | bass = self.transforms(bass) 69 | vocals = 
self.transforms(vocals) 70 | drums = self.transforms(drums) 71 | others = self.transforms(others) 72 | 73 | return (mixture,bass, vocals, drums, others) 74 | def __len__(self): 75 | return len(self.list) 76 | 77 | class SourceSepTest(Dataset): 78 | def __init__(self, path='../Val/Mixtures',transforms=None): 79 | # assuming this to be the directory containing all the magnitude spectrum 80 | #for all songs and all segments used in training 81 | self.path = path 82 | self.list = os.listdir(self.path) 83 | self.transforms = transforms 84 | 85 | def __getitem__(self, index): 86 | mixture_path = '../Val/Mixtures/' 87 | bass_path = '../Val/Bass/' 88 | vocals_path = '../Val/Vocals/' 89 | drums_path = '../Val/Drums/' 90 | others_path = '../Val/Others/' 91 | phase_path = '../Val/Phases/' 92 | 93 | phase_file=self.list[index].replace('_m','_p') 94 | phase_file=phase_file.replace('.pt','.npy') 95 | mixture = torch.load(mixture_path+self.list[index]) 96 | #phase = np.load(phase_path+phase_file) 97 | bass = torch.load(bass_path+self.list[index]) 98 | vocals = torch.load(vocals_path+self.list[index]) 99 | drums = torch.load(drums_path+self.list[index]) 100 | others = torch.load(others_path+self.list[index]) 101 | 102 | if self.transforms is not None: 103 | mixture = self.transforms(mixture) 104 | bass = self.transforms(bass) 105 | vocals = self.transforms(vocals) 106 | drums = self.transforms(drums) 107 | others = self.transforms(others) 108 | 109 | return (mixture,phase_file,self.list[index]) 110 | 111 | 112 | def __len__(self): 113 | return len(self.list) 114 | -------------------------------------------------------------------------------- /code/evaluate.m: -------------------------------------------------------------------------------- 1 | [data1, samp_freq1] = audioread('vocals.wav'); 2 | [data2, samp_freq2] = audioread('bass.wav'); 3 | [data3, samp_freq3] = audioread('drums.wav'); 4 | [data4, samp_freq4] = audioread('other.wav'); 5 | 6 | % take only num_points to compute power (ratio used) 7 | num_points = 100000; 8 | points_taken = floor(linspace(1, size(data1,1), num_points)); 9 | 10 | %%%%%%% Replace this with your model o/p %%%%%%%%%%%%% 11 | snr = 0.5; 12 | data1_predicted = awgn(data1, snr, 'measured'); 13 | data2_predicted = awgn(data2, snr, 'measured'); 14 | data3_predicted = awgn(data3, snr, 'measured'); 15 | data4_predicted = awgn(data4, snr, 'measured'); 16 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 17 | 18 | 19 | se = [data1_predicted(points_taken,1), data2_predicted(points_taken,1), data3_predicted(points_taken,1), data4_predicted(points_taken,1)]'; 20 | s = [data1(points_taken,1), data2(points_taken,1), data3(points_taken,1), data4(points_taken,1)]'; 21 | 22 | [SDR,SIR,SAR,perm]=bss_eval_sources(se,s); 23 | -------------------------------------------------------------------------------- /code/evaluate.py: -------------------------------------------------------------------------------- 1 | import mir_eval 2 | import numpy as np 3 | from scipy.io import wavfile 4 | import librosa 5 | 6 | 7 | ####################### MODIFY ############################## 8 | #### additional for loop to evaluate multiple songs ######### 9 | # increase step to decrease time 10 | step = 10 11 | bass_gt_path = 'bass.wav' 12 | bass_rec_path = 'bass_rec.wav' 13 | vocal_gt_path = 'vocals.wav' 14 | vocal_rec_path = 'vocals_rec.wav' 15 | drums_gt_path = 'drums.wav' 16 | drums_rec_path = 'drums_rec.wav' 17 | other_gt_path = 'other.wav' 18 | other_rec_path = 'other_rec.wav' 19 | 
############################################################ 20 | 21 | 22 | 23 | bass_gt, rate11 = librosa.load(bass_gt_path,sr=44100, offset=30*0.3,duration = 170*0.3) 24 | bass_rec, rate21 = librosa.load(bass_rec_path,sr=44100) 25 | 26 | vocals_gt, rate12 = librosa.load(vocal_gt_path,sr=44100, offset=30*0.3,duration = 170*0.3) 27 | vocals_rec, rate22 = librosa.load(vocal_rec_path,sr=44100) 28 | 29 | drums_gt, rate13 = librosa.load(drums_gt_path,sr=44100, offset=30*0.3,duration = 170*0.3) 30 | drums_rec, rate23 = librosa.load(drums_rec_path,sr=44100) 31 | 32 | other_gt, rate14 = librosa.load(other_gt_path,sr=44100, offset=30*0.3,duration = 170*0.3) 33 | other_rec, rate24 = librosa.load(other_rec_path,sr=44100) 34 | 35 | 36 | bass_gt = bass_gt[0:bass_rec.shape[0]:step] 37 | bass_gt = np.transpose(bass_gt.reshape(len(bass_gt), 1)) 38 | 39 | vocals_gt = vocals_gt[0:vocals_rec.shape[0]:step] 40 | vocals_gt = np.transpose(vocals_gt.reshape(len(vocals_gt), 1)) 41 | 42 | drums_gt = drums_gt[0:drums_rec.shape[0]:step] 43 | drums_gt = np.transpose(drums_gt.reshape(len(drums_gt), 1)) 44 | 45 | other_gt = other_gt[0:other_rec.shape[0]:step] 46 | other_gt = np.transpose(other_gt.reshape(len(other_gt), 1)) 47 | 48 | final_gt = np.concatenate((bass_gt, vocals_gt, drums_gt, other_gt), axis = 0) 49 | print(final_gt.shape) 50 | 51 | 52 | bass_rec = bass_rec[0:bass_rec.shape[0]:step] 53 | bass_rec = np.transpose(bass_rec.reshape(len(bass_rec), 1)) 54 | 55 | vocals_rec = vocals_rec[0:vocals_rec.shape[0]:step] 56 | vocals_rec = np.transpose(vocals_rec.reshape(len(vocals_rec), 1)) 57 | 58 | drums_rec = drums_rec[0:drums_rec.shape[0]:step] 59 | drums_rec = np.transpose(drums_rec.reshape(len(drums_rec), 1)) 60 | 61 | other_rec = other_rec[0:other_rec.shape[0]:step] 62 | other_rec = np.transpose(other_rec.reshape(len(other_rec), 1)) 63 | 64 | final_rec = np.concatenate((bass_rec, vocals_rec, drums_rec, other_rec), axis = 0) 65 | print(final_rec.shape) 66 | 67 | 68 | 69 | SDR, SIR, SAR, perm = mir_eval.separation.bss_eval_sources(final_gt, final_rec) 70 | 71 | print(SDR) 72 | print(SIR) 73 | print(SAR) 74 | print(perm) -------------------------------------------------------------------------------- /code/post_processing.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | #import matplotlib.pyplot as plt 4 | import pickle 5 | import torch 6 | import os 7 | import re 8 | 9 | def reconstruct(phase, bass_mag, vocals_mag, drums_mag,others_mag,song_num,segment_num,destination_path): 10 | # Retrieve complex STFT 11 | vocals = np.squeeze(vocals_mag.detach().numpy() * phase,axis= (0,1)) 12 | #print(vocals.shape) 13 | bass = np.squeeze(bass_mag.detach().numpy() * phase, axis=(0,1)) 14 | drums = np.squeeze(drums_mag.detach().numpy() * phase, axis=(0,1)) 15 | others = np.squeeze(others_mag.detach().numpy() * phase, axis=(0,1)) 16 | 17 | # Perform ISTFT 18 | vocals_audio = librosa.istft(vocals, win_length=1024,hop_length=256,window='hann',center=True) 19 | bass_audio = librosa.istft(bass, win_length=1024,hop_length=256,window='hann',center=True) 20 | drums_audio = librosa.istft(drums, win_length=1024,hop_length=256,window='hann',center=True) 21 | others_audio = librosa.istft(others, win_length=1024,hop_length=256,window='hann',center=True) 22 | 23 | # Save as wav files 24 | librosa.output.write_wav(os.path.join(destination_path,'vocals',str(song_num)+'_'+str(segment_num)+'.wav'), vocals_audio,sr=44100) 25 | 
librosa.output.write_wav(os.path.join(destination_path,'bass',str(song_num)+'_'+str(segment_num)+'.wav'), bass_audio, sr=44100) 26 | librosa.output.write_wav(os.path.join(destination_path,'drums',str(song_num)+'_'+str(segment_num)+'.wav'), drums_audio, sr=44100) 27 | librosa.output.write_wav(os.path.join(destination_path,'others',str(song_num)+'_'+str(segment_num)+'.wav'), others_audio, sr=44100) 28 | return 29 | -------------------------------------------------------------------------------- /code/pre_processing.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | #import matplotlib.pyplot as plt 4 | import pickle 5 | import torch 6 | import os 7 | import re 8 | 9 | path= "../dsd100/subset/" 10 | path_mixtures = path + "Mixtures/Dev/" 11 | path_sources = path + "Sources/Dev/" 12 | mean_var_path= "../Processed/" 13 | destination_path = "../Processed/Mixtures" 14 | phase_path= "../Processed/Phases" 15 | bass_path="../Processed/Bass" 16 | vocals_path="../Processed/Vocals" 17 | drums_path="../Processed/Drums" 18 | others_path="../Processed/Others" 19 | source_dest_paths=[vocals_path,bass_path,drums_path,others_path] 20 | 21 | path_val_mixtures = path + "Mixtures/Test/" 22 | path_val_sources = path + "Sources/Test/" 23 | validation_path = "../Val/Mixtures" 24 | val_phase_path= "../Val/Phases" 25 | val_bass_path="../Val/Bass" 26 | val_vocals_path="../Val/Vocals" 27 | val_drums_path="../Val/Drums" 28 | val_others_path="../Val/Others" 29 | source_val_paths=[val_vocals_path,val_bass_path,val_drums_path,val_others_path] 30 | 31 | ### test paths for now same as the validation path! 32 | path_test_mixtures = path + "Mixtures/Test/" 33 | path_test_sources = path + "Sources/Test/" 34 | testing_path = "../Test/Mixtures" 35 | test_phase_path= "../Test/Phases" 36 | test_bass_path="../Test/Bass" 37 | test_vocals_path="../Test/Vocals" 38 | test_drums_path="../Test/Drums" 39 | test_others_path="../Test/Others" 40 | source_test_paths=[test_vocals_path,test_bass_path,test_drums_path,test_others_path] 41 | 42 | 43 | def process(file_path,direc,destination_path,phase_bool,destination_phase_path): 44 | t1,t2=librosa.load(file_path,sr=None) 45 | duration=librosa.get_duration(t1,t2) 46 | regex = re.compile(r'\d+') 47 | index=regex.findall(direc) 48 | #print(index) 49 | num_segments=0 50 | #mean=np.zeros((513,52)) 51 | #var=np.zeros((513,52)) 52 | for start in range(30,int(200)): 53 | 54 | wave_array, fs = librosa.load(file_path,sr=44100,offset=start*0.3,duration = 0.3) 55 | 56 | mag, phase = librosa.magphase(librosa.stft(wave_array, n_fft=1024,hop_length=256,window='hann',center=True)) 57 | #mean+=mag 58 | #num_segments+=1; 59 | if not os.path.exists(destination_path): 60 | os.makedirs(destination_path) 61 | #print(mag.shape) 62 | #print(torch.from_numpy(np.expand_dims(mag,axis=0)).shape) 63 | 64 | # magnitude stored as tensor, phase as np array 65 | #pickle.dump(torch.from_numpy(np.expand_dims(mag,axis=2)),open(os.path.join(destination_path,(index[0] +"_" + str(start) +'_m.pt')),'wb')) 66 | torch.save(torch.from_numpy(np.expand_dims(mag,axis=0)),os.path.join(destination_path,(index[0] +"_" + str(start) +'_m.pt'))) 67 | if phase_bool: 68 | if not os.path.exists(destination_phase_path): 69 | os.makedirs(destination_phase_path) 70 | np.save(os.path.join(destination_phase_path,(index[0]+"_" +str(start)+'_p.npy')),phase) 71 | return 72 | 73 | #--------- training data------------------------------------- 74 | 75 | for subdirs, dirs, files 
in os.walk(path_mixtures): 76 | for direc in dirs: 77 | print('working with training '+ direc) 78 | total_mean=0 79 | total_num_segments=0 80 | for s,d,f in os.walk(path_mixtures + direc): 81 | process(os.path.join(path_mixtures,direc,f[0]),direc,destination_path,True,phase_path) 82 | #total_mean+= mean 83 | #total_num_segments+=num_segments 84 | #total_mean/= total_num_segments 85 | 86 | #torch.save(torch.from_numpy(np.expand_dims(total_mean,axis=0)).float(),os.path.join(mean_var_path,'mean.pt')) 87 | # print(total_mean) # print(total_mean) 88 | 89 | # print('##################################################################') 90 | # print(total_var) 91 | # assert False 92 | for subdirs, dirs, files in os.walk(path_sources): 93 | for direc in dirs: 94 | print('source with training '+ direc) 95 | for s,d,file in os.walk(path_sources + direc): 96 | for i in range(0,4): 97 | print(file[i]) 98 | process(os.path.join(path_sources,direc,file[i]),direc,source_dest_paths[i],False,phase_path) 99 | 100 | 101 | 102 | #------------------------ Validation data----------------------------------- 103 | 104 | for subdirs, dirs, files in os.walk(path_val_mixtures): 105 | for direc in dirs: 106 | print('working with validation '+ direc) 107 | for s,d,f in os.walk(path_val_mixtures + direc): 108 | 109 | process(os.path.join(path_val_mixtures,direc,f[0]),direc,validation_path,True,val_phase_path) 110 | 111 | for subdirs, dirs, files in os.walk(path_val_sources): 112 | for direc in dirs: 113 | print('source with validation '+ direc) 114 | for s,d,file in os.walk(path_val_sources + direc): 115 | for i in range(0,4): 116 | print(file[i]) 117 | process(os.path.join(path_val_sources,direc,file[i]),direc,source_val_paths[i],False,val_phase_path) 118 | 119 | #----------------------Testing data------------------------------------------- 120 | 121 | #for subdirs, dirs, files in os.walk(path_test_mixtures): 122 | # for direc in dirs: 123 | # print('working with validation '+ direc) 124 | # for s,d,f in os.walk(path_test_mixtures + direc): 125 | # 126 | # process(os.path.join(path_test_mixtures,direc,f[0]),direc,testing_path,True,test_phase_path) 127 | # 128 | #for subdirs, dirs, files in os.walk(path_test_sources): 129 | # for direc in dirs: 130 | # print('source with testset '+ direc) 131 | # for s,d,file in os.walk(path_test_sources + direc): 132 | # for i in range(0,4): 133 | # print(file[i]) 134 | # process(os.path.join(path_test_sources,direc,file[i]),direc,source_test_paths[i],False,test_phase_path) 135 | -------------------------------------------------------------------------------- /code/stiching.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import subprocess 3 | import numpy as np 4 | import os 5 | import re 6 | import glob 7 | 8 | destination_path='../Recovered_Songs_bigger5/' 9 | vocals_directory='../AudioResults/vocals' 10 | drums_directory='../AudioResults/drums' 11 | bass_directory='../AudioResults/bass' 12 | others_directory='../AudioResults/others' 13 | test_songs_list=[] 14 | test_segment_length=[] 15 | vocals_list=[] 16 | if not os.path.exists(destination_path): 17 | os.makedirs(destination_path) 18 | if not os.path.exists(vocals_directory): 19 | os.makedirs(vocals_directory) 20 | if not os.path.exists(drums_directory): 21 | os.makedirs(drums_directory) 22 | if not os.path.exists(bass_directory): 23 | os.makedirs(bass_directory) 24 | if not os.path.exists(others_directory): 25 | os.makedirs(others_directory) 26 | for subdirs, dirs, files 
in os.walk(vocals_directory): 27 | print('finding list of songs ') 28 | for file in files : 29 | regex = re.compile(r'\d+') 30 | index = regex.findall(file) 31 | if not (index[0] in test_songs_list) : 32 | test_songs_list.append(index[0]) 33 | 34 | for test_songs in (test_songs_list): 35 | combined_vocals=np.array([]) 36 | sr=None 37 | print('testing,..'+test_songs) 38 | print('Stitching Vocals') 39 | vocals_list = sorted(glob.glob(os.path.join(vocals_directory,test_songs+"*"))) 40 | vocals_path=os.path.join(destination_path,'vocals') 41 | if not os.path.exists(vocals_path): 42 | os.makedirs(vocals_path) 43 | sound_output_path = os.path.join(vocals_path,test_songs+'.wav') 44 | for segment in (vocals_list) : 45 | seg, sr = librosa.load(segment, sr=44100) 46 | print(sr) 47 | assert sr==44100 48 | combined_vocals= np.append(combined_vocals,seg) 49 | librosa.output.write_wav(sound_output_path,combined_vocals,sr) 50 | 51 | 52 | print('Stitching Bass') 53 | combined_bass=np.array([]) 54 | sr=None 55 | bass_list = sorted(glob.glob(os.path.join(bass_directory,test_songs+"*"))) 56 | bass_path=os.path.join(destination_path,'bass') 57 | if not os.path.exists(bass_path): 58 | os.makedirs(bass_path) 59 | sound_output_path = os.path.join(bass_path,test_songs+'.wav') 60 | for segment in (bass_list) : 61 | seg, sr = librosa.load(segment,sr=44100) 62 | assert sr==44100 63 | combined_bass= np.append(combined_bass,seg) 64 | librosa.output.write_wav(sound_output_path,combined_bass,sr) 65 | 66 | 67 | print('Stitching Drums') 68 | combined_drums=np.array([]) 69 | sr=None 70 | drums_list = sorted(glob.glob(os.path.join(drums_directory,test_songs+"*"))) 71 | drums_path=os.path.join(destination_path,'drums') 72 | if not os.path.exists(drums_path): 73 | os.makedirs(drums_path) 74 | sound_output_path = os.path.join(drums_path,test_songs+'.wav') 75 | for segment in (drums_list) : 76 | seg, sr = librosa.load(segment,sr=44100) 77 | combined_drums= np.append(combined_drums,seg) 78 | librosa.output.write_wav(sound_output_path,combined_drums,sr) 79 | 80 | print('Stitching Others') 81 | combined_others=np.array([]) 82 | sr=None 83 | others_list = sorted(glob.glob(os.path.join(others_directory,test_songs+"*"))) 84 | others_path=os.path.join(destination_path,'others') 85 | if not os.path.exists(others_path): 86 | os.makedirs(others_path) 87 | sound_output_path = os.path.join(others_path,test_songs+'.wav') 88 | for segment in (others_list) : 89 | seg, sr = librosa.load(segment,sr=44100) 90 | combined_others= np.append(combined_others,seg) 91 | librosa.output.write_wav(sound_output_path,combined_others,sr) 92 | -------------------------------------------------------------------------------- /code/test_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import glob 4 | import re 5 | import os 6 | from build_model_original import SepConvNet 7 | from torch.utils.data import DataLoader 8 | from data_loader import SourceSepTest 9 | from post_processing import reconstruct 10 | from train_model import TimeFreqMasking 11 | from tqdm import tqdm 12 | 13 | if __name__ == '__main__': 14 | inp_size = [513,52] 15 | t1=1 16 | f1=513#513 17 | t2=15 18 | f2=1 19 | N1=50 20 | N2=30 21 | NN=128 22 | alpha = 0.001 23 | beta = 0.01 24 | beta_vocals = 0.03 25 | batch_size = 1 26 | num_epochs = 50 27 | 28 | destination_path= '../AudioResults/' 29 | phase_path = '../Val/Phases/' 30 | vocals_directory='../AudioResults/vocals' 31 | 
drums_directory='../AudioResults/drums' 32 | bass_directory='../AudioResults/bass' 33 | others_directory='../AudioResults/others' 34 | 35 | if not os.path.exists(destination_path): 36 | os.makedirs(destination_path) 37 | if not os.path.exists(vocals_directory): 38 | os.makedirs(vocals_directory) 39 | if not os.path.exists(drums_directory): 40 | os.makedirs(drums_directory) 41 | if not os.path.exists(bass_directory): 42 | os.makedirs(bass_directory) 43 | if not os.path.exists(others_directory): 44 | os.makedirs(others_directory) 45 | 46 | 47 | net = SepConvNet(t1,f1,t2,f2,N1,N2,inp_size,NN) 48 | # net.load_state_dict(torch.load('Weights/Weights_200_3722932.6015625.pth')) #least score Weights so far 49 | net.load_state_dict(torch.load('Weights/Weights_norm_orig2.pth')) 50 | net.eval() 51 | test_set = SourceSepTest(transforms = None) 52 | test_loader = DataLoader(test_set, batch_size=batch_size,shuffle=False) 53 | for i,(test_inp,test_phase_file,file_str) in tqdm(enumerate(test_loader)): 54 | print('Testing, i='+str(i)) 55 | test_phase = np.load(phase_path+test_phase_file[0]) 56 | 57 | mean = torch.mean(test_inp) 58 | std = torch.std(test_inp) 59 | test_inp_n = (test_inp-mean)/std 60 | bass_mag, vocals_mag, drums_mag,others_mag = net(test_inp_n) 61 | bass_mag, vocals_mag, drums_mag,others_mag = TimeFreqMasking(bass_mag, vocals_mag, drums_mag,others_mag) 62 | bass_mag = bass_mag*test_inp 63 | vocals_mag = vocals_mag*test_inp 64 | drums_mag = drums_mag*test_inp 65 | others_mag = others_mag*test_inp 66 | 67 | regex = re.compile(r'\d+') 68 | index=regex.findall(file_str[0]) 69 | reconstruct(test_phase, bass_mag, vocals_mag, drums_mag,others_mag,index[0],index[1],destination_path) 70 | 71 | # list = sorted(glob.glob('*.wav')) 72 | -------------------------------------------------------------------------------- /code/train_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.utils.data import DataLoader 5 | from torch.autograd import Variable 6 | from build_model_original import * 7 | # from build_model_old import * 8 | # from cyclicAnnealing import CyclicLinearLR 9 | import os 10 | from tqdm import tqdm 11 | from data_loader import * 12 | from tensorboardX import SummaryWriter 13 | from torch.optim.lr_scheduler import MultiStepLR 14 | 15 | mean_var_path= "../Processed/" 16 | if not os.path.exists('Weights'): 17 | os.makedirs('Weights') 18 | #os.environ["CUDA_VISIBLE_DEVICES"]="0" 19 | #-------------------------- 20 | class Average(object): 21 | def __init__(self): 22 | self.reset() 23 | 24 | def reset(self): 25 | self.sum = 0 26 | self.count = 0 27 | 28 | def update(self, val, n=1): 29 | self.sum += val 30 | self.count += n 31 | 32 | #property 33 | def avg(self): 34 | return self.sum / self.count 35 | #------------------------------ 36 | # import csv 37 | writer = SummaryWriter() 38 | #---------------------------------------- 39 | 40 | inp_size = [513,52] 41 | t1=1 42 | f1=513 43 | t2=15 44 | f2=1 45 | N1=50 46 | N2=30 47 | NN=128 48 | alpha = 0.005 49 | beta = 0.05 50 | beta_vocals = 0.08 51 | batch_size = 30 52 | num_epochs = 50 53 | 54 | 55 | class MixedSquaredError(nn.Module): 56 | def __init__(self, weight=None, size_average=True): 57 | super(MixedSquaredError, self).__init__() 58 | 59 | def forward(self, pred_bass,pred_vocals,pred_drums,pred_others, gt_bass,gt_vocals,gt_drums, gt_others): 60 | 61 | 62 | L_sq = torch.sum((pred_bass-gt_bass).pow(2)) + 
torch.sum((pred_vocals-gt_vocals).pow(2)) + torch.sum((pred_drums-gt_drums).pow(2)) 63 | L_other = torch.sum((pred_bass-gt_others).pow(2)) + torch.sum((pred_drums-gt_others).pow(2)) 64 | #+ torch.sum((pred_vocals-gt_others).pow(2)) 65 | L_othervocals = torch.sum((pred_vocals - gt_others).pow(2)) 66 | L_diff = torch.sum((pred_bass-pred_vocals).pow(2)) + torch.sum((pred_bass-pred_drums).pow(2)) + torch.sum((pred_vocals-pred_drums).pow(2)) 67 | 68 | return (L_sq- alpha*L_diff - beta*L_other - beta_vocals*L_othervocals) 69 | 70 | def TimeFreqMasking(bass,vocals,drums,others,cuda=0): 71 | den = torch.abs(bass) + torch.abs(vocals) + torch.abs(drums) + torch.abs(others) 72 | if(cuda): 73 | den = den + 10e-8*torch.cuda.FloatTensor(bass.size()).normal_() 74 | else: 75 | den = den + 10e-8*torch.FloatTensor(bass.size()).normal_() 76 | 77 | 78 | bass = torch.abs(bass)/den 79 | vocals = torch.abs(vocals)/den 80 | drums = torch.abs(drums)/den 81 | others = torch.abs(others)/den 82 | 83 | return bass,vocals,drums,others 84 | #mu=torch.load(os.path.join(mean_var_path,'mean.pt')) 85 | #std=torch.load(os.path.join(mean_var_path,'std.pt')) 86 | #transformations_train = transforms.Compose([transforms.Normalize(mean = mu, std = std)]) 87 | # 88 | train_set = SourceSepTrain(transforms = None) 89 | 90 | 91 | #transformation_test = transforms.Compose([ transforms.Normalize(mean = 0.0, std =1./var), transforms.Normalize(mean = -1*mu, std = 1.0),]) 92 | 93 | 94 | def train(): 95 | cuda = torch.cuda.is_available() 96 | net = SepConvNet(t1,f1,t2,f2,N1,N2,inp_size,NN) 97 | criterion = MixedSquaredError() 98 | if cuda: 99 | net = net.cuda() 100 | criterion = criterion.cuda() 101 | optimizer = torch.optim.Adam(net.parameters(), lr=1e-3) 102 | # scheduler = CyclicLinearLR(optimizer, milestones=[60,120]) 103 | scheduler = MultiStepLR(optimizer, milestones=[60,120]) 104 | print("preparing training data ...") 105 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) 106 | print("done ...") 107 | val_set = SourceSepVal(transforms = None) 108 | val_loader = DataLoader(val_set, batch_size=batch_size,shuffle=False) 109 | 110 | for epoch in range(num_epochs): 111 | scheduler.step() 112 | train_loss = Average() 113 | 114 | net.train() 115 | for i, (inp, gt_bass,gt_vocals,gt_drums,gt_others) in enumerate(train_loader): 116 | mean = torch.mean(inp) 117 | std = torch.std(inp) 118 | inp_n = (inp-mean)/std 119 | 120 | inp = Variable(inp) 121 | inp_n = Variable(inp_n) 122 | gt_bass = Variable(gt_bass) 123 | gt_vocals = Variable(gt_vocals) 124 | gt_drums = Variable(gt_drums) 125 | gt_others= Variable(gt_others) 126 | if cuda: 127 | inp = inp.cuda() 128 | inp_n = inp_n.cuda() 129 | gt_bass = gt_bass.cuda() 130 | gt_vocals = gt_vocals.cuda() 131 | gt_drums = gt_drums.cuda() 132 | gt_others= gt_others.cuda() 133 | optimizer.zero_grad() 134 | o_bass, o_vocals, o_drums, o_others = net(inp_n) 135 | 136 | 137 | mask_bass,mask_vocals,mask_drums,mask_others = TimeFreqMasking(o_bass, o_vocals, o_drums, o_others,cuda) 138 | pred_drums=inp*mask_drums 139 | pred_vocals=inp*mask_vocals 140 | pred_bass=inp*mask_bass 141 | pred_others=inp*mask_others 142 | 143 | loss = criterion(pred_bass,pred_vocals,pred_drums,pred_others, gt_bass,gt_vocals,gt_drums,gt_others) 144 | writer.add_scalar('Train Loss',loss,epoch) 145 | loss.backward() 146 | optimizer.step() 147 | train_loss.update(loss.item(), inp.size(0)) 148 | for param_group in optimizer.param_groups: 149 | writer.add_scalar('Learning Rate',param_group['lr']) 150 | 151 | val_loss = 
Average() 152 | net.eval() 153 | for i,(val_inp, gt_bass,gt_vocals,gt_drums,gt_others) in enumerate(val_loader): 154 | val_mean = torch.mean(val_inp) 155 | val_std = torch.std(val_inp) 156 | val_inp_n = (val_inp-val_mean)/val_std 157 | 158 | val_inp = Variable(val_inp) 159 | val_inp_n = Variable(val_inp_n) 160 | gt_bass = Variable(gt_bass) 161 | gt_vocals = Variable(gt_vocals) 162 | gt_drums = Variable(gt_drums) 163 | gt_others = Variable(gt_others) 164 | if cuda: 165 | val_inp = val_inp.cuda() 166 | val_inp_n = val_inp_n.cuda() 167 | gt_bass = gt_bass.cuda() 168 | gt_vocals = gt_vocals.cuda() 169 | gt_drums = gt_drums.cuda() 170 | gt_others = gt_others.cuda() 171 | 172 | o_bass, o_vocals, o_drums, o_others = net(val_inp_n) 173 | mask_bass,mask_vocals,mask_drums,mask_others = TimeFreqMasking(o_bass, o_vocals, o_drums, o_others,cuda) 174 | #print(val_inp.shape) 175 | #print(mask_drums.shape) 176 | #assert False 177 | pred_drums=val_inp*mask_drums 178 | pred_vocals=val_inp*mask_vocals 179 | pred_bass=val_inp*mask_bass 180 | pred_others=val_inp*mask_others 181 | 182 | if (epoch)%10==0: 183 | writer.add_image('Validation Input',val_inp,epoch) 184 | writer.add_image('Validation Bass GT ',gt_bass,epoch) 185 | writer.add_image('Validation Bass Pred ',pred_bass,epoch) 186 | writer.add_image('Validation Vocals GT ',gt_vocals,epoch) 187 | writer.add_image('Validation Vocals Pred ',pred_vocals,epoch) 188 | writer.add_image('Validation Drums GT ',gt_drums,epoch) 189 | writer.add_image('Validation Drums Pred ',pred_drums,epoch) 190 | writer.add_image('Validation Other GT ',gt_others,epoch) 191 | writer.add_image('Validation Others Pred ',pred_others,epoch) 192 | 193 | vloss = criterion(pred_bass,pred_vocals,pred_drums,pred_others, gt_bass,gt_vocals,gt_drums, gt_others) 194 | writer.add_scalar('Validation loss',vloss,epoch) 195 | val_loss.update(vloss.item(), val_inp.size(0)) 196 | 197 | print("Epoch {}, Training Loss: {}, Validation Loss: {}".format(epoch+1, train_loss.avg(), val_loss.avg())) 198 | torch.save(net.state_dict(), 'Weights/Weights_{}_{}.pth'.format(epoch+1, val_loss.avg())) 199 | return net 200 | 201 | def test(model): 202 | model.eval() 203 | 204 | 205 | if __name__ == "__main__": 206 | train() 207 | --------------------------------------------------------------------------------
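Usage sketch (illustrative only, not a file in the repository): a minimal example of running the separator on one magnitude-spectrogram segment, assuming the [513, 52] input shape and the hyperparameters from test_model.py. The input tensor is a dummy, the commented weight path is the one test_model.py loads, and importing TimeFreqMasking from train_model assumes the ../Processed and ../Val data directories exist, since train_model.py builds its datasets at import time.

import torch
from build_model_original import SepConvNet
from train_model import TimeFreqMasking  # train_model.py creates its datasets at import time

# Hyperparameters as used in test_model.py
net = SepConvNet(t1=1, f1=513, t2=15, f2=1, N1=50, N2=30, input_shape=[513, 52], NN=128)
# net.load_state_dict(torch.load('Weights/Weights_norm_orig2.pth'))  # if trained weights exist
net.eval()

mix = torch.rand(1, 1, 513, 52)           # dummy magnitude spectrogram segment
mix_n = (mix - mix.mean()) / mix.std()    # per-segment normalisation, as in test_model.py

with torch.no_grad():
    o_bass, o_vocals, o_drums, o_others = net(mix_n)

# Soft time-frequency masks (non-negative, summing to ~1 per bin),
# applied to the unnormalised mixture magnitude
m_bass, m_vocals, m_drums, m_others = TimeFreqMasking(o_bass, o_vocals, o_drums, o_others)
bass_mag = m_bass * mix      # each estimate has shape [1, 1, 513, 52]
vocals_mag = m_vocals * mix
drums_mag = m_drums * mix
others_mag = m_others * mix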