├── README.md
├── gen_data.py
├── gen_songlist.py
├── get_data_stats.py
├── models
├── saved_model_0.pt
├── saved_model_1.pt
├── saved_model_2.pt
├── saved_model_3.pt
├── saved_model_4.pt
├── saved_model_5.pt
├── saved_model_6.pt
└── saved_model_7.pt
├── songlist.txt
├── test.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Musical onset detection using a CNN
2 |
3 | Pytorch implementation of the method described in:
4 | Schlüter, Jan, and Sebastian Böck. "Improved musical onset detection with convolutional neural networks." 2014 ieee international conference on acoustics, speech and signal processing (icassp). IEEE, 2014.
5 |
6 |
7 | ## Requirements
8 | * Pytorch
9 | * Librosa
10 | * Numpy
11 | * Matplotlib (optional)
12 |
13 | ## Dataset (used in the paper)
14 | * Can be obtained from here (until the Google drive links are alive) - "https://github.com/CPJKU/onset_db/issues/1#issuecomment-472300295"
15 |
16 | ## Usage
17 |
18 | ### Train the network
19 | 1. Run gen_songlist.py
to get the list of all songs for which there is onset annotation data available (there are some extra audio files in the dataset)
20 | 2. Run get_data_stats.py
to compute the mean and standard deviation across 80 mel bands over the entire dataset
21 | 3. Run gen_data.py
to generate the 15-frame mel spectrogram chunks and frame-wise labels for all the audios
22 | 4. Run train.py
to train the network. Specify a fold number in the command line when running this script. This is used to partition the data into train and val splits using the splits data provided by the authors. The training almost exactly follows the procedure described in the paper. The weights at the end of 100 epochs get saved in the models folder.
23 |
24 | ### Evaluate the network
25 | 1. Run test.py
to evaluate on the dataset. Again, specify a fold number to get the results for that fold. Results get saved to a text file in the form of #true-positives, #false-alarms, and #ground-truth-onsets, summed over all the validation songs, for different evaluation thresholds.
26 |
27 | ### Load saved model
28 | If you wish to use the trained model on different data, utils.py
contains the model class definition (and some other helper functions). Import the model class from here and load one of the saved model state dicts from the models folder.
29 |
--------------------------------------------------------------------------------
/gen_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import librosa
4 | import torch
5 |
#function to zero pad ends of spectrogram
def zeropad2d(x, n_frames):
    """Pad a 2-D array with n_frames all-zero columns on each side (axis 1)."""
    pad = np.zeros([x.shape[0], n_frames])
    return np.hstack((pad, x, pad))
11 |
#function to create N-frame overlapping chunks of the full audio spectrogram
def makechunks(x, duration):
    """Slice a (bands, frames) spectrogram into overlapping fixed-length chunks.

    Returns an array of shape (frames, bands, duration) where entry i is the
    window x[:, i:i+duration]; trailing entries with no full window stay zero.
    """
    y = np.zeros([x.shape[1], x.shape[0], duration])
    # BUG FIX: range(... - duration + 1) includes the final valid window;
    # the original range(... - duration) dropped it (off-by-one), leaving
    # one extra all-zero chunk at the end.
    for i_frame in range(x.shape[1] - duration + 1):
        y[i_frame] = x[:, i_frame:i_frame + duration]
    return y
18 |
#data dirs (hard-coded absolute paths -- point these at your local dataset copy)
audio_dir='/media/Sharedata/rohit/SS_onset_detection/audio'
onset_dir='/media/Sharedata/rohit/SS_onset_detection/onsets'
save_dir='/media/Sharedata/rohit/SS_onset_detection/data_pt_test'

#data stats for normalization (per-band means/stds computed by get_data_stats.py)
stats=np.load('means_stds.npy')
means=stats[0]
stds=stats[1]

#context parameters
contextlen=7 #+- frames
duration=2*contextlen+1 #15-frame window per training example

#main loop: one iteration per annotated song
songlist=np.loadtxt('songlist.txt',dtype=str)
audio_format='.flac'
labels_master={} #maps saved chunk path -> frame label (1 = onset)
weights_master={} #maps saved chunk path -> per-sample loss weight
filelist=[]
for item in songlist:
    print(item)
    #load audio and onsets; songs without an annotation file are skipped
    x,fs=librosa.load(os.path.join(audio_dir,item+audio_format), sr=44100)
    if not os.path.exists(os.path.join(onset_dir,item+'.onsets')): continue
    onsets=np.loadtxt(os.path.join(onset_dir,item+'.onsets'))

    #get mel spectrogram at three resolutions (n_fft 1024/2048/4096), 10 ms hop
    melgram1=librosa.feature.melspectrogram(x,sr=fs,n_fft=1024, hop_length=441,n_mels=80, fmin=27.5, fmax=16000)
    melgram2=librosa.feature.melspectrogram(x,sr=fs,n_fft=2048, hop_length=441,n_mels=80, fmin=27.5, fmax=16000)
    melgram3=librosa.feature.melspectrogram(x,sr=fs,n_fft=4096, hop_length=441,n_mels=80, fmin=27.5, fmax=16000)

    #log scaling (dB; small floor avoids log(0))
    melgram1=10*np.log10(1e-10+melgram1)
    melgram2=10*np.log10(1e-10+melgram2)
    melgram3=10*np.log10(1e-10+melgram3)

    #normalize each mel band with the dataset-wide statistics
    melgram1=(melgram1-np.atleast_2d(means[0]).T)/np.atleast_2d(stds[0]).T
    melgram2=(melgram2-np.atleast_2d(means[1]).T)/np.atleast_2d(stds[1]).T
    melgram3=(melgram3-np.atleast_2d(means[2]).T)/np.atleast_2d(stds[2]).T

    #zero pad ends so every frame has a full +-contextlen context
    melgram1=zeropad2d(melgram1,contextlen)
    melgram2=zeropad2d(melgram2,contextlen)
    melgram3=zeropad2d(melgram3,contextlen)

    #make chunks (one duration-frame patch per frame)
    melgram1_chunks=makechunks(melgram1,duration)
    melgram2_chunks=makechunks(melgram2,duration)
    melgram3_chunks=makechunks(melgram3,duration)

    #generate song labels (one per chunk); hop_dur matches 441/44100 s
    hop_dur=10e-3
    labels=np.zeros(melgram1_chunks.shape[0])
    weights=np.ones(melgram1_chunks.shape[0])
    idxs=np.array(np.round(onsets/hop_dur),dtype=int)
    labels[idxs]=1

    #target smearing: neighbouring frames also positive, at reduced loss weight
    #NOTE(review): for an onset rounding to frame 0, idxs-1 wraps to the last
    #frame, and idxs+1 can raise IndexError for an onset in the final frame --
    #confirm annotations never fall on the very ends
    labels[idxs-1]=1
    labels[idxs+1]=1
    weights[idxs-1]=0.25
    weights[idxs+1]=0.25

    labels_dict={}
    weights_dict={}

    #save each chunk as a 3-channel tensor (one channel per spectrogram resolution)
    savedir=os.path.join(save_dir,item)
    if not os.path.exists(savedir): os.makedirs(savedir)

    for i_chunk in range(melgram1_chunks.shape[0]):
        savepath=os.path.join(savedir,str(i_chunk)+'.pt')
        #np.save(savepath,np.array([melgram1_chunks[i_chunk],melgram2_chunks[i_chunk],melgram3_chunks[i_chunk]]))
        torch.save(torch.tensor(np.array([melgram1_chunks[i_chunk], melgram2_chunks[i_chunk], melgram3_chunks[i_chunk]])), savepath)
        filelist.append(savepath)
        labels_dict[savepath]=labels[i_chunk]
        weights_dict[savepath]=weights[i_chunk]

    #append labels to master
    labels_master.update(labels_dict)
    weights_master.update(weights_dict)

    np.savetxt(os.path.join(savedir,'labels.txt'),labels)
    np.savetxt(os.path.join(savedir,'weights.txt'),weights)

#dicts are pickled by np.save; load them with np.load(..., allow_pickle=True).item()
np.save('labels_master',labels_master)
np.save('weights_master',weights_master)
#np.savetxt('filelist.txt',filelist,fmt='%s')
--------------------------------------------------------------------------------
/gen_songlist.py:
--------------------------------------------------------------------------------
import numpy as np
import os

#collect every song named in any split file into one flat array
#(atleast_1d guards against single-line split files, which loadtxt returns 0-d)
parts = [np.atleast_1d(np.loadtxt('./splits/' + foldfile, dtype=str))
         for foldfile in os.listdir('./splits/')]
songlist = np.concatenate(parts) if parts else np.array([])

np.savetxt('songlist.txt', songlist, fmt='%s')
--------------------------------------------------------------------------------
/get_data_stats.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import librosa

#data dir (hard-coded absolute path -- edit for your setup)
audio_dir='/media/Sharedata/rohit/SS_onset_detection/audio'

#main: accumulate per-mel-band mean and std over every song, then average
songlist=np.loadtxt('songlist.txt',dtype=str)

#one accumulator per spectrogram resolution (n_fft = 1024 / 2048 / 4096)
means_song=[np.array([]),np.array([]),np.array([])]
stds_song=[np.array([]),np.array([]),np.array([])]

for i_song in range(len(songlist)):
    #load audio at the fixed analysis rate
    x,fs=librosa.load(os.path.join(audio_dir,songlist[i_song]+'.flac'), sr=44100)

    #get mel spectrogram at three resolutions, 10 ms hop, 80 bands
    melgram1=librosa.feature.melspectrogram(x,sr=fs,n_fft=1024, hop_length=441,n_mels=80, fmin=27.5, fmax=16000)
    melgram2=librosa.feature.melspectrogram(x,sr=fs,n_fft=2048, hop_length=441,n_mels=80, fmin=27.5, fmax=16000)
    melgram3=librosa.feature.melspectrogram(x,sr=fs,n_fft=4096, hop_length=441,n_mels=80, fmin=27.5, fmax=16000)

    #log scaling (dB; small floor avoids log(0))
    melgram1=10*np.log10(1e-10+melgram1)
    melgram2=10*np.log10(1e-10+melgram2)
    melgram3=10*np.log10(1e-10+melgram3)

    #accumulate per-band statistics across time (axis 1)
    if i_song==0:
        means_song[0]=np.mean(melgram1,1)
        means_song[1]=np.mean(melgram2,1)
        means_song[2]=np.mean(melgram3,1)

        stds_song[0]=np.std(melgram1,1)
        stds_song[1]=np.std(melgram2,1)
        stds_song[2]=np.std(melgram3,1)

    else:
        means_song[0]+=np.mean(melgram1,1)
        means_song[1]+=np.mean(melgram2,1)
        means_song[2]+=np.mean(melgram3,1)

        stds_song[0]+=np.std(melgram1,1)
        stds_song[1]+=np.std(melgram2,1)
        stds_song[2]+=np.std(melgram3,1)

#average over songs
#BUG FIX: the original divided by i_song, which equals len(songlist)-1 after
#the loop, biasing every statistic upward; divide by the song count instead
n_songs=len(songlist)
means_song[0]/=n_songs
means_song[1]/=n_songs
means_song[2]/=n_songs

stds_song[0]/=n_songs
stds_song[1]/=n_songs
stds_song[2]/=n_songs

#save as a (2, 3, 80) array: [means, stds] x [resolution] x [mel band]
np.save('means_stds', np.array([means_song,stds_song]))
--------------------------------------------------------------------------------
/models/saved_model_0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rohitma38/cnn-onset-detection/e79a4105a644f316f34f932fb4430c991dd6504b/models/saved_model_0.pt
--------------------------------------------------------------------------------
/models/saved_model_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rohitma38/cnn-onset-detection/e79a4105a644f316f34f932fb4430c991dd6504b/models/saved_model_1.pt
--------------------------------------------------------------------------------
/models/saved_model_2.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rohitma38/cnn-onset-detection/e79a4105a644f316f34f932fb4430c991dd6504b/models/saved_model_2.pt
--------------------------------------------------------------------------------
/models/saved_model_3.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rohitma38/cnn-onset-detection/e79a4105a644f316f34f932fb4430c991dd6504b/models/saved_model_3.pt
--------------------------------------------------------------------------------
/models/saved_model_4.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rohitma38/cnn-onset-detection/e79a4105a644f316f34f932fb4430c991dd6504b/models/saved_model_4.pt
--------------------------------------------------------------------------------
/models/saved_model_5.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rohitma38/cnn-onset-detection/e79a4105a644f316f34f932fb4430c991dd6504b/models/saved_model_5.pt
--------------------------------------------------------------------------------
/models/saved_model_6.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rohitma38/cnn-onset-detection/e79a4105a644f316f34f932fb4430c991dd6504b/models/saved_model_6.pt
--------------------------------------------------------------------------------
/models/saved_model_7.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rohitma38/cnn-onset-detection/e79a4105a644f316f34f932fb4430c991dd6504b/models/saved_model_7.pt
--------------------------------------------------------------------------------
/songlist.txt:
--------------------------------------------------------------------------------
1 | ah_development_guitar_2684_TexasMusicForge_Dandelion_pt1
2 | ah_development_oud_Diverse_-_01_-_Taksim_pt1
3 | ah_test_cello_03-Cello_Sonata_3__I_Allegro_ma_non_tanto_pt1
4 | ah_test_cello_14_VioloncelloTaksim_pt1
5 | ah_test_guitar_guitar2
6 | ah_test_guitar_Guitar_Licks_51-10
7 | ah_test_kemence_08_-_HicazTaksim_cut
8 | ah_test_kemence_10_huseyni_taksim_ve_cecenkizi_cut
9 | ah_test_kemence_11_RastTaksim_Kemence
10 | ah_test_ney_ne_se01
11 | ah_test_oud_ud_taksimleri_-_17_-_ussak_taksim
12 | ah_test_sax_Tubby_Hayes_-_The_Eighth_Wonder_-_11_-_Unidentified_12_Bar_Theme_pt1
13 | ah_test_trumpet_waldhorn33_-_Paloseco_pt1
14 | al_Albums-AnaBelen_Veneo-13(1.8-11.8)
15 | sb_Albums-Chrisanne3-07(3.0-13.0)
16 | sb_Albums-I_Like_It2-01(13.1-23.1)
17 | al_Albums-Latin_Jam2-03(6.1-16.1)
18 | al_Albums-Latino_Latino-09(8.8-18.8)
19 | sb_Albums-Step_By_Step-09(2.0-12.0)
20 | api_3-you_think_too_muchb
21 | api_RM-C003
22 | ff123_ItCouldBeSweet
23 | ff123_kraftwerk
24 | jpb_Jaillet70
25 | jpb_Jaillet75
26 | jpb_metheny
27 | jpb_wilco
28 | lame_t1
29 | lame_vbrtest
30 | al_Media-103917(13.5-23.5)
31 | al_Media-104218(9.3-19.3)
32 | al_Media-105404(8.0-18.0)
33 | al_Media-106115(15.9-25.9)
34 | mit_karaoke_tempo
35 | mit_track3
36 | SoundCheck2_61_Vocal_Soprano_opera
37 | mck_train10
38 | mck_train14
39 | mck_train20
40 | ah_development_oud_rast_taksim1
41 | ah_test_cello_08_-_Bach_(JS)-_Cello_Suite_-4_In_E_Flat_BWV_1010i_-_1._Preludium
42 | ah_test_guitar_Guitar_Licks_15-06
43 | ah_test_kemence_22_NevaTaksim_Kemence
44 | ah_test_mixtures_pop1
45 | ah_test_piano_Chick_Corea_-_08_-_So_in_Love
46 | ah_test_sax_20928_stephenchai_Gypsy_1_Alto_Sax_B_103bpm
47 | ah_test_tanburpluck_ta_1eb07
48 | sb_Albums-Cafe_Paradiso-05(4.3-14.3)
49 | sb_Albums-Cafe_Paradiso-07(13.1-23.1)
50 | sb_Albums-Chrisanne1-01(3.0-13.0)
51 | api_25-rujero
52 | api_2-uncle_mean
53 | api_realorgan3
54 | api_tiersen11
55 | ff123_charlies
56 | ff123_deerhunter
57 | ff123_freeup
58 | ff123_grace1
59 | ff123_Hongroise
60 | ff123_Waiting
61 | jpb_Jaillet17
62 | jpb_Jaillet21
63 | lame_fatboy
64 | al_Media-103515(19.7-29.7)
65 | al_Media-103611(0.2-10.2)
66 | al_Media-103814(13.8-23.8)
67 | al_Media-104105(15.6-25.6)
68 | al_Media-104317(12.3-22.3)
69 | al_Media-105009(20.3-30.3)
70 | al_Media-105402(9.7-19.7)
71 | sb_Media-106003(0.2-10.2)
72 | al_Media-106011(14.8-24.8)
73 | mit_track5
74 | SoundCheck2_72_Instrumental_Bongos
75 | SoundCheck2_74_Instrumental_Kick_drum
76 | SoundCheck2_84_Robin_S_-_Luv_4_Luv
77 | mck_train18
78 | mck_train7
79 | vorbis_bassdrum
80 | ah_development_guitar_Guitar_Licks_06-12
81 | ah_development_guitar_Guitar_Licks_15-05
82 | ah_development_percussion_castagnet1
83 | ah_development_percussion_conga1
84 | ah_test_cello_cello1
85 | ah_test_clarinet_44784_alikirodgers_AB_Clarinet_01_pt1
86 | ah_test_kemence_01_sultaniyghtaksim_Kemence
87 | ah_test_piano_Oscar_Peterson_-_02_-_Love_Ballade
88 | ah_test_piano_Oscar_Peterson_-_08_-_I_Love_You_Porgy
89 | ah_test_trumpet_trumpet1
90 | ah_test_violin_03__kurdilihicazkar__keman
91 | ah_test_violin_dark_eyes_viol
92 | al_Albums-Ballroom_Classics4-07(9.8-19.8)
93 | al_Albums-Latin_Jam-02(6.8-16.8)
94 | al_Albums-Latin_Jam3-11(11.5-21.5)
95 | al_Albums-Latino_Latino-05(7.8-17.8)
96 | al_Albums-Step_By_Step-07(1.1-11.1)
97 | api_15-tamerlano_act_i_track_15b
98 | api_its_alright_for_you_o
99 | ff123_BeautySlept
100 | ff123_beo1test
101 | ff123_eb_andul_short
102 | ff123_FloorEssence
103 | jpb_dido
104 | jpb_Jaillet27
105 | lame_ftb_samp
106 | lame_main_theme
107 | al_Media-103515(9.1-19.1)
108 | sb_Media-105215(12.0-22.0)
109 | sb_Media-105907(0.0-10.0)
110 | al_Media-106001(9.7-19.7)
111 | sb_Media-106103(4.0-14.0)
112 | sb_Media-106117(7.0-17.0)
113 | mit_track4
114 | SoundCheck2_60_Vocal_Tenor_opera
115 | SoundCheck2_64_Instrumental_Acoustic_steel_strung_guitar_finger_style
116 | SoundCheck2_82_Yello_-_The_race
117 | mck_train1
118 | mck_train3
119 | mck_train8
120 | ah_development_guitar_my_guitar1
121 | ah_test_clarinet_44361_debudding_Clarinet_ORTF_Stereo_Pair_NT_5_s_01
122 | ah_test_clarinet_clarinet1
123 | ah_test_mixtures_classic2
124 | ah_test_ney_ne_icm05
125 | ah_test_ney_ne_se04
126 | ah_test_piano_Chick_Corea_-_03_-_Folk_Song
127 | ah_test_sax_george_russel2_-_dimensions_pt1
128 | ah_test_tanburpluck_ta_1eb06
129 | ah_test_trumpet_1_-_Lost_In_Madrid_Part_1
130 | sb_Albums-Ballroom_Classics4-11(7.0-17.0)
131 | al_Albums-Cafe_Paradiso-09(14.9-24.9)
132 | al_Albums-Cafe_Paradiso-14(15.9-25.9)
133 | sb_Albums-Chrisanne3-02(12.0-22.0)
134 | sb_Albums-Latin_Jam-13(6.0-16.0)
135 | api_RM-C026
136 | api_RM-C036
137 | ff123_2nd_vent_clip
138 | ff123_DaFunk
139 | ff123_duel
140 | jpb_fiona
141 | jpb_Jaillet29
142 | jpb_Jaillet67
143 | jpb_PianoDebussy
144 | jpb_tabla
145 | lame_castanets
146 | lame_iron
147 | sb_Media-100608(3.0-13.0)
148 | al_Media-104717(18.2-28.2)
149 | al_Media-105016(4.1-14.1)
150 | al_Media-105614(20.9-30.9)
151 | al_Media-105913(10.6-20.6)
152 | al_Media-106109(12.4-22.4)
153 | mit_track1
154 | SoundCheck2_62_Vocal_Male_Rock_Vocal
155 | SoundCheck2_80_Instrumental_Cellos_and_violas
156 | mck_train11
157 | mck_train19
158 | unknown_violin
159 | vorbis_lalaw
160 | ah_development_guitar_Guitar_Licks_07-06
161 | ah_development_oud_8
162 | ah_development_piano_autumn
163 | ah_development_violin_01__hicaz__keman
164 | ah_development_violin_violin2
165 | ah_test_cello_03_CelloTaksimi_pt1
166 | ah_test_clarinet_my_clarinet1
167 | ah_test_clarinet_SL1_pt1
168 | ah_test_sax_sax1
169 | ah_test_trumpet_Miles_DAVIS__Michel_LEGRAND_DINGO_02
170 | sb_Albums-Ballroom_Magic-09(4.0-14.0)
171 | al_Albums-Cafe_Paradiso-16(17.0-27.0)
172 | al_Albums-Latin_Jam5-01(7.9-17.9)
173 | al_Albums-Latino_Latino-05(3.4-13.4)
174 | api_2-artificial
175 | api_8-ambrielb
176 | api_RM-C002
177 | ff123_41_30sec
178 | ff123_bloodline
179 | ff123_BlueEyesExcerpt
180 | ff123_drone_short
181 | ff123_ExitMusic
182 | ff123_rushing
183 | jpb_arab60s
184 | jpb_Jaillet64
185 | jpb_jaxx
186 | jpb_violin
187 | lame_else3
188 | lame_youcantdothat
189 | sb_Media-103302(15.0-25.0)
190 | sb_Media-103307(4.0-14.0)
191 | al_Media-103611(8.2-18.2)
192 | sb_Media-104210(2.0-12.0)
193 | al_Media-105320(2.5-12.5)
194 | al_Media-105615(12.4-22.4)
195 | al_Media-105808(15.5-25.5)
196 | al_Media-105905(7.8-17.8)
197 | SoundCheck2_65_Instrumental_Acoustic_steel_strung_guitar_strummed
198 | SoundCheck2_83_The_Alan_Parsons_Project_-_Limelight
199 | mck_train13
200 | ah_development_guitar_Guitar_Licks_07-11
201 | ah_development_oud_1
202 | ah_development_percussion_cajon116_13
203 | ah_development_percussion_tambourine
204 | ah_development_piano_MOON
205 | ah_development_piano_mussorgsky
206 | ah_development_violin_my_violin1
207 | ah_test_mixtures_jazz2
208 | ah_test_ney_ne_se03
209 | ah_test_ney_ne_tmu2n06
210 | ah_test_oud_Diverse_-_03_-_Muayyer_Kurdi_Taksim
211 | ah_test_oud_ud_taksimleri_-_06_-_segah_taksim
212 | ah_test_tanburpluck_ta_1eb08
213 | ah_test_violin_my_violin3
214 | sb_Albums-Ballroom_Magic-04(6.0-16.0)
215 | sb_Albums-Latin_Jam3-02(3.0-13.0)
216 | api_RM-C027
217 | api_RM-C038
218 | ff123_Debussy
219 | ff123_fossiles
220 | ff123_gekkou-intro
221 | ff123_wait
222 | gs_mix2_0dB
223 | jpb_Jaillet73
224 | lame_Fools
225 | lame_pipes
226 | lame_spahm
227 | al_Media-104306(5.0-15.0)
228 | al_Media-104415(10.7-20.7)
229 | sb_Media-105213(13.0-23.0)
230 | al_Media-105306(21.1-31.1)
231 | al_Media-105506(11.4-21.4)
232 | sb_Media-105801(11.0-21.0)
233 | al_Media-106015(16.4-26.4)
234 | al_Media-106110(6.4-16.4)
235 | al_Media-106113(16.4-26.4)
236 | SoundCheck2_63_Instrumental_Piano
237 | SoundCheck2_78_Instrumental_Whole_drum_kit
238 | mck_train5
239 | mck_train9
240 | ah_development_guitar_Guitar_Licks_06-11
241 | ah_development_percussion_bongo1
242 | ah_development_piano_p5
243 | ah_test_guitar_guitar3
244 | ah_test_guitar_Summer_Together_110_pt1
245 | ah_test_ney_ne_icm09
246 | ah_test_ney_ne_icm13
247 | ah_test_oud_Trio_Joubran-04-Safarcut
248 | ah_test_tanburpluck_ta_1eb03
249 | ah_test_violin_42954_FreqMan_hoochie_violin_pt1
250 | ah_test_violin_my_violin2
251 | sb_Albums-Chrisanne1-08(9.0-19.0)
252 | al_Albums-Fire-03(12.4-22.4)
253 | sb_Albums-Fire-13(15.0-25.0)
254 | api_3-long_gone
255 | api_6-three
256 | ff123_BigYellow
257 | ff123_Blackwater
258 | ff123_cymbals
259 | ff123_dogies
260 | ff123_Enchantment
261 | jpb_Jaillet34
262 | jpb_Jaillet66
263 | lame_hihat
264 | sb_Media-103416(12.0-22.0)
265 | al_Media-103505(1.9-11.9)
266 | al_Media-103714(4.2-14.2)
267 | al_Media-104002(11.3-21.3)
268 | al_Media-104016(5.1-15.1)
269 | al_Media-104506(3.8-13.8)
270 | al_Media-104807(16.5-26.5)
271 | sb_Media-105810(5.0-15.0)
272 | mit_track6
273 | SoundCheck2_70_Instrumental_Flute
274 | SoundCheck2_75_Instrumental_Snare_drum
275 | SoundCheck2_77_Instrumental_Toms
276 | SoundCheck2_79_Instrumental_Violins
277 | mck_train15
278 | mck_train16
279 | vorbis_11
280 | vorbis_dr4
281 | ah_development_oud_ud_taksimleri_-_02_-_huezzam_taksim_pt1
282 | ah_development_percussion_triangle
283 | ah_development_piano_p3
284 | ah_development_piano_piano1
285 | ah_test_mixtures_classic3
286 | ah_test_mixtures_jazz3
287 | ah_test_mixtures_rock1
288 | ah_test_oud_Diverse_-_02_-_Mahur_Pesrev_(Ud)
289 | ah_test_piano_PianoMic_full_stick
290 | ah_test_sax_herbieHanckockTrack06_pt1
291 | ah_test_tanburpluck_ta_1mc08
292 | ah_test_trumpet_chetbaker
293 | sb_Albums-Ballroom_Classics4-01(3.0-13.0)
294 | al_Albums-Cafe_Paradiso-10(1.5-11.5)
295 | sb_Albums-Chrisanne2-01(16.0-26.0)
296 | al_Albums-Latin_Jam2-09(5.3-15.3)
297 | api_RM-G008
298 | api_RM-J001
299 | ff123_ATrain
300 | ff123_BachS1007
301 | ff123_TheSource
302 | gs_mix1_0dB
303 | jpb_Jaillet15
304 | jpb_Jaillet65
305 | jpb_Jaillet74
306 | lame_BlackBird
307 | lame_testsignal2
308 | lame_velvet
309 | al_Media-103605(10.0-20.0)
310 | al_Media-103919(8.0-18.0)
311 | sb_Media-104111(5.0-15.0)
312 | al_Media-104917(15.6-25.6)
313 | al_Media-105020(11.1-21.1)
314 | sb_Media-105407(6.0-16.0)
315 | al_Media-105819(8.1-18.1)
316 | mit_track7
317 | SoundCheck2_71_Instrumental_Saxophone
318 | SoundCheck2_73_Instrumental_Tambourine
319 | mck_train2
320 | mck_train4
321 | mck_train6
322 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import sys
import glob
import torch
from torch.utils import data
from utils import onsetCNN, Dataset
import numpy as np
import matplotlib.pyplot as plt
import os
import utils

#use gpu if available, else fall back to cpu
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:1" if use_cuda else "cpu")

#evaluation tolerance and merge duration for close onsets
tolerance=60e-3 #+- tolerance/2 seconds
mergeDur=20e-3
hop_dur=10e-3 #frame hop of the spectrogram front-end
mergeDur_frame=mergeDur/hop_dur
tolerance_frame=tolerance/hop_dur

fold = int(sys.argv[1]) #cmd line argument: fold index

#load model
#map_location lets a gpu-trained checkpoint load on a cpu-only machine
path_to_saved_model = 'models/saved_model_%d.pt'%fold
model = onsetCNN().double().to(device)
model.load_state_dict(torch.load(path_to_saved_model, map_location=device))
model.eval()

#data
datadir='/media/Sharedata/rohit/SS_onset_detection/data_pt/'
#songlist=os.listdir(datadir)
songlist=np.loadtxt('splits/8-fold_cv_random_%d.fold'%fold,dtype=str)
#labels dict was pickled via np.save; numpy>=1.16.5 requires allow_pickle=True
labels = np.load('labels_master_test.npy', allow_pickle=True).item()

#loop over test songs
scores=np.array([])
n_songs=len(songlist)
i_song=0
for song in songlist:
    print('%d/%d songs\n'%(i_song,n_songs))
    i_song+=1

    odf=np.array([]) #onset detection function (model output per frame)
    gt=np.array([])  #ground-truth frame labels

    #generate frame-wise labels serially for song
    n_files=len(glob.glob(os.path.join(datadir,song+'/*.pt')))
    for i_file in range(n_files):
        x=torch.load(os.path.join(datadir,song+'/%d.pt'%i_file)).to(device)
        x=x.unsqueeze(0) #add batch dimension
        with torch.no_grad(): #inference only -- skip autograd bookkeeping
            y=model(x).squeeze().cpu().numpy()
        odf=np.append(odf,y)
        gt=np.append(gt,labels[os.path.join(datadir,song+'/%d.pt'%i_file)])

    #evaluate odf
    scores_thresh=np.array([])
    #loop over different peak-picking thresholds to optimize F-score
    for predict_thresh in [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]:
        odf_labels=np.zeros(len(odf))

        #pick peaks
        odf_labels[utils.peakPicker(odf,predict_thresh)]=1.

        #evaluate, get #hits and #misses
        scores_thresh=np.append(scores_thresh,utils.eval_output(odf_labels, odf, gt, tolerance_frame, mergeDur_frame))

    #accumulate hits and misses for every song
    if len(scores)==0: scores=np.atleast_2d(np.array(scores_thresh))
    else: scores=np.vstack((scores,np.atleast_2d(np.array(scores_thresh))))

#add hits and misses over all songs (to compute testset P, R and F-score)
scores=np.sum(scores,0)

#write results; the context manager guarantees the file is closed
with open('hr_fa_folds.txt','a') as fout:
    for item in scores:
        fout.write('%d\t'%item)
    fout.write('\n')
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import glob
3 | import torch
4 | from torch.utils import data
5 | from utils import onsetCNN, Dataset
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 |
#function to repeat positive samples to improve data balance
def balance_data(ids, labels):
    """Return every id whose label is 1, repeated three times (for oversampling)."""
    return [idi for idi in ids if labels[idi] == 1 for _ in range(3)]
18 |
#use GPU if available, else fall back to cpu (same device logic as test.py)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:1" if use_cuda else "cpu")

#parameters for data loader
params = {'batch_size': 256,'shuffle': True,'num_workers': 6}
max_epochs = 50

#data
datadir='/media/Sharedata/rohit/SS_onset_detection/data_pt/'
songlist=np.loadtxt('songlist.txt',dtype=str)
#label/weight dicts were pickled via np.save in gen_data.py;
#numpy>=1.16.5 requires allow_pickle=True to load them
labels = np.load('labels_master.npy', allow_pickle=True).item()
weights = np.load('weights_master.npy', allow_pickle=True).item()

#model, loss and optimizer (per-sample BCE so each loss term can be weighted)
model=onsetCNN().double().to(device)
criterion=torch.nn.BCELoss(reduction='none')
optimizer=torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.45)
#optimizer=torch.optim.Adam(model.parameters(), lr=0.05)

#cross-validation: the chosen fold is held out for validation
fold = int(sys.argv[1]) #cmd line argument
partition = {'all':[], 'train':[], 'validation':[]}
val_split = np.loadtxt('splits/8-fold_cv_random_%d.fold'%fold,dtype='str')
for song in songlist:
    ids = glob.glob(datadir+song+'/*.pt')
    if song in val_split: partition['validation'].extend(ids)
    else: partition['train'].extend(ids)

#balance data
#partition['train'].extend(balance_data(partition['train'],labels))

#print data balance percentage
n_ones=0.
for idi in partition['train']:
    if labels[idi]==1.: n_ones+=1
print('Fraction of positive examples: %f'%(n_ones/len(partition['train'])))

#generators
training_set = Dataset(partition['train'], labels, weights)
training_generator = data.DataLoader(training_set, **params)

validation_set = Dataset(partition['validation'], labels, weights)
validation_generator = data.DataLoader(validation_set, **params)

#training epochs loop
train_loss_epoch=[]
val_loss_epoch=[]
for epoch in range(max_epochs):
    train_loss_epoch+=[0]
    val_loss_epoch+=[0]

    ##training
    model.train() #enable dropout for training
    n_train=0
    for local_batch, local_labels, local_weights in training_generator:
        n_train+=local_batch.shape[0]

        #transfer to GPU
        local_batch, local_labels, local_weights = local_batch.to(device), local_labels.to(device), local_weights.to(device)

        #update weights (weighted per-sample BCE, normalized by batch size)
        optimizer.zero_grad()
        outs = model(local_batch).squeeze()
        loss = criterion(outs, local_labels)
        loss = torch.dot(loss,local_weights)
        loss /= local_batch.size()[0]
        loss.backward()
        optimizer.step()
        train_loss_epoch[-1]+=loss.item()
    train_loss_epoch[-1]/=n_train

    ##validation
    model.eval() #BUG FIX: disable dropout during validation (model was left in train mode)
    n_val=0
    with torch.set_grad_enabled(False):
        for local_batch, local_labels, local_weights in validation_generator:
            n_val+=local_batch.shape[0]

            #transfer to GPU
            local_batch, local_labels = local_batch.to(device), local_labels.to(device)

            #evaluate model
            outs = model(local_batch).squeeze()
            loss = criterion(outs, local_labels).mean()
            val_loss_epoch[-1]+=loss.item()
    val_loss_epoch[-1]/=n_val

    #print loss in current epoch
    print('Epoch no: %d/%d\tTrain loss: %f\tVal loss: %f'%(epoch, max_epochs, train_loss_epoch[-1], val_loss_epoch[-1]))

    #update LR and momentum (only if using SGD): LR decays every epoch,
    #momentum ramps up between epochs 10 and 20
    for param_group in optimizer.param_groups:
        param_group['lr'] *= 0.995
        if 10<=epoch<=20: param_group['momentum'] += 0.045

#plot losses vs epoch
plt.plot(train_loss_epoch,label='train')
plt.plot(val_loss_epoch,label='val')
plt.legend()
plt.savefig('./plots/loss_curves_%d'%fold)
plt.clf()
torch.save(model.state_dict(), 'saved_model_%d.pt'%fold)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.utils import data
6 |
#model
class onsetCNN(nn.Module):
    """Convolutional onset detector (Schlüter & Böck, 2014 architecture).

    Takes a 3-channel time-frequency patch (one channel per spectrogram
    resolution; fc1 sizing implies an 80x15 patch -- presumably 80 mel bands
    by 15 frames) and returns one sigmoid onset probability per example.
    """

    def __init__(self):
        super(onsetCNN, self).__init__()
        # layer construction order kept identical to the original so that
        # seeded random initialization and saved state_dicts stay compatible
        self.conv1 = nn.Conv2d(3, 10, (3,7))
        self.pool1 = nn.MaxPool2d((3,1))
        self.conv2 = nn.Conv2d(10, 20, 3)
        self.pool2 = nn.MaxPool2d((3,1))
        self.fc1 = nn.Linear(20 * 7 * 8, 256)
        self.fc2 = nn.Linear(256,1)
        self.dout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Map a (batch, 3, H, W) input to a (batch, 1) onset probability."""
        h = self.pool1(torch.tanh(self.conv1(x)))
        h = self.pool2(torch.tanh(self.conv2(h)))
        h = self.dout(h.view(-1, 20 * 7 * 8))       # flatten, then dropout
        h = self.dout(torch.sigmoid(self.fc1(h)))   # hidden layer + dropout
        return torch.sigmoid(self.fc2(h))
28 |
#data-loader(https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel)
class Dataset(data.Dataset):
    """Map-style dataset over pre-saved spectrogram-chunk tensors.

    Each ID is a filesystem path to a .pt tensor; `labels` and `weights`
    are dicts keyed by those same paths.
    """

    def __init__(self, list_IDs, labels, weights):
        """Store the sample paths and their label/weight lookup tables."""
        self.list_IDs = list_IDs
        self.labels = labels
        self.weights = weights

    def __len__(self):
        """Total number of samples."""
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Load one sample: (input tensor, label, loss weight)."""
        ID = self.list_IDs[index]
        return torch.load(ID), self.labels[ID], self.weights[ID]
54 |
#peak-picking function
def peakPicker(data, peakThresh):
    """Return indices of strict local maxima exceeding peakThresh.

    Endpoints are never reported; plateaus do not count as peaks.
    """
    hits = [i for i in range(1, len(data) - 1)
            if data[i] > data[i - 1] and data[i] > data[i + 1] and data[i] > peakThresh]
    return np.array(hits, dtype='int')
62 |
63 | #merge onsets if too close - retain only stronger one
64 | def merge_onsets(onsets,strengths,mergeDur):
65 | onsetLocs=np.where(onsets==1)[0]
66 | ind=1
67 | while ind