├── .gitignore ├── 50_tags.txt ├── README.md ├── __init__.py ├── annot_processor.py ├── audio_processor.py ├── config.py ├── data_loader.py ├── eval_tags.py ├── main.py ├── model.py ├── requirements.txt ├── solver.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | .DS_Store 104 | 105 | # Created by https://www.gitignore.io/api/pycharm 106 | 107 | ### PyCharm ### 108 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 109 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 110 | 111 | # User-specific stuff: 112 | .idea/**/workspace.xml 113 | .idea/**/tasks.xml 114 | .idea/dictionaries 115 | 116 | # Sensitive or high-churn files: 117 | .idea/**/dataSources/ 118 | .idea/**/dataSources.ids 119 | .idea/**/dataSources.xml 120 | .idea/**/dataSources.local.xml 121 | .idea/**/sqlDataSources.xml 122 | .idea/**/dynamic.xml 123 | .idea/**/uiDesigner.xml 124 | 125 | # Gradle: 126 | .idea/**/gradle.xml 127 | .idea/**/libraries 128 | 129 | # CMake 130 | cmake-build-debug/ 131 | 132 | # Mongo Explorer plugin: 133 | .idea/**/mongoSettings.xml 134 | 135 | ## File-based project format: 136 | *.iws 137 | 138 | ## Plugin-specific files: 139 | 140 | # IntelliJ 141 | /out/ 142 | 143 | # mpeltonen/sbt-idea plugin 144 | .idea_modules/ 145 | 146 | # JIRA plugin 147 | atlassian-ide-plugin.xml 148 | 149 | # Cursive Clojure plugin 150 | .idea/replstate.xml 151 | 152 | # Ruby plugin and RubyMine 153 | /.rakeTasks 154 | 155 | # Crashlytics plugin (for Android Studio and IntelliJ) 156 | com_crashlytics_export_strings.xml 157 | crashlytics.properties 158 | crashlytics-build.properties 159 | fabric.properties 160 | 161 | ### PyCharm Patch ### 162 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 163 | 164 | *.iml 165 | modules.xml 166 | .idea/misc.xml 167 | *.ipr 168 | .idea/vcs.xml 169 | .idea/inspectionProfiles/Project_Default.xml 170 | 171 | # Sonarlint plugin 172 | .idea/sonarlint 173 | 174 | # End of https://www.gitignore.io/api/pycharm 175 | -------------------------------------------------------------------------------- /50_tags.txt: -------------------------------------------------------------------------------- 1 | guitar 2 | classic 3 | slow 4 | techno 5 | string 6 | vocal 7 | electronic 8 | drum 9 | no vocal 10 | rock 11 | fast 12 | male 13 | beat 14 | female 15 | piano 16 | ambient 17 | violin 18 | synth 19 | indian 20 | singer 21 | opera 22 | harpsichord 23 | loud 24 | quiet 25 | flute 26 | pop 27 | soft 28 | sitar 29 | solo 30 | choir 31 | new age 32 | dance 33 | weird 34 | harp 35 | heavy 36 | cello 37 | jazz 38 | country 39 | eastern 40 | bass 41 | modern 42 | no piano 43 | hard 44 | chant 45 | baroque 46 | orchestra 47 | foreign 48 | trance 49 | folk 50 | no beat 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Sample-level Deep CNN 2 | Pytorch implementation of [Sample-level Deep Convolutional Neural Networks for Music Auto-tagging Using Raw Waveforms](https://arxiv.org/abs/1703.01789) 3 | 4 | ### Data 5 | [MagnaTagATune Dataset](http://mirg.city.ac.uk/codeapps/the-magnatagatune-dataset) 6 | * Used tag annotations and audio data 7 | 8 | ### Model 9 | 9 1D conv layers and input sample size of 59049 (~3 seconds) 10 | 11 | ### Procedures 12 | * Fix `config.py` file 13 | * Data processing 14 | * run ` python audio_processor.py ` : audio (to read audio signal from mp3s and save as npy) 15 | * run ` python annot_processor.py ` : annotation (process redundant tags and select top N=50 tags) 16 | * this will create and save train/valid/test annotation files 17 | * Training 18 | * You can set multigpu option by listing all the available devices 19 | * Ex. ` python main.py --gpus 0 1` 20 | * Ex. ` python main.py ` will use 1 gpu if available as a default 21 | 22 | ### Tag prediction 23 | * run `python eval_tags.py --gpus 0 1 --mp3_file "path/to/mp3file/to/predict.mp3" ` 24 | 25 | ### References 26 | * [https://github.com/jongpillee/sampleCNN](https://github.com/jongpillee/sampleCNN) 27 | * [https://github.com/tae-jun/sample-cnn](https://github.com/tae-jun/sample-cnn) 28 | * [https://github.com/keunwoochoi/magnatagatune-list](https://github.com/keunwoochoi/magnatagatune-list) 29 | 30 | 31 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from . import * -------------------------------------------------------------------------------- /annot_processor.py: -------------------------------------------------------------------------------- 1 | ''' Functions for processing the MTT annotation by selecting top N tags and dividing the dataset into train/valid/test set ''' 2 | 3 | import os 4 | import pandas as pd 5 | import numpy as np 6 | import config 7 | np.random.seed(0) 8 | 9 | def _merge_redundant_tags(filename): 10 | ''' Some tags are considered to be redundant, so it seems reasonable to do some cleanup. Tag organization by https://github.com/keunwoochoi/magnatagatune-list 11 | Args : 12 | filename : path to the MTT annotation csv file 13 | Return : 14 | new_df : pandas dataframe with merged tags 15 | ''' 16 | synonyms = [['beat', 'beats'], 17 | ['chant', 'chanting'], 18 | ['choir', 'choral'], 19 | ['classic', 'clasical', 'classical'], 20 | ['drum', 'drums'], 21 | ['electronic', 'electro', 'electronica', 'electric'], 22 | ['fast', 'fast beat', 'quick'], 23 | ['female', 'female singer', 'female singing', 'female vocal', 'female vocals', 'female voice', 'woman', 'woman singing', 'women'], 24 | ['flute', 'flutes'], 25 | ['guitar', 'guitars'], 26 | ['hard', 'hard rock'], 27 | ['harpsichord', 'harpsicord'], 28 | ['heavy', 'heavy metal', 'metal'], 29 | ['horn', 'horns'], 30 | ['indian', 'india'], 31 | ['jazz', 'jazzy'], 32 | ['male', 'male singer', 'male vocal', 'male vocals', 'male voice', 'man', 'man singing', 'men'], 33 | ['no beat', 'no drums'], 34 | ['no vocal', 'no singing', 'no singer','no vocals', 'no voice', 'no voices', 'instrumental'], 35 | ['opera', 'operatic'], 36 | ['orchestra', 'orchestral'], 37 | ['quiet', 'silence'], 38 | ['singer', 'singing'], 39 | ['space', 'spacey'], 40 | ['string', 'strings'], 41 | ['synth', 'synthesizer'], 42 | ['violin', 'violins'], 43 | ['vocal', 'vocals', 'voice', 'voices'], 44 | ['weird', 'strange']] 45 | 46 | synonyms_correct = [synonyms[i][0] for i in range(len(synonyms))] 47 | synonyms_redundant = [synonyms[i][1:] for i in range(len(synonyms))] 48 | 49 | df = pd.read_csv(filename, delimiter='\t') 50 | new_df = df.copy() 51 | 52 | for i in range(len(synonyms)): 53 | for j in range(len(synonyms_redundant[i])): 54 | redundant_df = df[synonyms_redundant[i][j]] 55 | new_df[synonyms_correct[i]] = (new_df[synonyms_correct[i]] + redundant_df) > 0 56 | new_df[synonyms_correct[i]] = new_df[synonyms_correct[i]].astype(int) 57 | new_df.drop(synonyms_redundant[i][j] ,1, inplace=True) 58 | return new_df 59 | 60 | def reduce_to_N_tags(filename, base_dir, n_tops=config.NUM_TAGS, merge=True): 61 | ''' There are a lot of tags, so reduce it to top N popular tags 62 | Args : 63 | filename : path to MTT annotation csv file 64 | base_dir : path to general project directory 65 | n_tops : number of tags to reduce to 66 | merge : combine similar tags, like female vocal & female vocals & women 67 | Return : 68 | new_filename : path to the new processed csv file with reduced tags 69 | ''' 70 | if merge: 71 | df = _merge_redundant_tags(filename) 72 | else : 73 | df = pd.read_csv(filename, delimiter='\t') 74 | print (df.drop(['clip_id', 'mp3_path'], axis=1).sum(axis=0).sort_values()) 75 | topN = df.drop(['clip_id', 'mp3_path'], axis=1).sum(axis=0).sort_values().tail(n_tops).index.tolist()[::-1] 76 | print (len(topN), topN) 77 | taglist_f = open(str(n_tops) + '_tags.txt', 'w') 78 | for tag in topN: 79 | taglist_f.write(tag+'\n') 80 | taglist_f.close() 81 | 82 | # df = df[topN + ['clip_id', 'mp3_path']] 83 | df = pd.concat([df.loc[:,topN], df.loc[:,'clip_id'], df.loc[:,'mp3_path']], axis=1) 84 | 85 | # remove rows with all 0 labels 86 | df = df.loc[~(df.loc[:, topN] == 0).all(axis=1)] 87 | print (df.shape) 88 | # save new csv file 89 | new_filename = base_dir + str(n_tops) + '_tags_' + filename.split('/')[-1] 90 | df.to_csv(new_filename, sep='\t', encoding='utf-8', index=False) 91 | return new_filename 92 | 93 | def split_data(filename, base_dir, ratio=0.2): 94 | ''' Split into train/val/test and saves each set to a new file 95 | Args: 96 | filename : path to the MTT annotation csv file 97 | base_dir : path to the general project directory 98 | Return : 99 | None 100 | ''' 101 | 102 | df = pd.read_csv(filename, delimiter='\t') 103 | data_len = df.shape[0] 104 | print ("Data shape {}".format(df.shape)) 105 | 106 | test_len = int (data_len * ratio) 107 | train_valid_len = data_len - test_len 108 | valid_len = int(train_valid_len * ratio) 109 | train_len = train_valid_len - valid_len 110 | print ("Train %d, valid %d, test %d"%(train_len, valid_len, test_len)) 111 | 112 | # add headers to all files 113 | test_df = df.iloc[train_valid_len:] 114 | valid_df = df.iloc[train_len : train_valid_len] 115 | train_df = df.iloc[:train_len] 116 | 117 | # save each test, valid, train files 118 | f = filename.split('/')[-1] 119 | test_df.to_csv(base_dir + 'test_' + f, sep='\t',index=False) 120 | valid_df.to_csv(base_dir + 'valid_' + f, sep='\t',index=False) 121 | train_df.to_csv(base_dir + 'train_' + f, sep='\t', index=False) 122 | 123 | 124 | if __name__ == "__main__": 125 | new_csvfile = reduce_to_N_tags(config.ANNOT_FILE, config.BASE_DIR) 126 | split_data(new_csvfile, base_dir) 127 | 128 | -------------------------------------------------------------------------------- /audio_processor.py: -------------------------------------------------------------------------------- 1 | ''' Run this file to process raw audio ''' 2 | import os, errno 3 | import numpy as np 4 | import torch 5 | import librosa 6 | from pathlib import Path 7 | import config 8 | 9 | 10 | def save_audio_to_npy(rawfilepath, npyfilepath): 11 | ''' Save audio signal with sr=sample_rate to npy file 12 | Args : 13 | rawfilepath : path to the MTT audio files 14 | npyfilepath : path to save the numpy array audio signal 15 | Return : 16 | None 17 | ''' 18 | 19 | # make directory if not existing 20 | if not os.path.exists(npyfilepath): 21 | os.makedirs(npyfilepath) 22 | 23 | 24 | mydir = [path for path in os.listdir(rawfilepath) if path >= '0' and path <= 'f'] 25 | for path in mydir : 26 | # create directory with names '0' to 'f' if it doesn't already exist 27 | try: 28 | os.mkdir(Path(npyfilepath) / path) 29 | except OSError as e: 30 | if e.errno != errno.EEXIST: 31 | raise 32 | audios = [audio for audio in os.listdir(Path(rawfilepath) / path) if audio.split(".")[-1] == 'mp3'] 33 | for audio in audios : 34 | try: 35 | y,sr = librosa.load(audio, sr=config.SR) 36 | if len(y)/self.NUM_SAMPLES < 10: 37 | print ("There are less than 10 segments in this audio") 38 | except: 39 | print ("Cannot load audio {}".format(audio)) 40 | continue 41 | 42 | fn = audio.split(".")[0] 43 | np.save(Path(npyfilepath) / (path + '/' + fn + '.npy'), y) 44 | 45 | 46 | def get_segment_from_npy(npyfile, segment_idx): 47 | ''' Return random segment of length num_samples from the audio 48 | Args : 49 | npyfile : path to all the npy files each containing audio signals 50 | segment_idx : index of the segment to retrieve; max(segment_idx) = total_samples//num_samples 51 | Return : 52 | segment : audio signal of length num_samples 53 | ''' 54 | song = np.load(npyfile) 55 | # randidx = np.random.randint(10) 56 | try : 57 | segment = song[segment_idx * config.NUM_SAMPLES : (segment_idx+1)*config.NUM_SAMPLES] 58 | except : 59 | randidx = np.random.randint(10) 60 | get_segment_from_npy(npyfile, randidx, config.NUM_SAMPLES) 61 | return segment 62 | 63 | 64 | 65 | if __name__ =='__main__': 66 | # read audio signal and save to npy format 67 | save_audio_to_npy(config.MTT_DIR, config.AUDIO_DIR) 68 | 69 | 70 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | DATA_DIR = './data/' 2 | BASE_DIR = './data/sampleCNN-data/' # data dir for this model 3 | MTT_DIR = './data/MaganatagatuePub/mp3/' # MTT data dir 4 | AUDIO_DIR = './data/pubMagnatagatune_mp3s_to_npy/' 5 | ANNOT_FILE = './data/annotations_final.csv' 6 | LIST_OF_TAGS = './data/sampleCNN-data/50_tags.txt' 7 | 8 | DEVICE_IDS=[0,1] 9 | 10 | # audio params 11 | SR = 22050 12 | NUM_SAMPLES = 59049 13 | NUM_TAGS = 50 14 | 15 | # train params 16 | BATCH_SIZE = 64 17 | LR = 0.008 18 | DROPOUT = 0.5 19 | NUM_EPOCHS = 100 20 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import pandas as pd 4 | import numpy as np 5 | from pathlib import Path 6 | from audio_processor import get_segment_from_npy 7 | import config 8 | 9 | 10 | ''' 11 | Load Dataset (divided into train/validate/test sets) 12 | * audio data : saved as segments in npy file 13 | * labels : 50-d labels in csv file 14 | ''' 15 | 16 | class SampleLevelMTTDataset(Dataset): 17 | def __init__(self): 18 | ''' 19 | Args : 20 | csvfile : train/val/test csvfiles 21 | audio_dir : directory that contains folders 0 - f 22 | ''' 23 | self.tag_list = open(config.LIST_OF_TAGS, 'r').read().split('\n') 24 | self.audio_dir = config.AUDIO_DIR 25 | self.num_tags = config.NUM_TAGS 26 | 27 | 28 | self.set_mode('train') 29 | 30 | 31 | def set_mode(self, mode): 32 | print ("dataset mode: ", mode) 33 | if mode == 'train': 34 | self.annotation_file = Path(config.BASE_DIR) / 'train_50_tags_annotations_final.csv' 35 | 36 | elif mode == 'valid': 37 | self.annotation_file = Path(config.BASE_DIR) / 'valid_50_tags_annotations_final.csv' 38 | 39 | elif mode == 'test': 40 | self.annotation_file = Path(config.BASE_DIR) / 'test_50_tags_annotations_final.csv' 41 | 42 | 43 | self.annotations_frame = pd.read_csv(self.annotation_file, delimiter='\t') # df 44 | self.labels = self.annotations_frame.drop(['clip_id', 'mp3_path'], axis=1) 45 | 46 | 47 | # get one segment (==59049 samples) and its 50-d label 48 | def __getitem__(self, index): 49 | idx = index // 10 50 | segment_idx = index % 10 51 | 52 | mp3filename = self.annotations_frame.iloc[idx]['mp3_path'].split('.')[0]+'.npy' 53 | try : 54 | segment = get_segment_from_npy(self.audio_dir + mp3filename, segment_idx) 55 | except : 56 | new_index = index-1 if index > 0 else index +1 57 | return self.__getitem__(new_index) 58 | 59 | # build label in the order of 50_tags.txt 60 | label = np.zeros(self.num_tags) 61 | for i,tag in enumerate(self.tag_list): 62 | if tag == '': 63 | continue 64 | if self.annotations_frame[tag].iloc[idx] == 1: 65 | label[i] = 1 66 | label = torch.FloatTensor(label) 67 | entry = {'audio': segment, 'label': label} 68 | return entry 69 | 70 | def __len__(self): 71 | return self.annotations_frame.shape[0] * 10 72 | -------------------------------------------------------------------------------- /eval_tags.py: -------------------------------------------------------------------------------- 1 | ''' Functions to evaluate the tag predictions ''' 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import pandas as pd 6 | import numpy as np 7 | import librosa 8 | import argparse 9 | from audio_processor import get_segment_from_npy 10 | from model import SampleCNN 11 | import config 12 | import utils 13 | 14 | 15 | cuda = torch.cuda.is_available() 16 | print("gpu available :", cuda) 17 | device = torch.device("cuda" if cuda else "cpu") 18 | num_gpu = torch.cuda.device_count() 19 | torch.cuda.manual_seed(5) 20 | 21 | 22 | parser= argparse.ArgumentParser() 23 | parser.add_argument('--gpus', nargs='+', type=int, default=[]) 24 | parser.add_argument('--mp3_file', type=str) 25 | args = parser.parse_args() 26 | print (args) 27 | 28 | if len(args.gpus) > 1 : 29 | multigpu = True 30 | else : 31 | multigpu = False 32 | 33 | utils.handle_multigpu(multigpu, args.gpus, num_gpu) 34 | 35 | def get_taglist(csvfile): 36 | ''' Get the human readable ordered list of tags as saved in csv file 37 | ''' 38 | df = pd.read_csv(csvfile, delimiter=',') 39 | l = list(df)[1:] 40 | l.remove('clip_id') 41 | l.remove('mp3_path') 42 | return l 43 | 44 | def load_model(model, saved_state): 45 | ''' Load the trained model 46 | Args : 47 | model : initialized model with no state 48 | saved_state : path to a specific model state 49 | Return 50 | model : trained model or None if not existing 51 | ''' 52 | if os.path.isfile(saved_state): 53 | model.load_state_dict(torch.load(saved_state)) 54 | print ("Model loaded") 55 | return model 56 | else : 57 | print ("Model not found..") 58 | return 59 | 60 | def predict_topN_tags(model, base_dir, song, N=5): 61 | ''' Predict tags for the given audio files 62 | Args : 63 | model : path to trained model 64 | song : path to the song to predict (mp3 file) 65 | N : number of top N tag predictions to see 66 | Return : 67 | predicted_tags : list of N predicted tags 68 | ''' 69 | taglist = open(config.LIST_OF_TAGS, 'r').read().split('\n') 70 | if len(taglist) != 50: 71 | print ("more than 50 tags? %d"%len(taglist), "fix..") 72 | for tag in taglist : 73 | if tag =='': 74 | taglist.remove(tag) 75 | print ("%d tags in total"%len(taglist)) 76 | taglist = np.array(taglist) 77 | 78 | print ("Evaluating %s"%song) 79 | y, sr = librosa.load(song, sr=config.SR) 80 | print ("%d samples with %d sample rate"%(len(y), sr)) 81 | 82 | # select middle 29.1secs(10 segments) and average them 83 | segments = [] 84 | num_segments = 10 85 | if len(y) < (config.NUM_SAMPLES * 10) : 86 | num_segments = y//config.NUM_SAMPLES 87 | print ("Number of segments to calculate %d"%num_segments) 88 | 89 | start_index = len(y)//2 - (config.NUM_SAMPLES*10)//2 90 | for i in range(num_segments): 91 | segments.append(y[start_index + (i*config.NUM_SAMPLES) : start_index + (i+1) * config.NUM_SAMPLES]) 92 | 93 | # predict value for each segment 94 | calculated_val = [] 95 | for segment in segments : 96 | segment = torch.FloatTensor(segment) 97 | segment = segment.view(1, segment.shape[0]).to(device) 98 | 99 | model.eval() 100 | out = model(segment) 101 | sigmoid = nn.Sigmoid() 102 | out = sigmoid(out) 103 | out = out.detach().cpu().numpy() 104 | calculated_val.append(out) 105 | 106 | # average 10 segment values 107 | calculated_val = np.array(calculated_val) 108 | print (calculated_val.shape) 109 | avg_val = np.sum(calculated_val, axis=0) /10 110 | 111 | # sort tags 112 | sorted_tags = np.argsort(avg_val)[::-1][:N] 113 | print (sorted_tags) 114 | predicted_tags = [] 115 | for idx in sorted_tags: 116 | predicted_tags.append(taglist[idx]) 117 | print (predicted_tags) 118 | 119 | 120 | if __name__ =='__main__': 121 | saved_state = 'SampleCNN-singletag.pth' 122 | samplecnn_model = SampleCNN() 123 | model = load_model(samplecnn_model, saved_state) 124 | if multigpu : 125 | model = torch.nn.DataParallel(model, device_ids=args.gpus) 126 | 127 | model.to(device) 128 | 129 | # Predict top 5 tags 130 | predict_topN_tags(model, config.BASE_DIR , args.mp3_file) 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import model 3 | from data_loader import SampleLevelMTTDataset 4 | import argparse 5 | import model 6 | from solver import Solver 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--gpus', nargs='+', type=int, default=[]) 10 | args = parser.parse_args() 11 | 12 | print ("gpu devices being used: ", args.gpus) 13 | 14 | def main() : 15 | 16 | dataset = SampleLevelMTTDataset() 17 | samplecnn = model.SampleCNN() 18 | 19 | # start training 20 | print ("Start training!!") 21 | mysolver = Solver(samplecnn, dataset, args) 22 | mysolver.train() 23 | 24 | print ("Finished! Hopefully..") 25 | 26 | # test it 27 | print ("Start testing...") 28 | 29 | 30 | 31 | 32 | if __name__ == '__main__': 33 | main() 34 | 35 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import config 4 | 5 | 6 | class SampleCNN(nn.Module): 7 | def __init__(self): 8 | super(SampleCNN, self).__init__() 9 | 10 | # 59049 x 1 11 | self.conv1 = nn.Sequential( 12 | nn.Conv1d(1, 128, kernel_size=3, stride=3, padding=0), 13 | nn.BatchNorm1d(128), 14 | nn.ReLU()) 15 | # 19683 x 128 16 | self.conv2 = nn.Sequential( 17 | nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1), 18 | nn.BatchNorm1d(128), 19 | nn.ReLU(), 20 | nn.MaxPool1d(3, stride=3)) 21 | # 6561 x 128 22 | self.conv3 = nn.Sequential( 23 | nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1), 24 | nn.BatchNorm1d(128), 25 | nn.ReLU(), 26 | nn.MaxPool1d(3,stride=3)) 27 | # 2187 x 128 28 | self.conv4 = nn.Sequential( 29 | nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1), 30 | nn.BatchNorm1d(256), 31 | nn.ReLU(), 32 | nn.MaxPool1d(3,stride=3)) 33 | # 729 x 256 34 | self.conv5 = nn.Sequential( 35 | nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1), 36 | nn.BatchNorm1d(256), 37 | nn.ReLU(), 38 | nn.MaxPool1d(3,stride=3)) 39 | # 243 x 256 40 | self.conv6 = nn.Sequential( 41 | nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1), 42 | nn.BatchNorm1d(256), 43 | nn.ReLU(), 44 | nn.MaxPool1d(3,stride=3), 45 | nn.Dropout(config.DROPOUT)) 46 | # 81 x 256 47 | self.conv7 = nn.Sequential( 48 | nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1), 49 | nn.BatchNorm1d(256), 50 | nn.ReLU(), 51 | nn.MaxPool1d(3,stride=3)) 52 | # 27 x 256 53 | self.conv8 = nn.Sequential( 54 | nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1), 55 | nn.BatchNorm1d(256), 56 | nn.ReLU(), 57 | nn.MaxPool1d(3,stride=3)) 58 | # 9 x 256 59 | self.conv9 = nn.Sequential( 60 | nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1), 61 | nn.BatchNorm1d(256), 62 | nn.ReLU(), 63 | nn.MaxPool1d(3,stride=3)) 64 | # 3 x 256 65 | self.conv10 = nn.Sequential( 66 | nn.Conv1d(256, 512, kernel_size=3, stride=1, padding=1), 67 | nn.BatchNorm1d(512), 68 | nn.ReLU(), 69 | nn.MaxPool1d(3,stride=3)) 70 | # 1 x 512 71 | self.conv11 = nn.Sequential( 72 | nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1), 73 | nn.BatchNorm1d(512), 74 | nn.ReLU(), 75 | nn.Dropout(config.DROPOUT)) 76 | # 1 x 512 77 | self.fc = nn.Linear(512, 50) 78 | self.activation = nn.Sigmoid() 79 | 80 | def forward(self, x): 81 | # input x : 23 x 59049 x 1 82 | # expected conv1d input : minibatch_size x num_channel x width 83 | 84 | x = x.view(x.shape[0], 1,-1) 85 | # x : 23 x 1 x 59049 86 | 87 | out = self.conv1(x) 88 | out = self.conv2(out) 89 | out = self.conv3(out) 90 | out = self.conv4(out) 91 | out = self.conv5(out) 92 | out = self.conv6(out) 93 | out = self.conv7(out) 94 | out = self.conv8(out) 95 | out = self.conv9(out) 96 | out = self.conv10(out) 97 | out = self.conv11(out) 98 | 99 | out = out.view(x.shape[0], out.size(1) * out.size(2)) 100 | logit = self.fc(out) 101 | 102 | #logit = self.activation(logit) 103 | 104 | return logit 105 | 106 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa==0.5.1 2 | numpy==1.14.0 3 | pandas==0.23.0 4 | torch==0.4.0 5 | scikit_learn==0.19.1 6 | -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.utils.data import DataLoader 7 | from torch.optim import lr_scheduler 8 | import utils 9 | import config 10 | 11 | cuda = torch.cuda.is_available() 12 | kwargs = {'num_workers':1, 'pin_memory':True} if cuda else {} 13 | print ("gpu available :", cuda) 14 | device = torch.device("cuda" if cuda else "cpu") 15 | num_gpu = torch.cuda.device_count() 16 | torch.cuda.manual_seed(5) 17 | 18 | class Solver(object): 19 | def __init__(self, model, dataset, args): 20 | self.samplecnn = model 21 | self.dataset = dataset 22 | self.args = args 23 | 24 | self.curr_epoch = 0 25 | 26 | self.model_savepath = './model' 27 | if not os.path.exists(self.model_savepath): 28 | os.makedirs(self.model_savepath) 29 | 30 | # define loss function 31 | self.bce = nn.BCEWithLogitsLoss() 32 | 33 | self._initialize() 34 | self.set_mode('train') 35 | 36 | 37 | def _initialize(self): 38 | self.optimizer = torch.optim.SGD(self.samplecnn.parameters(), lr=config.LR, weight_decay=1e-6, momentum=0.9, nesterov=True) 39 | self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2, patience=2, verbose=True) 40 | 41 | 42 | # initialize cuda 43 | if len(self.args.gpus) > 1: 44 | self.multigpu = True 45 | else : 46 | self.multigpu = False 47 | 48 | utils.handle_multigpu(self.multigpu, self.args.gpus, num_gpu) 49 | 50 | 51 | if self.multigpu : 52 | self.samplecnn = nn.DataParallel(self.samplecnn, device_ids=self.args.gpus) 53 | 54 | self.samplecnn.to(device) 55 | 56 | 57 | def set_mode(self, mode): 58 | print ("solver mode : ", mode) 59 | if mode == 'train': 60 | self.samplecnn.train() 61 | self.dataset.set_mode('train') 62 | 63 | elif mode == 'valid' : 64 | self.samplecnn.eval() 65 | self.dataset.set_mode('valid') 66 | 67 | elif mode == 'test': 68 | self.samplecnn.eval() 69 | self.dataset.set_mode('test') 70 | 71 | self.dataloader = DataLoader(self.dataset, batch_size=config.BATCH_SIZE, shuffle=True, drop_last=True, **kwargs) 72 | 73 | 74 | 75 | def train(self) : 76 | # Train the network 77 | for epoch in range(config.NUM_EPOCHS): 78 | self.set_mode('train') 79 | 80 | avg_auc1 = [] 81 | avg_ap1 = [] 82 | avg_auc2 = [] 83 | avg_ap2 = [] 84 | 85 | for i, data in enumerate(self.dataloader): 86 | audio = data['audio'].to(device) 87 | label = data['label'].to(device) 88 | 89 | outputs = self.samplecnn(audio) 90 | loss = self.bce(outputs, label) 91 | self.optimizer.zero_grad() 92 | loss.backward() 93 | self.optimizer.step() 94 | 95 | if (i+1) % 10 == 0: 96 | print ("Epoch [%d/%d], Iter [%d/%d] loss : %.4f" % (epoch+1, config.NUM_EPOCHS, i+1, len(self.dataloader), loss.item())) 97 | 98 | # retrieval 99 | auc1, ap1 = utils.tagwise_aroc_ap(label.cpu().detach().numpy(), outputs.cpu().detach().numpy()) 100 | avg_auc1.append(np.mean(auc1)) 101 | avg_ap1.append(np.mean(ap1)) 102 | # annotation 103 | auc2, ap2 = utils.itemwise_aroc_ap(label.cpu().detach().numpy(), outputs.cpu().detach().numpy()) 104 | avg_auc2.append(np.mean(auc2)) 105 | avg_ap2.append(np.mean(ap2)) 106 | 107 | print ("Retrieval : AROC = %.3f, AP = %.3f / "%(np.mean(auc1), np.mean(ap1)), "Annotation : AROC = %.3f, AP = %.3f"%(np.mean(auc2), np.mean(ap2))) 108 | 109 | 110 | self.curr_epoch +=1 111 | 112 | print ("Retrieval : Average AROC = %.3f, AP = %.3f / "%(np.mean(avg_auc1), np.mean(avg_ap1)), "Annotation :Average AROC = %.3f, AP = %.3f"%(np.mean(avg_auc2), np.mean(avg_ap2))) 113 | print ('Evaluating...') 114 | eval_loss = self.eval() 115 | 116 | self.scheduler.step(eval_loss) # use the learning rate scheduler 117 | curr_lr = self.optimizer.param_groups[0]['lr'] 118 | print ('Learning rate : {}'.format(curr_lr)) 119 | if curr_lr < 1e-7: 120 | print ("Early stopping") 121 | break 122 | 123 | torch.save(self.samplecnn.module.state_dict(), self.model_savepath / self.samplecnn.module.__class__.__name__ + '_' + str(self.curr_epoch) + '.pth') 124 | 125 | 126 | # Validate the network on the val_loader (during training) or test_loader (for checking result) 127 | # During training use this function for validation data. 128 | def eval(): 129 | self.set_mode('valid') 130 | 131 | eval_loss = 0.0 132 | avg_auc1 = [] 133 | avg_ap1 = [] 134 | avg_auc2 = [] 135 | avg_ap2 = [] 136 | for i, data in enumerate(self.dataloader): 137 | audio = data['audio'].to(device) 138 | label = data['label'].to(device) 139 | 140 | outputs = self.samplecnn(audio) 141 | loss = self.bce(outputs, label) 142 | 143 | auc1, aprec1 = utils.tagwise_aroc_ap(label.cpu().detach().numpy(), outputs.cpu().detach.numpy()) 144 | avg_auc1.append(np.mean(auc1)) 145 | avg_ap1.append(np.mean(aprec1)) 146 | auc2, aprec2 = utils.itemwise_aroc_ap(label.cpu().detach.numpy(), outputs.cpu().detach.numpy()) 147 | avg_auc2.append(np.mean(auc2)) 148 | avg_ap2.append(np.mean(aprec2)) 149 | 150 | eval_loss += loss.data[0] 151 | 152 | avg_loss =eval_loss/len(val_loader) 153 | print ("Retrieval : Average AROC = %.3f, AP = %.3f / "%(np.mean(avg_auc1), np.mean(avg_ap1)), "Annotation : Average AROC = %.3f, AP = %.3f"%(np.mean(avg_auc2), np.mean(avg_ap2))) 154 | print ('Average loss: {:.4f} \n'. format(avg_loss)) 155 | return avg_loss 156 | 157 | 158 | if __name__ == '__main__': 159 | model = SampleCNN() 160 | model = model.load_state_dict(torch.load('SampleCNN.pth')) 161 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | from sklearn.metrics import roc_auc_score, average_precision_score 4 | 5 | # Evaluation functions 6 | 7 | def tagwise_aroc_ap(tags_true_binary, tags_predicted): 8 | ''' Retrieval : tag-wise (col wise) calculation ''' 9 | n_tags = tags_true_binary.shape[1] 10 | auc = [] 11 | aprec = [] 12 | 13 | for i in range(n_tags): 14 | if np.sum(tags_true_binary[:, i]) != 0: 15 | auc.append(roc_auc_score(tags_true_binary[:, i], tags_predicted[:, i])) 16 | aprec.append(average_precision_score(tags_true_binary[:, i], tags_predicted[:, i])) 17 | 18 | 19 | return auc, aprec 20 | 21 | def itemwise_aroc_ap(tags_true_binary, tags_predicted): 22 | ''' Annotation : item-wise(row wise) calculation ''' 23 | n_songs = tags_true_binary.shape[0] 24 | auc = [] 25 | aprec = [] 26 | 27 | for i in range(n_songs): 28 | if np.sum(tags_true_binary[i]) != 0: 29 | auc.append(roc_auc_score(tags_true_binary[i], tags_predicted[i])) 30 | aprec.append(average_precision_score(tags_true_binary[i], tags_predicted[i])) 31 | 32 | 33 | return auc, aprec 34 | 35 | 36 | # CUDA multigpu functions 37 | 38 | def handle_multigpu(multigpu, user_gpu_list, available_gpus): 39 | ''' Check if multigpu is going to be used correctly 40 | Args : 41 | multigpu : user preference on whether to use mult gpu or not (bool) 42 | user_gpu_list : list of user assigned GPUs 43 | available_gpus : number of gpus available on the system 44 | ''' 45 | 46 | note = '[GPU AVAILABLILITY]' 47 | 48 | if multigpu and available_gpus <= 1: 49 | print (note, "You don't have enough GPUs. Do not set any argument for --gpus") 50 | sys.exit() 51 | 52 | elif not multigpu and available_gpus > 1: 53 | print (note, "You have %d GPUs but only assigned 1. You can assign list of gpus with --gpus option to utilize multigpu functions"%available_gpus) 54 | 55 | elif len(user_gpu_list) > available_gpus: 56 | print (note, "You don't have enough GPUs. Check you system and reassign.") 57 | sys.exit() 58 | 59 | elif multigpu and available_gpus > 1 : 60 | print (note, "You assigned %d/%d available GPUs"%(len(user_gpu_list), available_gpus)) 61 | 62 | 63 | 64 | 65 | 66 | --------------------------------------------------------------------------------