├── .gitignore
├── README.md
├── meta
│   ├── README.md
│   ├── __pycache__
│   │   ├── create_whamr_rirs.cpython-38.pyc
│   │   ├── preprocess_dynamic_mixing.cpython-38.pyc
│   │   └── wham_room.cpython-38.pyc
│   ├── activlev.m
│   ├── create_rirs.sh
│   ├── create_whamr_rirs.py
│   ├── data
│   │   ├── mix_2_spk_filenames_cv.csv
│   │   ├── mix_2_spk_filenames_tr.csv
│   │   ├── mix_2_spk_filenames_tt.csv
│   │   ├── reverb_params_cv.csv
│   │   ├── reverb_params_tr.csv
│   │   └── reverb_params_tt.csv
│   ├── maxfilt.m
│   ├── preprocess_dynamic_mixing.py
│   ├── rir_constants.py
│   └── wham_room.py
├── prepare_data.py
├── requirements.txt
└── separation
    ├── dynamic_mixing.py
    ├── hparams
    │   ├── baselines
    │   │   └── tcn
    │   │       └── tcn-whamr.yaml
    │   ├── deformable
    │   │   ├── dm
    │   │   │   ├── dtcn-whamr.yaml
    │   │   │   └── shared_weights
    │   │   │       └── dtcn-whamr.yaml
    │   │   ├── dtcn-whamr.yaml
    │   │   └── shared_weights
    │   │       └── dtcn-whamr.yaml
    │   └── wsj0-2mix
    │       ├── deformable
    │       │   ├── dm
    │       │   │   ├── dtcn-wsj0-2mix.yaml
    │       │   │   └── shared_weights
    │       │   │       └── dtcn-wsj0-2mix.yaml
    │       │   ├── dtcn-wsj0-2mix.yaml
    │       │   └── shared_weights
    │       │       └── dtcn-wsj0-2mix.yaml
    │       └── dtcn-wsj0-2mix.yaml
    ├── model_info.py
    ├── src
    │   ├── deformable.py
    │   ├── macs.py
    │   ├── measures.py
    │   ├── tcn.py
    │   └── utils.py
    └── train.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /__pycache__/*
2 | /separation/src/__pycache__/*
3 | /separation/results/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deformable Temporal Convolutional Networks (DTCN)
2 | 
3 | Work on this repository is moving to https://github.com/jwr1995/PubSep
4 | 
5 | This repository provides training and evaluation scripts for the DTCN speech separation model described in the paper "Deformable Temporal Convolutional Networks for Monaural Noisy Reverberant Speech Separation" - https://arxiv.org/pdf/2210.15305.pdf.
6 | 
7 | A baseline TCN schema is also provided, along with tools for estimating computational efficiency.
8 | 
9 | This recipe is a fork of the WHAMandWHAMR recipe in the SpeechBrain library (required, see below). For more help and information on any SpeechBrain-related issues:
10 | * https://speechbrain.github.io/
11 | * https://github.com/speechbrain/speechbrain
12 | 
13 | # Data and models
14 | Data:
15 | * WHAMR
16 | * WSJ0-2Mix
17 | 
18 | Models:
19 | * Deformable Temporal Convolutional Networks
20 | * Temporal Convolutional Networks (Conv-TasNet without skip connections)
21 | 
22 | # Running basic script
23 | First install SRMRpy and the remaining required packages:
24 | ```
25 | git clone https://github.com/jfsantos/SRMRpy.git
26 | cd SRMRpy
27 | python setup.py install
28 | 
29 | pip install -r requirements.txt
30 | ```
31 | To run basic training of a DTCN model, first change the ```data_folder``` hyperparameter in the ```separation/hparams/deformable/dtcn-whamr.yaml``` file. Then run
32 | ```
33 | cd separation
34 | HPARAMS=hparams/deformable/dtcn-whamr.yaml
35 | python train.py $HPARAMS
36 | ```
37 | or, if you wish to use multiple GPUs (recommended), run
38 | ```
39 | python -m torch.distributed.launch --nproc_per_node=$NGPU train.py $HPARAMS --distributed_launch --distributed_backend='nccl'
40 | 
41 | ```
42 | replacing ```NGPU``` with the desired number of GPUs to use.
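Hyperparameters defined in the YAML file can also usually be overridden on the command line, as in other SpeechBrain recipes. A minimal sketch, assuming your WHAMR! data lives at `/path/to/whamr` (a placeholder path):
```
python train.py $HPARAMS --data_folder /path/to/whamr
```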
43 | In order to use dynamic mixing you will also need to change the ```base_folder_dm``` and ```rir_path``` hyperparameters; refer to https://github.com/speechbrain/speechbrain/blob/develop/recipes/WHAMandWHAMR/separation/README.md for more info on setting up dynamic mixing in SpeechBrain recipes.
44 | 
45 | # Known issues
46 | * The main issue at present is with mixed-precision training (```autocast``` enabled). The cause is unknown, and for now we do not recommend using this functionality.
47 | 
48 | # Paper
49 | Please cite the following paper if you make use of any of this codebase:
50 | ```
51 | @INPROCEEDINGS{dtcn23,
52 |   author={Ravenscroft, William and Goetze, Stefan and Hain, Thomas},
53 |   booktitle={ICASSP 2023 - 2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
54 |   title={Deformable Temporal Convolutional Networks for Monaural Noisy Reverberant Speech Separation},
55 |   year={2023},
56 |   volume={},
57 |   number={},
58 |   pages={1-5},
59 |   doi={10.1109/ICASSP49357.2023.10095230}}
60 | ```
61 | 
--------------------------------------------------------------------------------
/meta/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # WSJ0-2mix dataset creation
3 | * The best way to create this dataset is to use the original MATLAB script. This script and the associated metadata can be obtained through the following [link](https://www.dropbox.com/s/gg524noqvfm1t7e/create_mixtures_wsj023mix.zip?dl=1).
4 | * The dataset creation script assumes that the original WSJ0 files in the sphere format have already been converted to .wav.
5 | 
6 | 
7 | # Dynamic Mixing:
8 | 
9 | * This recipe supports dynamic mixing, where the training data is dynamically created in order to obtain new utterance combinations during training. For this you need to have the WSJ0 dataset (available through LDC at `https://catalog.ldc.upenn.edu/LDC93S6A`). After this you need to run the preprocessing script under `recipes/WSJ0Mix/meta/preprocess_dynamic_mixing.py`. Then you need to specify the path to the output folder of this script through the `wsj0_tr` variable in your .yaml file. This script resamples the recordings to 8 kHz and runs the level-normalization script.
10 | 
11 | This script uses Octave to call the MATLAB function `activlev.m` for level normalization. Depending on your Octave version, you might observe the following warning:
12 | ```
13 | error: called from graphics_toolkit at line 81 column 5
14 | graphics_toolkit: = toolkit is not available
15 | ```
16 | This is in essence a warning and does not affect the results of this script.
17 | 
18 | To run the script, you need to specify:
19 | * `--input_folder`: This should point to the original WSJ0 folder with .wav files.
20 | * `--output_folder`: This will be the output of the script. You need to specify the path of this folder through the `wsj0_tr` variable in your .yaml file if you want to use dynamic mixing during training. An example invocation is shown below.
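A minimal sketch of the preprocessing call, run from the `meta` folder (both paths are placeholders):
```
python preprocess_dynamic_mixing.py \
    --input_folder /path/to/wsj0 \
    --output_folder /path/to/wsj0_processed \
    --fs 8000
```
Afterwards, point the `wsj0_tr` variable in your .yaml file at `/path/to/wsj0_processed`.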
21 | 
--------------------------------------------------------------------------------
/meta/__pycache__/create_whamr_rirs.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jwr1995/DTCN/00d67bfb50935c087f8be200db67f9504b2f5cf7/meta/__pycache__/create_whamr_rirs.cpython-38.pyc
--------------------------------------------------------------------------------
/meta/__pycache__/preprocess_dynamic_mixing.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jwr1995/DTCN/00d67bfb50935c087f8be200db67f9504b2f5cf7/meta/__pycache__/preprocess_dynamic_mixing.cpython-38.pyc
--------------------------------------------------------------------------------
/meta/__pycache__/wham_room.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jwr1995/DTCN/00d67bfb50935c087f8be200db67f9504b2f5cf7/meta/__pycache__/wham_room.cpython-38.pyc
--------------------------------------------------------------------------------
/meta/activlev.m:
--------------------------------------------------------------------------------
1 | function out = activlev(sp,fs,mode)
2 | %ACTIVLEV Measure active speech level as in ITU-T P.56 [LEV,AF,FSO]=(sp,FS,MODE)
3 | %
4 | %Usage: (1) lev=activlev(s,fs);     % speech level in units of power
5 | %       (2) db=activlev(s,fs,'d');  % speech level in dB
6 | %       (3) s=activlev(s,fs,'n');   % normalize active level to 0 dB
7 | %
8 | %Inputs: sp     is the speech signal (with better than 20dB SNR)
9 | %        FS     is the sample frequency in Hz (see also FSO below)
10 | %        MODE   is a combination of the following:
11 | %               0 - omit high pass filter completely (i.e. include DC)
12 | %               3 - high pass filter at 30 Hz instead of 200 Hz (but allows mains hum to pass)
13 | %               4 - high pass filter at 40 Hz instead of 200 Hz (but allows mains hum to pass)
14 | %               1 - use chebyshev 1 filter
15 | %               2 - use chebyshev 2 filter (default)
16 | %               e - use elliptic filter
17 | %               h - omit low pass filter at 5.5, 12 or 18 kHz
18 | %               w - use wideband filter frequencies: 70 Hz to 12 kHz
19 | %               W - use ultra wideband filter frequencies: 30 Hz to 18 kHz
20 | %               d - give outputs in dB rather than power
21 | %               n - output a normalized speech signal as the first argument
22 | %               N - output a normalized filtered speech signal as the first argument
23 | %               l - give both active and long-term power levels
24 | %               a - include A-weighting filter
25 | %               i - include ITU-R-BS.468/ITU-T-J.16 weighting filter
26 | %               z - do NOT zero-pad the signal by 0.35 s
27 | %
28 | %Outputs:
29 | %    If the "n" option is specified, a speech signal normalized to 0dB will be given as
30 | %    the first output followed by the other outputs.
31 | %    LEV    gives the speech level in units of power (or dB if mode='d')
32 | %           if mode='l' is specified, LEV is a row vector with the "long term
33 | %           level" as its second element (this is just the mean power)
34 | %    AF     is the activity factor (or duty cycle) in the range 0 to 1
35 | %    FSO    is a column vector of intermediate information that allows
36 | %           you to process a speech signal in chunks. Thus:
37 | %               fso=fs;
38 | %               for i=1:inc:nsamp
39 | %                   [lev,af,fso]=activlev(sp(i:min(i+inc-1,nsamp)),fso,['z' mode]);
40 | %               end
41 | %               lev=activlev([],fso)
42 | %           is equivalent to:
43 | %               lev=activlev(sp(1:nsamp),fs,mode)
44 | %           but is much slower.
The two methods will not give identical results 45 | % because they will use slightly different thresholds. Note you need 46 | % the 'z' option for all calls except the last. 47 | % VAD is a boolean vector the same length as sp that acts as an approximate voice activity detector 48 | 49 | %For completeness we list here the contents of the FSO structure: 50 | % 51 | % ffs : sample frequency 52 | % fmd : mode string 53 | % nh : hangover time in samples 54 | % ae : smoothing filter coefs 55 | % abl: HP filter numerator and denominator coefficient 56 | % bh : LP filter numerator coefficient 57 | % ah : LP filter denominator coefficients 58 | % ze : smoothing filter state 59 | % zl : HP filter state 60 | % zh : LP filter state 61 | % zx : hangover max filter state 62 | % emax : maximum envelope exponent + 1 63 | % ssq : signal sum of squares 64 | % ns : number of signal samples 65 | % ss : sum of speech samples (not actually used here) 66 | % kc : cumulative occupancy counts 67 | % aw : weighting filter denominator 68 | % bw : weighting filter numerator 69 | % zw : weighting filter state 70 | % 71 | % This routine implements "Method B" from [1],[2] to calculate the active 72 | % speech level which is defined to be the speech energy divided by the 73 | % duration of speech activity. Speech is designated as "active" based on an 74 | % adaptive threshold applied to the smoothed rectified speech signal. A 75 | % bandpass filter is first applied to the input speech whose -0.25 dB points 76 | % are at 200 Hz & 5.5 kHz by default but this can be changed to 70 Hz & 5.5 kHz 77 | % or to 30 Hz & 18 kHz by specifying the 'w' or 'W' options; these 78 | % correspond respectively to Annexes B and C in [2]. 79 | % 80 | % References: 81 | % [1] ITU-T. Objective measurement of active speech level. Recommendation P.56, Mar. 1993. 82 | % [2] ITU-T. Objective measurement of active speech level. Recommendation P.56, Dec. 2011. 83 | 84 | % Copyright (C) Mike Brookes 2008-2016 85 | % Version: $Id: activlev.m 9407 2017-02-07 13:25:55Z dmb $ 86 | % 87 | % VOICEBOX is a MATLAB toolbox for speech processing. 88 | % Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html 89 | % 90 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 91 | % This program is free software; you can redistribute it and/or modify 92 | % it under the terms of the GNU General Public License as published by 93 | % the Free Software Foundation; either version 2 of the License, or 94 | % (at your option) any later version. 95 | % 96 | % This program is distributed in the hope that it will be useful, 97 | % but WITHOUT ANY WARRANTY; without even the implied warranty of 98 | % MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 99 | % GNU General Public License for more details. 100 | % 101 | % You can obtain a copy of the GNU General Public License from 102 | % http://www.gnu.org/copyleft/gpl.html or by writing to 103 | % Free Software Foundation, Inc.,675 Mass Ave, Cambridge, MA 02139, USA. 
104 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 105 | 106 | dbstop ("asind", 1); 107 | 108 | persistent nbin thresh c25zp c15zp e5zp 109 | if isempty(nbin) 110 | nbin=20; % 60 dB range at 3dB per bin 111 | thresh=15.9; % threshold in dB 112 | % High pass s-domain zeros and poles of filters with passband ripple<0.25dB, stopband<-50dB, w0=1 113 | % w0=fzero(@ch2,0.5); [c2z,c2p,k]=cheby2(5,50,w0,'high','s'); 114 | % function v=ch2(w); [c2z,c2p,k]=cheby2(5,50,w,'high','s'); v= 20*log10(prod(abs(1i-c2z))/prod(abs(1i-c2p)))+0.25; 115 | c25zp=[0.37843443673309i 0.23388534441447i; -0.20640255179496+0.73942185906851i -0.54036889596392+0.45698784092898i]; 116 | c25zp=[[0; -0.66793268833792] c25zp conj(c25zp)]; 117 | % [c1z,c1p,c1k] = cheby1(5,0.25,1,'high','s'); 118 | c15zp=[-0.659002835294875+1.195798636925079i -0.123261821596263+0.947463030958881i]; 119 | c15zp=[zeros(1,5); -2.288586431066945 c15zp conj(c15zp)]; 120 | % [ez,ep,ek] = ellip(5,0.25,50,1,'high','s') 121 | e5zp=[0.406667680649209i 0.613849362744881i; -0.538736390607201+1.130245082677107i -0.092723126159100+0.958193646330194i]; 122 | e5zp=[[0; -1.964538608244084] e5zp conj(e5zp)]; 123 | % w=linspace(0.2,2,100); 124 | % figure(1); plot(w,20*log10(abs(freqs(real(poly(c15zp(1,:))),real(poly(c15zp(2,:))),w)))); title('Chebyshev 1'); 125 | % figure(2); plot(w,20*log10(abs(freqs(real(poly(c25zp(1,:))),real(poly(c25zp(2,:))),w)))); title('Chebyshev 2'); 126 | % figure(3); plot(w,20*log10(abs(freqs(real(poly(e5zp(1,:))),real(poly(e5zp(2,:))),w)))); title('Elliptic'); 127 | end 128 | 129 | if ~isstruct(fs) % no state vector given 130 | if nargin<3 131 | mode=' '; 132 | end 133 | fso.ffs=fs; % sample frequency 134 | 135 | ti=1/fs; 136 | g=exp(-ti/0.03); % pole position for envelope filter 137 | fso.ae=[1 -2*g g^2]/(1-g)^2; % envelope filter coefficients (DC gain = 1) 138 | fso.ze=zeros(2,1); 139 | fso.nh=ceil(0.2/ti)+1; % hangover time in samples 140 | fso.zx=-Inf; % initial value for maxfilt() 141 | fso.emax=-Inf; % maximum exponent 142 | fso.ns=0; 143 | fso.ssq=0; 144 | fso.ss=0; 145 | fso.kc=zeros(nbin,1); % cumulative occupancy counts 146 | % s-plane zeros and poles of high pass 5'th order filter -0.25dB at w=1 and -50dB stopband 147 | if any(mode=='1') 148 | szp=c15zp; % Chebyshev 1 149 | elseif any(mode=='e') 150 | szp=e5zp; % Elliptic 151 | else 152 | szp=c25zp; % Chebyshev 2 153 | end 154 | flh=[200 5500]; % default frequency range +- 0.25 dB 155 | if any(mode=='w') 156 | flh=[70 12000]; % super-wideband (Annex B of [2]) 157 | elseif any(mode=='W') 158 | flh=[30 18000]; % full band (Annex C of [2]) 159 | end 160 | if any(mode=='3') 161 | flh(1)=30; % force a 30 Hz HPF cutoff 162 | end 163 | if any(mode=='4') 164 | flh(1)=40; % force a 40 Hz HPF cutoff 165 | end 166 | if any(mode=='r') % included for backward compatibility 167 | mode=['0h' mode]; % abolish both filters 168 | elseif fs0 254 | aj=10*log10(fso.ssq*(fso.kc).^(-1)); 255 | % equivalent to cj=20*log10(sqrt(2).^(fso.emax-(1:nbin)-1)); 256 | cj=10*log10(2)*(fso.emax-(1:nbin)-1); % lower limit of bin j in dB 257 | mj=aj'-cj-thresh; 258 | % jj=find(mj*sign(mj(1))<=0); % Find threshold 259 | jj=find(mj(1:end-1)<0 & mj(2:end)>=0,1); % find +ve transition through threshold 260 | if isempty(jj) % if we never cross the threshold 261 | if mj(end)<=0 % if we end up below if 262 | jj=length(mj)-1; % take the threshold to be the bottom of the last (lowest) bin 263 | jf=1; 264 | else % if we are always above it 265 | jj=1; % take the threshold to be the 
bottom of the first (highest) bin 266 | jf=0; 267 | end 268 | else 269 | jf=1/(1-mj(jj+1)/mj(jj)); % fractional part of j using linear interpolation 270 | end 271 | lev=aj(jj)+jf*(aj(jj+1)-aj(jj)); % active level in decibels 272 | lp=10.^(lev/10); % active level in power 273 | if any(md=='d') % 'd' option -> output in dB 274 | lev=[lev 10*log10(fso.ssq/fso.ns)]; 275 | else % ~'d' option -> output in power 276 | lev=[lp fso.ssq/fso.ns]; 277 | end 278 | af=fso.ssq/(fso.ns*lp); 279 | else % if all samples are equal to zero 280 | af=0; 281 | if any(md=='d') % 'd' option -> output in dB 282 | lev=[-Inf -Inf]; % active level is 0 dB 283 | else % ~'d' option -> output in power 284 | lev=[0 0]; % active level is 0 power 285 | end 286 | end 287 | if all(md~='l') 288 | lev=lev(1); % only output the first element of lev unless 'l' option 289 | end 290 | end 291 | if nargout>3 292 | vad=maxfilt(s(1:nsp),1,fso.nh,1); 293 | vad=vad>(sqrt(lp)/10^(thresh/20)); 294 | end 295 | if ~nargout 296 | vad=maxfilt(s,1,fso.nh,1); 297 | vad=vad>(sqrt(lp)/10^(thresh/20)); 298 | levdb=10*log10(lp); 299 | %clf; 300 | %subplot(2,2,[1 2]); 301 | tax=(1:ns)/fso.ffs; 302 | %plot(tax,sp,'-y',tax,s,'-r',tax,(vad>0)*sqrt(lp),'-b'); 303 | %xlabel('Time (s)'); 304 | %title(sprintf('Active Level = %.2g dB, Activity = %.0f%% (ITU-T P.56)',levdb,100*af)); 305 | %axisenlarge([-1 -1 -1.4 -1.05]); 306 | %if nz>0 307 | % hold on 308 | % ylim=get(gca,'ylim'); 309 | % plot(tax(end-nz)*[1 1],ylim,':k'); 310 | % hold off 311 | %end 312 | %ylabel('Amplitude'); 313 | %legend('Signal','Smoothed envelope','VAD * Active-Level','Location','SouthEast'); 314 | %subplot(2,2,4); 315 | %plot(cj,repmat(levdb,nbin,1),'k:',cj,aj(:),'-b',cj,cj,'-r',levdb-thresh*ones(1,2),[levdb-thresh levdb],'-r'); 316 | %xlabel('Threshold (dB)'); 317 | %ylabel('Active Level (dB)'); 318 | %legend('Active Level','Speech>Thresh','Threshold','Location','NorthWest'); 319 | %texthvc(levdb-thresh,levdb-0.5*thresh,sprintf('%.1f dB ',thresh),'rmr'); 320 | %axisenlarge([-1 -1.05]); 321 | %ylim=get(gca,'ylim'); 322 | %set(gca,'ylim',[levdb-1.2*thresh max(ylim(2),levdb+1.9*thresh)]); 323 | %kch=filter([1 -1],1,kc); 324 | %subplot(2,2,3); 325 | %bar(5*log10(2)+cj(end:-1:1),kch(end:-1:1)*100/kc(end)); 326 | %set(gca,'xlim',[cj(end) cj(1)+10*log10(2)]); 327 | %ylim=get(gca,'ylim'); 328 | %hold on 329 | %plot(lev([1 1]),ylim,'k:',lev([1 1])-thresh,ylim,'r:'); 330 | %hold off 331 | %texthvc(lev(1),ylim(2),sprintf(' Act\n Lev'),'ltk'); 332 | %texthvc(lev(1)-thresh,ylim(2),sprintf('Threshold '),'rtr'); 333 | %xlabel('Frame power (dB)') 334 | %ylabel('% frames'); 335 | elseif any(md=='n') || any(md=='N') % output normalized speech waveform 336 | fsx=fso; % shift along other outputs 337 | fso=af; 338 | af=lev; 339 | if any(md=='n') 340 | sq=sp; % 'n' -> use unfiltered speech 341 | end 342 | if fsx.ns>0 && fsx.ssq>0 % if there has been any non-zero speech 343 | lev=sq(1:nsp)/sqrt(lp); 344 | else 345 | lev=sq(1:nsp); 346 | end 347 | end 348 | out = cat(1, squeeze(lev), af); 349 | 350 | -------------------------------------------------------------------------------- /meta/create_rirs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --nodes=1 3 | #SBATCH --mem=16000 4 | #SBATCH --time=24:00:00 5 | #SBATCH --mail-user=jwravenscroft1@sheffield.ac.uk 6 | 7 | #load the modules 8 | module load Anaconda3/5.3.0 9 | module load fosscuda/2019b # includes GCC 8.3 10 | module load imkl/2019.5.281-iimpi-2019b 11 | module load 
CMake/3.15.3-GCCcore-8.3.0
12 | #python environment
13 | source activate speechbrain
14 | 
15 | srun --export=ALL python3 create_whamr_rirs.py --output-dir ~/fastdata/data/whamr/rirs
--------------------------------------------------------------------------------
/meta/create_whamr_rirs.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from the original WHAMR script to obtain the Room Impulse Responses
3 | 
4 | Authors
5 |     * Cem Subakan 2021
6 | """
7 | import os,sys
8 | sys.path.append("../")
9 | import pandas as pd
10 | import argparse
11 | import torchaudio
12 | from meta.wham_room import WhamRoom
13 | from scipy.signal import resample_poly
14 | import torch
15 | from speechbrain.pretrained.fetching import fetch
16 | from tqdm import tqdm
17 | import pyroomacoustics
18 | 
19 | 
20 | def create_rirs(output_dir, sr=8000):
21 |     """
22 |     This function creates the room impulse responses from the WHAMR! dataset
23 |     The implementation is based on the scripts from http://wham.whisper.ai/
24 | 
25 |     Arguments:
26 |     ------
27 |     output_dir (str) : directory for saving the RIRs
28 |     sr (int) : sampling rate with which we save
29 | 
30 |     """
31 | 
32 |     assert (
33 |         pyroomacoustics.__version__ == "0.3.1"
34 |     ), "The pyroomacoustics version needs to be 0.3.1"
35 | 
36 |     os.makedirs(output_dir, exist_ok=True)
37 | 
38 |     metafilesdir = os.path.dirname(os.path.realpath(__file__))
39 |     filelist = [
40 |         "mix_2_spk_filenames_tr.csv",
41 |         "mix_2_spk_filenames_cv.csv",
42 |         "mix_2_spk_filenames_tt.csv",
43 |         "reverb_params_tr.csv",
44 |         "reverb_params_cv.csv",
45 |         "reverb_params_tt.csv",
46 |     ]
47 | 
48 |     savedir = os.path.join(metafilesdir, "data")
49 |     for fl in filelist:
50 |         if not os.path.exists(os.path.join(savedir, fl)):
51 |             fetch(
52 |                 "metadata/" + fl,
53 |                 "speechbrain/sepformer-whamr",
54 |                 savedir=savedir,
55 |                 save_filename=fl,
56 |             )
57 | 
58 |     FILELIST_STUB = os.path.join(
59 |         metafilesdir, "data", "mix_2_spk_filenames_{}.csv"
60 |     )
61 | 
62 |     SPLITS = ["tr","cv","tt"]
63 |     [os.makedirs(os.path.join(output_dir,splt), exist_ok=True) for splt in SPLITS]
64 | 
65 |     reverb_param_stub = os.path.join(
66 |         metafilesdir, "data", "reverb_params_{}.csv"
67 |     )
68 | 
69 |     for splt in SPLITS:
70 | 
71 |         wsjmix_path = FILELIST_STUB.format(splt)
72 |         wsjmix_df = pd.read_csv(wsjmix_path)
73 | 
74 |         reverb_param_path = reverb_param_stub.format(splt)
75 |         reverb_param_df = pd.read_csv(reverb_param_path)
76 | 
77 |         utt_ids = wsjmix_df.output_filename.values
78 | 
79 |         for output_name in tqdm(utt_ids):
80 |             utt_row = reverb_param_df[
81 |                 reverb_param_df["utterance_id"] == output_name
82 |             ]
83 |             room = WhamRoom(
84 |                 [
85 |                     utt_row["room_x"].iloc[0],
86 |                     utt_row["room_y"].iloc[0],
87 |                     utt_row["room_z"].iloc[0],
88 |                 ],
89 |                 [
90 |                     [
91 |                         utt_row["micL_x"].iloc[0],
92 |                         utt_row["micL_y"].iloc[0],
93 |                         utt_row["mic_z"].iloc[0],
94 |                     ],
95 |                     [
96 |                         utt_row["micR_x"].iloc[0],
97 |                         utt_row["micR_y"].iloc[0],
98 |                         utt_row["mic_z"].iloc[0],
99 |                     ],
100 |                 ],
101 |                 [
102 |                     utt_row["s1_x"].iloc[0],
103 |                     utt_row["s1_y"].iloc[0],
104 |                     utt_row["s1_z"].iloc[0],
105 |                 ],
106 |                 [
107 |                     utt_row["s2_x"].iloc[0],
108 |                     utt_row["s2_y"].iloc[0],
109 |                     utt_row["s2_z"].iloc[0],
110 |                 ],
111 |                 utt_row["T60"].iloc[0],
112 |             )
113 |             room.generate_rirs()
114 | 
115 |             rir = room.rir_reverberant
116 | 
117 |             for i, mics in enumerate(rir):
118 |                 for j, source in enumerate(mics):
119 |                     h = resample_poly(source, sr, 16000)
120 |                     h_torch = 
torch.from_numpy(h).float().unsqueeze(0) 121 | 122 | torchaudio.save( 123 | os.path.join( 124 | output_dir, splt, "{}_{}_".format(i, j) + output_name, 125 | ), 126 | h_torch, 127 | sr, 128 | ) 129 | 130 | 131 | if __name__ == "__main__": 132 | parser = argparse.ArgumentParser() 133 | 134 | parser.add_argument( 135 | "--output-dir", 136 | type=str, 137 | required=True, 138 | help="The output directory for saving the rirs for random augmentation style", 139 | ) 140 | 141 | args = parser.parse_args() 142 | create_rirs(args.output_dir) 143 | -------------------------------------------------------------------------------- /meta/data/mix_2_spk_filenames_cv.csv: -------------------------------------------------------------------------------- 1 | /home/acp19jwr/.cache/huggingface/hub/bf79a01f39717176f2b2131adc41fef7c6a6847aa4f879022177855544a49df1.dac41d06ee279fc105232fec09a6752946e6401a219592cd0f872f1609817786 -------------------------------------------------------------------------------- /meta/data/mix_2_spk_filenames_tr.csv: -------------------------------------------------------------------------------- 1 | /home/acp19jwr/.cache/huggingface/hub/d66712c8564128820c39dac603f0efe6afff2a1dfb11cb0f3f234784a11f3855.12aef7777563247a4c0006433200525efc62ea0237f287b4d5a6038ad0683b08 -------------------------------------------------------------------------------- /meta/data/mix_2_spk_filenames_tt.csv: -------------------------------------------------------------------------------- 1 | /home/acp19jwr/.cache/huggingface/hub/1ff6399dc0beaaa26093602af0e265fdcad656c707378f549c5c2bc7d8c33a6b.a593fa99494f56efa1369d036037ff107781be492d75e678f95742b1399af836 -------------------------------------------------------------------------------- /meta/data/reverb_params_cv.csv: -------------------------------------------------------------------------------- 1 | /home/acp19jwr/.cache/huggingface/hub/1ebe96fd99eee7b98021f8bca1e15f066290c924e635222d9f99fee9ae002a07.9640814893e00a236328b16fe8686fa904a8d3e08b61f0779a9227ca0d34c4ca -------------------------------------------------------------------------------- /meta/data/reverb_params_tr.csv: -------------------------------------------------------------------------------- 1 | /home/acp19jwr/.cache/huggingface/hub/9a7b23c5176141793650bf653abdd40655c657c48aef768ce50d220044c8e917.3e0c675eaebacbdd962578ca99ce8419f6f824458aa1f34f0ae464e10a454321 -------------------------------------------------------------------------------- /meta/data/reverb_params_tt.csv: -------------------------------------------------------------------------------- 1 | /home/acp19jwr/.cache/huggingface/hub/01bf00122c5493045173ba9a6c05cda1012c85702b764f3fbae8d5626e1eb9ee.2de9387bf3df00fe4e02266d84f4a9ce60672364b9edc1ebad7013eeda26ee8c -------------------------------------------------------------------------------- /meta/maxfilt.m: -------------------------------------------------------------------------------- 1 | function [y,k,y0]=maxfilt(x,f,n,d,x0) 2 | %MAXFILT find max of an exponentially weighted sliding window [Y,K,Y0]=(X,F,nn,D,X0) 3 | % 4 | % Usage: (1) y=maxfilt(x) % maximum filter along first non-singleton dimension 5 | % (2) y=maxfilt(x,0.95) % use a forgetting factor of 0.95 (= time const of -1/log(0.95)=19.5 samples) 6 | % (3) Two equivalent methods (i.e. 
you can process x in chunks): 7 | % y=maxfilt([u v]); [yu,ku,x0)=maxfilt(u); 8 | % yv=maxfilt(v,[],[],[],x0); 9 | % y=[yu yv]; 10 | % 11 | % Inputs: X Vector or matrix of input data 12 | % F exponential forgetting factor in the range 0 (very forgetful) to 1 (no forgetting) 13 | % F=exp(-1/T) gives a time constant of T samples [default = 1] 14 | % n Length of sliding window [default = Inf (equivalent to [])] 15 | % D Dimension for work along [default = first non-singleton dimension] 16 | % X0 Initial values placed in front of the X data 17 | % 18 | % Outputs: Y Output matrix - same size as X 19 | % K Index array: Y=X(K). (Note that these value may be <=0 if input X0 is present) 20 | % Y0 Last nn-1 values (used to initialize a subsequent call to 21 | % maxfilt()) (or last output if n=Inf) 22 | % 23 | % This routine calculates y(p)=max(f^r*x(p-r), r=0:n-1) where x(r)=-inf for r<1 24 | % y=x(k) on output 25 | 26 | % Example: find all peaks in x that are not exceeded within +-w samples 27 | % w=4;m=100;x=rand(m,1);[y,k]=maxfilt(x,1,2*w+1);p=find(((1:m)-k)==w);plot(1:m,x,'-',p-w,x(p-w),'+') 28 | 29 | % Copyright (C) Mike Brookes 2003 30 | % Version: $Id: maxfilt.m 4054 2014-01-12 19:11:46Z dmb $ 31 | % 32 | % VOICEBOX is a MATLAB toolbox for speech processing. 33 | % Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html 34 | % 35 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 36 | % This program is free software; you can redistribute it and/or modify 37 | % it under the terms of the GNU General Public License as published by 38 | % the Free Software Foundation; either version 2 of the License, or 39 | % (at your option) any later version. 40 | % 41 | % This program is distributed in the hope that it will be useful, 42 | % but WITHOUT ANY WARRANTY; without even the implied warranty of 43 | % MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 44 | % GNU General Public License for more details. 45 | % 46 | % You can obtain a copy of the GNU General Public License from 47 | % http://www.gnu.org/copyleft/gpl.html or by writing to 48 | % Free Software Foundation, Inc.,675 Mass Ave, Cambridge, MA 02139, USA. 
49 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
50 | 
51 | s=size(x);
52 | if nargin<4 || isempty(d)
53 |     d=find(s>1,1);                  % find first non-singleton dimension
54 |     if isempty(d)
55 |         d=1;
56 |     end
57 | end
58 | if nargin>4 && numel(x0)>0          % initial values specified
59 |     y=shiftdim(cat(d,x0,x),d-1);    % concatenate x0 and x along d
60 |     nx0=size(x0,d);                 % number of values added onto front of data
61 | else                                % dimension specified, d
62 |     y=shiftdim(x,d-1);
63 |     nx0=0;
64 | end
65 | s=size(y);
66 | s1=s(1);
67 | if nargin<3 || isempty(n)
68 |     n0=Inf;
69 | else
70 |     n0=max(n,1);
71 | end
72 | if nargin<2 || isempty(f)
73 |     f=1;
74 | end
75 | nn=n0;
76 | if nargout>2                        % we need to output the tail for next time
77 |     if n0<s1
92 | k=repmat((1:s1)',[1 s(2:end)]);
93 | if nn>1
94 |     j=1;
95 |     j2=1;
96 |     while j>0
97 |         g=f^j;
98 |         m=find(y(j+1:s1,:)<=g*y(1:s1-j,:));
99 |         m=m+j*fix((m-1)/(s1-j));
100 |         y(m+j)=g*y(m);
101 |         k(m+j)=k(m);
102 |         j2=j2+j;
103 |         j=min(j2,nn-j2);            % j approximately doubles each iteration
104 |     end
105 | end
106 | if nargout==0
107 |     if nargin<3
108 |         x=shiftdim(x);
109 |     else
110 |         x=shiftdim(x,d-1);
111 |     end
112 |     ss=min(prod(s(2:end)),5);       % maximum of 5 plots
113 |     %plot(1:s1,reshape(y(nx0+1:end,1:ss),s1,ss),'-r',1:s1,reshape(x(:,1:ss),s1,ss),'-b');
114 | else
115 |     if nargout>2 && n0==Inf && ny0==1   % if n0==Inf, we need to save the final output
116 |         y0=reshape(y(end,:),sy0);
117 |         y0=shiftdim(y0,ndims(x)-d+1);
118 |     end
119 |     if nx0>0                        % pre-data specified, x0
120 |         s(1)=s(1)-nx0;
121 |         y=shiftdim(reshape(y(nx0+1:end,:),s),ndims(x)-d+1);
122 |         k=shiftdim(reshape(k(nx0+1:end,:),s),ndims(x)-d+1)-nx0;
123 |     else                            % no pre-data
124 |         y=shiftdim(y,ndims(x)-d+1);
125 |         k=shiftdim(k,ndims(x)-d+1);
126 |     end
127 | end
128 | 
--------------------------------------------------------------------------------
/meta/preprocess_dynamic_mixing.py:
--------------------------------------------------------------------------------
1 | """
2 | This script allows you to resample a folder which contains audio files.
3 | The files are parsed recursively. An exact copy of the folder is created,
4 | with the same structure but containing resampled audio files.
5 | Resampling is performed by using sox through torchaudio.
6 | Author
7 | ------
8 | Samuele Cornell, 2020
9 | """
10 | 
11 | import os
12 | import argparse
13 | from pathlib import Path
14 | import tqdm
15 | import torchaudio
16 | import glob
17 | 
18 | # from oct2py import octave
19 | from scipy import signal
20 | import numpy as np
21 | import torch
22 | 
23 | 
24 | parser = argparse.ArgumentParser(
25 |     "utility for resampling all audio files in a folder recursively. "
26 |     "It copies --input_folder to --output_folder and "
27 |     "resamples all audio files of the specified format to --fs."
28 | )
29 | parser.add_argument("--input_folder", type=str, required=True)
30 | parser.add_argument("--output_folder", type=str, required=True)
31 | parser.add_argument(
32 |     "--fs", type=str, default=8000, help="this is the target sampling frequency"
33 | )
34 | parser.add_argument("--regex", type=str, default="**/*.wav")
35 | 
36 | 
37 | def resample_folder(input_folder, output_folder, fs, regex):
38 |     """Resamples the wav files within an input folder.
39 | 
40 |     Arguments
41 |     ---------
42 |     input_folder : path
43 |         Path of the folder to resample.
44 |     output_folder : path
45 |         Path of the output folder with the resampled data.
46 |     fs : int
47 |         Target sampling frequency.
48 |     regex : str
49 |         Regular expression for search.
50 | """ 51 | # filedir = os.path.dirname(os.path.realpath(__file__)) 52 | # octave.addpath(filedir) 53 | # add the matlab functions to octave dir here 54 | 55 | files = glob.glob(os.path.join(input_folder, regex), recursive=True) 56 | for f in tqdm.tqdm(files): 57 | 58 | audio, fs_read = torchaudio.load(f) 59 | audio = audio[0].numpy() 60 | audio = signal.resample_poly(audio, fs, fs_read) 61 | 62 | # tmp = octave.activlev(audio.tolist(), fs, "n") 63 | # audio, _ = tmp[:-1].squeeze(), tmp[-1] 64 | 65 | peak = np.max(np.abs(audio)) 66 | audio = audio / peak 67 | audio = torch.from_numpy(audio).float() 68 | 69 | relative_path = os.path.join( 70 | Path(f).relative_to(Path(input_folder)).parent, 71 | Path(f).relative_to(Path(input_folder)).stem 72 | + "_peak_{}.wav".format(peak), 73 | ) 74 | 75 | os.makedirs( 76 | Path( 77 | os.path.join( 78 | output_folder, Path(f).relative_to(Path(input_folder)) 79 | ) 80 | ).parent, 81 | exist_ok=True, 82 | ) 83 | 84 | torchaudio.save( 85 | os.path.join(output_folder, relative_path), 86 | audio.reshape(1, -1), 87 | fs, 88 | ) 89 | 90 | 91 | if __name__ == "__main__": 92 | 93 | args = parser.parse_args() 94 | resample_folder( 95 | args.input_folder, args.output_folder, int(args.fs), args.regex 96 | ) 97 | -------------------------------------------------------------------------------- /meta/rir_constants.py: -------------------------------------------------------------------------------- 1 | NUM_BANDS = 4 2 | SNR_THRESH = -6.0 3 | PRE_NOISE_SECONDS = 2.0 4 | SAMPLERATE = 16000 5 | MAX_SAMPLE_AMP = 0.95 6 | MIN_SNR_DB = -3.0 7 | MAX_SNR_DB = 6.0 8 | PRE_NOISE_SAMPLES = PRE_NOISE_SECONDS * SAMPLERATE 9 | -------------------------------------------------------------------------------- /meta/wham_room.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pyroomacoustics as pra 3 | from pyroomacoustics.parameters import constants 4 | from scipy.signal import resample_poly 5 | 6 | 7 | class WhamRoom(pra.room.ShoeBox): 8 | """ 9 | This class is taken from the original WHAMR! scripts. 10 | The original script can be found in 11 | http://wham.whisper.ai/ 12 | 13 | This class is used to simulate the room-impulse-responses (RIRs) in the WHAMR dataset. 
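    A usage sketch (made-up room geometry, purely illustrative; the arguments
    follow the constructor below):
        >>> room = WhamRoom([10, 7, 3], [[2, 2, 1.5], [2.1, 2, 1.5]],
        ...                 [4, 3, 1.5], [6, 4, 1.5], 0.4)
        >>> room.generate_rirs()
        >>> rirs = room.rir_reverberant  # nested list indexed as [mic][source]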
14 | """ 15 | 16 | def __init__( 17 | self, p, mics, s1, s2, T60, fs=16000, t0=0.0, sigma2_awgn=None 18 | ): 19 | 20 | self.T60 = T60 21 | self.max_rir_len = np.ceil(T60 * fs).astype(int) 22 | 23 | volume = p[0] * p[1] * p[2] 24 | surface_area = 2 * (p[0] * p[1] + p[0] * p[2] + p[1] * p[2]) 25 | absorption = ( 26 | 24 27 | * volume 28 | * np.log(10.0) 29 | / (constants.get("c") * surface_area * T60) 30 | ) 31 | 32 | # minimum max order to guarantee complete filter of length T60 33 | max_order = np.ceil(T60 * constants.get("c") / min(p)).astype(int) 34 | 35 | super().__init__( 36 | p, 37 | fs=fs, 38 | t0=t0, 39 | absorption=absorption, 40 | max_order=max_order, 41 | sigma2_awgn=sigma2_awgn, 42 | sources=None, 43 | mics=None, 44 | ) 45 | 46 | self.add_source(s1) 47 | self.add_source(s2) 48 | 49 | self.add_microphone_array(pra.MicrophoneArray(np.array(mics).T, fs)) 50 | 51 | def add_audio(self, s1, s2): 52 | self.sources[0].add_signal(s1) 53 | self.sources[1].add_signal(s2) 54 | 55 | def compute_rir(self): 56 | 57 | self.rir = [] 58 | self.visibility = None 59 | 60 | self.image_source_model() 61 | 62 | for m, mic in enumerate(self.mic_array.R.T): 63 | h = [] 64 | for s, source in enumerate(self.sources): 65 | h.append( 66 | source.get_rir( 67 | mic, self.visibility[s][m], self.fs, self.t0 68 | )[: self.max_rir_len] 69 | ) 70 | self.rir.append(h) 71 | 72 | def generate_rirs(self): 73 | 74 | original_max_order = self.max_order 75 | self.max_order = 0 76 | 77 | self.compute_rir() 78 | 79 | self.rir_anechoic = self.rir 80 | 81 | self.max_order = original_max_order 82 | 83 | self.compute_rir() 84 | 85 | self.rir_reverberant = self.rir 86 | 87 | def generate_audio(self, anechoic=False, fs=16000): 88 | 89 | if not self.rir: 90 | self.generate_rirs() 91 | if anechoic: 92 | self.rir = self.rir_anechoic 93 | else: 94 | self.rir = self.rir_reverberant 95 | audio_array = self.simulate(return_premix=True, recompute_rir=False) 96 | 97 | if type(fs) is not list: 98 | fs_array = [fs] 99 | else: 100 | fs_array = fs 101 | audio_out = [] 102 | for elem in fs_array: 103 | if type(elem) is str: 104 | elem = int(elem.replace("k", "000")) 105 | if elem != self.fs: 106 | assert self.fs % elem == 0 107 | audio_out.append( 108 | resample_poly(audio_array, elem, self.fs, axis=2) 109 | ) 110 | else: 111 | audio_out.append(audio_array) 112 | if type(fs) is not list: 113 | return audio_out[0] # array of shape (n_sources, n_mics, n_samples) 114 | else: 115 | return audio_out 116 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author 3 | * Cem Subakan 2020 4 | * Will Ravenscroft 2021 5 | 6 | The .csv preperation functions for WHAMR 7 | """ 8 | 9 | import os, glob, csv, json 10 | import soundfile as sf 11 | 12 | set_map = {"tr":"train","cv":"valid","tt":"test"} 13 | 14 | def get_rir_paths(rir_path, sets=set_map.keys()): 15 | rir_directory = { 16 | set : 17 | { 18 | "s1" : 19 | { 20 | os.path.basename(filepath).replace("0_0_","") : filepath 21 | for filepath in glob.glob(os.path.join(rir_path,set,"0_0_*")) 22 | }, 23 | "s2" : 24 | { 25 | os.path.basename(filepath).replace("0_1_","") : filepath 26 | for filepath in glob.glob(os.path.join(rir_path,set,"0_1_*")) 27 | } 28 | } 29 | for set in sets 30 | } 31 | return rir_directory 32 | 33 | def get_meta_data(creation_path, wham_noise_path): 34 | """ 35 | creation_path = Path to data folder in WHAMR! 
creation scripts 36 | """ 37 | rir_csv_paths = { 38 | os.path.basename(fname).replace("reverb_params_","").replace(".csv","") : fname 39 | for fname in glob.glob(os.path.join(creation_path,"data","reverb_params*.csv")) 40 | } 41 | # mix_csv_paths = { 42 | # fname.replace("reverb_params_","").replace(".csv") : fname 43 | # for fname in glob.glob(os.path.join(creation_path,"data","reverb_params*.csv")) 44 | # } 45 | noise_csv_paths = { 46 | os.path.basename(fname).replace("mix_param_meta_","").replace(".csv","") : fname 47 | for fname in glob.glob(os.path.join(wham_noise_path,"metadata","mix_param_meta*.csv")) 48 | } 49 | 50 | sets = ["cv","tr","tt"] 51 | 52 | directory = {} 53 | 54 | for i, set in enumerate(sets): 55 | try: 56 | with open(rir_csv_paths[set],'r') as f: 57 | reader = csv.DictReader(f) 58 | for i, row in enumerate(reader): 59 | room_size = float(row['room_x'])*float(row['room_y'])*float(row['room_z']) 60 | directory[row["utterance_id"]]={ 61 | "t60":row["T60"], 62 | "room_size":room_size, 63 | "room_x":row["room_x"], 64 | "room_y":row["room_y"], 65 | "room_z":row["room_z"], 66 | "micL_x":row["micL_x"], 67 | "micL_y":row["micL_y"], 68 | "micR_x":row["micR_x"], 69 | "micR_y":row["micR_y"], 70 | "mic_z":row["mic_z"], 71 | "s1_x":row["s1_x"], 72 | "s1_y":row["s1_y"], 73 | "s1_z":row["s1_z"], 74 | "s2_x":row["s2_x"], 75 | "s2_y":row["s2_y"], 76 | "s2_z":row["s2_z"] 77 | } 78 | 79 | with open(noise_csv_paths[set],'r') as f: 80 | reader = csv.DictReader(f) 81 | for i, row in enumerate (reader): 82 | directory[row["utterance_id"]]["snr"]=float(row["target_speaker1_snr_db"]) 83 | except KeyError as e: 84 | rir_key = "Keys found (RIRs):"+str(rir_csv_paths.keys()) 85 | noise_key = "Keys found (noises):"+str(noise_csv_paths.keys()) 86 | raise KeyError(str(e)+'. '+rir_key+'. '+noise_key) 87 | 88 | return directory 89 | 90 | def get_transcriptions( 91 | wsj0_path, 92 | case='u', 93 | filters=["\.PERIOD","\.COMMA","\-HYPHEN","\-\-DASH","\\\"DOUBLE\-QUOTE"] 94 | ): #/path/to/...11.-1.1 etc. 
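    # Each line of a WSJ0 .dot transcription file ends with an 8-character
    # utterance label wrapped in parentheses; after the parentheses are stripped,
    # the fixed-width slicing below splits each line into (utterance text, label).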
95 |     dot_files = (glob.glob(os.path.join(wsj0_path,"*/*/*/*.dot"))+
96 |         glob.glob(os.path.join(wsj0_path,"*/*/*/*/*.dot"))+
97 |         glob.glob(os.path.join(wsj0_path,"*/*/*/*/*/*.dot"))+
98 |         glob.glob(os.path.join(wsj0_path,"*/*/*/*/*/*/*.dot"))+
99 |         glob.glob(os.path.join(wsj0_path,"*/*/*/*/*/*/*/*.dot")))
100 | 
101 |     test_file = "/share/mini1/data/audvis/pub/asr/studio/us/wsj/v2/wsj0/11-10.1/wsj0/transcrp/dots/si_tr_s/01v/01vo0300.dot"
102 |     test_label = "01vo030q"
103 |     assert test_file in dot_files  # site-specific sanity check for the author's WSJ0 mount
104 | 
105 |     dot_directory = {}
106 |     for dot in dot_files:
107 |         with open(dot,'r') as f:
108 |             for line in f:
109 |                 try:
110 |                     line = line.replace(")","").replace("(","")
111 |                     utterance, label = line[:-9], line[-9:]
112 |                 except ValueError as e:
113 |                     raise ValueError("Can't split line \""+line+"\"")
114 |                 label = label.replace(")","").replace("\n","")
115 |                 # upper- or lower-case the transcription according to the 'case' argument
116 |                 dot_directory[label] = utterance.upper() if case == 'u' else utterance.lower()
117 | 
118 |     return dot_directory
119 | 
120 | def prepare_wham_whamr_csv(
121 |     datapath,
122 |     savepath,
123 |     skip_prep=False,
124 |     fs=8000,
125 |     mini=False,
126 |     mix_folder="mix_both_reverb",
127 |     target_condition="anechoic",
128 |     set_types=["tr", "cv", "tt"],
129 |     num_spks=2,
130 |     alternate_path=None,
131 |     creation_path=None,
132 |     wham_noise_path=None,
133 |     version="min",
134 |     wsj0_path=None, # for transcriptions
135 |     eval_original=False,
136 |     meta_dump=False,
137 |     use_rirs=False,
138 |     rir_path=None,
139 |     extended=False,
140 |     savename="whamr_"
141 | ):
142 |     """
143 |     Prepares the csv files for the wham or whamr dataset
144 | 
145 |     Arguments:
146 |     ----------
147 |     datapath (str) : path for the wsj0-mix dataset.
148 |     savepath (str) : path where we save the csv file.
149 |     skip_prep (bool): If True, skip data preparation
150 |     """
151 | 
152 |     if skip_prep:
153 |         return
154 | 
155 |     if "whamr" in datapath:
156 |         # if we want to train a model on the whamr dataset
157 |         create_wham_whamr_csv(datapath, savepath, fs,mix_folder=mix_folder,
158 |             alternate_path=alternate_path, num_spks=num_spks, target_condition=target_condition,
159 |             version=version, creation_path=creation_path, wham_noise_path=wham_noise_path,
160 |             wsj0_path=wsj0_path, eval_original=eval_original, meta_dump=meta_dump, rir_path=rir_path,
161 |             use_rirs=use_rirs,set_types=set_types,extended=extended,savename=savename
162 |         )
163 |     elif "wham" in datapath:
164 |         # if we want to train a model on the original wham dataset
165 |         create_wham_whamr_csv(
166 |             datapath, savepath, fs, savename="whamorg_", add_reverb=False, mini=mini,
167 |             mix_folder=mix_folder.replace("reverb","").replace("anechoic","")
168 |         )
169 |     else:
170 |         raise ValueError("Unsupported Dataset at: "+datapath)
171 | 
172 | def create_wham_whamr_csv(
173 |     datapath,
174 |     savepath,
175 |     fs,
176 |     version="min",
177 |     mix_folder="mix_both_reverb",
178 |     target_condition="anechoic",
179 |     savename="whamr_",
180 |     set_types=["tr", "cv", "tt"],
181 |     add_reverb=True,
182 |     mini=False,
183 |     num_spks=2,
184 |     alternate_path=None,
185 |     creation_path=None,
186 |     wham_noise_path=None,
187 |     wsj0_path=None, # for transcriptions
188 |     eval_original=False,
189 |     meta_dump = False, # grabs every kind of meta data for room simulation
190 |     use_rirs = False,
191 |     rir_path=None,
192 |     extended=False
193 | ):
194 |     """
195 |     This function creates the csv files to get the speechbrain data loaders for the whamr dataset.
196 | 
197 |     Arguments:
198 |         datapath (str) : path for the wsj0-mix dataset.
198 | savepath (str) : path where we save the csv file 199 | fs (int) : the sampling rate 200 | version (str) : min or max 201 | savename (str) : the prefix to use for the .csv files 202 | set_types (list) : the sets to create 203 | """ 204 | 205 | if alternate_path==None: 206 | target_path = datapath 207 | else: 208 | target_path = alternate_path 209 | 210 | if creation_path != None and wham_noise_path != None: 211 | extra = True 212 | directory = get_meta_data(creation_path, wham_noise_path) 213 | else: 214 | extra = False 215 | 216 | if not wsj0_path == None: 217 | transcription_directory = get_transcriptions(wsj0_path) 218 | transcribe = True 219 | else: 220 | transcribe = False 221 | 222 | # if not rir_path == None: 223 | # rir_directory= get_rir_paths(rir_path) 224 | # use_rirs = True 225 | # else: 226 | # use_rirs = False 227 | 228 | if fs == 8000: 229 | sample_rate = "8k" 230 | elif fs == 16000: 231 | sample_rate = "16k" 232 | else: 233 | raise ValueError("Unsupported sampling rate") 234 | 235 | if add_reverb: 236 | mix = mix_folder+"/" 237 | s1 = "s1_"+target_condition+"/" 238 | s2 = "s2_"+target_condition+"/" 239 | else: 240 | mix = mix_folder+"/" 241 | s1 = "s1/" 242 | s2 = "s2/" 243 | 244 | if eval_original: 245 | s1_eval = "s1_anechoic/" 246 | s2_eval = "s2_anechoic/" 247 | 248 | 249 | for set_type in set_types: 250 | mix_path = os.path.join( 251 | datapath, "wav{}".format(sample_rate), version, set_type, mix, 252 | ) 253 | if eval_original and (set_type == "tt" or set_type == "cv"): 254 | s1_path = os.path.join( 255 | datapath, "wav{}".format(sample_rate), version, set_type, s1_eval, 256 | ) 257 | if num_spks==2: 258 | s2_path = os.path.join( 259 | datapath, "wav{}".format(sample_rate), version, set_type, s2_eval, 260 | ) 261 | else: 262 | s1_path = os.path.join( 263 | target_path, "wav{}".format(sample_rate), version, set_type, s1, 264 | ) 265 | if num_spks==2: 266 | s2_path = os.path.join( 267 | target_path, "wav{}".format(sample_rate), version, set_type, s2, 268 | ) 269 | 270 | noise_path = os.path.join( 271 | datapath, "wav{}".format(sample_rate), version, set_type, "noise/" 272 | ) 273 | # rir_path = os.path.join( 274 | # datapath, "wav{}".format(sample_rate), version, set_type, "rirs/" 275 | # ) 276 | 277 | files = os.listdir(mix_path) 278 | 279 | mix_fl_paths = [mix_path + fl for fl in files] 280 | s1_fl_paths = [s1_path + fl for fl in files] 281 | if num_spks==2: 282 | s2_fl_paths = [s2_path + fl for fl in files] 283 | noise_fl_paths = [noise_path + fl for fl in files] 284 | # rir_fl_paths = [rir_path + fl + ".t" for fl in files] 285 | 286 | # if not rir_path ==None: 287 | # rir_paths = rir_directory[set_type] 288 | 289 | if num_spks==1: 290 | csv_columns = [ 291 | "ID", 292 | "duration", 293 | "mix_wav", 294 | "mix_wav_format", 295 | "mix_wav_opts", 296 | "s1_wav", 297 | "s1_wav_format", 298 | "s1_wav_opts", 299 | "noise_wav", 300 | "noise_wav_format", 301 | "noise_wav_opts", 302 | # "rir_t", 303 | # "rir_format", 304 | # "rir_opts", 305 | "t60", 306 | "room_size", 307 | "snr", 308 | "s1_dot" 309 | ] 310 | else: 311 | csv_columns = [ 312 | "ID", 313 | "duration", 314 | "mix_wav", 315 | "mix_wav_format", 316 | "mix_wav_opts", 317 | "s1_wav", 318 | "s1_wav_format", 319 | "s1_wav_opts", 320 | "s2_wav", 321 | "s2_wav_format", 322 | "s2_wav_opts", 323 | "noise_wav", 324 | "noise_wav_format", 325 | "noise_wav_opts", 326 | # "rir_t", 327 | # "rir_format", 328 | # "rir_opts", 329 | "t60", 330 | "room_size", 331 | "snr", 332 | "s1_dot", 333 | "s2_dot" 334 | ] 335 | 336 | if 
meta_dump: 337 | extra_cols = [ 338 | "room_x", 339 | "room_y", 340 | "room_z", 341 | "micL_x", 342 | "micL_y", 343 | "micR_x", 344 | "micR_y", 345 | "mic_z", 346 | "s1_x", 347 | "s1_y", 348 | "s1_z", 349 | "s2_x", 350 | "s2_y", 351 | "s2_z" 352 | ] 353 | csv_columns = csv_columns + extra_cols 354 | 355 | if use_rirs: 356 | extra_cols = [ 357 | "s1_rir", 358 | "s2_rir" 359 | ] 360 | csv_columns = csv_columns + extra_cols 361 | 362 | with open( 363 | ( 364 | os.path.join(savepath, savename + set_type + ".csv") if not extended else os.path.join(savepath, savename + set_type + "_ext.csv") 365 | ), "w" 366 | ) as csvfile: 367 | writer = csv.DictWriter(csvfile, fieldnames=csv_columns) 368 | writer.writeheader() 369 | if mini and (num_spks==1): 370 | zipped = list(zip( 371 | mix_fl_paths, 372 | s1_fl_paths, 373 | noise_fl_paths, 374 | # rir_fl_paths, 375 | ))[:len(mix_fl_paths)//4] 376 | elif (not mini) and (num_spks==1): 377 | zipped = zip( 378 | mix_fl_paths, 379 | s1_fl_paths, 380 | noise_fl_paths, 381 | # rir_fl_paths, 382 | ) 383 | elif mini and (num_spks==2): 384 | zipped = list(zip( 385 | mix_fl_paths, 386 | s1_fl_paths, 387 | s2_fl_paths, 388 | noise_fl_paths, 389 | # rir_fl_paths, 390 | ))[:len(mix_fl_paths)//4] 391 | else: 392 | zipped = zip( 393 | mix_fl_paths, 394 | s1_fl_paths, 395 | s2_fl_paths, 396 | noise_fl_paths, 397 | # rir_fl_paths, 398 | ) 399 | for (i, packed,) in enumerate(zipped): 400 | if num_spks==1: 401 | mix_path, s1_path, noise_path = packed 402 | basename = os.path.basename(mix_path) 403 | if transcribe: 404 | s1_basename = os.path.basename(s1_path).replace(".wav","") 405 | s1_dot = transcription_directory[s1_basename] 406 | else: 407 | s1_dot=None 408 | row = { 409 | "ID": i, 410 | "duration": 1.0, 411 | "mix_wav": mix_path, 412 | "mix_wav_format": "wav", 413 | "mix_wav_opts": None, 414 | "s1_wav": s1_path, 415 | "s1_wav_format": "wav", 416 | "s1_wav_opts": None, 417 | "noise_wav": noise_path, 418 | "noise_wav_format": "wav", 419 | "noise_wav_opts": None, 420 | # "rir_t": rir_path, 421 | # "rir_format": ".t", 422 | # "rir_opts": None, 423 | "t60": directory[basename]["t60"] if extra else None, 424 | "room_size": directory[basename]["room_size"] if extra else None, 425 | "snr": directory[basename]["snr"] if extra else None, 426 | "s1_dot": s1_dot 427 | } 428 | else: 429 | mix_path, s1_path, s2_path, noise_path = packed 430 | basename = os.path.basename(mix_path) 431 | if transcribe: 432 | s1_basename = os.path.basename(s1_path).replace(".wav","") 433 | s1_dot = transcription_directory[s1_basename] 434 | s2_basename = os.path.basename(s2_path).replace(".wav","") 435 | s2_dot = transcription_directory[s2_basename] 436 | else: 437 | s1_dot=None 438 | s2_dot=None 439 | row = { 440 | "ID": i, 441 | "duration": 1.0, 442 | "mix_wav": mix_path, 443 | "mix_wav_format": "wav", 444 | "mix_wav_opts": None, 445 | "s1_wav": s1_path, 446 | "s1_wav_format": "wav", 447 | "s1_wav_opts": None, 448 | "s2_wav": s2_path, 449 | "s2_wav_format": "wav", 450 | "s2_wav_opts": None, 451 | "noise_wav": noise_path, 452 | "noise_wav_format": "wav", 453 | "noise_wav_opts": None, 454 | # "rir_t": rir_path, 455 | # "rir_format": ".t", 456 | # "rir_opts": None, 457 | "t60": directory[basename]["t60"] if extra else None, 458 | "room_size": directory[basename]["room_size"] if extra else None, 459 | "snr": directory[basename]["snr"] if extra else None, 460 | "s1_dot": s1_dot, 461 | "s2_dot": s2_dot, 462 | } 463 | 464 | if meta_dump: 465 | row["room_x"] = directory[basename]["room_x"] 466 | row["room_y"] 
= directory[basename]["room_y"]
467 |                     row["room_z"] = directory[basename]["room_z"]
468 |                     row["micL_x"] = directory[basename]["micL_x"]
469 |                     row["micL_y"] = directory[basename]["micL_y"]
470 |                     row["micR_x"] = directory[basename]["micR_x"]
471 |                     row["micR_y"] = directory[basename]["micR_y"]
472 |                     row["mic_z"] = directory[basename]["mic_z"]
473 |                     row["s1_x"] = directory[basename]["s1_x"]
474 |                     row["s1_y"] = directory[basename]["s1_y"]
475 |                     row["s1_z"] = directory[basename]["s1_z"]
476 |                     row["s2_x"] = directory[basename]["s2_x"]
477 |                     row["s2_y"] = directory[basename]["s2_y"]
478 |                     row["s2_z"] = directory[basename]["s2_z"]
479 | 
480 |                 if use_rirs:
481 |                     # row["s1_rir"] = rir_paths["s1"][basename]
482 |                     # row["s2_rir"] = rir_paths["s2"][basename]
483 |                     row["s1_rir"] = s1_path.replace("s1_anechoic","s1_rir")
484 |                     if num_spks == 2:
485 |                         row["s2_rir"] = s2_path.replace("s2_anechoic","s2_rir")
486 | 
487 | 
488 |                 writer.writerow(row)
489 | 
490 | def create_whamr_rir_csv(datapath, savepath):
491 |     """
492 |     This function creates the csv files to get the data loaders for the whamr dataset.
493 | 
494 |     Arguments:
495 |         datapath (str) : path for the whamr rirs.
496 |         savepath (str) : path where we save the csv file
497 |     """
498 | 
499 |     csv_columns = ["ID", "duration", "wav", "wav_format", "wav_opts"]
500 | 
501 |     files = os.listdir(datapath)
502 |     all_paths = [os.path.join(datapath, fl) for fl in files]
503 | 
504 |     with open(savepath + "/whamr_rirs.csv", "w") as csvfile:
505 |         writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
506 |         writer.writeheader()
507 |         for i, wav_path in enumerate(all_paths):
508 | 
509 |             row = {
510 |                 "ID": i,
511 |                 "duration": 2.0,
512 |                 "wav": wav_path,
513 |                 "wav_format": "wav",
514 |                 "wav_opts": None,
515 |             }
516 |             writer.writerow(row)
517 | 
518 | def create_wham_whamr_json(
519 |     datapath,
520 |     savepath,
521 |     fs,
522 |     version="min",
523 |     mix_folder="mix_both_reverb",
524 |     target_condition="anechoic",
525 |     savename="whamr_",
526 |     set_types=["tr", "cv", "tt"],
527 |     add_reverb=True,
528 |     mini=False,
529 |     num_spks=2,
530 |     alternate_path=None,
531 |     creation_path=None,
532 |     wham_noise_path=None,
533 |     wsj0_path=None, # for transcriptions
534 |     rir_path=None,
535 |     eval_original=False,
536 | ):
537 |     """
538 |     This function creates the json files to get the speechbrain data loaders for the whamr dataset.
539 | 
540 |     Arguments:
541 |         datapath (str) : path for the wsj0-mix dataset.
542 | savepath (str) : path where we save the csv file 543 | fs (int) : the sampling rate 544 | version (str) : min or max 545 | savename (str) : the prefix to use for the .csv files 546 | set_types (list) : the sets to create 547 | """ 548 | 549 | if alternate_path==None: 550 | target_path = datapath 551 | else: 552 | target_path = alternate_path 553 | 554 | if creation_path != None and wham_noise_path != None: 555 | extra = True 556 | directory = get_meta_data(creation_path, wham_noise_path) 557 | else: 558 | extra = False 559 | 560 | if not wsj0_path == None: 561 | transcription_directory = get_transcriptions(wsj0_path) 562 | transcribe = True 563 | else: 564 | transcribe = False 565 | 566 | if fs == 8000: 567 | sample_rate = "8k" 568 | elif fs == 16000: 569 | sample_rate = "16k" 570 | else: 571 | raise ValueError("Unsupported sampling rate") 572 | 573 | if add_reverb: 574 | mix = mix_folder+"/" 575 | s1 = "s1_"+target_condition+"/" 576 | s2 = "s2_"+target_condition+"/" 577 | else: 578 | mix = mix_folder+"/" 579 | s1 = "s1/" 580 | s2 = "s2/" 581 | 582 | if eval_original: 583 | s1_eval = "s1_anechoic/" 584 | s2_eval = "s2_anechoic/" 585 | 586 | 587 | for set_type in set_types: 588 | mix_path = os.path.join( 589 | datapath, "wav{}".format(sample_rate), version, set_type, mix, 590 | ) 591 | if eval_original and (set_type == "tt" or set_type == "cv"): 592 | s1_path = os.path.join( 593 | datapath, "wav{}".format(sample_rate), version, set_type, s1_eval, 594 | ) 595 | if num_spks==2: 596 | s2_path = os.path.join( 597 | datapath, "wav{}".format(sample_rate), version, set_type, s2_eval, 598 | ) 599 | else: 600 | s1_path = os.path.join( 601 | target_path, "wav{}".format(sample_rate), version, set_type, s1, 602 | ) 603 | if num_spks==2: 604 | s2_path = os.path.join( 605 | target_path, "wav{}".format(sample_rate), version, set_type, s2, 606 | ) 607 | 608 | noise_path = os.path.join( 609 | datapath, "wav{}".format(sample_rate), version, set_type, "noise/" 610 | ) 611 | # rir_path = os.path.join( 612 | # datapath, "wav{}".format(sample_rate), version, set_type, "rirs/" 613 | # ) 614 | 615 | files = os.listdir(mix_path) 616 | 617 | mix_fl_paths = [mix_path + fl for fl in files] 618 | s1_fl_paths = [s1_path + fl for fl in files] 619 | if num_spks==2: 620 | s2_fl_paths = [s2_path + fl for fl in files] 621 | noise_fl_paths = [noise_path + fl for fl in files] 622 | # rir_fl_paths = [rir_path + fl + ".t" for fl in files] 623 | 624 | if num_spks==1: 625 | csv_columns = [ 626 | "ID", 627 | "duration", 628 | "mix_wav", 629 | "mix_wav_format", 630 | "mix_wav_opts", 631 | "s1_wav", 632 | "s1_wav_format", 633 | "s1_wav_opts", 634 | "noise_wav", 635 | "noise_wav_format", 636 | "noise_wav_opts", 637 | # "rir_t", 638 | # "rir_format", 639 | # "rir_opts", 640 | "t60", 641 | "room_size", 642 | "snr", 643 | "s1_dot" 644 | ] 645 | else: 646 | csv_columns = [ 647 | "ID", 648 | "duration", 649 | "mix_wav", 650 | "mix_wav_format", 651 | "mix_wav_opts", 652 | "s1_wav", 653 | "s1_wav_format", 654 | "s1_wav_opts", 655 | "s2_wav", 656 | "s2_wav_format", 657 | "s2_wav_opts", 658 | "noise_wav", 659 | "noise_wav_format", 660 | "noise_wav_opts", 661 | # "rir_t", 662 | # "rir_format", 663 | # "rir_opts", 664 | "t60", 665 | "room_size", 666 | "snr", 667 | "s1_dot", 668 | "s2_dot" 669 | ] 670 | 671 | with open( 672 | os.path.join(savepath, savename + set_type + ".csv"), "w" 673 | ) as csvfile: 674 | json_dict = {} 675 | if mini and (num_spks==1): 676 | zipped = list(zip( 677 | mix_fl_paths, 678 | s1_fl_paths, 679 | noise_fl_paths, 
680 | # rir_fl_paths, 681 | ))[:len(mix_fl_paths)//4] 682 | elif (not mini) and (num_spks==1): 683 | zipped = zip( 684 | mix_fl_paths, 685 | s1_fl_paths, 686 | noise_fl_paths, 687 | # rir_fl_paths, 688 | ) 689 | elif mini and (num_spks==2): 690 | zipped = list(zip( 691 | mix_fl_paths, 692 | s1_fl_paths, 693 | s2_fl_paths, 694 | noise_fl_paths, 695 | # rir_fl_paths, 696 | ))[:len(mix_fl_paths)//4] 697 | else: 698 | zipped = zip( 699 | mix_fl_paths, 700 | s1_fl_paths, 701 | s2_fl_paths, 702 | noise_fl_paths, 703 | # rir_fl_paths, 704 | ) 705 | for (i, packed,) in enumerate(zipped): 706 | if num_spks==1: 707 | mix_path, s1_path, noise_path = packed 708 | basename = os.path.basename(mix_path) 709 | data, fs = sf.read(mix_path) 710 | duration = len(data)/fs 711 | if transcribe: 712 | s1_basename = os.path.basename(s1_path).replace(".wav","") 713 | s1_dot = transcription_directory[s1_basename] 714 | else: 715 | s1_dot=None 716 | row = { 717 | "ID": i, 718 | "duration": duration, 719 | "mix_wav": mix_path, 720 | "mix_wav_format": "wav", 721 | "mix_wav_opts": None, 722 | "s1_wav": s1_path, 723 | "s1_wav_format": "wav", 724 | "s1_wav_opts": None, 725 | "noise_wav": noise_path, 726 | "noise_wav_format": "wav", 727 | "noise_wav_opts": None, 728 | # "rir_t": rir_path, 729 | # "rir_format": ".t", 730 | # "rir_opts": None, 731 | "t60": directory[basename]["t60"] if extra else None, 732 | "room_size": directory[basename]["room_size"] if extra else None, 733 | "snr": directory[basename]["snr"] if extra else None, 734 | "s1_dot": s1_dot 735 | } 736 | else: 737 | mix_path, s1_path, s2_path, noise_path = packed 738 | basename = os.path.basename(mix_path) 739 | data, fs = sf.read(mix_path) 740 | duration = len(data)/fs 741 | if transcribe: 742 | s1_basename = os.path.basename(s1_path).replace(".wav","") 743 | s1_dot = transcription_directory[s1_basename] 744 | s2_basename = os.path.basename(s2_path).replace(".wav","") 745 | s2_dot = transcription_directory[s2_basename] 746 | else: 747 | s1_dot=None 748 | s2_dot=None 749 | row = { 750 | "ID": i, 751 | "duration": duration, 752 | "mix_wav": mix_path, 753 | "mix_wav_format": "wav", 754 | "mix_wav_opts": None, 755 | "s1_wav": s1_path, 756 | "s1_wav_format": "wav", 757 | "s1_wav_opts": None, 758 | "s2_wav": s2_path, 759 | "s2_wav_format": "wav", 760 | "s2_wav_opts": None, 761 | "noise_wav": noise_path, 762 | "noise_wav_format": "wav", 763 | "noise_wav_opts": None, 764 | # "rir_t": rir_path, 765 | # "rir_format": ".t", 766 | # "rir_opts": None, 767 | "t60": directory[basename]["t60"] if extra else None, 768 | "room_size": directory[basename]["room_size"] if extra else None, 769 | "snr": directory[basename]["snr"] if extra else None, 770 | "s1_dot": s1_dot, 771 | "s2_dot": s2_dot, 772 | } 773 | 774 | json_dict[basename.replace(".wav","")+"_#_"+target_condition] = row 775 | json_fname = os.path.join(savepath,set_type+".json") 776 | with open(json_fname, 'w') as f: 777 | json_string= json.dumps(json_dict, sort_keys=False, indent=4) 778 | f.write(json_string) 779 | 780 | 781 | def tokenizer_data_prep( 782 | datapath, 783 | savepath, 784 | fs, 785 | version="min", 786 | mix_folder="mix_both_reverb", 787 | target_condition="anechoic", 788 | savename="whamr_", 789 | set_types=["tr", "cv", "tt"], 790 | add_reverb=True, 791 | mini=False, 792 | num_spks=2, 793 | wsj0_path=None, # for transcriptions 794 | ): 795 | """ 796 | This function creates the csv files to get the speechbrain data loaders for the whamr dataset. 
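(Note: despite the line above, this helper actually writes per-set .json manifests, with one entry per speaker keyed "<mixture>#s1"/"<mixture>#s2", rather than csv files.)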
797 | 798 | Arguments: 799 | datapath (str) : path for the wsj0-mix dataset. 800 | savepath (str) : path where we save the csv file 801 | fs (int) : the sampling rate 802 | version (str) : min or max 803 | savename (str) : the prefix to use for the .csv files 804 | set_types (list) : the sets to create 805 | """ 806 | 807 | os.makedirs(savepath, exist_ok=True) 808 | 809 | if not wsj0_path == None: 810 | transcription_directory = get_transcriptions(wsj0_path) 811 | transcribe = True 812 | else: 813 | transcribe = False 814 | 815 | if fs == 8000: 816 | sample_rate = "8k" 817 | elif fs == 16000: 818 | sample_rate = "16k" 819 | else: 820 | raise ValueError("Unsupported sampling rate") 821 | 822 | if add_reverb: 823 | mix = mix_folder+"/" 824 | s1 = "s1_"+target_condition+"/" 825 | s2 = "s2_"+target_condition+"/" 826 | else: 827 | mix = mix_folder+"/" 828 | s1 = "s1/" 829 | s2 = "s2/" 830 | 831 | 832 | for set_type in set_types: 833 | mix_path = os.path.join( 834 | datapath, "wav{}".format(sample_rate), version, set_type, mix, 835 | ) 836 | 837 | s1_path = os.path.join( 838 | datapath, "wav{}".format(sample_rate), version, set_type, s1, 839 | ) 840 | if num_spks==2: 841 | s2_path = os.path.join( 842 | datapath, "wav{}".format(sample_rate), version, set_type, s2, 843 | ) 844 | 845 | noise_path = os.path.join( 846 | datapath, "wav{}".format(sample_rate), version, set_type, "noise/" 847 | ) 848 | # rir_path = os.path.join( 849 | # datapath, "wav{}".format(sample_rate), version, set_type, "rirs/" 850 | # ) 851 | 852 | files = os.listdir(mix_path) 853 | 854 | mix_fl_paths = [mix_path + fl for fl in files] 855 | s1_fl_paths = [s1_path + fl for fl in files] 856 | if num_spks==2: 857 | s2_fl_paths = [s2_path + fl for fl in files] 858 | noise_fl_paths = [noise_path + fl for fl in files] 859 | # rir_fl_paths = [rir_path + fl + ".t" for fl in files] 860 | 861 | 862 | with open( 863 | os.path.join(savepath, set_map[set_type] + ".json"), "w" 864 | ) as jsonfile: 865 | json_dict = {} 866 | if mini and (num_spks==1): 867 | zipped = list(zip( 868 | mix_fl_paths, 869 | s1_fl_paths, 870 | noise_fl_paths, 871 | # rir_fl_paths, 872 | ))[:len(mix_fl_paths)//4] 873 | elif (not mini) and (num_spks==1): 874 | zipped = zip( 875 | mix_fl_paths, 876 | s1_fl_paths, 877 | noise_fl_paths, 878 | # rir_fl_paths, 879 | ) 880 | elif mini and (num_spks==2): 881 | zipped = list(zip( 882 | mix_fl_paths, 883 | s1_fl_paths, 884 | s2_fl_paths, 885 | noise_fl_paths, 886 | # rir_fl_paths, 887 | ))[:len(mix_fl_paths)//4] 888 | else: 889 | zipped = zip( 890 | mix_fl_paths, 891 | s1_fl_paths, 892 | s2_fl_paths, 893 | noise_fl_paths, 894 | # rir_fl_paths, 895 | ) 896 | for (i, packed,) in enumerate(zipped): 897 | if num_spks==1: 898 | mix_path, s1_path, noise_path = packed 899 | basename = os.path.basename(mix_path) 900 | data, fs = sf.read(mix_path) 901 | duration = len(data)/fs 902 | if transcribe: 903 | s1_basename = os.path.basename(s1_path).replace(".wav","") 904 | s1_dot = transcription_directory[s1_basename] 905 | else: 906 | s1_dot=None 907 | json_dict[basename.replace(".wav","")+"#s1"] = { 908 | "length": duration, 909 | "wav": s1_path, 910 | "words": s1_dot 911 | } 912 | else: 913 | mix_path, s1_path, s2_path, noise_path = packed 914 | basename = os.path.basename(mix_path) 915 | data, fs = sf.read(mix_path) 916 | duration = len(data)/fs 917 | try: 918 | if transcribe: 919 | s1_key, _, s2_key, _ = basename.replace(".wav","").split("_") 920 | s1_dot = transcription_directory[s1_key] 921 | s2_dot = transcription_directory[s2_key] 
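# mixture basenames follow <utt1>_<gain1>_<utt2>_<gain2>.wav (e.g. 019o031a_0.27588_01vo030q_-0.27588.wav),
# so the split above recovers the two source utterance IDs used to look up transcriptions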
922 | else: 923 | s1_dot=None 924 | s2_dot=None 925 | except KeyError as e: 926 | err_str = "Error with key "+str(e)+". Available keys: "+str(list(transcription_directory.keys())[:20])+"..." 927 | raise KeyError(err_str) 928 | 929 | json_dict[basename.replace(".wav","")+"#s1"] = { 930 | "length": duration, 931 | "wav": s1_path, 932 | "words": s1_dot 933 | } 934 | json_dict[basename.replace(".wav","")+"#s2"] = { 935 | "length": duration, 936 | "wav": s2_path, 937 | "words": s2_dot 938 | } 939 | json_string= json.dumps(json_dict, sort_keys=False, indent=4) 940 | try: 941 | jsonfile.write(json_string) 942 | except ValueError as e: 943 | raise ValueError("Error writing json dump "+json_string) 944 | 945 | def prepare_multi_whamr_csv( 946 | datapath, 947 | savepath, 948 | skip_prep=False, 949 | fs=8000, 950 | mini=False, 951 | mix_folder="mix_both_reverb", 952 | target_folders=["s1_anechoic","s2_anechoic"], 953 | set_types=["tr", "cv", "tt"], 954 | num_spks=2, 955 | alternate_path=None, 956 | creation_path=None, 957 | wham_noise_path=None, 958 | version="min", 959 | wsj0_path=None, # for transcriptions 960 | eval_original=False, 961 | meta_dump=False, 962 | use_rirs=False, 963 | rir_path=None, 964 | extended=False, 965 | savename="whamr_" 966 | ): 967 | """ 968 | Prepares the csv files for wham or whamr dataset 969 | 970 | Arguments: 971 | ---------- 972 | datapath (str) : path for the wsj0-mix dataset. 973 | savepath (str) : path where we save the csv file. 974 | skip_prep (bool): If True, skip data preparation 975 | """ 976 | 977 | if skip_prep: 978 | return 979 | 980 | if "whamr" in datapath: 981 | # if we want to train a model on the whamr dataset 982 | create_multi_whamr_csv(datapath, savepath, fs,mix_folder=mix_folder, 983 | alternate_path=alternate_path, num_spks=num_spks, target_folders=target_folders, 984 | version=version, creation_path=creation_path, wham_noise_path=wham_noise_path, 985 | wsj0_path=wsj0_path, eval_original=eval_original, meta_dump=meta_dump, rir_path=rir_path, 986 | use_rirs=use_rirs,set_types=set_types,extended=extended,savename=savename 987 | ) 988 | elif "wham" in datapath: 989 | # if we want to train a model on the original wham dataset 990 | create_wham_whamr_csv( 991 | datapath, savepath, fs, savename="whamorg_", add_reverb=False, mini=mini, 992 | mix_folder=mix_folder.replace("reverb","").replace("anechoic","") 993 | ) 994 | else: 995 | raise ValueError("Unsupported Dataset at: "+datapath) 996 | 997 | 998 | def create_multi_whamr_csv( 999 | datapath, 1000 | savepath, 1001 | fs, 1002 | version="min", 1003 | mix_folder="mix_both_reverb", 1004 | target_folders=["s1_anechoic","s2_anechoic"], 1005 | savename="whamr_", 1006 | set_types=["tr", "cv", "tt"], 1007 | add_reverb=True, 1008 | mini=False, 1009 | num_spks=2, 1010 | alternate_path=None, 1011 | creation_path=None, 1012 | wham_noise_path=None, 1013 | wsj0_path=None, # for transcriptions 1014 | eval_original=False, 1015 | meta_dump = False, # grabs every kind of meta data for room simulation 1016 | use_rirs = False, 1017 | rir_path=None, 1018 | extended=False 1019 | ): 1020 | """ 1021 | This function creates the csv files to get the speechbrain data loaders for the whamr dataset. 1022 | 1023 | Arguments: 1024 | datapath (str) : path for the wsj0-mix dataset. 
1025 | savepath (str) : path where we save the csv file 1026 | fs (int) : the sampling rate 1027 | version (str) : min or max 1028 | savename (str) : the prefix to use for the .csv files 1029 | set_types (list) : the sets to create 1030 | """ 1031 | 1032 | if alternate_path==None: 1033 | target_path = datapath 1034 | else: 1035 | target_path = alternate_path 1036 | 1037 | if creation_path != None and wham_noise_path != None: 1038 | extra = True 1039 | directory = get_meta_data(creation_path, wham_noise_path) 1040 | else: 1041 | extra = False 1042 | 1043 | if not wsj0_path == None: 1044 | transcription_directory = get_transcriptions(wsj0_path) 1045 | transcribe = True 1046 | else: 1047 | transcribe = False 1048 | 1049 | # if not rir_path == None: 1050 | # rir_directory= get_rir_paths(rir_path) 1051 | # use_rirs = True 1052 | # else: 1053 | # use_rirs = False 1054 | 1055 | if fs == 8000: 1056 | sample_rate = "8k" 1057 | elif fs == 16000: 1058 | sample_rate = "16k" 1059 | else: 1060 | raise ValueError("Unsupported sampling rate") 1061 | 1062 | mix = mix_folder+"/" 1063 | speakers = [folder+"/" for folder in target_folders] 1064 | 1065 | 1066 | if eval_original: 1067 | s1_eval = "s1_anechoic/" 1068 | s2_eval = "s2_anechoic/" 1069 | 1070 | 1071 | for set_type in set_types: 1072 | mix_path = os.path.join( 1073 | datapath, "wav{}".format(sample_rate), version, set_type, mix, 1074 | ) 1075 | speaker_paths = [os.path.join( 1076 | target_path, "wav{}".format(sample_rate), version, set_type, speaker, 1077 | ) for speaker in speakers] 1078 | 1079 | noise_path = os.path.join( 1080 | datapath, "wav{}".format(sample_rate), version, set_type, "noise/" 1081 | ) 1082 | 1083 | files = os.listdir(mix_path) 1084 | 1085 | mix_fl_paths = [mix_path + fl for fl in files] 1086 | speaker_fl_paths = [ 1087 | [speaker_path + fl for fl in files] for speaker_path in speaker_paths 1088 | ] 1089 | noise_fl_paths = [noise_path + fl for fl in files] 1090 | 1091 | csv_columns = [ 1092 | "ID", 1093 | "duration", 1094 | "mix_wav", 1095 | "mix_wav_format", 1096 | "mix_wav_opts"] 1097 | for i in range(len(speaker_paths)): 1098 | pre_spk = "s"+str(i+1) 1099 | spk_cols = [pre_spk+"_wav",pre_spk+"_wav_format",pre_spk+"_wav_opts"] 1100 | csv_columns = csv_columns+spk_cols 1101 | csv_columns = csv_columns + ["noise_wav", 1102 | "noise_wav_format", 1103 | "noise_wav_opts", 1104 | "t60", 1105 | "room_size", 1106 | "snr", 1107 | ] 1108 | 1109 | if meta_dump: 1110 | extra_cols = [ 1111 | "room_x", 1112 | "room_y", 1113 | "room_z", 1114 | "micL_x", 1115 | "micL_y", 1116 | "micR_x", 1117 | "micR_y", 1118 | "mic_z", 1119 | ] 1120 | for i in range(len(speaker_paths)): 1121 | pre_spk = "s"+str(i+1) 1122 | spk_cols = [pre_spk+"_x",pre_spk+"_y",pre_spk+"_z"] 1123 | extra_cols = extra_cols+spk_cols 1124 | csv_columns = csv_columns + extra_cols 1125 | 1126 | if use_rirs: 1127 | extra_cols = ["s"+str(i+1)+"_rir" for i in range(len(speaker_paths))] 1128 | csv_columns = csv_columns + extra_cols 1129 | 1130 | with open( 1131 | ( 1132 | os.path.join(savepath, savename + set_type + ".csv") if not extended else os.path.join(savepath, savename + set_type + "_ext.csv") 1133 | ), "w" 1134 | ) as csvfile: 1135 | writer = csv.DictWriter(csvfile, fieldnames=csv_columns) 1136 | writer.writeheader() 1137 | if mini: 1138 | zipped = list(zip( 1139 | mix_fl_paths, 1140 | *speaker_fl_paths, 1141 | noise_fl_paths, 1142 | # rir_fl_paths, 1143 | ))[:len(mix_fl_paths)//4] 1144 | elif (not mini): 1145 | zipped = zip( 1146 | mix_fl_paths, 1147 | *speaker_fl_paths, 1148 | 
noise_fl_paths,
1149 | # rir_fl_paths,
1150 | )
1151 |
1152 | for i, packed in enumerate(zipped):
1153 | mix_path = packed[0]
1154 | speaker_paths_pack = packed[1:-1]
1155 | noise_path = packed[-1]
1156 | basename = os.path.basename(mix_path)
1157 |
1158 | row = {
1159 | "ID": i,
1160 | "duration": 1.0,
1161 | "mix_wav": mix_path,
1162 | "mix_wav_format": "wav",
1163 | "mix_wav_opts": None}
1164 | for spk_idx in range(len(speaker_paths)): # spk_idx rather than i, so the row "ID" above is not clobbered
1165 | pre_spk = "s"+str(spk_idx+1)
1166 | spk_dict = {pre_spk+"_wav": speaker_paths_pack[spk_idx],
1167 | pre_spk+"_wav_format": "wav",
1168 | pre_spk+"_wav_opts": None}
1169 | row = {**row, **spk_dict}
1170 |
1171 | row = {**row, **{"noise_wav": noise_path,
1172 | "noise_wav_format": "wav",
1173 | "noise_wav_opts": None,
1174 | "t60": directory[basename]["t60"] if extra else None,
1175 | "room_size": directory[basename]["room_size"] if extra else None,
1176 | "snr": directory[basename]["snr"] if extra else None,
1177 | }}
1178 |
1179 | if meta_dump:
1180 | row["room_x"] = directory[basename]["room_x"]
1181 | row["room_y"] = directory[basename]["room_y"]
1182 | row["room_z"] = directory[basename]["room_z"]
1183 | row["micL_x"] = directory[basename]["micL_x"]
1184 | row["micL_y"] = directory[basename]["micL_y"]
1185 | row["micR_x"] = directory[basename]["micR_x"]
1186 | row["micR_y"] = directory[basename]["micR_y"]
1187 | row["mic_z"] = directory[basename]["mic_z"]
1188 | for spk_idx in range(len(speaker_paths)):
1189 | row["s"+str(spk_idx+1)+"_x"] = directory[basename]["s"+str(spk_idx+1)+"_x"] # str() fixes a TypeError: "s"+(i+1) concatenated str and int
1190 | row["s"+str(spk_idx+1)+"_y"] = directory[basename]["s"+str(spk_idx+1)+"_y"]
1191 | row["s"+str(spk_idx+1)+"_z"] = directory[basename]["s"+str(spk_idx+1)+"_z"]
1192 |
1193 | # if use_rirs:
1194 | # row["s1_rir"] = s1_path.replace("s1_anechoic","s1_rir")
1195 | # if num_spks == 2:
1196 | # row["s2_rir"] = s2_path.replace("s2_anechoic","s2_rir")
1197 |
1198 | # print(row)
1199 | writer.writerow(row)
1200 |
1201 |
1202 | ###### WSJ0MIX ########
1203 | def prepare_wsjmix_csv(
1204 | datapath,
1205 | savepath,
1206 | n_spks=2,
1207 | skip_prep=False,
1208 | librimix_addnoise=False,
1209 | fs=8000,
1210 | ):
1211 | """
1212 | Prepares wsj0-2mix csv files if n_spks=2 and wsj0-3mix csv files if n_spks=3.
1213 |
1214 | Arguments:
1215 | ----------
1216 | datapath (str) : path for the wsj0-mix dataset.
1217 | savepath (str) : path where we save the csv file.
1218 | n_spks (int): number of speakers
1219 | skip_prep (bool): If True, skip data preparation
1220 | librimix_addnoise: If True, add whamnoise to librimix datasets
1221 | """
1222 |
1223 | if skip_prep:
1224 | return
1225 |
1226 |
1227 | if n_spks == 2:
1228 | assert (
1229 | "2mix" in datapath
1230 | ), "Inconsistent number of speakers and datapath"
1231 | create_wsj_csv(datapath, savepath)
1232 | elif n_spks == 3:
1233 | assert (
1234 | "3mix" in datapath
1235 | ), "Inconsistent number of speakers and datapath"
1236 | create_wsj_csv_3spks(datapath, savepath)
1237 | else:
1238 | raise ValueError("Unsupported Number of Speakers")
1239 |
1240 | def create_wsj_csv(datapath, savepath):
1241 | """
1242 | This function creates the csv files to get the speechbrain data loaders for the wsj0-2mix dataset.
1243 |
1244 | Arguments:
1245 | datapath (str) : path for the wsj0-mix dataset.
1246 | savepath (str) : path where we save the csv file 1247 | """ 1248 | for set_type in ["tr", "cv", "tt"]: 1249 | mix_path = os.path.join(datapath, "wav8k/min/" + set_type + "/mix/") 1250 | s1_path = os.path.join(datapath, "wav8k/min/" + set_type + "/s1/") 1251 | s2_path = os.path.join(datapath, "wav8k/min/" + set_type + "/s2/") 1252 | 1253 | files = os.listdir(mix_path) 1254 | 1255 | mix_fl_paths = [mix_path + fl for fl in files] 1256 | s1_fl_paths = [s1_path + fl for fl in files] 1257 | s2_fl_paths = [s2_path + fl for fl in files] 1258 | 1259 | csv_columns = [ 1260 | "ID", 1261 | "duration", 1262 | "mix_wav", 1263 | "mix_wav_format", 1264 | "mix_wav_opts", 1265 | "s1_wav", 1266 | "s1_wav_format", 1267 | "s1_wav_opts", 1268 | "s2_wav", 1269 | "s2_wav_format", 1270 | "s2_wav_opts", 1271 | ] 1272 | 1273 | with open(savepath + "/wsj_" + set_type + ".csv", "w") as csvfile: 1274 | writer = csv.DictWriter(csvfile, fieldnames=csv_columns) 1275 | writer.writeheader() 1276 | for i, (mix_path, s1_path, s2_path) in enumerate( 1277 | zip(mix_fl_paths, s1_fl_paths, s2_fl_paths) 1278 | ): 1279 | 1280 | row = { 1281 | "ID": i, 1282 | "duration": 1.0, 1283 | "mix_wav": mix_path, 1284 | "mix_wav_format": "wav", 1285 | "mix_wav_opts": None, 1286 | "s1_wav": s1_path, 1287 | "s1_wav_format": "wav", 1288 | "s1_wav_opts": None, 1289 | "s2_wav": s2_path, 1290 | "s2_wav_format": "wav", 1291 | "s2_wav_opts": None, 1292 | } 1293 | writer.writerow(row) 1294 | 1295 | 1296 | def create_wsj_csv_3spks(datapath, savepath): 1297 | """ 1298 | This function creates the csv files to get the speechbrain data loaders for the wsj0-3mix dataset. 1299 | 1300 | Arguments: 1301 | datapath (str) : path for the wsj0-mix dataset. 1302 | savepath (str) : path where we save the csv file 1303 | """ 1304 | for set_type in ["tr", "cv", "tt"]: 1305 | mix_path = os.path.join(datapath, "wav8k/min/" + set_type + "/mix/") 1306 | s1_path = os.path.join(datapath, "wav8k/min/" + set_type + "/s1/") 1307 | s2_path = os.path.join(datapath, "wav8k/min/" + set_type + "/s2/") 1308 | s3_path = os.path.join(datapath, "wav8k/min/" + set_type + "/s3/") 1309 | 1310 | files = os.listdir(mix_path) 1311 | 1312 | mix_fl_paths = [mix_path + fl for fl in files] 1313 | s1_fl_paths = [s1_path + fl for fl in files] 1314 | s2_fl_paths = [s2_path + fl for fl in files] 1315 | s3_fl_paths = [s3_path + fl for fl in files] 1316 | 1317 | csv_columns = [ 1318 | "ID", 1319 | "duration", 1320 | "mix_wav", 1321 | "mix_wav_format", 1322 | "mix_wav_opts", 1323 | "s1_wav", 1324 | "s1_wav_format", 1325 | "s1_wav_opts", 1326 | "s2_wav", 1327 | "s2_wav_format", 1328 | "s2_wav_opts", 1329 | "s3_wav", 1330 | "s3_wav_format", 1331 | "s3_wav_opts", 1332 | ] 1333 | 1334 | with open(savepath + "/wsj_" + set_type + ".csv", "w") as csvfile: 1335 | writer = csv.DictWriter(csvfile, fieldnames=csv_columns) 1336 | writer.writeheader() 1337 | for i, (mix_path, s1_path, s2_path, s3_path) in enumerate( 1338 | zip(mix_fl_paths, s1_fl_paths, s2_fl_paths, s3_fl_paths) 1339 | ): 1340 | 1341 | row = { 1342 | "ID": i, 1343 | "duration": 1.0, 1344 | "mix_wav": mix_path, 1345 | "mix_wav_format": "wav", 1346 | "mix_wav_opts": None, 1347 | "s1_wav": s1_path, 1348 | "s1_wav_format": "wav", 1349 | "s1_wav_opts": None, 1350 | "s2_wav": s2_path, 1351 | "s2_wav_format": "wav", 1352 | "s2_wav_opts": None, 1353 | "s3_wav": s3_path, 1354 | "s3_wav_format": "wav", 1355 | "s3_wav_opts": None, 1356 | } 1357 | writer.writerow(row) -------------------------------------------------------------------------------- 
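A minimal usage sketch for the csv helpers above (the `prepare_data` module name and the paths are assumptions; `create_wsj_csv` opens files directly inside `savepath`, so the folder must exist):
```python
# Hypothetical driver, run from the repository root; paths are placeholders.
import os
from prepare_data import prepare_wsjmix_csv

savepath = "/yourpath/csv"
os.makedirs(savepath, exist_ok=True)  # create_wsj_csv does not create this folder itself

prepare_wsjmix_csv(
    datapath="/yourpath/wsj0-2mix",  # must contain "2mix" (asserted above)
    savepath=savepath,
    n_spks=2,
)
# -> /yourpath/csv/wsj_tr.csv, wsj_cv.csv and wsj_tt.csv
```
These are the csv files that the `train_data`/`valid_data`/`test_data` keys in the wsj0-2mix yamls below expect under `<save_folder>`.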
/requirements.txt: -------------------------------------------------------------------------------- 1 | speechbrain 2 | dc1d==0.0.2 3 | pysepm 4 | pesq 5 | thop 6 | mir-eval==0.6 7 | pyroomacoustics==0.3.1 8 | -------------------------------------------------------------------------------- /separation/dynamic_mixing.py: -------------------------------------------------------------------------------- 1 | import speechbrain as sb 2 | import numpy as np 3 | import torch 4 | import torchaudio 5 | import glob 6 | import os 7 | from pathlib import Path 8 | import random 9 | from speechbrain.processing.signal_processing import rescale 10 | from speechbrain.dataio.batch import PaddedBatch 11 | 12 | """ 13 | The functions to implement Dynamic Mixing For SpeechSeparation 14 | 15 | Authors 16 | * Samuele Cornell 2021 17 | * Cem Subakan 2021 18 | """ 19 | 20 | 21 | def build_spk_hashtable(hparams): 22 | """ 23 | This function builds a dictionary of speaker-utterance pairs to be used in dynamic mixing 24 | """ 25 | 26 | wsj0_utterances = glob.glob( 27 | os.path.join(hparams["base_folder_dm"], "**/*.wav"), recursive=True 28 | ) 29 | 30 | spk_hashtable = {} 31 | for utt in wsj0_utterances: 32 | 33 | spk_id = Path(utt).stem[:3] 34 | assert torchaudio.info(utt).sample_rate == hparams["sample_rate"] 35 | 36 | # e.g. 2speakers/wav8k/min/tr/mix/019o031a_0.27588_01vo030q_-0.27588.wav 37 | # id of speaker 1 is 019 utterance id is o031a 38 | # id of speaker 2 is 01v utterance id is 01vo030q 39 | 40 | if spk_id not in spk_hashtable.keys(): 41 | spk_hashtable[spk_id] = [utt] 42 | else: 43 | spk_hashtable[spk_id].append(utt) 44 | 45 | # calculate weights for each speaker ( len of list of utterances) 46 | spk_weights = [len(spk_hashtable[x]) for x in spk_hashtable.keys()] 47 | 48 | return spk_hashtable, spk_weights 49 | 50 | 51 | def get_wham_noise_filenames(hparams): 52 | "This function lists the WHAM! noise files to be used in dynamic mixing" 53 | 54 | if "Libri" in hparams["data_folder"]: 55 | # Data folder should point to Libri2Mix folder 56 | if hparams["sample_rate"] == 8000: 57 | noise_path = "wav8k/min/train-360/noise/" 58 | elif hparams["sample_rate"] == 16000: 59 | noise_path = "wav16k/min/train-360/noise/" 60 | else: 61 | raise ValueError("Unsupported Sampling Rate") 62 | else: 63 | if hparams["sample_rate"] == 8000: 64 | noise_path = "wav8k/min/tr/noise/" 65 | elif hparams["sample_rate"] == 16000: 66 | noise_path = "wav16k/min/tr/noise/" 67 | else: 68 | raise ValueError("Unsupported Sampling Rate") 69 | 70 | noise_files = glob.glob( 71 | os.path.join(hparams["data_folder"], noise_path, "*.wav") 72 | ) 73 | return noise_files 74 | 75 | 76 | def dynamic_mix_data_prep(hparams): 77 | """ 78 | Dynamic mixing for WSJ0-2/3Mix and WHAM!/WHAMR! 79 | """ 80 | 81 | # 1. 
Define datasets 82 | train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( 83 | csv_path=hparams["train_data"], 84 | replacements={"data_root": hparams["data_folder"]}, 85 | ) 86 | 87 | # we build an dictionary where keys are speakers id and entries are list 88 | # of utterances files of that speaker 89 | spk_hashtable, spk_weights = build_spk_hashtable(hparams) 90 | 91 | spk_list = [x for x in spk_hashtable.keys()] 92 | spk_weights = [x / sum(spk_weights) for x in spk_weights] 93 | 94 | if "wham" in Path(hparams["data_folder"]).stem: 95 | noise_files = get_wham_noise_filenames(hparams) 96 | 97 | @sb.utils.data_pipeline.takes("mix_wav") 98 | @sb.utils.data_pipeline.provides( 99 | "mix_sig", "s1_sig", "s2_sig", "s3_sig", "noise_sig" 100 | ) 101 | def audio_pipeline( 102 | mix_wav, 103 | ): # this is dummy --> it means one epoch will be same as without dynamic mixing 104 | """ 105 | This audio pipeline defines the compute graph for dynamic mixing 106 | """ 107 | 108 | speakers = np.random.choice( 109 | spk_list, hparams["num_spks"], replace=False, p=spk_weights 110 | ) 111 | 112 | if "wham" in Path(hparams["data_folder"]).stem: 113 | noise_file = np.random.choice(noise_files, 1, replace=False) 114 | 115 | noise, fs_read = torchaudio.load(noise_file[0]) 116 | noise = noise.squeeze() 117 | # gain = np.clip(random.normalvariate(1, 10), -4, 15) 118 | # noise = rescale(noise, torch.tensor(len(noise)), gain, scale="dB").squeeze() 119 | 120 | # select two speakers randomly 121 | sources = [] 122 | first_lvl = None 123 | 124 | spk_files = [ 125 | np.random.choice(spk_hashtable[spk], 1, False)[0] 126 | for spk in speakers 127 | ] 128 | 129 | minlen = min( 130 | *[torchaudio.info(x).num_frames for x in spk_files], 131 | hparams["training_signal_len"], 132 | ) 133 | 134 | for i, spk_file in enumerate(spk_files): 135 | 136 | # select random offset 137 | length = torchaudio.info(spk_file).num_frames 138 | start = 0 139 | stop = length 140 | if length > minlen: # take a random window 141 | start = np.random.randint(0, length - minlen) 142 | stop = start + minlen 143 | 144 | tmp, fs_read = torchaudio.load( 145 | spk_file, frame_offset=start, num_frames=stop - start, 146 | ) 147 | 148 | # peak = float(Path(spk_file).stem.split("_peak_")[-1]) 149 | tmp = tmp[0] # * peak # remove channel dim and normalize 150 | 151 | if i == 0: 152 | gain = np.clip(random.normalvariate(-27.43, 2.57), -45, 0) 153 | tmp = rescale(tmp, torch.tensor(len(tmp)), gain, scale="dB") 154 | # assert not torch.all(torch.isnan(tmp)) 155 | first_lvl = gain 156 | else: 157 | gain = np.clip( 158 | first_lvl + random.normalvariate(-2.51, 2.66), -45, 0 159 | ) 160 | tmp = rescale(tmp, torch.tensor(len(tmp)), gain, scale="dB") 161 | # assert not torch.all(torch.isnan(tmp)) 162 | sources.append(tmp) 163 | 164 | # we mix the sources together 165 | # here we can also use augmentations ! -> runs on cpu and for each 166 | # mixture parameters will be different rather than for whole batch. 
167 | # no difference however for bsz=1 :) 168 | 169 | # padding left 170 | # sources, _ = batch_pad_right(sources) 171 | 172 | sources = torch.stack(sources) 173 | mixture = torch.sum(sources, 0) 174 | if "wham" in Path(hparams["data_folder"]).stem: 175 | len_noise = len(noise) 176 | len_mix = len(mixture) 177 | min_len = min(len_noise, len_mix) 178 | mixture = mixture[:min_len] + noise[:min_len] 179 | 180 | max_amp = max( 181 | torch.abs(mixture).max().item(), 182 | *[x.item() for x in torch.abs(sources).max(dim=-1)[0]], 183 | ) 184 | mix_scaling = 1 / max_amp * 0.9 185 | sources = mix_scaling * sources 186 | mixture = mix_scaling * mixture 187 | 188 | yield mixture 189 | for i in range(hparams["num_spks"]): 190 | yield sources[i] 191 | 192 | # If the number of speakers is 2, yield None for the 3rd speaker 193 | if hparams["num_spks"] == 2: 194 | yield None 195 | 196 | if "wham" in Path(hparams["data_folder"]).stem: 197 | mean_source_lvl = sources.abs().mean() 198 | mean_noise_lvl = noise.abs().mean() 199 | noise = (mean_source_lvl / mean_noise_lvl) * noise 200 | yield noise 201 | else: 202 | yield None 203 | 204 | sb.dataio.dataset.add_dynamic_item([train_data], audio_pipeline) 205 | sb.dataio.dataset.set_output_keys( 206 | [train_data], 207 | ["id", "mix_sig", "s1_sig", "s2_sig", "s3_sig", "noise_sig"], 208 | ) 209 | 210 | train_data = torch.utils.data.DataLoader( 211 | train_data, 212 | batch_size=hparams["dataloader_opts"]["batch_size"], 213 | num_workers=hparams["dataloader_opts"]["num_workers"], 214 | collate_fn=PaddedBatch, 215 | worker_init_fn=lambda x: np.random.seed( 216 | int.from_bytes(os.urandom(4), "little") + x 217 | ), 218 | ) 219 | return train_data 220 | -------------------------------------------------------------------------------- /separation/hparams/baselines/tcn/tcn-whamr.yaml: -------------------------------------------------------------------------------- 1 | # ################################ 2 | # Model: tcn for source separation 3 | # https://arxiv.org/abs/2010.13154 4 | # 5 | # Dataset : WHAMR! 6 | # ################################ 7 | # Basic parameters 8 | # Seed needs to be set at top of yaml, before objects with parameters are made 9 | #l 10 | seed: 1234 11 | __set_seed: !apply:torch.manual_seed [!ref ] 12 | 13 | # Data params 14 | 15 | # the data folder for the wham dataset 16 | # data_folder needs to follow the format: /yourpath/whamr. 17 | # make sure to use the name whamr at your top folder for the dataset! 18 | data_folder: /fastdata/acp19jwr/data/whamr 19 | alternate_path: !ref 20 | mix_folder: mix_both_reverb 21 | mini: False # if true only uses a quarter of the wham/whamr data 22 | 23 | 24 | # the path for wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used 25 | # e.g. 
/yourpath/wsj0-processed/si_tr_s/
26 | # you need to convert the original wsj0 to 8k
27 | # you can do this conversion with the script ../meta/preprocess_dynamic_mixing.py
28 | base_folder_dm: /yourpath/wsj0-processed/si_tr_s/
29 |
30 | experiment_name: tcn-whamr
31 | experiment_folder: !ref tcn/baselines/R=<R>/X=<X>
32 | output_folder: !ref results/<experiment_name>/<experiment_folder>/<seed>
33 | train_log: !ref <output_folder>/train_log.txt
34 | save_folder: !ref <output_folder>/save
35 |
36 | # the file names should start with whamr instead of whamorg
37 | train_data: !ref <save_folder>/whamr_tr.csv
38 | valid_data: !ref <save_folder>/whamr_cv.csv
39 | test_data: !ref <save_folder>/whamr_tt.csv
40 | skip_prep: False
41 |
42 | # Experiment params
43 | auto_mix_prec: False # Set it to True for mixed precision
44 | test_only: False
45 | num_spks: 2 # set to 3 for wsj0-3mix
46 | progressbar: True
47 | save_audio: False # Save estimated sources on disk
48 | sample_rate: 8000
49 |
50 | # Training parameters
51 | N_epochs: 100
52 | batch_size: 8
53 | lr: 0.001
54 | clip_grad_norm: 5
55 | loss_upper_lim: 9999 # this is the upper limit for an acceptable loss
56 | # if True, the training sequences are cut to a specified length
57 | limit_training_signal_len: True
58 | # this is the length of sequences if we choose to limit
59 | # the signal length of training sequences
60 | training_signal_len: 32000
61 |
62 | # Set it to True to dynamically create mixtures at training time
63 | dynamic_mixing: False
64 |
65 | # Parameters for data augmentation
66 |
67 | # rir_path variable points to the directory of the room impulse responses
68 | # e.g. /miniscratch/subakany/rir_wavs
69 | # If the path does not exist, it is created automatically.
70 | # rir_path: /share/mini1/usr/will/scratch/whamr/rir_wavs
71 |
72 | use_wavedrop: False
73 | use_speedperturb: False
74 | use_speedperturb_sameforeachsource: False
75 | use_rand_shift: False
76 | min_shift: -8000
77 | max_shift: 8000
78 |
79 | speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
80 | perturb_prob: 1.0
81 | drop_freq_prob: 0.0
82 | drop_chunk_prob: 0.0
83 | sample_rate: !ref <sample_rate>
84 | speeds: [95, 100, 105]
85 |
86 | wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
87 | perturb_prob: 0.0
88 | drop_freq_prob: 1.0
89 | drop_chunk_prob: 1.0
90 | sample_rate: !ref <sample_rate>
91 |
92 | # loss thresholding -- this thresholds the training loss
93 | threshold_byloss: True
94 | threshold: -30
95 |
96 | # Encoder parameters
97 | N_encoder_out: 512
98 | out_channels: 512
99 | kernel_size: 16
100 | kernel_stride: 8
101 |
102 | # Dataloader options
103 | dataloader_opts:
104 | batch_size: !ref <batch_size>
105 | num_workers: 4
106 |
107 | # Specifying the network
108 | Encoder: !new:speechbrain.lobes.models.dual_path.Encoder
109 | kernel_size: !ref <kernel_size>
110 | out_channels: !ref <N_encoder_out>
111 |
112 | X: 6
113 | R: 4
114 |
115 | MaskNet: !new:speechbrain.lobes.models.conv_tasnet.MaskNet
116 | N: !ref <N_encoder_out>
117 | B: 128
118 | H: 512
119 | P: 3
120 | X: !ref <X>
121 | R: !ref <R>
122 | C: !ref <num_spks>
123 | norm_type: 'gLN'
124 | causal: False
125 | mask_nonlinear: 'relu'
126 |
127 | Decoder: !new:speechbrain.lobes.models.dual_path.Decoder
128 | in_channels: !ref <N_encoder_out>
129 | out_channels: 1
130 | kernel_size: !ref <kernel_size>
131 | stride: !ref <kernel_stride>
132 | bias: False
133 |
134 | optimizer: !name:torch.optim.Adam
135 | lr: !ref <lr>
136 | weight_decay: 0
137 |
138 | loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper
139 |
140 | lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
141 | factor: 0.5
142 | patience: 3
143 | dont_halve_until_epoch: 3
144 |
145 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
146 | limit: !ref <N_epochs>
147 |
148 | modules:
149 | encoder: !ref <Encoder>
150 | decoder: !ref <Decoder>
151 | masknet: !ref <MaskNet>
152 |
153 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
154 | checkpoints_dir: !ref <save_folder>
155 | recoverables:
156 | encoder: !ref <Encoder>
157 | decoder: !ref <Decoder>
158 | masknet: !ref <MaskNet>
159 | counter: !ref <epoch_counter>
160 | lr_scheduler: !ref <lr_scheduler>
161 |
162 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
163 | save_file: !ref <train_log>
164 | --------------------------------------------------------------------------------
/separation/hparams/deformable/dm/dtcn-whamr.yaml:
--------------------------------------------------------------------------------
1 | # ################################
2 | # Model: dtcn for source separation
3 | # https://arxiv.org/abs/2210.15305
4 | #
5 | # Dataset : WHAMR!
6 | # ################################
7 | # Basic parameters
8 | # Seed needs to be set at top of yaml, before objects with parameters are made
9 | #l
10 | seed: 1234
11 | __set_seed: !apply:torch.manual_seed [!ref <seed>]
12 |
13 | # Data params
14 |
15 | # the data folder for the wham dataset
16 | # data_folder needs to follow the format: /yourpath/whamr.
17 | # make sure to use the name whamr at your top folder for the dataset!
18 | data_folder: /fastdata/acp19jwr/data/whamr
19 | alternate_path: !ref <data_folder>
20 | mix_folder: mix_both_reverb
21 | mini: False # if true only uses a quarter of the wham/whamr data
22 |
23 |
24 | # the path for wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used
25 | # e.g. /yourpath/wsj0-processed/si_tr_s/
26 | # you need to convert the original wsj0 to 8k
27 | # you can do this conversion with the script ../meta/preprocess_dynamic_mixing.py
28 | base_folder_dm: /home/acp19jwr/fastdata/data/wsj0/si_tr_s_processed
29 |
30 | experiment_name: dtcn-whamr
31 | experiment_folder: !ref deformable/dm/R=<R>/X=<X>
32 | output_folder: !ref results/<experiment_name>/<experiment_folder>/<seed>
33 | train_log: !ref <output_folder>/train_log.txt
34 | save_folder: !ref <output_folder>/save
35 |
36 | # the file names should start with whamr instead of whamorg
37 | train_data: !ref <save_folder>/whamr_tr.csv
38 | valid_data: !ref <save_folder>/whamr_cv.csv
39 | test_data: !ref <save_folder>/whamr_tt.csv
40 | skip_prep: False
41 |
42 | # Experiment params
43 | auto_mix_prec: False # Set it to True for mixed precision
44 | test_only: False
45 | num_spks: 2 # set to 3 for wsj0-3mix
46 | progressbar: True
47 | save_audio: True # Save estimated sources on disk
48 | sample_rate: 8000
49 |
50 | # Training parameters
51 | N_epochs: 100
52 | batch_size: 2
53 | lr: 0.001
54 | clip_grad_norm: 5
55 | loss_upper_lim: 9999 # this is the upper limit for an acceptable loss
56 | # if True, the training sequences are cut to a specified length
57 | limit_training_signal_len: True
58 | # this is the length of sequences if we choose to limit
59 | # the signal length of training sequences
60 | training_signal_len: 32000
61 |
62 | # Set it to True to dynamically create mixtures at training time
63 | dynamic_mixing: True
64 |
65 | # Parameters for data augmentation
66 |
67 | # rir_path variable points to the directory of the room impulse responses
68 | # e.g. /miniscratch/subakany/rir_wavs
69 | # If the path does not exist, it is created automatically.
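# NOTE: the example below is site-specific -- point rir_path at your own generated RIRs (see ../meta/create_whamr_rirs.py)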
70 | rir_path: /home/acp19jwr/fastdata/data/whamr/rirs/tr 71 | 72 | use_wavedrop: False 73 | use_speedperturb: True 74 | use_speedperturb_sameforeachsource: False 75 | use_rand_shift: False 76 | min_shift: -8000 77 | max_shift: 8000 78 | 79 | speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment 80 | perturb_prob: 1.0 81 | drop_freq_prob: 0.0 82 | drop_chunk_prob: 0.0 83 | sample_rate: !ref 84 | speeds: [95, 100, 105] 85 | 86 | wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment 87 | perturb_prob: 0.0 88 | drop_freq_prob: 1.0 89 | drop_chunk_prob: 1.0 90 | sample_rate: !ref 91 | 92 | # loss thresholding -- this thresholds the training loss 93 | threshold_byloss: True 94 | threshold: -30 95 | 96 | # Encoder parameters 97 | N_encoder_out: 512 98 | out_channels: 512 99 | kernel_size: 16 100 | kernel_stride: 8 101 | 102 | # Dataloader options 103 | dataloader_opts: 104 | batch_size: !ref 105 | num_workers: 4 106 | 107 | # Specifying the network 108 | Encoder: !new:speechbrain.lobes.models.dual_path.Encoder 109 | kernel_size: !ref 110 | out_channels: !ref 111 | 112 | X: 8 113 | R: 3 114 | 115 | MaskNet: !new:src.deformable.MaskNet 116 | N: !ref 117 | B: 128 118 | H: 512 119 | P: 3 120 | X: !ref 121 | R: !ref 122 | C: !ref 123 | norm_type: 'gLN' 124 | causal: False 125 | mask_nonlinear: 'relu' 126 | 127 | Decoder: !new:speechbrain.lobes.models.dual_path.Decoder 128 | in_channels: !ref 129 | out_channels: 1 130 | kernel_size: !ref 131 | stride: !ref 132 | bias: False 133 | 134 | optimizer: !name:torch.optim.Adam 135 | lr: !ref 136 | weight_decay: 0 137 | 138 | loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper 139 | 140 | lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau 141 | factor: 0.5 142 | patience: 3 143 | dont_halve_until_epoch: 3 144 | 145 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 146 | limit: !ref 147 | 148 | modules: 149 | encoder: !ref 150 | decoder: !ref 151 | masknet: !ref 152 | 153 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 154 | checkpoints_dir: !ref 155 | recoverables: 156 | encoder: !ref 157 | decoder: !ref 158 | masknet: !ref 159 | counter: !ref 160 | lr_scheduler: !ref 161 | 162 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 163 | save_file: !ref 164 | -------------------------------------------------------------------------------- /separation/hparams/deformable/dm/shared_weights/dtcn-whamr.yaml: -------------------------------------------------------------------------------- 1 | # ################################ 2 | # Model: dtcn for source separation 3 | # https://arxiv.org/abs/2010.13154 4 | # 5 | # Dataset : WHAMR! 6 | # ################################ 7 | # Basic parameters 8 | # Seed needs to be set at top of yaml, before objects with parameters are made 9 | #l 10 | seed: 1234 11 | __set_seed: !apply:torch.manual_seed [!ref ] 12 | 13 | # Data params 14 | 15 | # the data folder for the wham dataset 16 | # data_folder needs to follow the format: /yourpath/whamr. 17 | # make sure to use the name whamr at your top folder for the dataset! 18 | data_folder: /fastdata/acp19jwr/data/whamr 19 | alternate_path: !ref 20 | mix_folder: mix_both_reverb 21 | mini: False # if true only uses a quarter of the wham/whamr data 22 | 23 | 24 | # the path for wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used 25 | # e.g. 
/yourpath/wsj0-processed/si_tr_s/ 26 | # you need to convert the original wsj0 to 8k 27 | # you can do this conversion with the script ../meta/preprocess_dynamic_mixing.py 28 | base_folder_dm: /home/acp19jwr/fastdata/data/wsj0/si_tr_s_processed 29 | 30 | experiment_name: dtcn-whamr 31 | experiment_folder: !ref deformable/dm/shared_weights/R=/X= 32 | output_folder: !ref results/// 33 | train_log: !ref /train_log.txt 34 | save_folder: !ref /save 35 | 36 | # the file names should start with whamr instead of whamorg 37 | train_data: !ref /whamr_tr.csv 38 | valid_data: !ref /whamr_cv.csv 39 | test_data: !ref /whamr_tt.csv 40 | skip_prep: False 41 | 42 | # Experiment params 43 | auto_mix_prec: False # Set it to True for mixed precision 44 | test_only: False 45 | num_spks: 2 # set to 3 for wsj0-3mix 46 | progressbar: True 47 | save_audio: True # Save estimated sources on disk 48 | sample_rate: 8000 49 | 50 | # Training parameters 51 | N_epochs: 100 52 | batch_size: 2 53 | lr: 0.001 54 | clip_grad_norm: 5 55 | loss_upper_lim: 9999 # this is the upper limit for an acceptable loss 56 | # if True, the training sequences are cut to a specified length 57 | limit_training_signal_len: True 58 | # this is the length of sequences if we choose to limit 59 | # the signal length of training sequences 60 | training_signal_len: 32000 61 | 62 | # Set it to True to dynamically create mixtures at training time 63 | dynamic_mixing: True 64 | 65 | # Parameters for data augmentation 66 | 67 | # rir_path variable points to the directory of the room impulse responses 68 | # e.g. /miniscratch/subakany/rir_wavs 69 | # If the path does not exist, it is created automatically. 70 | rir_path: /home/acp19jwr/fastdata/data/whamr/rirs/tr 71 | 72 | use_wavedrop: False 73 | use_speedperturb: True 74 | use_speedperturb_sameforeachsource: False 75 | use_rand_shift: False 76 | min_shift: -8000 77 | max_shift: 8000 78 | 79 | speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment 80 | perturb_prob: 1.0 81 | drop_freq_prob: 0.0 82 | drop_chunk_prob: 0.0 83 | sample_rate: !ref 84 | speeds: [95, 100, 105] 85 | 86 | wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment 87 | perturb_prob: 0.0 88 | drop_freq_prob: 1.0 89 | drop_chunk_prob: 1.0 90 | sample_rate: !ref 91 | 92 | # loss thresholding -- this thresholds the training loss 93 | threshold_byloss: True 94 | threshold: -30 95 | 96 | # Encoder parameters 97 | N_encoder_out: 512 98 | out_channels: 512 99 | kernel_size: 16 100 | kernel_stride: 8 101 | 102 | # Dataloader options 103 | dataloader_opts: 104 | batch_size: !ref 105 | num_workers: 4 106 | 107 | # Specifying the network 108 | Encoder: !new:speechbrain.lobes.models.dual_path.Encoder 109 | kernel_size: !ref 110 | out_channels: !ref 111 | 112 | X: 8 113 | R: 3 114 | 115 | MaskNet: !new:src.deformable.MaskNet 116 | N: !ref 117 | B: 128 118 | H: 512 119 | P: 3 120 | X: !ref 121 | R: !ref 122 | C: !ref 123 | norm_type: 'gLN' 124 | causal: False 125 | mask_nonlinear: 'relu' 126 | shared_weights: True 127 | 128 | Decoder: !new:speechbrain.lobes.models.dual_path.Decoder 129 | in_channels: !ref 130 | out_channels: 1 131 | kernel_size: !ref 132 | stride: !ref 133 | bias: False 134 | 135 | optimizer: !name:torch.optim.Adam 136 | lr: !ref 137 | weight_decay: 0 138 | 139 | loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper 140 | 141 | lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau 142 | factor: 0.5 143 | patience: 3 144 | dont_halve_until_epoch: 3 145 | 146 | epoch_counter: 
!new:speechbrain.utils.epoch_loop.EpochCounter 147 | limit: !ref 148 | 149 | modules: 150 | encoder: !ref 151 | decoder: !ref 152 | masknet: !ref 153 | 154 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 155 | checkpoints_dir: !ref 156 | recoverables: 157 | encoder: !ref 158 | decoder: !ref 159 | masknet: !ref 160 | counter: !ref 161 | lr_scheduler: !ref 162 | 163 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 164 | save_file: !ref 165 | -------------------------------------------------------------------------------- /separation/hparams/deformable/dtcn-whamr.yaml: -------------------------------------------------------------------------------- 1 | # ################################ 2 | # Model: dtcn for source separation 3 | # https://arxiv.org/abs/2010.13154 4 | # 5 | # Dataset : WHAMR! 6 | # ################################ 7 | # Basic parameters 8 | # Seed needs to be set at top of yaml, before objects with parameters are made 9 | #l 10 | seed: 1234 11 | __set_seed: !apply:torch.manual_seed [!ref ] 12 | 13 | # Data params 14 | 15 | # the data folder for the wham dataset 16 | # data_folder needs to follow the format: /yourpath/whamr. 17 | # make sure to use the name whamr at your top folder for the dataset! 18 | data_folder: /fastdata/acp19jwr/data/whamr 19 | alternate_path: !ref 20 | mix_folder: mix_both_reverb 21 | mini: False # if true only uses a quarter of the wham/whamr data 22 | 23 | 24 | # the path for wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used 25 | # e.g. /yourpath/wsj0-processed/si_tr_s/ 26 | # you need to convert the original wsj0 to 8k 27 | # you can do this conversion with the script ../meta/preprocess_dynamic_mixing.py 28 | base_folder_dm: /yourpath/wsj0-processed/si_tr_s/ 29 | 30 | experiment_name: dtcn-whamr 31 | experiment_folder: !ref deformable/baselines/R=/X= 32 | output_folder: !ref results/// 33 | train_log: !ref /train_log.txt 34 | save_folder: !ref /save 35 | 36 | # the file names should start with whamr instead of whamorg 37 | train_data: !ref /whamr_tr.csv 38 | valid_data: !ref /whamr_cv.csv 39 | test_data: !ref /whamr_tt.csv 40 | skip_prep: False 41 | 42 | # Experiment params 43 | auto_mix_prec: False # Set it to True for mixed precision 44 | test_only: False 45 | num_spks: 2 # set to 3 for wsj0-3mix 46 | progressbar: True 47 | save_audio: True # Save estimated sources on disk 48 | sample_rate: 8000 49 | 50 | # Training parameters 51 | N_epochs: 100 52 | batch_size: 2 53 | lr: 0.001 54 | clip_grad_norm: 5 55 | loss_upper_lim: 9999 # this is the upper limit for an acceptable loss 56 | # if True, the training sequences are cut to a specified length 57 | limit_training_signal_len: True 58 | # this is the length of sequences if we choose to limit 59 | # the signal length of training sequences 60 | training_signal_len: 32000 61 | 62 | # Set it to True to dynamically create mixtures at training time 63 | dynamic_mixing: False 64 | 65 | # Parameters for data augmentation 66 | 67 | # rir_path variable points to the directory of the room impulse responses 68 | # e.g. /miniscratch/subakany/rir_wavs 69 | # If the path does not exist, it is created automatically. 
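# (left commented out below: this config has dynamic_mixing: False, so no RIRs are needed)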
70 | # rir_path: /share/mini1/usr/will/scratch/whamr/rir_wavs 71 | 72 | use_wavedrop: False 73 | use_speedperturb: False 74 | use_speedperturb_sameforeachsource: False 75 | use_rand_shift: False 76 | min_shift: -8000 77 | max_shift: 8000 78 | 79 | speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment 80 | perturb_prob: 1.0 81 | drop_freq_prob: 0.0 82 | drop_chunk_prob: 0.0 83 | sample_rate: !ref 84 | speeds: [95, 100, 105] 85 | 86 | wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment 87 | perturb_prob: 0.0 88 | drop_freq_prob: 1.0 89 | drop_chunk_prob: 1.0 90 | sample_rate: !ref 91 | 92 | # loss thresholding -- this thresholds the training loss 93 | threshold_byloss: True 94 | threshold: -30 95 | 96 | # Encoder parameters 97 | N_encoder_out: 512 98 | out_channels: 512 99 | kernel_size: 16 100 | kernel_stride: 8 101 | 102 | # Dataloader options 103 | dataloader_opts: 104 | batch_size: !ref 105 | num_workers: 4 106 | 107 | # Specifying the network 108 | Encoder: !new:speechbrain.lobes.models.dual_path.Encoder 109 | kernel_size: !ref 110 | out_channels: !ref 111 | 112 | X: 8 113 | R: 3 114 | 115 | MaskNet: !new:src.deformable.MaskNet 116 | N: !ref 117 | B: 128 118 | H: 512 119 | P: 3 120 | X: !ref 121 | R: !ref 122 | C: !ref 123 | norm_type: 'gLN' 124 | causal: False 125 | mask_nonlinear: 'relu' 126 | 127 | Decoder: !new:speechbrain.lobes.models.dual_path.Decoder 128 | in_channels: !ref 129 | out_channels: 1 130 | kernel_size: !ref 131 | stride: !ref 132 | bias: False 133 | 134 | optimizer: !name:torch.optim.Adam 135 | lr: !ref 136 | weight_decay: 0 137 | 138 | loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper 139 | 140 | lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau 141 | factor: 0.5 142 | patience: 3 143 | dont_halve_until_epoch: 3 144 | 145 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 146 | limit: !ref 147 | 148 | modules: 149 | encoder: !ref 150 | decoder: !ref 151 | masknet: !ref 152 | 153 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 154 | checkpoints_dir: !ref 155 | recoverables: 156 | encoder: !ref 157 | decoder: !ref 158 | masknet: !ref 159 | counter: !ref 160 | lr_scheduler: !ref 161 | 162 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 163 | save_file: !ref 164 | -------------------------------------------------------------------------------- /separation/hparams/deformable/shared_weights/dtcn-whamr.yaml: -------------------------------------------------------------------------------- 1 | # ################################ 2 | # Model: dtcn for source separation 3 | # https://arxiv.org/abs/2010.13154 4 | # 5 | # Dataset : WHAMR! 6 | # ################################ 7 | # Basic parameters 8 | # Seed needs to be set at top of yaml, before objects with parameters are made 9 | #l 10 | seed: 1234 11 | __set_seed: !apply:torch.manual_seed [!ref ] 12 | 13 | # Data params 14 | 15 | # the data folder for the wham dataset 16 | # data_folder needs to follow the format: /yourpath/whamr. 17 | # make sure to use the name whamr at your top folder for the dataset! 18 | data_folder: /fastdata/acp19jwr/data/whamr 19 | alternate_path: !ref 20 | mix_folder: mix_both_reverb 21 | mini: False # if true only uses a quarter of the wham/whamr data 22 | 23 | 24 | # the path for wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used 25 | # e.g. 
/yourpath/wsj0-processed/si_tr_s/ 26 | # you need to convert the original wsj0 to 8k 27 | # you can do this conversion with the script ../meta/preprocess_dynamic_mixing.py 28 | base_folder_dm: /yourpath/wsj0-processed/si_tr_s/ 29 | 30 | experiment_name: dtcn-whamr 31 | experiment_folder: !ref deformable/shared_weights/R=/X= 32 | output_folder: !ref results/// 33 | train_log: !ref /train_log.txt 34 | save_folder: !ref /save 35 | 36 | # the file names should start with whamr instead of whamorg 37 | train_data: !ref /whamr_tr.csv 38 | valid_data: !ref /whamr_cv.csv 39 | test_data: !ref /whamr_tt.csv 40 | skip_prep: False 41 | 42 | # Experiment params 43 | auto_mix_prec: False # Set it to True for mixed precision 44 | test_only: False 45 | num_spks: 2 # set to 3 for wsj0-3mix 46 | progressbar: True 47 | save_audio: True # Save estimated sources on disk 48 | sample_rate: 8000 49 | 50 | # Training parameters 51 | N_epochs: 100 52 | batch_size: 2 53 | lr: 0.001 54 | clip_grad_norm: 5 55 | loss_upper_lim: 9999 # this is the upper limit for an acceptable loss 56 | # if True, the training sequences are cut to a specified length 57 | limit_training_signal_len: True 58 | # this is the length of sequences if we choose to limit 59 | # the signal length of training sequences 60 | training_signal_len: 32000 61 | 62 | # Set it to True to dynamically create mixtures at training time 63 | dynamic_mixing: False 64 | 65 | # Parameters for data augmentation 66 | 67 | # rir_path variable points to the directory of the room impulse responses 68 | # e.g. /miniscratch/subakany/rir_wavs 69 | # If the path does not exist, it is created automatically. 70 | # rir_path: /share/mini1/usr/will/scratch/whamr/rir_wavs 71 | 72 | use_wavedrop: False 73 | use_speedperturb: False 74 | use_speedperturb_sameforeachsource: False 75 | use_rand_shift: False 76 | min_shift: -8000 77 | max_shift: 8000 78 | 79 | speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment 80 | perturb_prob: 1.0 81 | drop_freq_prob: 0.0 82 | drop_chunk_prob: 0.0 83 | sample_rate: !ref 84 | speeds: [95, 100, 105] 85 | 86 | wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment 87 | perturb_prob: 0.0 88 | drop_freq_prob: 1.0 89 | drop_chunk_prob: 1.0 90 | sample_rate: !ref 91 | 92 | # loss thresholding -- this thresholds the training loss 93 | threshold_byloss: True 94 | threshold: -30 95 | 96 | # Encoder parameters 97 | N_encoder_out: 512 98 | out_channels: 512 99 | kernel_size: 16 100 | kernel_stride: 8 101 | 102 | # Dataloader options 103 | dataloader_opts: 104 | batch_size: !ref 105 | num_workers: 4 106 | 107 | # Specifying the network 108 | Encoder: !new:speechbrain.lobes.models.dual_path.Encoder 109 | kernel_size: !ref 110 | out_channels: !ref 111 | 112 | X: 8 113 | R: 3 114 | 115 | MaskNet: !new:src.deformable.MaskNet 116 | N: !ref 117 | B: 128 118 | H: 512 119 | P: 3 120 | X: !ref 121 | R: !ref 122 | C: !ref 123 | norm_type: 'gLN' 124 | causal: False 125 | mask_nonlinear: 'relu' 126 | shared_weights: True 127 | 128 | Decoder: !new:speechbrain.lobes.models.dual_path.Decoder 129 | in_channels: !ref 130 | out_channels: 1 131 | kernel_size: !ref 132 | stride: !ref 133 | bias: False 134 | 135 | optimizer: !name:torch.optim.Adam 136 | lr: !ref 137 | weight_decay: 0 138 | 139 | loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper 140 | 141 | lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau 142 | factor: 0.5 143 | patience: 3 144 | dont_halve_until_epoch: 3 145 | 146 | epoch_counter: 
!new:speechbrain.utils.epoch_loop.EpochCounter 147 | limit: !ref 148 | 149 | modules: 150 | encoder: !ref 151 | decoder: !ref 152 | masknet: !ref 153 | 154 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 155 | checkpoints_dir: !ref 156 | recoverables: 157 | encoder: !ref 158 | decoder: !ref 159 | masknet: !ref 160 | counter: !ref 161 | lr_scheduler: !ref 162 | 163 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 164 | save_file: !ref 165 | -------------------------------------------------------------------------------- /separation/hparams/wsj0-2mix/deformable/dm/dtcn-wsj0-2mix.yaml: -------------------------------------------------------------------------------- 1 | # ################################ 2 | # Model: dtcn for source separation 3 | # https://arxiv.org/abs/2010.13154 4 | # 5 | # Dataset : WSJ0-2Mix 6 | # ################################ 7 | # Basic parameters 8 | # Seed needs to be set at top of yaml, before objects with parameters are made 9 | #l 10 | seed: 1234 11 | __set_seed: !apply:torch.manual_seed [!ref ] 12 | 13 | # Data params 14 | 15 | # the data folder for the wham dataset 16 | # data_folder needs to follow the format: /yourpath/wsj0-2mix. 17 | # make sure to use the name wsj0-2mix at your top folder for the dataset! 18 | data_folder: /fastdata/acp19jwr/data/wsj0-2mix 19 | # mini: False # if true only uses a quarter of the wham/wsj0-2mix data 20 | 21 | 22 | # the path for wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used 23 | # e.g. /yourpath/wsj0-processed/si_tr_s/ 24 | # you need to convert the original wsj0 to 8k 25 | # you can do this conversion with the script ../meta/preprocess_dynamic_mixing.py 26 | base_folder_dm: /fastdata/acp19jwr/data/wsj0/si_tr_s_processed 27 | 28 | experiment_name: deformable/dm 29 | experiment_folder: !ref wsj0-2mix//R=/X= 30 | output_folder: !ref results// 31 | train_log: !ref /train_log.txt 32 | save_folder: !ref /save 33 | 34 | # the file names should start with wsj0-2mix instead of whamorg 35 | train_data: !ref /wsj_tr.csv 36 | valid_data: !ref /wsj_cv.csv 37 | test_data: !ref /wsj_tt.csv 38 | skip_prep: False 39 | 40 | # Experiment params 41 | auto_mix_prec: False # Set it to True for mixed precision 42 | test_only: False 43 | num_spks: 2 # set to 3 for wsj0-3mix 44 | progressbar: True 45 | save_audio: False # Save estimated sources on disk 46 | sample_rate: 8000 47 | 48 | # Training parameters 49 | N_epochs: 100 50 | batch_size: 1 51 | lr: 0.001 52 | clip_grad_norm: 5 53 | loss_upper_lim: 9999 # this is the upper limit for an acceptable loss 54 | # if True, the training sequences are cut to a specified length 55 | limit_training_signal_len: True 56 | # this is the length of sequences if we choose to limit 57 | # the signal length of training sequences 58 | training_signal_len: 32000 59 | 60 | # Set it to True to dynamically create mixtures at training time 61 | dynamic_mixing: True 62 | 63 | # Parameters for data augmentation 64 | 65 | # rir_path variable points to the directory of the room impulse responses 66 | # e.g. /miniscratch/subakany/rir_wavs 67 | # If the path does not exist, it is created automatically. 
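# (kept commented out below: wsj0-2mix is anechoic, so no RIRs are used even with dynamic_mixing: True)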
68 | # rir_path: /share/mini1/usr/will/scratch/wsj0-2mix/rir_wavs 69 | 70 | use_wavedrop: False 71 | use_speedperturb: True 72 | use_speedperturb_sameforeachsource: False 73 | use_rand_shift: False 74 | min_shift: -8000 75 | max_shift: 8000 76 | 77 | speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment 78 | perturb_prob: 1.0 79 | drop_freq_prob: 0.0 80 | drop_chunk_prob: 0.0 81 | sample_rate: !ref 82 | speeds: [95, 100, 105] 83 | 84 | wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment 85 | perturb_prob: 0.0 86 | drop_freq_prob: 1.0 87 | drop_chunk_prob: 1.0 88 | sample_rate: !ref 89 | 90 | # loss thresholding -- this thresholds the training loss 91 | threshold_byloss: True 92 | threshold: -30 93 | 94 | # Encoder parameters 95 | N_encoder_out: 512 96 | out_channels: 512 97 | kernel_size: 16 98 | kernel_stride: 8 99 | 100 | # Dataloader options 101 | dataloader_opts: 102 | batch_size: !ref 103 | num_workers: 4 104 | 105 | # Specifying the network 106 | Encoder: !new:speechbrain.lobes.models.dual_path.Encoder 107 | kernel_size: !ref 108 | out_channels: !ref 109 | 110 | X: 8 111 | R: 3 112 | 113 | MaskNet: !new:src.deformable.MaskNet 114 | N: !ref 115 | B: 128 116 | H: 512 117 | P: 3 118 | X: !ref 119 | R: !ref 120 | C: !ref 121 | norm_type: 'gLN' 122 | causal: False 123 | mask_nonlinear: 'relu' 124 | 125 | Decoder: !new:speechbrain.lobes.models.dual_path.Decoder 126 | in_channels: !ref 127 | out_channels: 1 128 | kernel_size: !ref 129 | stride: !ref 130 | bias: False 131 | 132 | optimizer: !name:torch.optim.Adam 133 | lr: !ref 134 | weight_decay: 0 135 | 136 | loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper 137 | 138 | lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau 139 | factor: 0.5 140 | patience: 3 141 | dont_halve_until_epoch: 3 142 | 143 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 144 | limit: !ref 145 | 146 | modules: 147 | encoder: !ref 148 | decoder: !ref 149 | masknet: !ref 150 | 151 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 152 | checkpoints_dir: !ref 153 | recoverables: 154 | encoder: !ref 155 | decoder: !ref 156 | masknet: !ref 157 | counter: !ref 158 | lr_scheduler: !ref 159 | 160 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 161 | save_file: !ref 162 | -------------------------------------------------------------------------------- /separation/hparams/wsj0-2mix/deformable/dm/shared_weights/dtcn-wsj0-2mix.yaml: -------------------------------------------------------------------------------- 1 | # ################################ 2 | # Model: dtcn for source separation 3 | # https://arxiv.org/abs/2010.13154 4 | # 5 | # Dataset : WSJ0-2Mix 6 | # ################################ 7 | # Basic parameters 8 | # Seed needs to be set at top of yaml, before objects with parameters are made 9 | #l 10 | seed: 1234 11 | __set_seed: !apply:torch.manual_seed [!ref ] 12 | 13 | # Data params 14 | 15 | # the data folder for the wham dataset 16 | # data_folder needs to follow the format: /yourpath/wsj0-2mix. 17 | # make sure to use the name wsj0-2mix at your top folder for the dataset! 18 | data_folder: /fastdata/acp19jwr/data/wsj0-2mix 19 | # mini: False # if true only uses a quarter of the wham/wsj0-2mix data 20 | 21 | 22 | # the path for wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used 23 | # e.g. 
--------------------------------------------------------------------------------
/separation/hparams/wsj0-2mix/deformable/dm/shared_weights/dtcn-wsj0-2mix.yaml:
--------------------------------------------------------------------------------
# ################################
# Model: DTCN for source separation
# https://arxiv.org/abs/2210.15305
#
# Dataset : WSJ0-2Mix
# ################################
# Basic parameters
# Seed needs to be set at the top of the yaml, before objects with parameters are made
seed: 1234
__set_seed: !apply:torch.manual_seed [!ref <seed>]

# Data params

# the data folder for the wsj0-2mix dataset
# data_folder needs to follow the format: /yourpath/wsj0-2mix
# make sure to use the name wsj0-2mix for the top folder of the dataset!
data_folder: /fastdata/acp19jwr/data/wsj0-2mix
# mini: False # if True, only a quarter of the wham/wsj0-2mix data is used

# the path of the wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used
# e.g. /yourpath/wsj0-processed/si_tr_s/
# you need to convert the original wsj0 data to 8 kHz;
# you can do this conversion with the script ../meta/preprocess_dynamic_mixing.py
base_folder_dm: /fastdata/acp19jwr/data/wsj0/si_tr_s_processed

experiment_name: deformable/dm/shared_weights
experiment_folder: !ref wsj0-2mix/<experiment_name>/R=<R>/X=<X>
output_folder: !ref results/<experiment_folder>/<seed>
train_log: !ref <output_folder>/train_log.txt
save_folder: !ref <output_folder>/save

# the file names should start with wsj0-2mix instead of whamorg
train_data: !ref <save_folder>/wsj_tr.csv
valid_data: !ref <save_folder>/wsj_cv.csv
test_data: !ref <save_folder>/wsj_tt.csv
skip_prep: False

# Experiment params
auto_mix_prec: False # Set it to True for mixed precision
test_only: False
num_spks: 2 # set to 3 for wsj0-3mix
progressbar: True
save_audio: False # Save estimated sources on disk
sample_rate: 8000

# Training parameters
N_epochs: 100
batch_size: 1
lr: 0.001
clip_grad_norm: 5
loss_upper_lim: 9999 # this is the upper limit for an acceptable loss
# if True, the training sequences are cut to the specified length
limit_training_signal_len: True
# this is the length of training sequences if we choose to limit them
training_signal_len: 32000

# Set it to True to dynamically create mixtures at training time
dynamic_mixing: True

# Parameters for data augmentation

# rir_path points to the directory of the room impulse responses
# e.g. /miniscratch/subakany/rir_wavs
# If the path does not exist, it is created automatically.
# rir_path: /share/mini1/usr/will/scratch/wsj0-2mix/rir_wavs

use_wavedrop: False
use_speedperturb: True
use_speedperturb_sameforeachsource: False
use_rand_shift: False
min_shift: -8000
max_shift: 8000

speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
    perturb_prob: 1.0
    drop_freq_prob: 0.0
    drop_chunk_prob: 0.0
    sample_rate: !ref <sample_rate>
    speeds: [95, 100, 105]

wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
    perturb_prob: 0.0
    drop_freq_prob: 1.0
    drop_chunk_prob: 1.0
    sample_rate: !ref <sample_rate>

# loss thresholding -- this thresholds the training loss
threshold_byloss: True
threshold: -30

# Encoder parameters
N_encoder_out: 512
out_channels: 512
kernel_size: 16
kernel_stride: 8

# Dataloader options
dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: 4

# Specifying the network
Encoder: !new:speechbrain.lobes.models.dual_path.Encoder
    kernel_size: !ref <kernel_size>
    out_channels: !ref <N_encoder_out>

X: 8
R: 3

MaskNet: !new:src.deformable.MaskNet
    N: !ref <N_encoder_out>
    B: 128
    H: 512
    P: 3
    X: !ref <X>
    R: !ref <R>
    C: !ref <num_spks>
    norm_type: 'gLN'
    causal: False
    mask_nonlinear: 'relu'
    shared_weights: True

Decoder: !new:speechbrain.lobes.models.dual_path.Decoder
    in_channels: !ref <N_encoder_out>
    out_channels: 1
    kernel_size: !ref <kernel_size>
    stride: !ref <kernel_stride>
    bias: False

optimizer: !name:torch.optim.Adam
    lr: !ref <lr>
    weight_decay: 0

loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper

lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
    factor: 0.5
    patience: 3
    dont_halve_until_epoch: 3

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <N_epochs>

modules:
    encoder: !ref <Encoder>
    decoder: !ref <Decoder>
    masknet: !ref <MaskNet>

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        encoder: !ref <Encoder>
        decoder: !ref <Decoder>
        masknet: !ref <MaskNet>
        counter: !ref <epoch_counter>
        lr_scheduler: !ref <lr_scheduler>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
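The only functional difference from the previous file is `shared_weights: True`, which makes the mask network instantiate a single repeat of `X` blocks and loop over it `R` times (see `src/deformable.py` later in this listing). A quick way to see the effect on model size, assuming `src.deformable` and its `dc1d` dependency are importable:
```
import torch
from src.deformable import MaskNet

def n_params(module: torch.nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

cfg = dict(N=512, B=128, H=512, P=3, X=8, R=3, C=2)
shared = MaskNet(**cfg, shared_weights=True)
unshared = MaskNet(**cfg, shared_weights=False)

# With R=3 the shared variant stores one repeat of temporal blocks instead
# of three, so its block parameters shrink by roughly a factor of R.
print(n_params(shared), n_params(unshared))
```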
--------------------------------------------------------------------------------
/separation/hparams/wsj0-2mix/deformable/dtcn-wsj0-2mix.yaml:
--------------------------------------------------------------------------------
# ################################
# Model: DTCN for source separation
# https://arxiv.org/abs/2210.15305
#
# Dataset : WSJ0-2Mix
# ################################
# Basic parameters
# Seed needs to be set at the top of the yaml, before objects with parameters are made
seed: 1234
__set_seed: !apply:torch.manual_seed [!ref <seed>]

# Data params

# the data folder for the wsj0-2mix dataset
# data_folder needs to follow the format: /yourpath/wsj0-2mix
# make sure to use the name wsj0-2mix for the top folder of the dataset!
data_folder: /fastdata/acp19jwr/data/wsj0-2mix
# mini: False # if True, only a quarter of the wham/wsj0-2mix data is used

# the path of the wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used
# e.g. /yourpath/wsj0-processed/si_tr_s/
# you need to convert the original wsj0 data to 8 kHz;
# you can do this conversion with the script ../meta/preprocess_dynamic_mixing.py
base_folder_dm: /yourpath/wsj0-processed/si_tr_s/

experiment_name: deformable
experiment_folder: !ref wsj0-2mix/<experiment_name>/R=<R>/X=<X>
output_folder: !ref results/<experiment_folder>/<seed>
train_log: !ref <output_folder>/train_log.txt
save_folder: !ref <output_folder>/save

# the file names should start with wsj0-2mix instead of whamorg
train_data: !ref <save_folder>/wsj_tr.csv
valid_data: !ref <save_folder>/wsj_cv.csv
test_data: !ref <save_folder>/wsj_tt.csv
skip_prep: False

# Experiment params
auto_mix_prec: False # Set it to True for mixed precision
test_only: False
num_spks: 2 # set to 3 for wsj0-3mix
progressbar: True
save_audio: False # Save estimated sources on disk
sample_rate: 8000

# Training parameters
N_epochs: 100
batch_size: 1
lr: 0.001
clip_grad_norm: 5
loss_upper_lim: 9999 # this is the upper limit for an acceptable loss
# if True, the training sequences are cut to the specified length
limit_training_signal_len: True
# this is the length of training sequences if we choose to limit them
training_signal_len: 32000

# Set it to True to dynamically create mixtures at training time
dynamic_mixing: False

# Parameters for data augmentation

# rir_path points to the directory of the room impulse responses
# e.g. /miniscratch/subakany/rir_wavs
# If the path does not exist, it is created automatically.
# rir_path: /share/mini1/usr/will/scratch/wsj0-2mix/rir_wavs

use_wavedrop: False
use_speedperturb: False
use_speedperturb_sameforeachsource: False
use_rand_shift: False
min_shift: -8000
max_shift: 8000

speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
    perturb_prob: 1.0
    drop_freq_prob: 0.0
    drop_chunk_prob: 0.0
    sample_rate: !ref <sample_rate>
    speeds: [95, 100, 105]

wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
    perturb_prob: 0.0
    drop_freq_prob: 1.0
    drop_chunk_prob: 1.0
    sample_rate: !ref <sample_rate>

# loss thresholding -- this thresholds the training loss
threshold_byloss: True
threshold: -30

# Encoder parameters
N_encoder_out: 512
out_channels: 512
kernel_size: 16
kernel_stride: 8

# Dataloader options
dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: 4

# Specifying the network
Encoder: !new:speechbrain.lobes.models.dual_path.Encoder
    kernel_size: !ref <kernel_size>
    out_channels: !ref <N_encoder_out>

X: 8
R: 3

MaskNet: !new:src.deformable.MaskNet
    N: !ref <N_encoder_out>
    B: 128
    H: 512
    P: 3
    X: !ref <X>
    R: !ref <R>
    C: !ref <num_spks>
    norm_type: 'gLN'
    causal: False
    mask_nonlinear: 'relu'

Decoder: !new:speechbrain.lobes.models.dual_path.Decoder
    in_channels: !ref <N_encoder_out>
    out_channels: 1
    kernel_size: !ref <kernel_size>
    stride: !ref <kernel_stride>
    bias: False

optimizer: !name:torch.optim.Adam
    lr: !ref <lr>
    weight_decay: 0

loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper

lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
    factor: 0.5
    patience: 3
    dont_halve_until_epoch: 3

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <N_epochs>

modules:
    encoder: !ref <Encoder>
    decoder: !ref <Decoder>
    masknet: !ref <MaskNet>

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        encoder: !ref <Encoder>
        decoder: !ref <Decoder>
        masknet: !ref <MaskNet>
        counter: !ref <epoch_counter>
        lr_scheduler: !ref <lr_scheduler>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
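With `P: 3`, `X: 8` and `R: 3`, the mask network has a fixed nominal receptive field (the deformable offsets shift sampling positions inside it rather than growing it). Back-of-the-envelope, using the encoder settings above:
```
# Receptive field of a TCN with kernel P, X blocks per repeat (dilations
# 1, 2, ..., 2**(X-1)) and R repeats, measured in encoder frames:
P, X, R = 3, 8, 3
rf_frames = 1 + (P - 1) * R * (2**X - 1)  # 1531 frames

# Converted to samples/seconds with kernel_size=16, kernel_stride=8 at 8 kHz:
kernel_size, kernel_stride, sample_rate = 16, 8, 8000
rf_samples = (rf_frames - 1) * kernel_stride + kernel_size
print(rf_frames, rf_samples / sample_rate)  # 1531 frames, ~1.53 s
```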
--------------------------------------------------------------------------------
/separation/hparams/wsj0-2mix/deformable/shared_weights/dtcn-wsj0-2mix.yaml:
--------------------------------------------------------------------------------
# ################################
# Model: DTCN for source separation
# https://arxiv.org/abs/2210.15305
#
# Dataset : WSJ0-2Mix
# ################################
# Basic parameters
# Seed needs to be set at the top of the yaml, before objects with parameters are made
seed: 1234
__set_seed: !apply:torch.manual_seed [!ref <seed>]

# Data params

# the data folder for the wsj0-2mix dataset
# data_folder needs to follow the format: /yourpath/wsj0-2mix
# make sure to use the name wsj0-2mix for the top folder of the dataset!
data_folder: /fastdata/acp19jwr/data/wsj0-2mix
# mini: False # if True, only a quarter of the wham/wsj0-2mix data is used

# the path of the wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used
# e.g. /yourpath/wsj0-processed/si_tr_s/
# you need to convert the original wsj0 data to 8 kHz;
# you can do this conversion with the script ../meta/preprocess_dynamic_mixing.py
base_folder_dm: /fastdata/acp19jwr/data/wsj0/si_tr_s_processed

experiment_name: deformable/shared_weights
experiment_folder: !ref wsj0-2mix/<experiment_name>/R=<R>/X=<X>
output_folder: !ref results/<experiment_folder>/<seed>
train_log: !ref <output_folder>/train_log.txt
save_folder: !ref <output_folder>/save

# the file names should start with wsj0-2mix instead of whamorg
train_data: !ref <save_folder>/wsj_tr.csv
valid_data: !ref <save_folder>/wsj_cv.csv
test_data: !ref <save_folder>/wsj_tt.csv
skip_prep: False

# Experiment params
auto_mix_prec: False # Set it to True for mixed precision
test_only: False
num_spks: 2 # set to 3 for wsj0-3mix
progressbar: True
save_audio: False # Save estimated sources on disk
sample_rate: 8000

# Training parameters
N_epochs: 100
batch_size: 2
lr: 0.001
clip_grad_norm: 5
loss_upper_lim: 9999 # this is the upper limit for an acceptable loss
# if True, the training sequences are cut to the specified length
limit_training_signal_len: True
# this is the length of training sequences if we choose to limit them
training_signal_len: 32000

# Set it to True to dynamically create mixtures at training time
dynamic_mixing: False

# Parameters for data augmentation

# rir_path points to the directory of the room impulse responses
# e.g. /miniscratch/subakany/rir_wavs
# If the path does not exist, it is created automatically.
# rir_path: /share/mini1/usr/will/scratch/wsj0-2mix/rir_wavs

use_wavedrop: False
use_speedperturb: False
use_speedperturb_sameforeachsource: False
use_rand_shift: False
min_shift: -8000
max_shift: 8000

speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
    perturb_prob: 1.0
    drop_freq_prob: 0.0
    drop_chunk_prob: 0.0
    sample_rate: !ref <sample_rate>
    speeds: [95, 100, 105]

wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
    perturb_prob: 0.0
    drop_freq_prob: 1.0
    drop_chunk_prob: 1.0
    sample_rate: !ref <sample_rate>

# loss thresholding -- this thresholds the training loss
threshold_byloss: True
threshold: -30

# Encoder parameters
N_encoder_out: 512
out_channels: 512
kernel_size: 16
kernel_stride: 8

# Dataloader options
dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: 4

# Specifying the network
Encoder: !new:speechbrain.lobes.models.dual_path.Encoder
    kernel_size: !ref <kernel_size>
    out_channels: !ref <N_encoder_out>

X: 8
R: 3

MaskNet: !new:src.deformable.MaskNet
    N: !ref <N_encoder_out>
    B: 128
    H: 512
    P: 3
    X: !ref <X>
    R: !ref <R>
    C: !ref <num_spks>
    norm_type: 'gLN'
    causal: False
    mask_nonlinear: 'relu'
    shared_weights: True

Decoder: !new:speechbrain.lobes.models.dual_path.Decoder
    in_channels: !ref <N_encoder_out>
    out_channels: 1
    kernel_size: !ref <kernel_size>
    stride: !ref <kernel_stride>
    bias: False

optimizer: !name:torch.optim.Adam
    lr: !ref <lr>
    weight_decay: 0

loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper

lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
    factor: 0.5
    patience: 3
    dont_halve_until_epoch: 3

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <N_epochs>

modules:
    encoder: !ref <Encoder>
    decoder: !ref <Decoder>
    masknet: !ref <MaskNet>

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        encoder: !ref <Encoder>
        decoder: !ref <Decoder>
        masknet: !ref <MaskNet>
        counter: !ref <epoch_counter>
        lr_scheduler: !ref <lr_scheduler>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
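`ReduceLROnPlateau` above halves the learning rate (`factor: 0.5`) once the tracked validation loss stops improving for `patience` epochs, but never before `dont_halve_until_epoch`. A sketch of the per-epoch call, following the usual SpeechBrain recipe pattern (the exact call site is in `train.py`; treat the signature here as indicative, not definitive):
```
import torch
from speechbrain.nnet.schedulers import ReduceLROnPlateau

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(factor=0.5, patience=3, dont_halve_until_epoch=3)

# SpeechBrain schedulers are called once per epoch with the optimizer(s)
# and the monitored quantity (here a made-up, non-improving valid loss).
for epoch in range(1, 9):
    current_lr, next_lr = scheduler([optimizer], epoch, -10.0)
    print(epoch, current_lr, next_lr)
```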
--------------------------------------------------------------------------------
/separation/hparams/wsj0-2mix/dtcn-wsj0-2mix.yaml:
--------------------------------------------------------------------------------
# ################################
# Model: TCN baseline for source separation (Conv-TasNet MaskNet)
# https://arxiv.org/abs/1809.07454
#
# Dataset : WSJ0-2Mix
# ################################
# Basic parameters
# Seed needs to be set at the top of the yaml, before objects with parameters are made
seed: 1234
__set_seed: !apply:torch.manual_seed [!ref <seed>]

# Data params

# the data folder for the wsj0-2mix dataset
# data_folder needs to follow the format: /yourpath/wsj0-2mix
# make sure to use the name wsj0-2mix for the top folder of the dataset!
data_folder: /fastdata/acp19jwr/data/wsj0-2mix
# mini: False # if True, only a quarter of the wham/wsj0-2mix data is used

# the path of the wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used
# e.g. /yourpath/wsj0-processed/si_tr_s/
# you need to convert the original wsj0 data to 8 kHz;
# you can do this conversion with the script ../meta/preprocess_dynamic_mixing.py
base_folder_dm: /yourpath/wsj0-processed/si_tr_s/

experiment_name: dtcn
experiment_folder: !ref wsj0-2mix/<experiment_name>/R=<R>/X=<X>
output_folder: !ref results/<experiment_folder>/<seed>
train_log: !ref <output_folder>/train_log.txt
save_folder: !ref <output_folder>/save

# the file names should start with wsj0-2mix instead of whamorg
train_data: !ref <save_folder>/wsj_tr.csv
valid_data: !ref <save_folder>/wsj_cv.csv
test_data: !ref <save_folder>/wsj_tt.csv
skip_prep: False

# Experiment params
auto_mix_prec: False # Set it to True for mixed precision
test_only: False
num_spks: 2 # set to 3 for wsj0-3mix
progressbar: True
save_audio: False # Save estimated sources on disk
sample_rate: 8000

# Training parameters
N_epochs: 100
batch_size: 8
lr: 0.001
clip_grad_norm: 5
loss_upper_lim: 9999 # this is the upper limit for an acceptable loss
# if True, the training sequences are cut to the specified length
limit_training_signal_len: True
# this is the length of training sequences if we choose to limit them
training_signal_len: 32000

# Set it to True to dynamically create mixtures at training time
dynamic_mixing: False

# Parameters for data augmentation

# rir_path points to the directory of the room impulse responses
# e.g. /miniscratch/subakany/rir_wavs
# If the path does not exist, it is created automatically.
# rir_path: /share/mini1/usr/will/scratch/wsj0-2mix/rir_wavs

use_wavedrop: False
use_speedperturb: False
use_speedperturb_sameforeachsource: False
use_rand_shift: False
min_shift: -8000
max_shift: 8000

speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
    perturb_prob: 1.0
    drop_freq_prob: 0.0
    drop_chunk_prob: 0.0
    sample_rate: !ref <sample_rate>
    speeds: [95, 100, 105]

wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
    perturb_prob: 0.0
    drop_freq_prob: 1.0
    drop_chunk_prob: 1.0
    sample_rate: !ref <sample_rate>

# loss thresholding -- this thresholds the training loss
threshold_byloss: True
threshold: -30

# Encoder parameters
N_encoder_out: 512
out_channels: 512
kernel_size: 16
kernel_stride: 8

# Dataloader options
dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: 4

# Specifying the network
Encoder: !new:speechbrain.lobes.models.dual_path.Encoder
    kernel_size: !ref <kernel_size>
    out_channels: !ref <N_encoder_out>

X: 8
R: 3

MaskNet: !new:speechbrain.lobes.models.conv_tasnet.MaskNet
    N: !ref <N_encoder_out>
    B: 128
    H: 512
    P: 3
    X: !ref <X>
    R: !ref <R>
    C: !ref <num_spks>
    norm_type: 'gLN'
    causal: False
    mask_nonlinear: 'relu'

Decoder: !new:speechbrain.lobes.models.dual_path.Decoder
    in_channels: !ref <N_encoder_out>
    out_channels: 1
    kernel_size: !ref <kernel_size>
    stride: !ref <kernel_stride>
    bias: False

optimizer: !name:torch.optim.Adam
    lr: !ref <lr>
    weight_decay: 0

loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper

lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
    factor: 0.5
    patience: 3
    dont_halve_until_epoch: 3

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <N_epochs>

modules:
    encoder: !ref <Encoder>
    decoder: !ref <Decoder>
    masknet: !ref <MaskNet>

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        encoder: !ref <Encoder>
        decoder: !ref <Decoder>
        masknet: !ref <MaskNet>
        counter: !ref <epoch_counter>
        lr_scheduler: !ref <lr_scheduler>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
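All five recipes above train with the same objective: negative scale-invariant SNR under permutation-invariant training (PIT), via `speechbrain.nnet.losses.get_si_snr_with_pitwrapper`. A self-contained call with the shapes the recipes use:
```
import torch
from speechbrain.nnet.losses import get_si_snr_with_pitwrapper

batch, time, n_spks = 2, 8000, 2
targets = torch.randn(batch, time, n_spks)
estimates = torch.randn(batch, time, n_spks)

# Returns the negative SI-SNR of the best speaker permutation, one value
# per mixture; lower is better, hence thresholds like `threshold: -30`.
loss = get_si_snr_with_pitwrapper(targets, estimates)
print(loss.shape)  # torch.Size([2])
```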
--------------------------------------------------------------------------------
/separation/model_info.py:
--------------------------------------------------------------------------------

import os
import torch
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
from thop import clever_format

from train import Separation

# Hyperparameter files of the models to profile (paths relative to hparams/).
hparam_file_list = [
    "baselines/tcn/tcn-whamr.yaml",
    "deformable/dtcn-whamr.yaml",
    "deformable/shared_weights/dtcn-whamr.yaml",
]

model_names = ["_".join(f.split("/")[:-1]) for f in hparam_file_list]

run_opts_list = ["--data_folder /fastdata/acp19jwr/data/mono-whamr"] * len(hparam_file_list)
overrides_list = [""] * len(hparam_file_list)

sig_len = 6  # length of the test signal in seconds

for h, r, o, m in zip(hparam_file_list, run_opts_list, overrides_list, model_names):
    hparams_file, run_opts, overrides = sb.parse_arguments([h] + r.split(" ") + o.split(" "))

    with open(os.path.join("hparams", hparams_file)) as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    run_opts["auto_mix_prec"] = hparams["auto_mix_prec"]

    test_input = torch.randn(1, hparams["sample_rate"] * sig_len).cuda()

    model = Separation(
        modules=hparams["modules"],
        opt_class=hparams["optimizer"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )
    model.hparams.encode_rirs = False

    # compute_forward with profiler=True returns per-module MAC and parameter
    # counts; the decoder runs once per estimated speaker, so scale it by C.
    prof_macs, prof_params = model.compute_forward(test_input, None, sb.Stage.TEST, None, profiler=True)
    prof_macs["decoder"] = prof_macs["decoder"] * model.hparams.num_spks
    total_macs = sum(prof_macs.values()) / sig_len  # MACs per second of audio
    total_params = sum(prof_params.values())
    macs, params = clever_format([total_macs, total_params], "%.3f")
    print(m, macs, params)
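`model_info.py` relies on thop-style hooks that `train.py` wires into `compute_forward` when `profiler=True`. The underlying mechanism in its plainest form (the module below is a stand-in, not one of the repo's models):
```
import torch
from thop import profile, clever_format

# Count multiply-accumulates and parameters of any nn.Module by running one
# profiled forward pass on a representative input.
net = torch.nn.Conv1d(in_channels=1, out_channels=512, kernel_size=16, stride=8)
x = torch.randn(1, 1, 8000 * 6)  # six seconds of 8 kHz audio

macs, params = profile(net, inputs=(x,))
print(clever_format([macs / 6, params], "%.3f"))  # MACs per second of audio
```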
--------------------------------------------------------------------------------
/separation/src/deformable.py:
--------------------------------------------------------------------------------

"""Implementation of the Deformable Temporal Convolutional Network (DTCN)
mask estimation network, adapted from the SpeechBrain Conv-TasNet MaskNet.
"""
import torch
import torch.nn as nn
import speechbrain as sb
import torch.nn.functional as F

from speechbrain.lobes.models.conv_tasnet import ChannelwiseLayerNorm, Chomp1d, choose_norm

from dc1d.nn import PackedDeformConv1d

EPS = 1e-8

class DeformableTemporalBlocksSequential(sb.nnet.containers.Sequential):
    """A wrapper that replicates the deformable temporal-block layer.

    Arguments
    ---------
    input_shape : tuple
        Expected shape of the input.
    H : int
        The number of intermediate channels.
    P : int
        The kernel size in the convolutions.
    R : int
        The number of times to replicate the multilayer temporal blocks.
    X : int
        The number of layers of temporal blocks with different dilations.
    norm_type : str
        The type of normalization, in ['gLN', 'cLN'].
    causal : bool
        To use causal or non-causal convolutions, in [True, False].

    Example
    -------
    >>> x = torch.randn(14, 100, 10)
    >>> H, P, R, X = 10, 5, 2, 3
    >>> blocks = DeformableTemporalBlocksSequential(
    ...     x.shape, H, P, R, X, 'gLN', False
    ... )
    >>> y = blocks(x)
    >>> y.shape
    torch.Size([14, 100, 10])
    """

    def __init__(
        self,
        input_shape,
        H,
        P,
        R,
        X,
        norm_type,
        causal,
        bias=True,
        store_intermediates=False,
        shared_weights=False,
    ):
        super().__init__(input_shape=input_shape)

        for r in range(R):
            for x in range(X):
                dilation = 2 ** x
                self.append(
                    DeformableTemporalBlock,
                    out_channels=H,
                    kernel_size=P,
                    stride=1,
                    padding="same",
                    dilation=dilation,
                    norm_type=norm_type,
                    causal=causal,
                    layer_name=f"temporalblock_{r}_{x}",
                    bias=bias,
                    store_intermediates=store_intermediates,
                )
            if shared_weights:
                # Build only one repeat of X blocks and loop over it R times
                # in forward() instead of instantiating R separate repeats.
                self.shared_weights = shared_weights
                self.R = R
                break

    def set_store_intermediates(self, store_intermediates=True):
        self.store_intermediates = store_intermediates
        for layer in self.values():
            layer.set_store_intermediates(store_intermediates)

    def forward(self, x):
        i_dict = {}
        # With shared weights, the single repeat of blocks is applied R times.
        repeat = getattr(self, "R", 1)
        store = getattr(self, "store_intermediates", False)
        for _ in range(repeat):
            for name, layer in self.items():
                if store:
                    layer.set_store_intermediates(store)
                    x, intermediate = layer(x)
                    i_dict[name] = intermediate
                else:
                    x = layer(x)
                    if isinstance(x, tuple):
                        x = x[0]

        if store:
            return x, i_dict
        return x
class MaskNet(nn.Module):
    """Deformable TCN mask estimation network.

    Arguments
    ---------
    N : int
        Number of filters in autoencoder.
    B : int
        Number of channels in bottleneck 1 × 1-conv block.
    H : int
        Number of channels in convolutional blocks.
    P : int
        Kernel size in convolutional blocks.
    X : int
        Number of convolutional blocks in each repeat.
    R : int
        Number of repeats.
    C : int
        Number of speakers.
    norm_type : str
        One of BN, gLN, cLN.
    causal : bool
        Causal or non-causal.
    mask_nonlinear : str
        Which non-linear function to generate the mask with, in ['softmax', 'relu'].

    Example
    -------
    >>> N, B, H, P, X, R, C = 11, 12, 2, 5, 3, 1, 2
    >>> masknet = MaskNet(N, B, H, P, X, R, C)
    >>> mixture_w = torch.randn(10, 11, 100)
    >>> est_mask = masknet(mixture_w)
    >>> est_mask.shape
    torch.Size([2, 10, 11, 100])
    """

    def __init__(
        self,
        N,
        B,
        H,
        P,
        X,
        R,
        C,
        norm_type="gLN",
        causal=False,
        mask_nonlinear="relu",
        store_intermediates=False,
        shared_weights=False,
    ):
        super(MaskNet, self).__init__()

        # Hyper-parameters
        self.C = C
        self.mask_nonlinear = mask_nonlinear
        self.shared_weights = shared_weights

        # Components
        # [M, K, N] -> [M, K, N]
        self.layer_norm = ChannelwiseLayerNorm(N)

        # [M, K, N] -> [M, K, B]
        self.bottleneck_conv1x1 = sb.nnet.CNN.Conv1d(
            in_channels=N, out_channels=B, kernel_size=1, bias=False,
        )

        # [M, K, B] -> [M, K, B]
        in_shape = (None, None, B)
        self.temporal_conv_net = DeformableTemporalBlocksSequential(
            in_shape,
            H,
            P,
            R,
            X,
            norm_type,
            causal,
            bias=True,
            store_intermediates=store_intermediates,
            shared_weights=shared_weights,
        )

        # [M, K, B] -> [M, K, C*N]
        self.mask_conv1x1 = sb.nnet.CNN.Conv1d(
            in_channels=B, out_channels=C * N, kernel_size=1, bias=False
        )

    def set_store_intermediates(self, store_intermediates=True):
        self.store_intermediates = store_intermediates
        self.temporal_conv_net.set_store_intermediates(store_intermediates)

    def forward(self, mixture_w):
        """Keeps this API the same as TasNet.

        Arguments
        ---------
        mixture_w : Tensor
            Tensor of shape [M, N, K], where M is the batch size.

        Returns
        -------
        est_mask : Tensor
            Tensor of shape [C, M, N, K] (with an intermediates dict appended
            when ``store_intermediates`` is enabled).
        """
        mixture_w = mixture_w.permute(0, 2, 1)
        M, K, N = mixture_w.size()
        y = self.layer_norm(mixture_w)
        y = self.bottleneck_conv1x1(y)

        i_dict = {}
        store = getattr(self, "store_intermediates", False)
        if store:
            self.temporal_conv_net.set_store_intermediates(store)
        y = self.temporal_conv_net(y)
        if isinstance(y, tuple):
            y, i_dict = y

        score = self.mask_conv1x1(y)
        score = score.contiguous().reshape(
            M, K, self.C, N
        )  # [M, K, C*N] -> [M, K, C, N]

        # [M, K, C, N] -> [C, M, N, K]
        score = score.permute(2, 0, 3, 1)

        if self.mask_nonlinear == "softmax":
            est_mask = F.softmax(score, dim=2)
        elif self.mask_nonlinear == "relu":
            est_mask = F.relu(score)
        else:
            raise ValueError("Unsupported mask non-linear function")

        if store:
            return est_mask, i_dict
        return est_mask
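# NOTE: the estimated masks returned above have shape [C, M, N, K]
# (speakers x batch x encoder channels x frames). In the training recipe each
# speaker's mask is multiplied element-wise with the encoder output and the
# product is decoded back to a waveform, as in the standard Conv-TasNet
# masking pipeline.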
class DeformableTemporalBlock(torch.nn.Module):
    """The compound conv1d layers used in the DTCN MaskNet.

    Arguments
    ---------
    input_shape : tuple
        The expected shape of the input.
    out_channels : int
        The number of intermediate channels.
    kernel_size : int
        The kernel size in the convolutions.
    stride : int
        Convolution stride in convolutional layers.
    dilation : int
        Amount of dilation in convolutional layers.
    padding : str
        The type of padding in the convolutional layers,
        (same, valid, causal). If "valid", no padding is performed.
    norm_type : str
        The type of normalization, in ['gLN', 'cLN'].
    causal : bool
        To use causal or non-causal convolutions, in [True, False].

    Example
    -------
    >>> x = torch.randn(14, 100, 10)
    >>> block = DeformableTemporalBlock(x.shape, 10, 11, stride=1, dilation=1)
    >>> y, _ = block(x)
    >>> y.shape
    torch.Size([14, 100, 10])
    """

    def __init__(
        self,
        input_shape,
        out_channels,
        kernel_size,
        stride,
        dilation,
        padding="same",
        norm_type="gLN",
        causal=False,
        bias=True,
        store_intermediates=False,
    ):
        super().__init__()
        M, K, B = input_shape  # batch x time x features

        self.layers = sb.nnet.containers.Sequential(input_shape=input_shape)

        # [M, K, B] -> [M, K, H]
        self.layers.append(
            sb.nnet.CNN.Conv1d,
            out_channels=out_channels,
            kernel_size=1,
            bias=False,
            layer_name="conv",
        )
        self.layers.append(nn.PReLU(), layer_name="act")
        self.layers.append(
            choose_norm(norm_type, out_channels), layer_name="norm"
        )

        # [M, K, H] -> [M, K, B]
        self.layers.append(
            DeformableDepthwiseSeparableConv,
            out_channels=B,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            norm_type=norm_type,
            causal=causal,
            bias=bias,
            layer_name="DSconv",
        )
        self.store_intermediates = store_intermediates

    def set_store_intermediates(self, store_intermediates=True):
        self.store_intermediates = store_intermediates
        self.layers["DSconv"].set_store_intermediates(store_intermediates)

    def forward(self, x):
        """
        Arguments
        ---------
        x : Tensor
            Tensor of shape [batch size, sequence length, input channels].

        Returns
        -------
        A tuple of the residual output, of shape [M, K, B], and a dict of
        stored intermediates (empty unless ``store_intermediates`` is set).
        """
        residual = x
        i_dict = {}
        for name, layer in self.layers.items():
            if (
                isinstance(layer, DeformableDepthwiseSeparableConv)
                and getattr(self, "store_intermediates", False)
            ):
                layer.set_store_intermediates(self.store_intermediates)
                x, intermediate = layer(x)
                i_dict[name] = intermediate
            else:
                x = layer(x)
            if x is None:
                raise ValueError(
                    f"Output of layer {name} should not be None but it is"
                )
        return x + residual, i_dict

class DeformableDepthwiseSeparableConv(nn.Module):
    """Deformable depthwise-separable convolution: the building block of the
    temporal blocks above, with the depthwise stage made deformable.

    Arguments
    ---------
    input_shape : tuple
        Expected shape of the input.
    out_channels : int
        Number of output channels.
    kernel_size : int
        The kernel size in the convolutions.
    stride : int
        Convolution stride in convolutional layers.
    padding : str
        The type of padding in the convolutional layers,
        (same, valid, causal). If "valid", no padding is performed.
    dilation : int
        Amount of dilation in convolutional layers.
    norm_type : str
        The type of normalization, in ['gLN', 'cLN'].
    causal : bool
        To use causal or non-causal convolutions, in [True, False].

    Example
    -------
    >>> x = torch.randn(14, 100, 10)
    >>> DSconv = DeformableDepthwiseSeparableConv(x.shape, 10, 11, stride=1, dilation=1)
    >>> y = DSconv(x)
    >>> y.shape
    torch.Size([14, 100, 10])
    """

    def __init__(
        self,
        input_shape,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        norm_type="gLN",
        causal=False,
        padding="same",
        bias=True,
        store_intermediates=False,
        layer_name=None,
        *args,
        **kwargs,
    ):
        super(DeformableDepthwiseSeparableConv, self).__init__(*args, **kwargs)

        batchsize, time, in_channels = input_shape

        # Depthwise (deformable) [M, K, H] -> [M, K, H]
        self.depthwise_conv = PackedDeformConv1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )

        if causal:
            self.chomp = Chomp1d(padding if isinstance(padding, int) else 0)

        self.prelu = nn.PReLU()
        self.norm = choose_norm(norm_type, in_channels)

        # Pointwise [M, K, H] -> [M, K, B]
        self.pointwise_conv = sb.nnet.CNN.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            groups=1,
            bias=False,
        )
        self.store_intermediates = store_intermediates
        self.layer_name = layer_name

    def set_store_intermediates(self, store_intermediates=True):
        self.store_intermediates = store_intermediates

    def forward(self, x):
        """Applies the depthwise and pointwise stages in sequence.

        Arguments
        ---------
        x : torch.Tensor
            The input tensor to run through the network.
        """
        i_dict = {}
        if getattr(self, "store_intermediates", False):
            # The deformable conv can also return its predicted offsets.
            x, offsets = self.depthwise_conv(x.moveaxis(1, 2), self.store_intermediates)
            i_dict["offsets"] = offsets
        else:
            x = self.depthwise_conv(x.moveaxis(1, 2))
        x = self.prelu(x.moveaxis(2, 1))
        x = self.norm(x)
        x = self.pointwise_conv(x)

        if getattr(self, "store_intermediates", False):
            return x, i_dict
        return x
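# NOTE: PackedDeformConv1d (from the external `dc1d` package) replaces the
# ordinary depthwise convolution of Conv-TasNet. It contains a small offset
# sub-network that predicts a fractional shift for every kernel tap at every
# time step; the input is resampled at the shifted positions by linear
# interpolation before the depthwise kernel is applied. Passing
# store_intermediates=True makes it return those offsets, which is what the
# demo below inspects.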
if __name__ == '__main__':
    batch_size, N, L = 4, 25, 3321
    P = 3

    x = torch.rand((batch_size, N, L), device="cuda")

    B = N // 4
    H = N
    X = 8
    R = 3
    C = 2

    ddc = MaskNet(
        N=N,
        B=B,
        H=H,
        P=P,
        X=X,
        R=R,
        C=C,
        norm_type="gLN",
        causal=False,
        mask_nonlinear="relu",
        shared_weights=True,
    ).to("cuda:0")

    print(x.shape, x[0, 0, :3])
    ddc.set_store_intermediates(True)
    x, i = ddc(x)
    print(i["temporalblock_0_0"]["DSconv"]["offsets"].shape)
    print(x.shape, x[0, 0, 0, :3])
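For orientation, this is how `MaskNet` slots between the encoder and decoder configured in the YAML recipes, a sketch assuming `dc1d` is installed and mirroring the shapes used above:
```
import torch
from speechbrain.lobes.models.dual_path import Encoder, Decoder
from src.deformable import MaskNet

encoder = Encoder(kernel_size=16, out_channels=512)
masknet = MaskNet(N=512, B=128, H=512, P=3, X=8, R=3, C=2)
decoder = Decoder(in_channels=512, out_channels=1, kernel_size=16, stride=8, bias=False)

mix = torch.randn(1, 16000)  # two seconds of 8 kHz audio, [batch, time]
mix_w = encoder(mix)         # [batch, N, frames]
est_mask = masknet(mix_w)    # [spks, batch, N, frames]
est_sources = torch.stack(
    [decoder(mix_w * est_mask[c]) for c in range(2)], dim=-1
)                            # [batch, time, spks]
print(est_sources.shape)
```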
--------------------------------------------------------------------------------
/separation/src/macs.py:
--------------------------------------------------------------------------------
import torch
import speechbrain as sb
from thop.vision.basic_hooks import count_convNd, count_normalization
from thop.vision.calc_func import l_prod
from dc1d.nn import gLN, DeformConv1d

def calculate_conv2d_flops(input_size: list, output_size: list, kernel_size: list, groups: int, bias: bool = False):
    # MACs of a convolution: output elements x (input channels per group)
    # x spatial kernel size. Layouts: [n, c, *spatial].
    in_c = input_size[1]
    g = groups
    return l_prod(output_size) * (in_c // g) * l_prod(kernel_size[2:])

def sb_count_convNd(m: sb.nnet.CNN.Conv1d, x, y: torch.Tensor):
    # SpeechBrain's Conv1d wraps the underlying torch conv as `m.conv`.
    x = x[0]
    m.total_ops += calculate_conv2d_flops(
        input_size=list(x.shape),
        output_size=list(y.shape),
        kernel_size=list(m.conv.weight.shape),
        groups=m.conv.groups,
        bias=m.conv.bias,
    )

def count_deformconvNd(m: DeformConv1d, x, y: torch.Tensor):
    x = x[0]
    m.total_ops += calculate_conv2d_flops(
        input_size=list(x.shape),
        output_size=list(y.shape),
        kernel_size=list(m.conv.weight.shape),
        groups=m.conv.groups,
        bias=m.conv.bias,
    )

# Map of layer classes to thop counting hooks (used by model_info.py).
sb_ops_dict = {
    gLN: count_normalization,
    DeformConv1d: count_convNd,
    sb.lobes.models.conv_tasnet.ChannelwiseLayerNorm: count_normalization,
    sb.lobes.models.conv_tasnet.GlobalLayerNorm: count_normalization,
    sb.lobes.models.dual_path.Decoder: count_convNd,
    # sb.nnet.CNN.Conv1d: sb_count_convNd,
}
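`sb_ops_dict` exists to be passed through thop's `custom_ops` argument, so that layer types thop does not know (SpeechBrain's norms, dc1d's deformable conv) contribute to the MAC total instead of being skipped. A sketch on a small DTCN mask network, assuming `dc1d` is installed:
```
import torch
from thop import profile
from src.deformable import MaskNet
from src.macs import sb_ops_dict

net = MaskNet(N=64, B=16, H=64, P=3, X=2, R=1, C=2)
x = torch.randn(1, 64, 100)  # [batch, N, frames]

# custom_ops maps layer classes to counting hooks.
macs, params = profile(net, inputs=(x,), custom_ops=sb_ops_dict)
print(macs, params)
```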
--------------------------------------------------------------------------------
/separation/src/measures.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from speechbrain.nnet.losses import PitWrapper
from pysepm import stoi
from pesq import pesq
from srmrpy import srmr

def _pesq_call(ref_deg, fs):
    zipped = zip(ref_deg[0], ref_deg[1])
    results = [torch.Tensor([_pesq(r_d, fs)]) for r_d in zipped]
    return torch.Tensor(results)

def _pesq(ref_deg, fs):  # arguments reordered relative to the pesq package
    """
    Input args = ((T), (T)), int
    Output args = float
    """
    ref, deg = ref_deg

    if fs == 8000:
        mode = 'nb'
    else:
        mode = 'wb'
    try:
        return pesq(fs, ref, deg, mode=mode)
    except Exception:
        # PESQ can fail on degenerate inputs; 1.0 is roughly its floor.
        return 1.0

class PESQ():
    def __init__(self, fs):
        self.fs = fs
        self.pesq_with_pitwrapper = PitWrapper(self.cal_pesq_loss)

    def cal_pesq_loss(self, source, estimate):
        """
        Input size = T, B, C
        """
        assert source.shape == estimate.shape

        source_np = source.detach().cpu().numpy()
        estimate_np = estimate.detach().cpu().numpy()
        source_np = np.moveaxis(source_np, 0, -1)
        estimate_np = np.moveaxis(estimate_np, 0, -1)

        zipped = zip(estimate_np, source_np)
        results = [_pesq_call(ref_deg, fs=self.fs) for ref_deg in zipped]

        # Negated so that the (minimising) PIT wrapper maximises PESQ.
        results = -torch.stack(results)
        return results.unsqueeze(0)

    def get_pesq_loss_with_pit(self, source, estimate):
        """
        Input shapes = (B, T, C), (B, T, C)
        Output shape = (1), (B, C)
        """
        results, perms = self.pesq_with_pitwrapper(source, estimate)
        return results

    def pesq_measure_with_pit(self, source, estimate):
        """
        Input shapes = (B, T, C), (B, T, C)
        Output shape = (1), (B, C)
        """
        return -self.get_pesq_loss_with_pit(source, estimate)

def _stoi_call(ref_deg, fs, extended=False):
    zipped = zip(ref_deg[0], ref_deg[1])
    results = [torch.Tensor([_stoi(r_d, fs, extended)]) for r_d in zipped]
    return torch.Tensor(results)

def _stoi(ref_deg, fs, extended):  # arguments reordered relative to pysepm
    """
    Input args = ((T), (T)), int, bool
    Output args = float
    """
    ref, deg = ref_deg
    try:
        return stoi(ref, deg, fs, extended=extended)
    except Exception:
        return float("nan")

class STOI():
    def __init__(self, fs, extended=False):
        self.fs = fs
        self.extended = extended
        self.stoi_with_pitwrapper = PitWrapper(self.cal_stoi_loss)

    def cal_stoi_loss(self, source, estimate):
        """
        Input size = T, B, C
        """
        assert source.shape == estimate.shape

        source_np = source.detach().cpu().numpy()
        estimate_np = estimate.detach().cpu().numpy()
        source_np = np.moveaxis(source_np, 0, -1)
        estimate_np = np.moveaxis(estimate_np, 0, -1)

        zipped = zip(estimate_np, source_np)
        results = [_stoi_call(ref_deg, fs=self.fs, extended=self.extended) for ref_deg in zipped]

        results = -torch.stack(results)
        return results.unsqueeze(0)

    def get_stoi_loss_with_pit(self, source, estimate):
        """
        Input shapes = (B, T, C), (B, T, C)
        Output shape = (1), (B, C)
        """
        results, perms = self.stoi_with_pitwrapper(source, estimate)
        median = torch.nanmedian(results)
        return torch.nan_to_num(results, nan=median)  # account for nan values

    def stoi_measure_with_pit(self, source, estimate):
        """
        Input shapes = (B, T, C), (B, T, C)
        Output shape = (1), (B, C)
        """
        return -self.get_stoi_loss_with_pit(source, estimate)

def _srmr_call(deg, fs):
    results = [torch.Tensor([srmr(d, fs)[0]]) for d in deg]
    return torch.Tensor(results)

class SRMR():
    def __init__(self, fs=8000):
        self.fs = fs
        self.srmr_with_pitwrapper = PitWrapper(self.cal_srmr_loss)

    def cal_srmr_loss(self, _unused, estimate):
        # SRMR is non-intrusive (no reference signal), but PitWrapper expects
        # a two-argument loss, hence the unused first argument.
        estimate_np = estimate.detach().cpu().numpy()
        estimate_np = np.moveaxis(estimate_np, 0, -1)

        results = [_srmr_call(d, fs=self.fs) for d in estimate_np]
        results = -torch.stack(results)
        return results.unsqueeze(0)

    def get_srmr_loss_with_pit(self, estimate):
        """
        Input shape = (B, T, C)
        Output shape = (1), (B, C)
        """
        results, perms = self.srmr_with_pitwrapper(estimate, estimate)
        median = torch.nanmedian(results)
        return torch.nan_to_num(results, nan=median)  # account for nan values

    def srmr_measure_with_pit(self, estimate):
        """
        Input shape = (B, T, C)
        Output shape = (1), (B, C)
        """
        return -self.get_srmr_loss_with_pit(estimate)

if __name__ == '__main__':
    import time

    class Timer():
        def __init__(self):
            self.start = time.time()
            self.stop = self.start

        def tic(self):
            self.start = time.time()

        def toc(self):
            self.stop = time.time()
            self.time_taken = self.stop - self.start
            print("Elapsed time: {}s\n".format(self.time_taken))

    timer = Timer()

    B = 2
    C = 2
    T = 32000
    fs = 8000

    x = torch.randn(B, T, C)
    xhat = (x + torch.randn(B, T, C)) / 2

    # PESQ
    print("PESQ")
    pesq_measure = PESQ(fs)
    timer.tic()
    print(pesq_measure.pesq_measure_with_pit(x, xhat))
    timer.toc()

    # STOI
    print("STOI")
    stoi_measure = STOI(fs)
    timer.tic()
    print(stoi_measure.stoi_measure_with_pit(x, xhat))
    timer.toc()

    # ESTOI
    print("ESTOI")
    estoi_measure = STOI(fs, extended=True)
    timer.tic()
    print(estoi_measure.stoi_measure_with_pit(x, xhat))
    timer.toc()

    # SRMR
    print("SRMR")
    srmr_measure = SRMR(fs)
    timer.tic()
    print(srmr_measure.srmr_measure_with_pit(xhat))
    timer.toc()
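A note on the `_pesq` fallback above: at 8 kHz only the narrow-band ('nb') PESQ mode is defined, and the `pesq` package raises on inputs where it cannot find speech, so failures are pinned to 1.0, roughly the bottom of the PESQ scale. Direct use of the underlying call:
```
import numpy as np
from pesq import pesq

fs = 8000
ref = np.random.randn(fs).astype(np.float32)
deg = ref + 0.1 * np.random.randn(fs).astype(np.float32)

# 'nb' (narrow-band) is the only valid mode at 8 kHz; 16 kHz also allows 'wb'.
# PESQ can refuse degenerate inputs (e.g. no detected speech), which is why
# _pesq above catches exceptions and falls back to 1.0.
try:
    print(pesq(fs, ref, deg, mode="nb"))
except Exception as err:
    print("PESQ failed:", err)
```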
--------------------------------------------------------------------------------
/separation/src/tcn.py:
--------------------------------------------------------------------------------
import torch
from speechbrain.lobes.models.conv_tasnet import MaskNet
from speechbrain.lobes.models.dual_path import Encoder, Decoder

class TCNEncoder(Encoder):
    """Convolutional encoder followed by a single-mask (C=1) TCN."""

    def __init__(
        self,
        in_channels=1,
        out_channels=512,
        kernel_size=16,
        B=128,
        H=512,
        P=3,
        X=4,
        R=6,
        device='cuda' if torch.cuda.is_available() else 'cpu',
    ):
        super(TCNEncoder, self).__init__(
            kernel_size=kernel_size,
            out_channels=out_channels,
            in_channels=in_channels,
        )
        self.tcn = MaskNet(
            N=out_channels,
            B=B,
            H=H,
            P=P,
            X=X,
            R=R,
            C=1,
            norm_type="gLN",
            causal=False,
            mask_nonlinear="relu",
        )
        self._device = device
        self.to(self._device)

    def forward(self, x):
        x = super().forward(x)
        x = self.tcn(x)
        # MaskNet returns [C, M, N, K]; take the single mask channel.
        return x[0]

class TCNDecoder(Decoder):
    """TCN masking applied to encoded features, followed by a transposed-conv decoder."""

    def __init__(
        self,
        in_channels=512,
        out_channels=1,
        kernel_size=16,
        B=128,
        H=512,
        P=3,
        X=4,
        R=6,
        device='cuda' if torch.cuda.is_available() else 'cpu',
    ):
        super(TCNDecoder, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=kernel_size // 2,
        )
        self.tcn = MaskNet(
            N=in_channels,
            B=B,
            H=H,
            P=P,
            X=X,
            R=R,
            C=1,
            norm_type="gLN",
            causal=False,
            mask_nonlinear="relu",
        )
        self._device = device
        self.to(self._device)

    def forward(self, x):
        mask = self.tcn(x)
        x = mask[0] * x
        x = super().forward(x)
        return x

if __name__ == '__main__':
    kernel_size = 16
    stride = kernel_size // 2
    M, L, N = 4, 32000, 256
    B, H, P, X, R = 128, 256, 3, 10, 1

    encoder = TCNEncoder(
        in_channels=1,
        out_channels=N,
        kernel_size=kernel_size,
        B=B,
        H=H,
        P=P,
        X=X,
        R=R,
    )

    decoder = TCNDecoder(
        in_channels=N,
        out_channels=1,
        kernel_size=kernel_size,
        B=B,
        H=H,
        P=P,
        X=X,
        R=R,
    )

    x = torch.rand((M, L)).cuda()
    y = encoder(x)
    print("encoder shape:", y.shape)
    z = decoder(y)
    print("decoder shape:", z.shape)
--------------------------------------------------------------------------------
/separation/src/utils.py:
--------------------------------------------------------------------------------
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances as dist

from UPGMApy.UPGMA import UPGMA

def normalize(mat, neg=False):
    """Min-max normalise `mat` to [0, 1], or to [-1, 1] when `neg` is True."""
    min_value = np.min(mat)
    data = mat - min_value
    max_value = np.max(data)
    data = data / max_value
    if not neg:
        return data
    return data * 2 - 1

euclidean_matrix = lambda X: dist(X, X)

# Lower triangle of the pairwise distance matrix, as nested lists
# (the input format expected by UPGMA).
distance_triangle = lambda X: [list(row[:i]) for i, row in enumerate(euclidean_matrix(X))]

def channel_sort(X):
    """Greedily order the rows of an N x T matrix by pairwise distance."""
    dist_mat = euclidean_matrix(X)
    max_val = np.max(dist_mat) + 1
    dist_mat[dist_mat == 0.0] = max_val
    channel_order = np.where(dist_mat == dist_mat.min())[0]  # starting channels
    subset_1 = dist_mat[channel_order[0]]
    subset_1[channel_order[1]] = max_val

    while len(channel_order) < X.shape[0]:
        min_arg_1 = np.argmin(subset_1)
        channel_order = np.insert(channel_order, 0, min_arg_1)
        subset_1[min_arg_1] = max_val

    return channel_order

def count_clusters(no_clusters, clusters):
    if isinstance(clusters[0], tuple):
        no_clusters += 1
        no_clusters = count_clusters(no_clusters, clusters[0])
    return no_clusters

def unpack_clusters(clusters, label_list=None):
    if label_list is None:
        label_list = []

    for i in range(len(clusters)):
        if isinstance(clusters[-i], int):
            label_list.append(clusters[-i])
        else:
            label_list = unpack_clusters(clusters[-i], label_list)

    return label_list

def tuple_string_to_list(tuple_string):
    return [entry for entry in tuple_string.replace(")", "").replace("(", "").split(",")]

def tuple_string_to_enumerated_list(tuple_string):
    return [int(entry) for entry in tuple_string.replace(")", "").replace("(", "").split(",")]

def upgma_channel_sort(W):
    """Order channels by UPGMA clustering of their pairwise distances."""
    N = len(W)  # number of output channels (W is out_channels x in_channels x kernel_size)
    W = W.squeeze()
    W = distance_triangle(W)
    channel_labels = [str(i) for i in range(N)]
    assert len(W) == len(channel_labels)
    upgma_clusters = UPGMA(W, channel_labels)
    channel_order = tuple_string_to_enumerated_list(upgma_clusters)
    return channel_order

if __name__ == '__main__':
    # Reference example from the UPGMApy documentation:
    # M_labels = alpha_labels("A", "G")  # A through G
    # M = [ [], [19], [27, 31], [8, 18, 26], [33, 36, 41, 31],
    #       [18, 1, 32, 17, 35], [13, 13, 29, 14, 28, 12] ]
    # UPGMA(M, M_labels) should output '((((A,D),((B,F),G)),C),E)'

    W = np.random.rand(512, 1, 16)
    channel_order = upgma_channel_sort(W)
    print(channel_order)
    print(channel_sort(W.squeeze()))
--------------------------------------------------------------------------------