├── README.md ├── gen_data.py ├── gen_songlist.py ├── get_data_stats.py ├── models ├── saved_model_0.pt ├── saved_model_1.pt ├── saved_model_2.pt ├── saved_model_3.pt ├── saved_model_4.pt ├── saved_model_5.pt ├── saved_model_6.pt └── saved_model_7.pt ├── songlist.txt ├── test.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Musical onset detection using a CNN 2 | 3 | Pytorch implementation of the method described in:
4 | Schlüter, Jan, and Sebastian Böck. "Improved musical onset detection with convolutional neural networks." 2014 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2014. 5 | 6 | 7 | ## Requirements 8 | * Pytorch 9 | * Librosa 10 | * Numpy 11 | * Matplotlib (optional) 12 | 13 | ## Dataset (used in the paper) 14 | * Can be obtained from here (as long as the Google Drive links are alive) - "https://github.com/CPJKU/onset_db/issues/1#issuecomment-472300295" 15 | 16 | ## Usage 17 | 18 | ### Train the network 19 | 1. Run gen_songlist.py to get the list of all songs for which there is onset annotation data available (there are some extra audios in the dataset) 20 | 2. Run get_data_stats.py to compute the mean and standard deviation across 80 mel bands over the entire dataset 21 | 3. Run gen_data.py to generate the 15-frame mel spectrogram chunks and frame-wise labels for all the audios 22 | 4. Run train.py to train the network. Specify a fold number in the command line when running this script. This is used to partition the data into train and val splits using the splits data provided by the authors. The training almost exactly follows the procedure described in the paper. The weights at the end of 100 epochs get saved in the models folder. 23 | 24 | ### Evaluate the network 25 | 1. Run test.py to evaluate on the dataset. Again, specify a fold number to get the results for that fold. Results get saved to a text file in the form of #true-positives, #false-alarms, and #ground-truth-onsets, summed over all the validation songs, for different evaluation thresholds. 26 | 27 | ### Load saved model 28 | If you wish to use the trained model on different data, utils.py contains the model class definition (and some other helper functions). Import the model class from here and load one of the saved model state dicts from the models folder. 
29 | -------------------------------------------------------------------------------- /gen_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import librosa 4 | import torch 5 | 6 | #function to zero pad ends of spectrogram 7 | def zeropad2d(x,n_frames): 8 | y=np.hstack((np.zeros([x.shape[0],n_frames]), x)) 9 | y=np.hstack((y,np.zeros([x.shape[0],n_frames]))) 10 | return y 11 | 12 | #function to create N-frame overlapping chunks of the full audio spectrogram 13 | def makechunks(x,duration): 14 | y=np.zeros([x.shape[1],x.shape[0],duration]) 15 | for i_frame in range(x.shape[1]-duration): 16 | y[i_frame]=x[:,i_frame:i_frame+duration] 17 | return y 18 | 19 | #data dirs 20 | audio_dir='/media/Sharedata/rohit/SS_onset_detection/audio' 21 | onset_dir='/media/Sharedata/rohit/SS_onset_detection/onsets' 22 | save_dir='/media/Sharedata/rohit/SS_onset_detection/data_pt_test' 23 | 24 | #data stats for normalization 25 | stats=np.load('means_stds.npy') 26 | means=stats[0] 27 | stds=stats[1] 28 | 29 | #context parameters 30 | contextlen=7 #+- frames 31 | duration=2*contextlen+1 32 | 33 | #main 34 | songlist=np.loadtxt('songlist.txt',dtype=str) 35 | audio_format='.flac' 36 | labels_master={} 37 | weights_master={} 38 | filelist=[] 39 | for item in songlist: 40 | print(item) 41 | #load audio and onsets 42 | x,fs=librosa.load(os.path.join(audio_dir,item+audio_format), sr=44100) 43 | if not os.path.exists(os.path.join(onset_dir,item+'.onsets')): continue 44 | onsets=np.loadtxt(os.path.join(onset_dir,item+'.onsets')) 45 | 46 | #get mel spectrogram 47 | melgram1=librosa.feature.melspectrogram(x,sr=fs,n_fft=1024, hop_length=441,n_mels=80, fmin=27.5, fmax=16000) 48 | melgram2=librosa.feature.melspectrogram(x,sr=fs,n_fft=2048, hop_length=441,n_mels=80, fmin=27.5, fmax=16000) 49 | melgram3=librosa.feature.melspectrogram(x,sr=fs,n_fft=4096, hop_length=441,n_mels=80, fmin=27.5, fmax=16000) 50 | 51 | #log scaling 
52 | melgram1=10*np.log10(1e-10+melgram1) 53 | melgram2=10*np.log10(1e-10+melgram2) 54 | melgram3=10*np.log10(1e-10+melgram3) 55 | 56 | #normalize 57 | melgram1=(melgram1-np.atleast_2d(means[0]).T)/np.atleast_2d(stds[0]).T 58 | melgram2=(melgram2-np.atleast_2d(means[1]).T)/np.atleast_2d(stds[1]).T 59 | melgram3=(melgram3-np.atleast_2d(means[2]).T)/np.atleast_2d(stds[2]).T 60 | 61 | #zero pad ends 62 | melgram1=zeropad2d(melgram1,contextlen) 63 | melgram2=zeropad2d(melgram2,contextlen) 64 | melgram3=zeropad2d(melgram3,contextlen) 65 | 66 | #make chunks 67 | melgram1_chunks=makechunks(melgram1,duration) 68 | melgram2_chunks=makechunks(melgram2,duration) 69 | melgram3_chunks=makechunks(melgram3,duration) 70 | 71 | #generate song labels 72 | hop_dur=10e-3 73 | labels=np.zeros(melgram1_chunks.shape[0]) 74 | weights=np.ones(melgram1_chunks.shape[0]) 75 | idxs=np.array(np.round(onsets/hop_dur),dtype=int) 76 | labels[idxs]=1 77 | 78 | #target smearing 79 | labels[idxs-1]=1 80 | labels[idxs+1]=1 81 | weights[idxs-1]=0.25 82 | weights[idxs+1]=0.25 83 | 84 | labels_dict={} 85 | weights_dict={} 86 | 87 | #save 88 | savedir=os.path.join(save_dir,item) 89 | if not os.path.exists(savedir): os.makedirs(savedir) 90 | 91 | for i_chunk in range(melgram1_chunks.shape[0]): 92 | savepath=os.path.join(savedir,str(i_chunk)+'.pt') 93 | #np.save(savepath,np.array([melgram1_chunks[i_chunk],melgram2_chunks[i_chunk],melgram3_chunks[i_chunk]])) 94 | torch.save(torch.tensor(np.array([melgram1_chunks[i_chunk], melgram2_chunks[i_chunk], melgram3_chunks[i_chunk]])), savepath) 95 | filelist.append(savepath) 96 | labels_dict[savepath]=labels[i_chunk] 97 | weights_dict[savepath]=weights[i_chunk] 98 | 99 | #append labels to master 100 | labels_master.update(labels_dict) 101 | weights_master.update(weights_dict) 102 | 103 | np.savetxt(os.path.join(savedir,'labels.txt'),labels) 104 | np.savetxt(os.path.join(savedir,'weights.txt'),weights) 105 | 106 | np.save('labels_master',labels_master) 107 | 
np.save('weights_master',weights_master) 108 | #np.savetxt('filelist.txt',filelist,fmt='%s') 109 | -------------------------------------------------------------------------------- /gen_songlist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | songlist=np.array([]) 5 | for foldfile in os.listdir('./splits/'): 6 | foldsongs=np.loadtxt('./splits/'+foldfile,dtype=str) 7 | songlist=np.append(songlist,foldsongs) 8 | 9 | np.savetxt('songlist.txt',songlist,fmt='%s') 10 | -------------------------------------------------------------------------------- /get_data_stats.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import librosa 4 | 5 | #data dir 6 | audio_dir='/media/Sharedata/rohit/SS_onset_detection/audio' 7 | 8 | #main 9 | songlist=np.loadtxt('songlist.txt',dtype=str) 10 | i_song=0 11 | 12 | means_song=[np.array([]),np.array([]),np.array([])] 13 | stds_song=[np.array([]),np.array([]),np.array([])] 14 | 15 | for i_song in range(len(songlist)): 16 | #load audio 17 | x,fs=librosa.load(os.path.join(audio_dir,songlist[i_song]+'.flac'), sr=44100) 18 | 19 | #get mel spectrogram 20 | melgram1=librosa.feature.melspectrogram(x,sr=fs,n_fft=1024, hop_length=441,n_mels=80, fmin=27.5, fmax=16000) 21 | melgram2=librosa.feature.melspectrogram(x,sr=fs,n_fft=2048, hop_length=441,n_mels=80, fmin=27.5, fmax=16000) 22 | melgram3=librosa.feature.melspectrogram(x,sr=fs,n_fft=4096, hop_length=441,n_mels=80, fmin=27.5, fmax=16000) 23 | 24 | #log scaling 25 | melgram1=10*np.log10(1e-10+melgram1) 26 | melgram2=10*np.log10(1e-10+melgram2) 27 | melgram3=10*np.log10(1e-10+melgram3) 28 | 29 | #compute mean and std of dataset 30 | if i_song==0: 31 | means_song[0]=np.mean(melgram1,1) 32 | means_song[1]=np.mean(melgram2,1) 33 | means_song[2]=np.mean(melgram3,1) 34 | 35 | stds_song[0]=np.std(melgram1,1) 36 | stds_song[1]=np.std(melgram2,1) 37 | 
stds_song[2]=np.std(melgram3,1) 38 | 39 | else: 40 | means_song[0]+=np.mean(melgram1,1) 41 | means_song[1]+=np.mean(melgram2,1) 42 | means_song[2]+=np.mean(melgram3,1) 43 | 44 | stds_song[0]+=np.std(melgram1,1) 45 | stds_song[1]+=np.std(melgram2,1) 46 | stds_song[2]+=np.std(melgram3,1) 47 | 48 | means_song[0]/=i_song 49 | means_song[1]/=i_song 50 | means_song[2]/=i_song 51 | 52 | stds_song[0]/=i_song 53 | stds_song[1]/=i_song 54 | stds_song[2]/=i_song 55 | 56 | np.save('means_stds', np.array([means_song,stds_song])) 57 | -------------------------------------------------------------------------------- /models/saved_model_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitma38/cnn-onset-detection/e79a4105a644f316f34f932fb4430c991dd6504b/models/saved_model_0.pt -------------------------------------------------------------------------------- /models/saved_model_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitma38/cnn-onset-detection/e79a4105a644f316f34f932fb4430c991dd6504b/models/saved_model_1.pt -------------------------------------------------------------------------------- /models/saved_model_2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitma38/cnn-onset-detection/e79a4105a644f316f34f932fb4430c991dd6504b/models/saved_model_2.pt -------------------------------------------------------------------------------- /models/saved_model_3.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitma38/cnn-onset-detection/e79a4105a644f316f34f932fb4430c991dd6504b/models/saved_model_3.pt -------------------------------------------------------------------------------- /models/saved_model_4.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rohitma38/cnn-onset-detection/e79a4105a644f316f34f932fb4430c991dd6504b/models/saved_model_4.pt -------------------------------------------------------------------------------- /models/saved_model_5.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitma38/cnn-onset-detection/e79a4105a644f316f34f932fb4430c991dd6504b/models/saved_model_5.pt -------------------------------------------------------------------------------- /models/saved_model_6.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitma38/cnn-onset-detection/e79a4105a644f316f34f932fb4430c991dd6504b/models/saved_model_6.pt -------------------------------------------------------------------------------- /models/saved_model_7.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitma38/cnn-onset-detection/e79a4105a644f316f34f932fb4430c991dd6504b/models/saved_model_7.pt -------------------------------------------------------------------------------- /songlist.txt: -------------------------------------------------------------------------------- 1 | ah_development_guitar_2684_TexasMusicForge_Dandelion_pt1 2 | ah_development_oud_Diverse_-_01_-_Taksim_pt1 3 | ah_test_cello_03-Cello_Sonata_3__I_Allegro_ma_non_tanto_pt1 4 | ah_test_cello_14_VioloncelloTaksim_pt1 5 | ah_test_guitar_guitar2 6 | ah_test_guitar_Guitar_Licks_51-10 7 | ah_test_kemence_08_-_HicazTaksim_cut 8 | ah_test_kemence_10_huseyni_taksim_ve_cecenkizi_cut 9 | ah_test_kemence_11_RastTaksim_Kemence 10 | ah_test_ney_ne_se01 11 | ah_test_oud_ud_taksimleri_-_17_-_ussak_taksim 12 | ah_test_sax_Tubby_Hayes_-_The_Eighth_Wonder_-_11_-_Unidentified_12_Bar_Theme_pt1 13 | ah_test_trumpet_waldhorn33_-_Paloseco_pt1 14 | al_Albums-AnaBelen_Veneo-13(1.8-11.8) 15 | sb_Albums-Chrisanne3-07(3.0-13.0) 16 | 
sb_Albums-I_Like_It2-01(13.1-23.1) 17 | al_Albums-Latin_Jam2-03(6.1-16.1) 18 | al_Albums-Latino_Latino-09(8.8-18.8) 19 | sb_Albums-Step_By_Step-09(2.0-12.0) 20 | api_3-you_think_too_muchb 21 | api_RM-C003 22 | ff123_ItCouldBeSweet 23 | ff123_kraftwerk 24 | jpb_Jaillet70 25 | jpb_Jaillet75 26 | jpb_metheny 27 | jpb_wilco 28 | lame_t1 29 | lame_vbrtest 30 | al_Media-103917(13.5-23.5) 31 | al_Media-104218(9.3-19.3) 32 | al_Media-105404(8.0-18.0) 33 | al_Media-106115(15.9-25.9) 34 | mit_karaoke_tempo 35 | mit_track3 36 | SoundCheck2_61_Vocal_Soprano_opera 37 | mck_train10 38 | mck_train14 39 | mck_train20 40 | ah_development_oud_rast_taksim1 41 | ah_test_cello_08_-_Bach_(JS)-_Cello_Suite_-4_In_E_Flat_BWV_1010i_-_1._Preludium 42 | ah_test_guitar_Guitar_Licks_15-06 43 | ah_test_kemence_22_NevaTaksim_Kemence 44 | ah_test_mixtures_pop1 45 | ah_test_piano_Chick_Corea_-_08_-_So_in_Love 46 | ah_test_sax_20928_stephenchai_Gypsy_1_Alto_Sax_B_103bpm 47 | ah_test_tanburpluck_ta_1eb07 48 | sb_Albums-Cafe_Paradiso-05(4.3-14.3) 49 | sb_Albums-Cafe_Paradiso-07(13.1-23.1) 50 | sb_Albums-Chrisanne1-01(3.0-13.0) 51 | api_25-rujero 52 | api_2-uncle_mean 53 | api_realorgan3 54 | api_tiersen11 55 | ff123_charlies 56 | ff123_deerhunter 57 | ff123_freeup 58 | ff123_grace1 59 | ff123_Hongroise 60 | ff123_Waiting 61 | jpb_Jaillet17 62 | jpb_Jaillet21 63 | lame_fatboy 64 | al_Media-103515(19.7-29.7) 65 | al_Media-103611(0.2-10.2) 66 | al_Media-103814(13.8-23.8) 67 | al_Media-104105(15.6-25.6) 68 | al_Media-104317(12.3-22.3) 69 | al_Media-105009(20.3-30.3) 70 | al_Media-105402(9.7-19.7) 71 | sb_Media-106003(0.2-10.2) 72 | al_Media-106011(14.8-24.8) 73 | mit_track5 74 | SoundCheck2_72_Instrumental_Bongos 75 | SoundCheck2_74_Instrumental_Kick_drum 76 | SoundCheck2_84_Robin_S_-_Luv_4_Luv 77 | mck_train18 78 | mck_train7 79 | vorbis_bassdrum 80 | ah_development_guitar_Guitar_Licks_06-12 81 | ah_development_guitar_Guitar_Licks_15-05 82 | ah_development_percussion_castagnet1 83 | 
ah_development_percussion_conga1 84 | ah_test_cello_cello1 85 | ah_test_clarinet_44784_alikirodgers_AB_Clarinet_01_pt1 86 | ah_test_kemence_01_sultaniyghtaksim_Kemence 87 | ah_test_piano_Oscar_Peterson_-_02_-_Love_Ballade 88 | ah_test_piano_Oscar_Peterson_-_08_-_I_Love_You_Porgy 89 | ah_test_trumpet_trumpet1 90 | ah_test_violin_03__kurdilihicazkar__keman 91 | ah_test_violin_dark_eyes_viol 92 | al_Albums-Ballroom_Classics4-07(9.8-19.8) 93 | al_Albums-Latin_Jam-02(6.8-16.8) 94 | al_Albums-Latin_Jam3-11(11.5-21.5) 95 | al_Albums-Latino_Latino-05(7.8-17.8) 96 | al_Albums-Step_By_Step-07(1.1-11.1) 97 | api_15-tamerlano_act_i_track_15b 98 | api_its_alright_for_you_o 99 | ff123_BeautySlept 100 | ff123_beo1test 101 | ff123_eb_andul_short 102 | ff123_FloorEssence 103 | jpb_dido 104 | jpb_Jaillet27 105 | lame_ftb_samp 106 | lame_main_theme 107 | al_Media-103515(9.1-19.1) 108 | sb_Media-105215(12.0-22.0) 109 | sb_Media-105907(0.0-10.0) 110 | al_Media-106001(9.7-19.7) 111 | sb_Media-106103(4.0-14.0) 112 | sb_Media-106117(7.0-17.0) 113 | mit_track4 114 | SoundCheck2_60_Vocal_Tenor_opera 115 | SoundCheck2_64_Instrumental_Acoustic_steel_strung_guitar_finger_style 116 | SoundCheck2_82_Yello_-_The_race 117 | mck_train1 118 | mck_train3 119 | mck_train8 120 | ah_development_guitar_my_guitar1 121 | ah_test_clarinet_44361_debudding_Clarinet_ORTF_Stereo_Pair_NT_5_s_01 122 | ah_test_clarinet_clarinet1 123 | ah_test_mixtures_classic2 124 | ah_test_ney_ne_icm05 125 | ah_test_ney_ne_se04 126 | ah_test_piano_Chick_Corea_-_03_-_Folk_Song 127 | ah_test_sax_george_russel2_-_dimensions_pt1 128 | ah_test_tanburpluck_ta_1eb06 129 | ah_test_trumpet_1_-_Lost_In_Madrid_Part_1 130 | sb_Albums-Ballroom_Classics4-11(7.0-17.0) 131 | al_Albums-Cafe_Paradiso-09(14.9-24.9) 132 | al_Albums-Cafe_Paradiso-14(15.9-25.9) 133 | sb_Albums-Chrisanne3-02(12.0-22.0) 134 | sb_Albums-Latin_Jam-13(6.0-16.0) 135 | api_RM-C026 136 | api_RM-C036 137 | ff123_2nd_vent_clip 138 | ff123_DaFunk 139 | ff123_duel 140 | jpb_fiona 
141 | jpb_Jaillet29 142 | jpb_Jaillet67 143 | jpb_PianoDebussy 144 | jpb_tabla 145 | lame_castanets 146 | lame_iron 147 | sb_Media-100608(3.0-13.0) 148 | al_Media-104717(18.2-28.2) 149 | al_Media-105016(4.1-14.1) 150 | al_Media-105614(20.9-30.9) 151 | al_Media-105913(10.6-20.6) 152 | al_Media-106109(12.4-22.4) 153 | mit_track1 154 | SoundCheck2_62_Vocal_Male_Rock_Vocal 155 | SoundCheck2_80_Instrumental_Cellos_and_violas 156 | mck_train11 157 | mck_train19 158 | unknown_violin 159 | vorbis_lalaw 160 | ah_development_guitar_Guitar_Licks_07-06 161 | ah_development_oud_8 162 | ah_development_piano_autumn 163 | ah_development_violin_01__hicaz__keman 164 | ah_development_violin_violin2 165 | ah_test_cello_03_CelloTaksimi_pt1 166 | ah_test_clarinet_my_clarinet1 167 | ah_test_clarinet_SL1_pt1 168 | ah_test_sax_sax1 169 | ah_test_trumpet_Miles_DAVIS__Michel_LEGRAND_DINGO_02 170 | sb_Albums-Ballroom_Magic-09(4.0-14.0) 171 | al_Albums-Cafe_Paradiso-16(17.0-27.0) 172 | al_Albums-Latin_Jam5-01(7.9-17.9) 173 | al_Albums-Latino_Latino-05(3.4-13.4) 174 | api_2-artificial 175 | api_8-ambrielb 176 | api_RM-C002 177 | ff123_41_30sec 178 | ff123_bloodline 179 | ff123_BlueEyesExcerpt 180 | ff123_drone_short 181 | ff123_ExitMusic 182 | ff123_rushing 183 | jpb_arab60s 184 | jpb_Jaillet64 185 | jpb_jaxx 186 | jpb_violin 187 | lame_else3 188 | lame_youcantdothat 189 | sb_Media-103302(15.0-25.0) 190 | sb_Media-103307(4.0-14.0) 191 | al_Media-103611(8.2-18.2) 192 | sb_Media-104210(2.0-12.0) 193 | al_Media-105320(2.5-12.5) 194 | al_Media-105615(12.4-22.4) 195 | al_Media-105808(15.5-25.5) 196 | al_Media-105905(7.8-17.8) 197 | SoundCheck2_65_Instrumental_Acoustic_steel_strung_guitar_strummed 198 | SoundCheck2_83_The_Alan_Parsons_Project_-_Limelight 199 | mck_train13 200 | ah_development_guitar_Guitar_Licks_07-11 201 | ah_development_oud_1 202 | ah_development_percussion_cajon116_13 203 | ah_development_percussion_tambourine 204 | ah_development_piano_MOON 205 | ah_development_piano_mussorgsky 
206 | ah_development_violin_my_violin1 207 | ah_test_mixtures_jazz2 208 | ah_test_ney_ne_se03 209 | ah_test_ney_ne_tmu2n06 210 | ah_test_oud_Diverse_-_03_-_Muayyer_Kurdi_Taksim 211 | ah_test_oud_ud_taksimleri_-_06_-_segah_taksim 212 | ah_test_tanburpluck_ta_1eb08 213 | ah_test_violin_my_violin3 214 | sb_Albums-Ballroom_Magic-04(6.0-16.0) 215 | sb_Albums-Latin_Jam3-02(3.0-13.0) 216 | api_RM-C027 217 | api_RM-C038 218 | ff123_Debussy 219 | ff123_fossiles 220 | ff123_gekkou-intro 221 | ff123_wait 222 | gs_mix2_0dB 223 | jpb_Jaillet73 224 | lame_Fools 225 | lame_pipes 226 | lame_spahm 227 | al_Media-104306(5.0-15.0) 228 | al_Media-104415(10.7-20.7) 229 | sb_Media-105213(13.0-23.0) 230 | al_Media-105306(21.1-31.1) 231 | al_Media-105506(11.4-21.4) 232 | sb_Media-105801(11.0-21.0) 233 | al_Media-106015(16.4-26.4) 234 | al_Media-106110(6.4-16.4) 235 | al_Media-106113(16.4-26.4) 236 | SoundCheck2_63_Instrumental_Piano 237 | SoundCheck2_78_Instrumental_Whole_drum_kit 238 | mck_train5 239 | mck_train9 240 | ah_development_guitar_Guitar_Licks_06-11 241 | ah_development_percussion_bongo1 242 | ah_development_piano_p5 243 | ah_test_guitar_guitar3 244 | ah_test_guitar_Summer_Together_110_pt1 245 | ah_test_ney_ne_icm09 246 | ah_test_ney_ne_icm13 247 | ah_test_oud_Trio_Joubran-04-Safarcut 248 | ah_test_tanburpluck_ta_1eb03 249 | ah_test_violin_42954_FreqMan_hoochie_violin_pt1 250 | ah_test_violin_my_violin2 251 | sb_Albums-Chrisanne1-08(9.0-19.0) 252 | al_Albums-Fire-03(12.4-22.4) 253 | sb_Albums-Fire-13(15.0-25.0) 254 | api_3-long_gone 255 | api_6-three 256 | ff123_BigYellow 257 | ff123_Blackwater 258 | ff123_cymbals 259 | ff123_dogies 260 | ff123_Enchantment 261 | jpb_Jaillet34 262 | jpb_Jaillet66 263 | lame_hihat 264 | sb_Media-103416(12.0-22.0) 265 | al_Media-103505(1.9-11.9) 266 | al_Media-103714(4.2-14.2) 267 | al_Media-104002(11.3-21.3) 268 | al_Media-104016(5.1-15.1) 269 | al_Media-104506(3.8-13.8) 270 | al_Media-104807(16.5-26.5) 271 | sb_Media-105810(5.0-15.0) 272 | 
mit_track6 273 | SoundCheck2_70_Instrumental_Flute 274 | SoundCheck2_75_Instrumental_Snare_drum 275 | SoundCheck2_77_Instrumental_Toms 276 | SoundCheck2_79_Instrumental_Violins 277 | mck_train15 278 | mck_train16 279 | vorbis_11 280 | vorbis_dr4 281 | ah_development_oud_ud_taksimleri_-_02_-_huezzam_taksim_pt1 282 | ah_development_percussion_triangle 283 | ah_development_piano_p3 284 | ah_development_piano_piano1 285 | ah_test_mixtures_classic3 286 | ah_test_mixtures_jazz3 287 | ah_test_mixtures_rock1 288 | ah_test_oud_Diverse_-_02_-_Mahur_Pesrev_(Ud) 289 | ah_test_piano_PianoMic_full_stick 290 | ah_test_sax_herbieHanckockTrack06_pt1 291 | ah_test_tanburpluck_ta_1mc08 292 | ah_test_trumpet_chetbaker 293 | sb_Albums-Ballroom_Classics4-01(3.0-13.0) 294 | al_Albums-Cafe_Paradiso-10(1.5-11.5) 295 | sb_Albums-Chrisanne2-01(16.0-26.0) 296 | al_Albums-Latin_Jam2-09(5.3-15.3) 297 | api_RM-G008 298 | api_RM-J001 299 | ff123_ATrain 300 | ff123_BachS1007 301 | ff123_TheSource 302 | gs_mix1_0dB 303 | jpb_Jaillet15 304 | jpb_Jaillet65 305 | jpb_Jaillet74 306 | lame_BlackBird 307 | lame_testsignal2 308 | lame_velvet 309 | al_Media-103605(10.0-20.0) 310 | al_Media-103919(8.0-18.0) 311 | sb_Media-104111(5.0-15.0) 312 | al_Media-104917(15.6-25.6) 313 | al_Media-105020(11.1-21.1) 314 | sb_Media-105407(6.0-16.0) 315 | al_Media-105819(8.1-18.1) 316 | mit_track7 317 | SoundCheck2_71_Instrumental_Saxophone 318 | SoundCheck2_73_Instrumental_Tambourine 319 | mck_train2 320 | mck_train4 321 | mck_train6 322 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import glob 3 | import torch 4 | from torch.utils import data 5 | from utils import onsetCNN, Dataset 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import os 9 | import utils 10 | 11 | #Use gpu 12 | use_cuda = torch.cuda.is_available() 13 | device = torch.device("cuda:1" if 
use_cuda else "cpu") 14 | 15 | #evaluation tolerance and merge duration for close onsets 16 | tolerance=60e-3 #+- tolerance/2 seconds 17 | mergeDur=20e-3 18 | hop_dur=10e-3 19 | mergeDur_frame=mergeDur/hop_dur 20 | tolerance_frame=tolerance/hop_dur 21 | 22 | fold = int(sys.argv[1]) #cmd line argument 23 | 24 | #load model 25 | path_to_saved_model = 'models/saved_model_%d.pt'%fold 26 | model = onsetCNN().double().to(device) 27 | model.load_state_dict(torch.load(path_to_saved_model)) 28 | model.eval() 29 | 30 | #data 31 | datadir='/media/Sharedata/rohit/SS_onset_detection/data_pt/' 32 | #songlist=os.listdir(datadir) 33 | songlist=np.loadtxt('splits/8-fold_cv_random_%d.fold'%fold,dtype=str) 34 | labels = np.load('labels_master_test.npy').item() 35 | 36 | #loop over test songs 37 | scores=np.array([]) 38 | n_songs=len(songlist) 39 | i_song=0 40 | for song in songlist: 41 | print('%d/%d songs\n'%(i_song,n_songs)) 42 | i_song+=1 43 | 44 | odf=np.array([]) 45 | gt=np.array([]) 46 | 47 | #generate frame-wise labels serially for song 48 | n_files=len(glob.glob(os.path.join(datadir,song+'/*.pt'))) 49 | for i_file in range(n_files): 50 | x=torch.load(os.path.join(datadir,song+'/%d.pt'%i_file)).to(device) 51 | x=x.unsqueeze(0) 52 | y=model(x).squeeze().cpu().detach().numpy() 53 | odf=np.append(odf,y) 54 | gt=np.append(gt,labels[os.path.join(datadir,song+'/%d.pt'%i_file)]) 55 | 56 | #evaluate odf 57 | scores_thresh=np.array([]) 58 | #loop over different peak-picking thresholds to optimize F-score 59 | for predict_thresh in [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]: 60 | odf_labels=np.zeros(len(odf)) 61 | 62 | #pick peaks 63 | odf_labels[utils.peakPicker(odf,predict_thresh)]=1. 
64 | 65 | #evaluate, get #hits and #misses 66 | scores_thresh=np.append(scores_thresh,utils.eval_output(odf_labels, odf, gt, tolerance_frame, mergeDur_frame)) 67 | 68 | #accumulate hits and misses for every song 69 | if len(scores)==0: scores=np.atleast_2d(np.array(scores_thresh)) 70 | else: scores=np.vstack((scores,np.atleast_2d(np.array(scores_thresh)))) 71 | 72 | #add hits and misses over all songs (to compute testset P, R and F-score) 73 | scores=np.sum(scores,0) 74 | 75 | # Write to file 76 | fout=open('hr_fa_folds.txt','a') 77 | for item in scores: 78 | fout.write('%d\t'%item) 79 | fout.write('\n') 80 | fout.close() 81 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import glob 3 | import torch 4 | from torch.utils import data 5 | from utils import onsetCNN, Dataset 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | #function to repeat positive samples to improve data balance 10 | def balance_data(ids, labels): 11 | ids2add=[] 12 | for idi in ids: 13 | if labels[idi]==1: 14 | ids2add.append(idi) 15 | ids2add.append(idi) 16 | ids2add.append(idi) 17 | return ids2add 18 | 19 | #use GPU 20 | use_cuda = torch.cuda.is_available() 21 | device = torch.device("cuda:1")#torch.device("cuda:0" if use_cuda else "cpu") 22 | 23 | #parameters for data loader 24 | params = {'batch_size': 256,'shuffle': True,'num_workers': 6} 25 | max_epochs = 50 26 | 27 | #data 28 | datadir='/media/Sharedata/rohit/SS_onset_detection/data_pt/' 29 | songlist=np.loadtxt('songlist.txt',dtype=str) 30 | labels = np.load('labels_master.npy').item() 31 | weights = np.load('weights_master.npy').item() 32 | 33 | #model 34 | model=onsetCNN().double().to(device) 35 | criterion=torch.nn.BCELoss(reduction='none') 36 | optimizer=torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.45) 37 | #optimizer=torch.optim.Adam(model.parameters(), lr=0.05) 
38 | 39 | #cross-validation loop 40 | fold = int(sys.argv[1]) #cmd line argument 41 | partition = {'all':[], 'train':[], 'validation':[]} 42 | val_split = np.loadtxt('splits/8-fold_cv_random_%d.fold'%fold,dtype='str') 43 | for song in songlist: 44 | ids = glob.glob(datadir+song+'/*.pt') 45 | if song in val_split: partition['validation'].extend(ids) 46 | else: partition['train'].extend(ids) 47 | 48 | #balance data 49 | #partition['train'].extend(balance_data(partition['train'],labels)) 50 | 51 | #print data balance percentage 52 | n_ones=0. 53 | for idi in partition['train']: 54 | if labels[idi]==1.: n_ones+=1 55 | print('Fraction of positive examples: %f'%(n_ones/len(partition['train']))) 56 | 57 | #generators 58 | training_set = Dataset(partition['train'], labels, weights) 59 | training_generator = data.DataLoader(training_set, **params) 60 | 61 | validation_set = Dataset(partition['validation'], labels, weights) 62 | validation_generator = data.DataLoader(validation_set, **params) 63 | 64 | #training epochs loop 65 | train_loss_epoch=[] 66 | val_loss_epoch=[] 67 | for epoch in range(max_epochs): 68 | train_loss_epoch+=[0] 69 | val_loss_epoch+=[0] 70 | 71 | ##training 72 | n_train=0 73 | for local_batch, local_labels, local_weights in training_generator: 74 | n_train+=local_batch.shape[0] 75 | 76 | #transfer to GPU 77 | local_batch, local_labels, local_weights = local_batch.to(device), local_labels.to(device), local_weights.to(device) 78 | 79 | #update weights 80 | optimizer.zero_grad() 81 | outs = model(local_batch).squeeze() 82 | loss = criterion(outs, local_labels) 83 | loss = torch.dot(loss,local_weights) 84 | loss /= local_batch.size()[0] 85 | loss.backward() 86 | optimizer.step() 87 | train_loss_epoch[-1]+=loss.item() 88 | train_loss_epoch[-1]/=n_train 89 | 90 | ##validation 91 | n_val=0 92 | with torch.set_grad_enabled(False): 93 | for local_batch, local_labels, local_weights in validation_generator: 94 | n_val+=local_batch.shape[0] 95 | 96 | #transfer to 
GPU 97 | local_batch, local_labels = local_batch.to(device), local_labels.to(device) 98 | 99 | #evaluate model 100 | outs = model(local_batch).squeeze() 101 | loss = criterion(outs, local_labels).mean() 102 | val_loss_epoch[-1]+=loss.item() 103 | val_loss_epoch[-1]/=n_val 104 | 105 | #print loss in current epoch 106 | print('Epoch no: %d/%d\tTrain loss: %f\tVal loss: %f'%(epoch, max_epochs, train_loss_epoch[-1], val_loss_epoch[-1])) 107 | 108 | #update LR and momentum (only if using SGD) 109 | for param_group in optimizer.param_groups: 110 | param_group['lr'] *= 0.995 111 | if 10<=epoch<=20: param_group['momentum'] += 0.045 112 | 113 | #plot losses vs epoch 114 | plt.plot(train_loss_epoch,label='train') 115 | plt.plot(val_loss_epoch,label='val') 116 | plt.legend() 117 | plt.savefig('./plots/loss_curves_%d'%fold) 118 | plt.clf() 119 | torch.save(model.state_dict(), 'saved_model_%d.pt'%fold) 120 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.utils import data 6 | 7 | #model 8 | class onsetCNN(nn.Module): 9 | def __init__(self): 10 | super(onsetCNN, self).__init__() 11 | self.conv1 = nn.Conv2d(3, 10, (3,7)) 12 | self.pool1 = nn.MaxPool2d((3,1)) 13 | self.conv2 = nn.Conv2d(10, 20, 3) 14 | self.pool2 = nn.MaxPool2d((3,1)) 15 | self.fc1 = nn.Linear(20 * 7 * 8, 256) 16 | self.fc2 = nn.Linear(256,1) 17 | self.dout = nn.Dropout(p=0.5) 18 | 19 | def forward(self,x): 20 | y=torch.tanh(self.conv1(x)) 21 | y=self.pool1(y) 22 | y=torch.tanh(self.conv2(y)) 23 | y=self.pool2(y) 24 | y=self.dout(y.view(-1,20*7*8)) 25 | y=self.dout(torch.sigmoid(self.fc1(y))) 26 | y=torch.sigmoid(self.fc2(y)) 27 | return y 28 | 29 | #data-loader(https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel) 30 | class Dataset(data.Dataset): 31 | 
'Characterizes a dataset for PyTorch' 32 | def __init__(self, list_IDs, labels, weights): 33 | 'Initialization' 34 | self.labels = labels 35 | self.weights = weights 36 | self.list_IDs = list_IDs 37 | 38 | def __len__(self): 39 | 'Denotes the total number of samples' 40 | return len(self.list_IDs) 41 | 42 | def __getitem__(self, index): 43 | 'Generates one sample of data' 44 | # Select sample 45 | ID = self.list_IDs[index] 46 | 47 | # Load data and get label 48 | #X = torch.tensor(np.load(ID)) 49 | X = torch.load(ID) 50 | y = self.labels[ID]#.replace('.npy','')] 51 | w = self.weights[ID]#.replace('.npy','')] 52 | 53 | return X, y, w 54 | 55 | #peak-picking function 56 | def peakPicker(data, peakThresh): 57 | peaks=np.array([],dtype='int') 58 | for ind in range(1,len(data)-1): 59 | if ((data[ind+1] < data[ind] > data[ind-1]) & (data[ind]>peakThresh)): 60 | peaks=np.append(peaks,ind) 61 | return peaks 62 | 63 | #merge onsets if too close - retain only stronger one 64 | def merge_onsets(onsets,strengths,mergeDur): 65 | onsetLocs=np.where(onsets==1)[0] 66 | ind=1 67 | while ind