├── fig ├── figure.png └── psla_poster_rs.png ├── src ├── dataloaders │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── audioset_dataset.cpython-37.pyc │ ├── __init__.py │ └── audioset_dataset.py ├── models │ ├── __pycache__ │ │ └── HigherModels.cpython-37.pyc │ ├── __init__.py │ ├── Models.py │ └── HigherModels.py ├── utilities │ ├── __init__.py │ ├── stats.py │ └── util.py ├── ensemble │ ├── as_ensemble.log │ ├── weight_averaging.py │ └── ensemble.py ├── gen_weight_file.py ├── label_enhancement │ ├── check_label_error.py │ ├── merge_type_1_2.py │ ├── fix_type1.py │ └── fix_type2.py ├── run.py └── traintest.py ├── .gitignore ├── requirements.txt ├── LICENSE ├── egs ├── fsd50k │ ├── run.sh │ ├── README.md │ ├── prep_fsd.py │ └── class_labels_indices.csv └── audioset │ ├── run.sh │ ├── README.md │ └── class_labels_indices.csv ├── pretrained_models └── README.md └── README.md /fig/figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuanGongND/psla/HEAD/fig/figure.png -------------------------------------------------------------------------------- /fig/psla_poster_rs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuanGongND/psla/HEAD/fig/psla_poster_rs.png -------------------------------------------------------------------------------- /src/dataloaders/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuanGongND/psla/HEAD/src/dataloaders/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/HigherModels.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuanGongND/psla/HEAD/src/models/__pycache__/HigherModels.cpython-37.pyc -------------------------------------------------------------------------------- /src/dataloaders/__pycache__/audioset_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuanGongND/psla/HEAD/src/dataloaders/__pycache__/audioset_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /src/utilities/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 10/4/20 10:16 AM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : __init__.py 7 | 8 | from .util import * 9 | from .stats import * -------------------------------------------------------------------------------- /src/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 10/4/20 10:08 AM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : __init__.py.py 7 | 8 | from .audioset_dataset import AudiosetDataset -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 10/4/20 10:14 AM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : 
yuangong@mit.edu 6 | # @File : __init__.py.py 7 | 8 | from .HigherModels import * 9 | from .Models import * -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | *.csv 3 | *.pdf 4 | /src/ensemble/copy_ensemble.sh 5 | /dropbox/* 6 | /src/label_enhancement/release_labelset.py 7 | .idea 8 | fig/model.png 9 | fig/psla.png 10 | fig/* 11 | !fig/psla_poster_rs.png 12 | .DS_Store 13 | !/egs/audioset/class_labels_indices.csv 14 | !/egs/fsd50k/class_labels_indices.csv 15 | *.pptx 16 | *.mp4 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | llvmlite==0.36.0 2 | matplotlib==3.4.2 3 | numba==0.53.1 4 | numpy==1.20.3 5 | scikit-learn==0.24.2 6 | scipy==1.6.3 7 | sklearn==0.0 8 | torch==1.6.0 9 | torchaudio==0.6.0 10 | torchvision==0.7.0 11 | wget==3.2 12 | zipp==3.4.1 13 | efficientnet-lite-pytorch==0.1.0 14 | efficientnet-pytorch==0.7.0 15 | -e git+https://github.com/pytorch/fairseq@bcc81f6d5291c3996c8b2472282458dead46343f#egg=fairseq -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, Yuan Gong 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /egs/fsd50k/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p sm 3 | #SBATCH -x sls-sm-1,sls-2080-[1,3],sls-1080-3,sls-sm-5 4 | #SBATCH -p gpu 5 | #SBATCH -x sls-titan-[0-2,9] 6 | #SBATCH --gres=gpu:4 7 | #SBATCH -c 4 8 | #SBATCH -n 1 9 | #SBATCH --mem=48000 10 | #SBATCH --job-name="psla_fsd" 11 | #SBATCH --output=./log_%j.txt 12 | 13 | set -x 14 | source ../../venv-psla/bin/activate 15 | export TORCH_HOME=./ 16 | 17 | att_head=4 18 | model=efficientnet 19 | psla=True 20 | eff_b=2 21 | batch_size=24 22 | 23 | if [ $psla == True ] 24 | then 25 | impretrain=True 26 | freqm=48 27 | timem=192 28 | mixup=0.5 29 | bal=True 30 | else 31 | impretrain=False 32 | freqm=0 33 | timem=0 34 | mixup=0 35 | bal=False 36 | fi 37 | 38 | lr=5e-4 39 | p=mean 40 | if [ $p == none ] 41 | then 42 | trpath=./datafiles/fsd50k_tr_full.json 43 | else 44 | trpath=./datafiles/fsd50k_tr_full_type1_2_${p}.json 45 | fi 46 | 47 | epoch=40 48 | wa_start=21 49 | wa_end=40 50 | lrscheduler_start=10 51 | 52 | exp_dir=./exp/demo-${model}-${eff_b}-${lr}-fsd50k-impretrain-${impretrain}-fm${freqm}-tm${timem}-mix${mixup}-bal-${bal}-b${batch_size}-le${p}-2 53 | mkdir -p $exp_dir 54 | 55 | CUDA_CACHE_DISABLE=1 python ../../src/run.py --data-train $trpath --data-val ./datafiles/fsd50k_val_full.json --data-eval ./datafiles/fsd50k_eval_full.json \ 56 | --exp-dir $exp_dir --n-print-steps 1000 --save_model True --num-workers 32 --label-csv ./class_labels_indices.csv \ 57 | --n_class 200 --n-epochs ${epoch} --batch-size ${batch_size} --lr $lr \ 58 | --model ${model} --eff_b $eff_b --impretrain ${impretrain} --att_head ${att_head} \ 59 | --freqm $freqm --timem $timem --mixup ${mixup} --bal ${bal} --lr_patience 2 \ 60 | --dataset_mean -4.6476 --dataset_std 4.5699 --target_length 3000 --noise False \ 61 | --metrics mAP --warmup True --loss BCE --lrscheduler_start ${lrscheduler_start} --lrscheduler_decay 0.5 \ 62 | --wa True --wa_start ${wa_start} --wa_end ${wa_end} -------------------------------------------------------------------------------- /src/utilities/stats.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import stats 3 | from sklearn import metrics 4 | import torch 5 | 6 | def d_prime(auc): 7 | standard_normal = stats.norm() 8 | d_prime = standard_normal.ppf(auc) * np.sqrt(2.0) 9 | return d_prime 10 | 11 | def calculate_stats(output, target): 12 | """Calculate statistics including mAP, AUC, etc. 13 | 14 | Args: 15 | output: 2d array, (samples_num, classes_num) 16 | target: 2d array, (samples_num, classes_num) 17 | 18 | Returns: 19 | stats: list of statistic of each class. 
20 | """ 21 | 22 | classes_num = target.shape[-1] 23 | stats = [] 24 | 25 | # Class-wise statistics 26 | for k in range(classes_num): 27 | 28 | # Average precision 29 | avg_precision = metrics.average_precision_score( 30 | target[:, k], output[:, k], average=None) 31 | 32 | # AUC 33 | auc = metrics.roc_auc_score(target[:, k], output[:, k], average=None) 34 | 35 | # Accuracy 36 | # this is only used for single-label classification such as esc-50, not for multiple label one such as AudioSet 37 | acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1)) 38 | 39 | # Precisions, recalls 40 | (precisions, recalls, thresholds) = metrics.precision_recall_curve( 41 | target[:, k], output[:, k]) 42 | 43 | # FPR, TPR 44 | (fpr, tpr, thresholds) = metrics.roc_curve(target[:, k], output[:, k]) 45 | 46 | save_every_steps = 1000 # Sample statistics to reduce size 47 | dict = {'precisions': precisions[0::save_every_steps], 48 | 'recalls': recalls[0::save_every_steps], 49 | 'AP': avg_precision, 50 | 'fpr': fpr[0::save_every_steps], 51 | 'fnr': 1. - tpr[0::save_every_steps], 52 | 'auc': auc, 53 | 'acc': acc 54 | } 55 | stats.append(dict) 56 | 57 | return stats 58 | 59 | -------------------------------------------------------------------------------- /pretrained_models/README.md: -------------------------------------------------------------------------------- 1 | We provide full AudioSet and FSD50K pretrained models (click the mAP to download the model(s)). 2 | 3 | | | # Models |AudioSet (Eval mAP) | FSD50K (Eval mAP) | 4 | |------------------------------------------|:------:|:--------:|:------:| 5 | | Single Model | 1 |[0.440](https://www.dropbox.com/s/d1z27wj30ew5qrs/as_mdl_0.pth?dl=1) | [0.559](https://www.dropbox.com/s/stzrmfty2oyqnnj/fsd_mdl_best_single.pth?dl=1) | 6 | | Weight Averaging Model | 1 | [0.444](https://www.dropbox.com/s/ieggie0ara4x26d/as_mdl_0_wa.pth?dl=1) | [0.562](https://www.dropbox.com/s/5fvybrbulvhsish/fsd_mdl_wa.pth?dl=1) | 7 | | Ensemble (Single Run, All Checkpoints) | 30/40 |[0.453](https://www.dropbox.com/sh/jo6te8fcy1ptabw/AAAtJ9sMn93-3L0XkebzQQxIa?dl=1) | [0.573](https://www.dropbox.com/sh/gyv95m53sib36vk/AADWCgApSxtAEVU1KrnQApi3a?dl=1) | 8 | | Ensemble (3 Run with Same Setting) | 3 | [0.464](https://www.dropbox.com/sh/c83w8816vl6yhty/AADjoO9irfP1RCr-qyZMJg_-a?dl=1) | N/A | 9 | | Ensemble (All, Different Settings) | 10 | [0.474](https://www.dropbox.com/sh/ihfbxcemxamihz9/AAD9zqnUptZzyZlquqpWllDya?dl=1) | N/A | 10 | 11 | All models are EfficientNet B2 model with 4-headed attention with 13.6M parameters, trained with 16kHz audio. Load the model by using follows: 12 | 13 | ```python 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | num_class = 527 if dataset=='audioset' else 200 16 | audio_model = models.EffNetAttention(label_dim=num_class, b=2, pretrain=False, head_num=4) 17 | audio_model = torch.nn.DataParallel(audio_model) 18 | audio_model.load_state_dict(sd, strict=False) 19 | ``` 20 | 21 | We strongly recommend to use the pretrained model with our dataloader for inference to avoid the dataloading mismatch. 
22 | 23 | For ensemble experiments, uncompress the models and place in this folder: 24 | 25 | ``` 26 | pretrained_models 27 | │ README.md 28 | └───audioset 29 | │ │ as_mdl_0.pth 30 | │ │ as_mdl_1.pth 31 | └───fsd50k 32 | │ fsd_mdl_0.pth 33 | │ fsd_mdl_1.pth 34 | ``` -------------------------------------------------------------------------------- /egs/audioset/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ##SBATCH -p sm 3 | ##SBATCH -x sls-sm-1,sls-2080-[1,3],sls-1080-3,sls-sm-5 4 | #SBATCH -p gpu 5 | #SBATCH -x sls-titan-[0-2,9] 6 | #SBATCH --gres=gpu:4 7 | #SBATCH -c 4 8 | #SBATCH -n 1 9 | #SBATCH --mem=48000 10 | #SBATCH --job-name="psla_as" 11 | #SBATCH --output=./log_%j.txt 12 | 13 | set -x 14 | source ../../venv-psla/bin/activate 15 | export TORCH_HOME=./ 16 | 17 | subset=balanced 18 | att_head=4 19 | model=efficientnet 20 | psla=True 21 | eff_b=2 22 | batch_size=100 23 | 24 | if [ $psla == True ] 25 | then 26 | impretrain=True 27 | freqm=48 28 | timem=192 29 | mixup=0.5 30 | full_bal=True 31 | else 32 | impretrain=False 33 | freqm=0 34 | timem=0 35 | mixup=0 36 | full_bal=False 37 | fi 38 | 39 | if [ $subset == balanced ] 40 | then 41 | bal=False 42 | lr=1e-3 43 | p=mean 44 | # label enhanced set 45 | trpath=./datafiles/balanced_train_data_type1_2_${p}.json 46 | # original set 47 | #trpath=./datafiles/balanced_train_data.json 48 | epoch=60 49 | wa_start=41 50 | wa_end=60 51 | lrscheduler_start=35 52 | else 53 | bal=${full_bal} 54 | lr=1e-4 55 | p=None 56 | #trpath=/data/sls/scratch/yuangong/aed-pc/src/enhance_label/datafiles_local/whole_train_data_type1_2_${p}.json 57 | trpath=./datafiles/full_train_data.json 58 | epoch=30 59 | wa_start=16 60 | wa_end=30 61 | lrscheduler_start=10 62 | fi 63 | 64 | exp_dir=./exp/demo-${model}-${eff_b}-${lr}-${subset}-impretrain-${impretrain}-fm${freqm}-tm${timem}-mix${mixup}-bal-${bal}-b${batch_size}-git 65 | mkdir -p $exp_dir 66 | 67 | CUDA_CACHE_DISABLE=1 python ../../src/run.py --data-train $trpath --data-val ./datafiles/eval_data.json \ 68 | --exp-dir $exp_dir --n-print-steps 100 --save_model True --num-workers 32 --label-csv /data/sls/scratch/yuangong/audioset/utilities/class_labels_indices.csv \ 69 | --n_class 527 --n-epochs ${epoch} --batch-size ${batch_size} --lr $lr \ 70 | --model ${model} --eff_b $eff_b --impretrain ${impretrain} --att_head ${att_head} \ 71 | --freqm $freqm --timem $timem --mixup ${mixup} --bal ${bal} --lr_patience 2 \ 72 | --dataset_mean -4.6476 --dataset_std 4.5699 --target_length 1056 --noise False \ 73 | --metrics mAP --warmup True --loss BCE --lrscheduler_start ${lrscheduler_start} --lrscheduler_decay 0.5 \ 74 | --wa True --wa_start ${wa_start} --wa_end ${wa_end} \ 75 | -------------------------------------------------------------------------------- /src/ensemble/as_ensemble.log: -------------------------------------------------------------------------------- 1 | # Ensemble 3 AudioSet Models Trained with Exactly Same Setting (Best Setting), But Different Random Seeds. 2 | ---------------Ensemble Result Summary--------------- 3 | Model 0 ../../pretrained_models/audioset/as_mdl_0.pth mAP: 0.440298, AUC: 0.974047, d-prime: 2.749102 4 | Model 1 ../../pretrained_models/audioset/as_mdl_1.pth mAP: 0.439790, AUC: 0.973978, d-prime: 2.747493 5 | Model 2 ../../pretrained_models/audioset/as_mdl_2.pth mAP: 0.439322, AUC: 0.973591, d-prime: 2.738487 6 | Ensemble 3 Models mAP: 0.464112, AUC: 0.978222, d-prime: 2.854353 7 | 8 | # Ensemble 5 Top-Performance AudioSet Models. 
9 | ---------------Ensemble Result Summary--------------- 10 | Model 0 ../../pretrained_models/audioset/as_mdl_0.pth mAP: 0.440298, AUC: 0.974047, d-prime: 2.749102 11 | Model 1 ../../pretrained_models/audioset/as_mdl_1.pth mAP: 0.439790, AUC: 0.973978, d-prime: 2.747493 12 | Model 2 ../../pretrained_models/audioset/as_mdl_2.pth mAP: 0.439322, AUC: 0.973591, d-prime: 2.738487 13 | Model 3 ../../pretrained_models/audioset/as_mdl_3.pth mAP: 0.440555, AUC: 0.973639, d-prime: 2.739613 14 | Model 4 ../../pretrained_models/audioset/as_mdl_4.pth mAP: 0.439713, AUC: 0.973579, d-prime: 2.738213 15 | Ensemble 5 Models mAP: 0.469050, AUC: 0.978875, d-prime: 2.872325 16 | 17 | # Ensemble All 10 AudioSet Models Presented in the PSLA Paper 18 | ---------------Ensemble Result Summary--------------- 19 | Model 0 ../pretrained_models/audioset/as_mdl_1.pth mAP: 0.440298, AUC: 0.974047, d-prime: 2.749102 20 | Model 1 ../pretrained_models/audioset/as_mdl_0.pth mAP: 0.439790, AUC: 0.973978, d-prime: 2.747493 21 | Model 2 ../pretrained_models/audioset/as_mdl_2.pth mAP: 0.439322, AUC: 0.973591, d-prime: 2.738487 22 | Model 3 ../pretrained_models/audioset/as_mdl_3.pth mAP: 0.440555, AUC: 0.973639, d-prime: 2.739613 23 | Model 4 ../pretrained_models/audioset/as_mdl_4.pth mAP: 0.439713, AUC: 0.973579, d-prime: 2.738213 24 | Model 5 ../pretrained_models/audioset/as_mdl_5.pth mAP: 0.438852, AUC: 0.973534, d-prime: 2.737183 25 | Model 6 ../pretrained_models/audioset/as_mdl_6.pth mAP: 0.394262, AUC: 0.973054, d-prime: 2.726193 26 | Model 7 ../pretrained_models/audioset/as_mdl_7.pth mAP: 0.370860, AUC: 0.961183, d-prime: 2.495504 27 | Model 8 ../pretrained_models/audioset/as_mdl_8.pth mAP: 0.426624, AUC: 0.973353, d-prime: 2.733006 28 | Model 9 ../pretrained_models/audioset/as_mdl_9.pth mAP: 0.372092, AUC: 0.970509, d-prime: 2.670498 29 | Ensemble 10 Models mAP: 0.474380, AUC: 0.981043, d-prime: 2.935611 -------------------------------------------------------------------------------- /src/gen_weight_file.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 11/17/20 3:22 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : gen_weight_file.py 7 | 8 | # gen sample weight = sum(label_weight) for label in all labels of the audio clip, where label_weight is the reciprocal of the total sample count of that class. 
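# (illustration, hypothetical counts) a clip labeled with both 'Speech' (100,000 training clips)
# and 'Dog' (5,000 training clips) gets sample_weight ~= 1000/100000 + 1000/5000 = 0.01 + 0.2 = 0.21,
# so clips containing rare classes are drawn more often by the weighted sampler.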
9 | # Note: AudioSet and FSD50K are multi-label datasets.
10 | 
11 | import argparse
12 | import json
13 | import numpy as np
14 | import sys, os, csv
15 | 
16 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
17 | parser.add_argument("--dataset", type=str, default="audioset", help="the dataset to generate sample weights for", choices=["audioset", "fsd50k"])
18 | parser.add_argument("--label_indices_path", type=str, default="./class_labels_indices.csv", help="the label vocabulary file.")
19 | parser.add_argument("--datafile_path", type=str, default='./datafiles/balanced_train_data.json', help="the path of the data json file")
20 | 
21 | def make_index_dict(label_csv):
22 |     index_lookup = {}
23 |     with open(label_csv, 'r') as f:
24 |         csv_reader = csv.DictReader(f)
25 |         line_count = 0
26 |         for row in csv_reader:
27 |             index_lookup[row['mid']] = row['index']
28 |             line_count += 1
29 |     return index_lookup
30 | 
31 | if __name__ == '__main__':
32 |     args = parser.parse_args()
33 |     data_path = args.datafile_path
34 | 
35 |     index_dict = make_index_dict(args.label_indices_path)
36 |     num_class = 527 if args.dataset == 'audioset' else 200
37 |     label_count = np.zeros(num_class)
38 | 
39 |     with open(data_path, 'r', encoding='utf8') as fp:
40 |         data = json.load(fp)
41 |     data = data['data']
42 | 
43 |     for sample in data:
44 |         sample_labels = sample['labels'].split(',')
45 |         for label in sample_labels:
46 |             label_idx = int(index_dict[label])
47 |             label_count[label_idx] = label_count[label_idx] + 1
48 | 
49 |     # scale by 1000 (rather than 1) to avoid underflow for majority classes; the small constant avoids division by zero
50 |     label_weight = 1000.0 / (label_count + 0.01)
51 |     sample_weight = np.zeros(len(data))
52 | 
53 |     for i, sample in enumerate(data):
54 |         sample_labels = sample['labels'].split(',')
55 |         for label in sample_labels:
56 |             label_idx = int(index_dict[label])
57 |             # sum the weights of all classes that appear in the clip (AudioSet is multi-label)
58 |             sample_weight[i] += label_weight[label_idx]
59 |     np.savetxt(data_path[:-5]+'_weight.csv', sample_weight, delimiter=',')
60 | 
61 | 
62 | 
63 | 
--------------------------------------------------------------------------------
/src/label_enhancement/check_label_error.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 5/24/21 12:55 AM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : check_label_error.py
7 | 
8 | # This is an example (male, female, and child speech classes) showing the label errors in AudioSet.
9 | 10 | import json 11 | """ 12 | 0,/m/09x0r,"Speech" 13 | 1,/m/05zppz,"Male speech, man speaking" 14 | 2,/m/02zsn,"Female speech, woman speaking" 15 | 3,/m/0ytgt,"Child speech, kid speaking" 16 | """ 17 | 18 | def check_type1_error(json_path): 19 | total_speech_cnt = 0 20 | male_cnt= 0 21 | female_cnt = 0 22 | child_cnt = 0 23 | with open(json_path,'r',encoding='utf8')as fp: 24 | data_file = json.load(fp) 25 | data = data_file['data'] 26 | # for each sample 27 | for i, sample in enumerate(data): 28 | sample_labels = sample['labels'].split(',') 29 | if '/m/09x0r' in sample_labels: 30 | total_speech_cnt += 1 31 | if '/m/05zppz' in sample_labels: 32 | male_cnt += 1 33 | if '/m/02zsn' in sample_labels: 34 | female_cnt += 1 35 | if '/m/0ytgt' in sample_labels: 36 | child_cnt += 1 37 | print('Type-I Error:') 38 | print('There are {:d}, {:d}, {:d} samples that are labeled as male, female, and child speech (sum: {:d}), but there are {:d} samples labeled as speech in {:s}.'.format(male_cnt, female_cnt, child_cnt, (male_cnt+female_cnt+child_cnt), total_speech_cnt ,json_path)) 39 | 40 | 41 | def check_type2_error(json_path): 42 | miss_male_cnt=0 43 | miss_female_cnt = 0 44 | miss_child_cnt=0 45 | with open(json_path,'r',encoding='utf8')as fp: 46 | data_file = json.load(fp) 47 | data = data_file['data'] 48 | # for each sample 49 | for i, sample in enumerate(data): 50 | sample_labels = sample['labels'].split(',') 51 | if '/m/05zppz' in sample_labels and '/m/09x0r' not in sample_labels: 52 | miss_male_cnt += 1 53 | if '/m/02zsn' in sample_labels and '/m/09x0r' not in sample_labels: 54 | miss_female_cnt += 1 55 | if '/m/0ytgt' in sample_labels and '/m/09x0r' not in sample_labels: 56 | miss_child_cnt += 1 57 | print('Type-II Error:') 58 | print('There are {:d}, {:d}, {:d} samples that are labeled as male, female, and child speech, respectively, but are not labeled as speech in {:s}.'.format(miss_male_cnt, miss_female_cnt, miss_child_cnt, json_path)) 59 | 60 | # before label enhancement 61 | check_type1_error('../../egs/audioset/datafiles/balanced_train_data.json') 62 | check_type2_error('../../egs/audioset/datafiles/balanced_train_data.json') 63 | 64 | # after label enhancement 65 | check_type1_error('../../egs/audioset/datafiles/balanced_train_data_type1_2_mean.json') 66 | check_type2_error('../../egs/audioset/datafiles/balanced_train_data_type1_2_mean.json') -------------------------------------------------------------------------------- /src/label_enhancement/merge_type_1_2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 12/24/20 2:34 AM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : merge_type_1_2.py 7 | 8 | # merge label enhancement 1 and label enhancement 2. 9 | 10 | import json 11 | import os 12 | 13 | # count the total number of labels (one audio clip may have multiple labels) of a data json file. 14 | def count_label(js): 15 | cnt = 0 16 | with open(js,'r',encoding='utf8')as fp: 17 | data_file = json.load(fp) 18 | data = data_file['data'] 19 | for i, sample in enumerate(data): 20 | sample_labels = sample['labels'].split(',') 21 | cnt += len(sample_labels) 22 | return cnt 23 | 24 | # merge datafiles of type-1 label enhancement and type-2 label enhancement. 
25 | def merge_label(js1, js2, output_path):
26 |     total_label_cnt = 0
27 |     with open(js1, 'r', encoding='utf8') as fp:
28 |         data_file = json.load(fp)
29 |     data1 = data_file['data']
30 |     with open(js2, 'r', encoding='utf8') as fp:
31 |         data_file = json.load(fp)
32 |     data2 = data_file['data']
33 |     for i, sample in enumerate(data1):
34 |         sample_labels1 = sample['labels'].split(',')
35 |         sample_labels2 = data2[i]['labels'].split(',')
36 |         merged_labels = list(set(sample_labels1 + sample_labels2))
37 |         data1[i]['labels'] = ','.join(merged_labels)
38 |         total_label_cnt += len(merged_labels)
39 |     output = {'data': data1}
40 |     with open(output_path, 'w') as f:
41 |         json.dump(output, f, indent=1)
42 |     print('Input Json file 1 has {:d} labels'.format(count_label(js1)))
43 |     print('Input Json file 2 has {:d} labels'.format(count_label(js2)))
44 |     print('Merged Json file has {:d} labels'.format(total_label_cnt))
45 | 
46 | if __name__ == '__main__':
47 |     # 'audioset' or 'fsd50k'
48 |     # for audioset, we demo label enhancement on the balanced training set
49 |     dataset = 'fsd50k'
50 |     # for each label modification threshold
51 |     for p in ['mean', 'median', '25', '10', '5']:
52 |         print('----------------Merge Type 1&2 Label Enhancement with {:s} Threshold'.format(p))
53 |         if dataset == 'fsd50k':
54 |             path1 = '../../egs/' + dataset + '/datafiles/fsd50k_tr_full_type1_' + p + '.json'
55 |             path2 = '../../egs/' + dataset + '/datafiles/fsd50k_tr_full_type2_' + p + '.json'
56 |             out_path = '../../egs/' + dataset + '/datafiles/fsd50k_tr_full_type1_2_' + p + '.json'
57 |             merge_label(path1, path2, out_path)
58 | 
59 |         if dataset == 'audioset':
60 |             path1 = '../../egs/' + dataset + '/datafiles/balanced_train_data_type1_' + p + '.json'
61 |             path2 = '../../egs/' + dataset + '/datafiles/balanced_train_data_type2_' + p + '.json'
62 |             out_path = '../../egs/' + dataset + '/datafiles/balanced_train_data_type1_2_' + p + '.json'
63 |             merge_label(path1, path2, out_path)
64 | 
65 |         # (optional) generate the balanced sampling weight file for each enhanced label set.
66 |         os.system('python ../gen_weight_file.py --dataset {:s} --label_indices_path {:s} --datafile_path {:s}'.format(
67 |             dataset, '../../egs/' + dataset + '/class_labels_indices.csv', out_path))
68 | 
--------------------------------------------------------------------------------
/egs/fsd50k/README.md:
--------------------------------------------------------------------------------
1 | ## FSD50K Recipe
2 | The FSD50K recipe is in `psla/egs/fsd50k/`. Note that we use a 16kHz sampling rate, which is lower than the original FSD50K sampling rate, to reduce the computational overhead. Please make sure you have installed the dependencies in `psla/requirements.txt`.
3 | 
4 | **Step 1. Download the FSD50K dataset from [the official website](https://zenodo.org/record/4060432).**
5 | 
6 | **Step 2. Prepare the data.**
7 | 
8 | Change `fsd_path` in `egs/fsd50k/prep_fsd.py` (line 15) to your dataset path, then run:
9 | 
10 | ```
11 | cd psla/egs/fsd50k
12 | python3 prep_fsd.py
13 | ```
14 | 
15 | This should create three JSON datafiles, `fsd50k_tr_full.json`, `fsd50k_val_full.json`, and `fsd50k_eval_full.json`, in `egs/fsd50k/datafiles`. These will be used for training and evaluation.
16 | 
17 | (Optional) **Step 3. Enhance the labels of the FSD50K training set.**
18 | 
19 | Download our model predictions (or use your own after a first round of training) from [here](https://www.dropbox.com/s/kd84fq9ygwmidvp/prediction_files.zip?dl=1). Place the archive in `psla/src/label_enhancement/` and uncompress it.
20 | ```
21 | cd psla/src/label_enhancement
22 | python3 fix_type1.py
23 | python3 fix_type2.py
24 | python3 merge_type_1_2.py
25 | ```
26 | This will automatically generate a set of new JSON datafiles in `egs/fsd50k/datafiles`; e.g., `fsd50k_tr_full_type1_2_mean.json` is the datafile with labels enhanced for both Type-I and Type-II errors using the mean label modification threshold. You can use these new datafiles as the input of `egs/fsd50k/run.sh`; specifically, set `p` (the label modification threshold) to one of `[none, mean, median, 5, 10, 25]`.
27 | 
28 | If you skip this step, set `p` in `egs/fsd50k/run.sh` to `none`; you will get a slightly worse result.
29 | 
30 | **Step 4. Run the training and evaluation.**
31 | 
32 | ```
33 | cd psla/egs/fsd50k
34 | (slurm user) sbatch run.sh
35 | (local user) ./run.sh
36 | ```
37 | 
38 | The recipe was tested on 4 GTX TITAN GPUs with 12GB memory each. The entire training and evaluation takes about 15 hours. Trimming `target_length` in `run.sh` from 3000 to 1000 and increasing the batch size and learning rate accordingly can significantly reduce the running time, at the cost of a slightly worse result.
39 | 
40 | **Step 5. Get the results.**
41 | 
42 | The running log presents the results of 1) the single model, 2) the weight-averaging model (i.e., averaging the weights of the last few checkpoints, which does not increase the model size or the computational overhead), and 3) the checkpoint-ensemble model (i.e., averaging the predictions of the per-epoch checkpoints) on both the official FSD50K validation set and the evaluation set. These results are also saved in `psla/egs/fsd50k/exp/yourexpname/[best_single_,wa_,ensemble_]result.csv`.
43 | 
44 | We share our training and evaluation log in `psla/egs/fsd50k/exp/`; you can expect results similar to the following. We also share our entire experiment directory at this [dropbox link](https://www.dropbox.com/s/qfaeyuvtse420dn/demo-efficientnet-2-5e-4-fsd50k-impretrain-True-fm48-tm192-mix0.5-bal-True-b24-lemean-2.zip?dl=1).
45 | 46 | ``` 47 | ---------------evaluate best single model on the validation set--------------- 48 | mAP: 0.588115 49 | AUC: 0.960351 50 | ---------------evaluate best single model on the evaluation set--------------- 51 | mAP: 0.558463 52 | AUC: 0.943927 53 | ---------------evaluate weight average model on the validation set--------------- 54 | mAP: 0.587779 55 | AUC: 0.960226 56 | ---------------evaluate weight averages model on the evaluation set--------------- 57 | mAP: 0.561647 58 | AUC: 0.943910 59 | ---------------evaluate ensemble model on the validation set--------------- 60 | mAP: 0.601013 61 | AUC: 0.970726 62 | ---------------evaluate ensemble model on the evaluation set--------------- 63 | mAP: 0.572588 64 | AUC: 0.955053 65 | ``` -------------------------------------------------------------------------------- /src/ensemble/weight_averaging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 10/28/21 1:55 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : weight_averaging.py 7 | 8 | import os, sys, argparse 9 | parentdir = str(os.path.abspath(os.path.join(__file__ ,"../.."))) 10 | sys.path.append(parentdir) 11 | import dataloaders 12 | import models 13 | from utilities import * 14 | from traintest import validate 15 | import numpy as np 16 | from scipy import stats 17 | import torch 18 | 19 | def get_wa_res(mdl_list, base_path, dataset='audioset'): 20 | num_class = 527 if dataset=='audioset' else 200 21 | # the 0-len(mdl_list) rows record the results of single models, the last row record the result of the ensemble model. 22 | ensemble_res = np.zeros([len(mdl_list) + 1, 3]) 23 | if os.path.exists(base_path) == False: 24 | os.mkdir(base_path) 25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | 27 | for model_idx, mdl in enumerate(mdl_list): 28 | print('-----------------------') 29 | print('now loading model {:d}: {:s}'.format(model_idx, mdl)) 30 | 31 | if model_idx == 0: 32 | sdA = torch.load(mdl, map_location=device) 33 | model_cnt = 1 34 | else: 35 | sdB = torch.load(mdl, map_location=device) 36 | for key in sdA: 37 | sdA[key] = sdA[key] + sdB[key] 38 | model_cnt += 1 39 | 40 | for key in sdA: 41 | sdA[key] = sdA[key] / float(model_cnt) 42 | 43 | sd = sdA 44 | if 'module.effnet._fc.weight' in sd.keys(): 45 | del sd['module.effnet._fc.weight'] 46 | del sd['module.effnet._fc.bias'] 47 | torch.save(sd, '../../pretrained_models/audioset/as_mdl_0_wa.pth') 48 | 49 | audio_model = models.EffNetAttention(label_dim=num_class , b=2, pretrain=False, head_num=4) 50 | audio_model = torch.nn.DataParallel(audio_model) 51 | audio_model.load_state_dict(sd, strict=True) 52 | 53 | args.exp_dir = base_path 54 | 55 | stats, _ = validate(audio_model, eval_loader, args, 'wa') 56 | mAP = np.mean([stat['AP'] for stat in stats]) 57 | mAUC = np.mean([stat['auc'] for stat in stats]) 58 | dprime = d_prime(mAUC) 59 | ensemble_res[model_idx, :] = [mAP, mAUC, dprime] 60 | print("Model {:d} {:s} mAP: {:.6f}, AUC: {:.6f}, d-prime: {:.6f}".format(model_idx, mdl, mAP, mAUC, dprime)) 61 | 62 | def d_prime(auc): 63 | standard_normal = stats.norm() 64 | d_prime = standard_normal.ppf(auc) * np.sqrt(2.0) 65 | return d_prime 66 | 67 | # dataloader settings 68 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 69 | args = parser.parse_args() 70 | 71 | dataset = 'audioset' 72 | # uncomment this 
line if you want test ensemble on fsd50k 73 | # dataset = 'fsd50k' 74 | 75 | args.dataset = dataset 76 | if dataset == 'audioset': 77 | args.data_eval = '../../egs/audioset/datafiles/eval_data.json' 78 | else: 79 | args.data_eval = '../../egs/fsd50k/datafiles/fsd50k_eval_full.json' 80 | args.label_csv='../../egs/' + dataset + '/class_labels_indices.csv' 81 | args.loss_fn = torch.nn.BCELoss() 82 | norm_stats = {'audioset': [-4.6476, 4.5699], 'fsd50k': [-4.6476, 4.5699]} 83 | target_length = {'audioset': 1056, 'fsd50k': 3000} 84 | batch_size = 200 if dataset=='audioset' else 48 85 | 86 | val_audio_conf = {'num_mel_bins': 128, 'target_length': target_length[args.dataset], 'freqm': 0, 'timem': 0, 'mixup': 0, 'dataset': args.dataset, 'mode':'evaluation', 'mean':norm_stats[args.dataset][0], 'std':norm_stats[args.dataset][1], 'noise': False} 87 | eval_loader = torch.utils.data.DataLoader( 88 | dataloaders.AudiosetDataset(args.data_eval, label_csv=args.label_csv, audio_conf=val_audio_conf), 89 | batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True) 90 | 91 | if dataset == 'audioset': 92 | # ensemble 3 audioset models trained with exactly same setting, but different random seeds 93 | wa_list_3 = ['../../pretrained_models/audioset/wa/audio_model.'+str(i)+'.pth' for i in range(16, 31)] 94 | get_wa_res(wa_list_3, './ensemble_as', dataset) 95 | 96 | else: 97 | pass -------------------------------------------------------------------------------- /egs/fsd50k/prep_fsd.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 4/27/21 3:21 AM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : prep_fsd.py 7 | 8 | import numpy as np 9 | import json 10 | import os 11 | 12 | # dataset downloaded from https://zenodo.org/record/4060432#.YXXR0tnMLfs 13 | # please change it to your FSD50K dataset path 14 | # the data organization might change with versioning, the code is tested early 2021 15 | fsd_path = '/data/sls/scratch/yuangong/dataset/FSD50K/' 16 | 17 | # convert all samples to 16kHZ 18 | print('Now converting all FSD50K audio to 16kHz, this may take dozens of minutes.') 19 | def get_immediate_files(a_dir): 20 | return [name for name in os.listdir(a_dir) if os.path.isfile(os.path.join(a_dir, name))] 21 | 22 | resample_cnt = 0 23 | set_list = ['dev', 'eval'] 24 | for set in set_list: 25 | basepath = fsd_path + '/FSD50K.'+ set +'_audio/' 26 | targetpath = fsd_path + '/FSD50K.'+ set +'_audio_16k/' 27 | if os.path.exists(targetpath) == False: 28 | os.mkdir(targetpath) 29 | files = get_immediate_files(basepath) 30 | for audiofile in files: 31 | os.system('sox ' + basepath + audiofile+' -r 16000 ' + targetpath + audiofile + '> /dev/null 2>&1') 32 | resample_cnt += 1 33 | if resample_cnt % 1000 == 0: 34 | print('Resampled {:d} samples.'.format(resample_cnt)) 35 | print('Resampling finished.') 36 | print('--------------------------------------------') 37 | 38 | # create json datafiles for training, validation, and evaluation set 39 | 40 | # training set and validation set are from the official 'dev' set, we use the official training and validation set split. 
41 | fsdeval = fsd_path + '/FSD50K.ground_truth/dev.csv' 42 | fsdeval = np.loadtxt(fsdeval, skiprows=1, dtype=str) 43 | 44 | tr_cnt, val_cnt = 0, 0 45 | 46 | # only apply to the vocal sound data 47 | fsd_tr_data = [] 48 | fsd_val_data = [] 49 | for i in range(len(fsdeval)): 50 | try: 51 | fileid = fsdeval[i].split(',"')[0] 52 | labels = fsdeval[i].split(',"')[2][0:-1] 53 | set_info = labels.split('",')[1] 54 | except: 55 | fileid = fsdeval[i].split(',')[0] 56 | labels = fsdeval[i].split(',')[2] 57 | set_info = fsdeval[i].split(',')[3][0:-1] 58 | 59 | labels = labels.split('",')[0] 60 | label_list = labels.split(',') 61 | new_label_list = [] 62 | for label in label_list: 63 | new_label_list.append(label) 64 | new_label_list = ','.join(new_label_list) 65 | # note, all recording we use are 16kHZ. 66 | cur_dict = {"wav": fsd_path + '/FSD50K.dev_audio_16k/'+fileid+'.wav', "labels":new_label_list} 67 | 68 | if set_info == 'trai': 69 | fsd_tr_data.append(cur_dict) 70 | tr_cnt += 1 71 | elif set_info == 'va': 72 | fsd_val_data.append(cur_dict) 73 | val_cnt += 1 74 | else: 75 | raise ValueError('unrecognized set') 76 | 77 | if os.path.exists('datafiles') == False: 78 | os.mkdir('datafiles') 79 | 80 | with open('./datafiles/fsd50k_tr_full.json', 'w') as f: 81 | json.dump({'data': fsd_tr_data}, f, indent=1) 82 | print('Processed {:d} samples for the FSD50K training set.'.format(tr_cnt)) 83 | 84 | with open('./datafiles/fsd50k_val_full.json', 'w') as f: 85 | json.dump({'data': fsd_val_data}, f, indent=1) 86 | print('Processed {:d} samples for the FSD50K validation set.'.format(val_cnt)) 87 | 88 | ## process the evaluation set 89 | fsdeval = fsd_path + '/FSD50K.ground_truth/eval.csv' 90 | fsdeval = np.loadtxt(fsdeval, skiprows=1, dtype=str) 91 | 92 | cnt = 0 93 | 94 | # only apply to the vocal sound data 95 | vc_data = [] 96 | for i in range(len(fsdeval)): 97 | try: 98 | fileid = fsdeval[i].split(',"')[0] 99 | labels = fsdeval[i].split(',"')[2][0:-1] 100 | except: 101 | fileid = fsdeval[i].split(',')[0] 102 | labels = fsdeval[i].split(',')[2] 103 | 104 | label_list = labels.split(',') 105 | new_label_list = [] 106 | for label in label_list: 107 | new_label_list.append(label) 108 | 109 | if len(new_label_list) != 0: 110 | new_label_list = ','.join(new_label_list) 111 | cur_dict = {"wav": fsd_path + '/FSD50K.eval_audio_16k/'+fileid+'.wav', "labels": new_label_list} 112 | vc_data.append(cur_dict) 113 | cnt += 1 114 | 115 | with open('./datafiles/fsd50k_eval_full.json', 'w') as f: 116 | json.dump({'data': vc_data}, f, indent=1) 117 | print('Processed {:d} samples for the FSD50K evaluation set.'.format(cnt)) 118 | 119 | # generate balanced sampling weight file 120 | os.system('python ../../src/gen_weight_file.py --dataset fsd50k --label_indices_path {:s} --datafile_path {:s}'.format('./class_labels_indices.csv', './datafiles/fsd50k_tr_full.json')) 121 | 122 | # (optional) create label enhanced set. 
123 | # Go to /src/label_enhancement/ 124 | 125 | 126 | -------------------------------------------------------------------------------- /src/models/Models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from .HigherModels import * 4 | from efficientnet_pytorch import EfficientNet 5 | import torchvision 6 | 7 | class ResNetAttention(nn.Module): 8 | def __init__(self, label_dim=527, pretrain=True): 9 | super(ResNetAttention, self).__init__() 10 | 11 | self.model = torchvision.models.resnet50(pretrained=False) 12 | 13 | if pretrain == False: 14 | print('ResNet50 Model Trained from Scratch (ImageNet Pretraining NOT Used).') 15 | else: 16 | print('Now Use ImageNet Pretrained ResNet50 Model.') 17 | 18 | self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 19 | 20 | # remove the original ImageNet classification layers to save space. 21 | self.model.fc = torch.nn.Identity() 22 | self.model.avgpool = torch.nn.Identity() 23 | 24 | # attention pooling module 25 | self.attention = Attention( 26 | 2048, 27 | label_dim, 28 | att_activation='sigmoid', 29 | cla_activation='sigmoid') 30 | self.avgpool = nn.AvgPool2d((4, 1)) 31 | 32 | def forward(self, x): 33 | # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128) 34 | x = x.unsqueeze(1) 35 | x = x.transpose(2, 3) 36 | 37 | batch_size = x.shape[0] 38 | x = self.model(x) 39 | x = x.reshape([batch_size, 2048, 4, 33]) 40 | x = self.avgpool(x) 41 | x = x.transpose(2,3) 42 | out, norm_att = self.attention(x) 43 | return out 44 | 45 | class MBNet(nn.Module): 46 | def __init__(self, label_dim=527, pretrain=True): 47 | super(MBNet, self).__init__() 48 | 49 | self.model = torchvision.models.mobilenet_v2(pretrained=pretrain) 50 | 51 | self.model.features[0][0] = torch.nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) 52 | self.model.classifier = torch.nn.Linear(in_features=1280, out_features=label_dim, bias=True) 53 | 54 | def forward(self, x, nframes): 55 | # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128) 56 | x = x.unsqueeze(1) 57 | x = x.transpose(2, 3) 58 | 59 | out = torch.sigmoid(self.model(x)) 60 | return out 61 | 62 | 63 | class EffNetAttention(nn.Module): 64 | def __init__(self, label_dim=527, b=0, pretrain=True, head_num=4): 65 | super(EffNetAttention, self).__init__() 66 | self.middim = [1280, 1280, 1408, 1536, 1792, 2048, 2304, 2560] 67 | if pretrain == False: 68 | print('EfficientNet Model Trained from Scratch (ImageNet Pretraining NOT Used).') 69 | self.effnet = EfficientNet.from_name('efficientnet-b'+str(b), in_channels=1) 70 | else: 71 | print('Now Use ImageNet Pretrained EfficientNet-B{:d} Model.'.format(b)) 72 | self.effnet = EfficientNet.from_pretrained('efficientnet-b'+str(b), in_channels=1) 73 | # multi-head attention pooling 74 | if head_num > 1: 75 | print('Model with {:d} attention heads'.format(head_num)) 76 | self.attention = MHeadAttention( 77 | self.middim[b], 78 | label_dim, 79 | att_activation='sigmoid', 80 | cla_activation='sigmoid') 81 | # single-head attention pooling 82 | elif head_num == 1: 83 | print('Model with single attention heads') 84 | self.attention = Attention( 85 | self.middim[b], 86 | label_dim, 87 | att_activation='sigmoid', 88 | cla_activation='sigmoid') 89 | # mean pooling (no attention) 90 | elif head_num == 0: 91 | print('Model with mean pooling (NO Attention Heads)') 92 | 
self.attention = MeanPooling( 93 | self.middim[b], 94 | label_dim, 95 | att_activation='sigmoid', 96 | cla_activation='sigmoid') 97 | else: 98 | raise ValueError('Attention head must be integer >= 0, 0=mean pooling, 1=single-head attention, >1=multi-head attention.') 99 | 100 | self.avgpool = nn.AvgPool2d((4, 1)) 101 | #remove the original ImageNet classification layers to save space. 102 | self.effnet._fc = nn.Identity() 103 | 104 | def forward(self, x, nframes=1056): 105 | # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128) 106 | x = x.unsqueeze(1) 107 | x = x.transpose(2, 3) 108 | 109 | x = self.effnet.extract_features(x) 110 | x = self.avgpool(x) 111 | x = x.transpose(2,3) 112 | out, norm_att = self.attention(x) 113 | return out 114 | 115 | if __name__ == '__main__': 116 | input_tdim = 1056 117 | #ast_mdl = ResNetNewFullAttention(pretrain=False) 118 | psla_mdl = EffNetFullAttention(pretrain=False, b=0, head_num=0) 119 | # input a batch of 10 spectrogram, each with 100 time frames and 128 frequency bins 120 | test_input = torch.rand([10, input_tdim, 128]) 121 | test_output = psla_mdl(test_input) 122 | # output should be in shape [10, 527], i.e., 10 samples, each with prediction of 527 classes. 123 | print(test_output.shape) -------------------------------------------------------------------------------- /src/models/HigherModels.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def init_layer(layer): 7 | if layer.weight.ndimension() == 4: 8 | (n_out, n_in, height, width) = layer.weight.size() 9 | n = n_in * height * width 10 | elif layer.weight.ndimension() == 2: 11 | (n_out, n) = layer.weight.size() 12 | 13 | std = math.sqrt(2. / n) 14 | scale = std * math.sqrt(3.) 15 | layer.weight.data.uniform_(-scale, scale) 16 | 17 | if layer.bias is not None: 18 | layer.bias.data.fill_(0.) 19 | 20 | def init_bn(bn): 21 | bn.weight.data.fill_(1.) 22 | 23 | class Attention(nn.Module): 24 | def __init__(self, n_in, n_out, att_activation, cla_activation): 25 | super(Attention, self).__init__() 26 | 27 | self.att_activation = att_activation 28 | self.cla_activation = cla_activation 29 | 30 | self.att = nn.Conv2d( 31 | in_channels=n_in, out_channels=n_out, kernel_size=( 32 | 1, 1), stride=( 33 | 1, 1), padding=( 34 | 0, 0), bias=True) 35 | 36 | self.cla = nn.Conv2d( 37 | in_channels=n_in, out_channels=n_out, kernel_size=( 38 | 1, 1), stride=( 39 | 1, 1), padding=( 40 | 0, 0), bias=True) 41 | 42 | self.init_weights() 43 | 44 | 45 | def init_weights(self): 46 | init_layer(self.att) 47 | init_layer(self.cla) 48 | 49 | def activate(self, x, activation): 50 | 51 | if activation == 'linear': 52 | return x 53 | 54 | elif activation == 'relu': 55 | return F.relu(x) 56 | 57 | elif activation == 'sigmoid': 58 | return torch.sigmoid(x) 59 | 60 | elif activation == 'softmax': 61 | return F.softmax(x, dim=1) 62 | 63 | def forward(self, x): 64 | """input: (samples_num, freq_bins, time_steps, 1) 65 | """ 66 | 67 | att = self.att(x) 68 | att = self.activate(att, self.att_activation) 69 | 70 | cla = self.cla(x) 71 | cla = self.activate(cla, self.cla_activation) 72 | 73 | att = att[:, :, :, 0] # (samples_num, classes_num, time_steps) 74 | cla = cla[:, :, :, 0] # (samples_num, classes_num, time_steps) 75 | 76 | epsilon = 1e-7 77 | att = torch.clamp(att, epsilon, 1. 
- epsilon) 78 | 79 | norm_att = att / torch.sum(att, dim=2)[:, :, None] 80 | x = torch.sum(norm_att * cla, dim=2) 81 | 82 | return x, norm_att 83 | 84 | class MeanPooling(nn.Module): 85 | def __init__(self, n_in, n_out, att_activation, cla_activation): 86 | super(MeanPooling, self).__init__() 87 | 88 | self.cla_activation = cla_activation 89 | 90 | self.cla = nn.Conv2d( 91 | in_channels=n_in, out_channels=n_out, kernel_size=( 92 | 1, 1), stride=( 93 | 1, 1), padding=( 94 | 0, 0), bias=True) 95 | 96 | self.init_weights() 97 | 98 | def init_weights(self): 99 | init_layer(self.cla) 100 | 101 | def activate(self, x, activation): 102 | return torch.sigmoid(x) 103 | 104 | def forward(self, x): 105 | """input: (samples_num, freq_bins, time_steps, 1) 106 | """ 107 | 108 | cla = self.cla(x) 109 | cla = self.activate(cla, self.cla_activation) 110 | 111 | cla = cla[:, :, :, 0] # (samples_num, classes_num, time_steps) 112 | 113 | x = torch.mean(cla, dim=2) 114 | 115 | return x, [] 116 | 117 | class MHeadAttention(nn.Module): 118 | def __init__(self, n_in, n_out, att_activation, cla_activation, head_num=4): 119 | super(MHeadAttention, self).__init__() 120 | 121 | self.head_num = head_num 122 | 123 | self.att_activation = att_activation 124 | self.cla_activation = cla_activation 125 | 126 | self.att = nn.ModuleList([]) 127 | self.cla = nn.ModuleList([]) 128 | for i in range(self.head_num): 129 | self.att.append(nn.Conv2d(in_channels=n_in, out_channels=n_out, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)) 130 | self.cla.append(nn.Conv2d(in_channels=n_in, out_channels=n_out, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)) 131 | 132 | self.head_weight = nn.Parameter(torch.tensor([1.0/self.head_num] * self.head_num)) 133 | 134 | def activate(self, x, activation): 135 | if activation == 'linear': 136 | return x 137 | elif activation == 'relu': 138 | return F.relu(x) 139 | elif activation == 'sigmoid': 140 | return torch.sigmoid(x) 141 | elif activation == 'softmax': 142 | return F.softmax(x, dim=1) 143 | 144 | def forward(self, x): 145 | """input: (samples_num, freq_bins, time_steps, 1) 146 | """ 147 | 148 | x_out = [] 149 | for i in range(self.head_num): 150 | att = self.att[i](x) 151 | att = self.activate(att, self.att_activation) 152 | 153 | cla = self.cla[i](x) 154 | cla = self.activate(cla, self.cla_activation) 155 | 156 | att = att[:, :, :, 0] # (samples_num, classes_num, time_steps) 157 | cla = cla[:, :, :, 0] # (samples_num, classes_num, time_steps) 158 | 159 | epsilon = 1e-7 160 | att = torch.clamp(att, epsilon, 1. - epsilon) 161 | 162 | norm_att = att / torch.sum(att, dim=2)[:, :, None] 163 | x_out.append(torch.sum(norm_att * cla, dim=2) * self.head_weight[i]) 164 | 165 | x = (torch.stack(x_out, dim=0)).sum(dim=0) 166 | 167 | return x, [] -------------------------------------------------------------------------------- /egs/audioset/README.md: -------------------------------------------------------------------------------- 1 | ## Audioset Recipe 2 | 3 | Audioset recipe is very similar with FSD50K recipe, but is a little bit more complex, you will need to prepare your data by yourself. The AudioSet recipe is in `psla/egs/audioset/`. 4 | 5 | **Step 1. Prepare the data.** 6 | 7 | Please prepare the json files (i.e., `train_data.json` and `eval_data.json`) by your self. 8 | The reason is that the raw wavefiles of Audioset is not released and you need to download them by yourself. 
We have put a sample json file in `psla/egs/audioset/datafiles`; please generate your files in the same format (you can also refer to `psla/egs/fsd50k/prep_fsd.py`), and keep the label codes consistent with `psla/egs/audioset/class_labels_indices.csv`.
9 | 
10 | Note: we use a `16kHz` sampling rate for all AudioSet experiments.
11 | 
12 | Once you have the json files, you will need to generate the balanced-sampling weight file for the full AudioSet json file:
13 | ```
14 | cd psla/egs/audioset
15 | python ../../src/gen_weight_file.py --dataset audioset --label_indices_path ./class_labels_indices.csv --datafile_path ./datafiles/yourdatafile.json
16 | ```
17 | 
18 | (Optional) **Step 2. Enhance the labels of the balanced AudioSet training set.**
19 | 
20 | If you are experimenting with the full AudioSet, you can skip this step; the enhanced labels do not improve performance there (potentially because the evaluation label set is itself noisy).
21 | 
22 | If you are experimenting with the balanced AudioSet, or want to enhance the label set anyway, check our [pretrained enhanced label set](here) and change the labels in your datafile accordingly.
23 | 
24 | **Step 3. Run the training and evaluation.**
25 | 
26 | Change `data-train` and `data-val` in `psla/egs/audioset/run.sh` to your datafile paths.
27 | Also set `subset` to `balanced` or `full` for the balanced and full AudioSet, respectively.
28 | 
29 | ```
30 | cd psla/egs/audioset
31 | (slurm user) sbatch run.sh
32 | (local user) ./run.sh
33 | ```
34 | 
35 | The recipe was tested on 4 GTX TITAN GPUs with 12GB memory each. The entire training and evaluation takes about 12 hours for the balanced AudioSet and about one week for the full AudioSet.
36 | 
37 | **Step 4. Get the results.**
38 | 
39 | The running log presents the results of 1) the single model, 2) the weight-averaging model (i.e., averaging the weights of the last few checkpoints, which does not increase the model size or the computational overhead), and 3) the checkpoint-ensemble model (i.e., averaging the predictions of the per-epoch checkpoints) on the AudioSet evaluation set. These results are also saved in `psla/egs/audioset/exp/yourexpname/[best_single_,wa_,ensemble_]result.csv`.
40 | 
41 | We share our training and evaluation log in `psla/egs/audioset/exp/`; you can expect results similar to the following. We also share our entire experiment directory at this [dropbox link]().
42 | 
43 | **Step 5. Reproducing the Ensemble Results in the PSLA paper.**
44 | 
45 | In Step 4, only a single model is trained, and only a single-run checkpoint ensemble is used. To reproduce the best ensemble result (0.474 mAP) in the PSLA paper, you need to run Step 4 multiple times with the same or different settings. To ease this process, we provide pretrained models for all ensemble members. You can download them [here](https://www.dropbox.com/sh/ihfbxcemxamihz9/AAD9zqnUptZzyZlquqpWllDya?dl=1). Place the models in `psla/pretrained_models/audioset/` and run `psla/src/ensemble/ensemble.py`. You can expect results similar to the following (though for AudioSet, results can differ, as both the training and evaluation data can be different).
46 | 
47 | ```
48 | # Ensemble 3 AudioSet Models Trained with Exactly Same Setting (Best Setting), But Different Random Seeds.
49 | ---------------Ensemble Result Summary--------------- 50 | Model 0 ../../pretrained_models/audioset/as_mdl_0.pth mAP: 0.440298, AUC: 0.974047, d-prime: 2.749102 51 | Model 1 ../../pretrained_models/audioset/as_mdl_1.pth mAP: 0.439790, AUC: 0.973978, d-prime: 2.747493 52 | Model 2 ../../pretrained_models/audioset/as_mdl_2.pth mAP: 0.439322, AUC: 0.973591, d-prime: 2.738487 53 | Ensemble 3 Models mAP: 0.464112, AUC: 0.978222, d-prime: 2.854353 54 | 55 | # Ensemble 5 Top-Performance AudioSet Models. 56 | ---------------Ensemble Result Summary--------------- 57 | Model 0 ../../pretrained_models/audioset/as_mdl_0.pth mAP: 0.440298, AUC: 0.974047, d-prime: 2.749102 58 | Model 1 ../../pretrained_models/audioset/as_mdl_1.pth mAP: 0.439790, AUC: 0.973978, d-prime: 2.747493 59 | Model 2 ../../pretrained_models/audioset/as_mdl_2.pth mAP: 0.439322, AUC: 0.973591, d-prime: 2.738487 60 | Model 3 ../../pretrained_models/audioset/as_mdl_3.pth mAP: 0.440555, AUC: 0.973639, d-prime: 2.739613 61 | Model 4 ../../pretrained_models/audioset/as_mdl_4.pth mAP: 0.439713, AUC: 0.973579, d-prime: 2.738213 62 | Ensemble 5 Models mAP: 0.469050, AUC: 0.978875, d-prime: 2.872325 63 | 64 | # Ensemble All 10 AudioSet Models Presented in the PSLA Paper 65 | ---------------Ensemble Result Summary--------------- 66 | Model 0 ../pretrained_models/audioset/as_mdl_1.pth mAP: 0.440298, AUC: 0.974047, d-prime: 2.749102 67 | Model 1 ../pretrained_models/audioset/as_mdl_0.pth mAP: 0.439790, AUC: 0.973978, d-prime: 2.747493 68 | Model 2 ../pretrained_models/audioset/as_mdl_2.pth mAP: 0.439322, AUC: 0.973591, d-prime: 2.738487 69 | Model 3 ../pretrained_models/audioset/as_mdl_3.pth mAP: 0.440555, AUC: 0.973639, d-prime: 2.739613 70 | Model 4 ../pretrained_models/audioset/as_mdl_4.pth mAP: 0.439713, AUC: 0.973579, d-prime: 2.738213 71 | Model 5 ../pretrained_models/audioset/as_mdl_5.pth mAP: 0.438852, AUC: 0.973534, d-prime: 2.737183 72 | Model 6 ../pretrained_models/audioset/as_mdl_6.pth mAP: 0.394262, AUC: 0.973054, d-prime: 2.726193 73 | Model 7 ../pretrained_models/audioset/as_mdl_7.pth mAP: 0.370860, AUC: 0.961183, d-prime: 2.495504 74 | Model 8 ../pretrained_models/audioset/as_mdl_8.pth mAP: 0.426624, AUC: 0.973353, d-prime: 2.733006 75 | Model 9 ../pretrained_models/audioset/as_mdl_9.pth mAP: 0.372092, AUC: 0.970509, d-prime: 2.670498 76 | Ensemble 10 Models mAP: 0.474380, AUC: 0.981043, d-prime: 2.935611 77 | ``` -------------------------------------------------------------------------------- /src/ensemble/ensemble.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 6/23/21 5:38 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : ensemble.py 7 | 8 | # get the ensemble result 9 | 10 | import os, sys, argparse 11 | parentdir = str(os.path.abspath(os.path.join(__file__ ,"../.."))) 12 | sys.path.append(parentdir) 13 | import dataloaders 14 | import models 15 | from utilities import * 16 | from traintest import validate 17 | import numpy as np 18 | from scipy import stats 19 | import torch 20 | 21 | def get_ensemble_res(mdl_list, base_path, dataset='audioset'): 22 | num_class = 527 if dataset=='audioset' else 200 23 | # the 0-len(mdl_list) rows record the results of single models, the last row record the result of the ensemble model. 
24 | ensemble_res = np.zeros([len(mdl_list) + 1, 3]) 25 | if os.path.exists(base_path) == False: 26 | os.mkdir(base_path) 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | 29 | for model_idx, mdl in enumerate(mdl_list): 30 | print('-----------------------') 31 | print('now loading model {:d}: {:s}'.format(model_idx, mdl)) 32 | 33 | # sd = torch.load('/Users/yuan/Documents/ast/pretrained_models/audio_model_wa.pth', map_location=device) 34 | sd = torch.load(mdl, map_location=device) 35 | if 'module.effnet._fc.weight' in sd.keys(): 36 | del sd['module.effnet._fc.weight'] 37 | del sd['module.effnet._fc.bias'] 38 | torch.save(sd, mdl) 39 | audio_model = models.EffNetAttention(label_dim=num_class , b=2, pretrain=False, head_num=4) 40 | audio_model = torch.nn.DataParallel(audio_model) 41 | audio_model.load_state_dict(sd, strict=True) 42 | 43 | args.exp_dir = base_path 44 | 45 | stats, _ = validate(audio_model, eval_loader, args, model_idx) 46 | mAP = np.mean([stat['AP'] for stat in stats]) 47 | mAUC = np.mean([stat['auc'] for stat in stats]) 48 | dprime = d_prime(mAUC) 49 | ensemble_res[model_idx, :] = [mAP, mAUC, dprime] 50 | print("Model {:d} {:s} mAP: {:.6f}, AUC: {:.6f}, d-prime: {:.6f}".format(model_idx, mdl, mAP, mAUC, dprime)) 51 | 52 | # calculate the ensemble result 53 | # get the ground truth label 54 | target = np.loadtxt(base_path + '/predictions/target.csv', delimiter=',') 55 | # get the ground truth label 56 | prediction_sample = np.loadtxt(base_path + '/predictions/predictions_0.csv', delimiter=',') 57 | # allocate memory space for the ensemble prediction 58 | predictions_table = np.zeros([len(mdl_list) , prediction_sample.shape[0], prediction_sample.shape[1]]) 59 | for model_idx in range(0, len(mdl_list)): 60 | predictions_table[model_idx, :, :] = np.loadtxt(base_path + '/predictions/predictions_' + str(model_idx) + '.csv', delimiter=',') 61 | model_idx += 1 62 | 63 | ensemble_predictions = np.mean(predictions_table, axis=0) 64 | stats = calculate_stats(ensemble_predictions, target) 65 | ensemble_mAP = np.mean([stat['AP'] for stat in stats]) 66 | ensemble_mAUC = np.mean([stat['auc'] for stat in stats]) 67 | ensemble_dprime = d_prime(ensemble_mAUC) 68 | ensemble_res[-1, :] = [ensemble_mAP, ensemble_mAUC, ensemble_dprime] 69 | print('---------------Ensemble Result Summary---------------') 70 | for model_idx in range(len(mdl_list)): 71 | print("Model {:d} {:s} mAP: {:.6f}, AUC: {:.6f}, d-prime: {:.6f}".format(model_idx, mdl_list[model_idx], ensemble_res[model_idx, 0], ensemble_res[model_idx, 1], ensemble_res[model_idx, 2])) 72 | print("Ensemble {:d} Models mAP: {:.6f}, AUC: {:.6f}, d-prime: {:.6f}".format(len(mdl_list), ensemble_mAP, ensemble_mAUC, ensemble_dprime)) 73 | np.savetxt(base_path + '/ensemble_result.csv', ensemble_res, delimiter=',') 74 | 75 | def d_prime(auc): 76 | standard_normal = stats.norm() 77 | d_prime = standard_normal.ppf(auc) * np.sqrt(2.0) 78 | return d_prime 79 | 80 | # dataloader settings 81 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 82 | args = parser.parse_args() 83 | 84 | dataset = 'audioset' 85 | # uncomment this line if you want test ensemble on fsd50k 86 | # dataset = 'fsd50k' 87 | 88 | args.dataset = dataset 89 | if dataset == 'audioset': 90 | args.data_eval = '../../egs/audioset/datafiles/eval_data.json' 91 | else: 92 | args.data_eval = '../../egs/fsd50k/datafiles/fsd50k_eval_full.json' 93 | args.label_csv='../../egs/' + dataset + '/class_labels_indices.csv' 94 | 
args.loss_fn = torch.nn.BCELoss() 95 | norm_stats = {'audioset': [-4.6476, 4.5699], 'fsd50k': [-4.6476, 4.5699]} 96 | target_length = {'audioset': 1056, 'fsd50k': 3000} 97 | batch_size = 200 if dataset=='audioset' else 48 98 | 99 | val_audio_conf = {'num_mel_bins': 128, 'target_length': target_length[args.dataset], 'freqm': 0, 'timem': 0, 'mixup': 0, 'dataset': args.dataset, 'mode':'evaluation', 'mean':norm_stats[args.dataset][0], 'std':norm_stats[args.dataset][1], 'noise': False} 100 | eval_loader = torch.utils.data.DataLoader( 101 | dataloaders.AudiosetDataset(args.data_eval, label_csv=args.label_csv, audio_conf=val_audio_conf), 102 | batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True) 103 | 104 | if dataset == 'audioset': 105 | # ensemble 3 audioset models trained with exactly same setting, but different random seeds 106 | mdl_list_3 = ['../../pretrained_models/audioset/as_mdl_'+str(i)+'.pth' for i in range(3)] 107 | 108 | # ensemble top 5 audioset models, mAP = 109 | mdl_list_5 = ['../../pretrained_models/audioset/as_mdl_'+str(i)+'.pth' for i in range(5)] 110 | 111 | # ensemble entire 10 audioset models, mAP = 112 | mdl_list_10 = ['../../pretrained_models/audioset/as_mdl_'+str(i)+'.pth' for i in range(10)] 113 | 114 | get_ensemble_res(mdl_list_3, './ensemble_as', dataset) 115 | get_ensemble_res(mdl_list_5, './ensemble_as', dataset) 116 | get_ensemble_res(mdl_list_10, './ensemble_as', dataset) 117 | 118 | else: 119 | pass -------------------------------------------------------------------------------- /egs/fsd50k/class_labels_indices.csv: -------------------------------------------------------------------------------- 1 | index,mid,display_name 2 | 0,/m/07q2z82,Accelerating_and_revving_and_vroom 3 | 1,/m/0mkg,Accordion 4 | 2,/m/042v_gx,Acoustic_guitar 5 | 3,/m/0k5j,Aircraft 6 | 4,/m/07pp_mv,Alarm 7 | 5,/m/0jbk,Animal 8 | 6,/m/028ght,Applause 9 | 7,/m/05tny_,Bark 10 | 8,/m/0bm02,Bass_drum 11 | 9,/m/018vs,Bass_guitar 12 | 10,/m/03dnzn,Bathtub_(filling_or_washing) 13 | 11,/m/0395lw,Bell 14 | 12,/m/0199g,Bicycle 15 | 13,/m/0gy1t2s,Bicycle_bell 16 | 14,/m/015p6,Bird 17 | 15,/m/020bb7,Bird_vocalization_and_bird_call_and_bird_song 18 | 16,/m/019jd,Boat_and_Water_vehicle 19 | 17,/m/0dv3j,Boiling 20 | 18,/m/07qqyl4,Boom 21 | 19,/m/0l14_3,Bowed_string_instrument 22 | 20,/m/01kcd,Brass_instrument 23 | 21,/m/0lyf6,Breathing 24 | 22,/m/03q5_w,Burping_and_eructation 25 | 23,/m/01bjv,Bus 26 | 24,/m/07pjwq1,Buzz 27 | 25,/m/0dv5r,Camera 28 | 26,/m/0k4j,Car 29 | 27,/t/dd00134,Car_passing_by 30 | 28,/m/01yrx,Cat 31 | 29,/m/07rkbfh,Chatter 32 | 30,/m/053hz1,Cheering 33 | 31,/m/03cczk,Chewing_and_mastication 34 | 32,/m/09b5t,Chicken_and_rooster 35 | 33,/m/0ytgt,Child_speech_and_kid_speaking 36 | 34,/m/0f8s22,Chime 37 | 35,/m/07q7njn,Chink_and_clink 38 | 36,/m/07pggtn,Chirp_and_tweet 39 | 37,/m/07rgt08,Chuckle_and_chortle 40 | 38,/m/03w41f,Church_bell 41 | 39,/m/0l15bq,Clapping 42 | 40,/m/01x3z,Clock 43 | 41,/m/0242l,Coin_(dropping) 44 | 42,/m/01m2v,Computer_keyboard 45 | 43,/m/01h8n0,Conversation 46 | 44,/m/01b_21,Cough 47 | 45,/m/0239kh,Cowbell 48 | 46,/m/07qs1cx,Crack 49 | 47,/m/07pzfmf,Crackle 50 | 48,/m/0bm0k,Crash_cymbal 51 | 49,/m/09xqv,Cricket 52 | 50,/m/04s8yn,Crow 53 | 51,/m/03qtwd,Crowd 54 | 52,/t/dd00112,Crumpling_and_crinkling 55 | 53,/m/07plct2,Crushing 56 | 54,/m/0463cq4,Crying_and_sobbing 57 | 55,/m/0642b4,Cupboard_open_or_close 58 | 56,/m/023pjk,Cutlery_and_silverware 59 | 57,/m/01qbl,Cymbal 60 | 58,/m/04brg2,Dishes_and_pots_and_pans 61 | 59,/m/0bt9lr,Dog 62 | 
60,/m/068hy,Domestic_animals_and_pets 63 | 61,/t/dd00071,Domestic_sounds_and_home_sounds 64 | 62,/m/02dgv,Door 65 | 63,/m/03wwcy,Doorbell 66 | 64,/m/0fqfqc,Drawer_open_or_close 67 | 65,/m/01d380,Drill 68 | 66,/m/07r5v4s,Drip 69 | 67,/m/026t6,Drum 70 | 68,/m/02hnl,Drum_kit 71 | 69,/m/02sgy,Electric_guitar 72 | 70,/m/02mk9,Engine 73 | 71,/t/dd00130,Engine_starting 74 | 72,/m/014zdl,Explosion 75 | 73,/m/02_nn,Fart 76 | 74,/t/dd00004,Female_singing 77 | 75,/m/02zsn,Female_speech_and_woman_speaking 78 | 76,/m/07p7b8y,Fill_(with_liquid) 79 | 77,/m/025_jnm,Finger_snapping 80 | 78,/m/02_41,Fire 81 | 79,/m/0g6b5,Fireworks 82 | 80,/m/0cmf2,Fixed-wing_aircraft_and_airplane 83 | 81,/m/025rv6n,Fowl 84 | 82,/m/09ld4,Frog 85 | 83,/m/0dxrf,Frying_(food) 86 | 84,/m/07s0dtb,Gasp 87 | 85,/m/07r660_,Giggle 88 | 86,/m/039jq,Glass 89 | 87,/m/0dwtp,Glockenspiel 90 | 88,/m/0mbct,Gong 91 | 89,/m/0ghcn6,Growling 92 | 90,/m/0342h,Guitar 93 | 91,/m/01dwxx,Gull_and_seagull 94 | 92,/m/032s66,Gunshot_and_gunfire 95 | 93,/m/07swgks,Gurgling 96 | 94,/m/03l9g,Hammer 97 | 95,/m/0k65p,Hands 98 | 96,/m/03qjg,Harmonica 99 | 97,/m/03m5k,Harp 100 | 98,/m/03qtq,Hi-hat 101 | 99,/m/07rjwbb,Hiss 102 | 100,/t/dd00012,Human_group_actions 103 | 101,/m/09l8g,Human_voice 104 | 102,/m/07pb8fc,Idling 105 | 103,/m/03vt0,Insect 106 | 104,/m/05148p4,Keyboard_(musical) 107 | 105,/m/03v3yw,Keys_jangling 108 | 106,/m/07r4wb8,Knock 109 | 107,/m/01j3sz,Laughter 110 | 108,/m/04k94,Liquid 111 | 109,/m/0ch8v,Livestock_and_farm_animals_and_working_animals 112 | 110,/t/dd00003,Male_singing 113 | 111,/m/05zppz,Male_speech_and_man_speaking 114 | 112,/m/0j45pbj,Mallet_percussion 115 | 113,/m/0dwsp,Marimba_and_xylophone 116 | 114,/m/02x984l,Mechanical_fan 117 | 115,/t/dd00077,Mechanisms 118 | 116,/m/07qrkrw,Meow 119 | 117,/m/0fx9l,Microwave_oven 120 | 118,/m/012f08,Motor_vehicle_(road) 121 | 119,/m/04_sv,Motorcycle 122 | 120,/m/04rlf,Music 123 | 121,/m/04szw,Musical_instrument 124 | 122,/m/05kq4,Ocean 125 | 123,/m/013y1f,Organ 126 | 124,/m/05mxj0q,Packing_tape_and_duct_tape 127 | 125,/m/0l14md,Percussion 128 | 126,/m/05r5c,Piano 129 | 127,/m/0fx80y,Plucked_string_instrument 130 | 128,/m/07prgkl,Pour 131 | 129,/m/0_ksk,Power_tool 132 | 130,/m/01m4t,Printer 133 | 131,/m/02yds9,Purr 134 | 132,/m/0ltv,Race_car_and_auto_racing 135 | 133,/m/06d_3,Rail_transport 136 | 134,/m/06mb1,Rain 137 | 135,/m/07r10fb,Raindrop 138 | 136,/m/02bm9n,Ratchet_and_pawl 139 | 137,/m/07qn4z3,Rattle 140 | 138,/m/05r5wn,Rattle_(instrument) 141 | 139,/m/09hlz4,Respiratory_sounds 142 | 140,/m/01hnzm,Ringtone 143 | 141,/m/06h7j,Run 144 | 142,/m/01b82r,Sawing 145 | 143,/m/01lsmm,Scissors 146 | 144,/m/01hgjl,Scratching_(performance_technique) 147 | 145,/m/03qc9zr,Screaming 148 | 146,/m/07q8k13,Screech 149 | 147,/m/07rn7sz,Shatter 150 | 148,/m/07p6fty,Shout 151 | 149,/m/07plz5l,Sigh 152 | 150,/m/015lz1,Singing 153 | 151,/m/0130jx,Sink_(filling_or_washing) 154 | 152,/m/03kmc9,Siren 155 | 153,/m/06_fw,Skateboard 156 | 154,/m/07rjzl8,Slam 157 | 155,/m/02y_763,Sliding_door 158 | 156,/m/06rvn,Snare_drum 159 | 157,/m/01hsr_,Sneeze 160 | 158,/m/09x0r,Speech 161 | 159,/m/0brhx,Speech_synthesizer 162 | 160,/m/07rrlb6,Splash_and_splatter 163 | 161,/m/07q6cd_,Squeak 164 | 162,/m/0j6m2,Stream 165 | 163,/m/07s0s5r,Strum 166 | 164,/m/0195fx,Subway_and_metro_and_underground 167 | 165,/m/01p970,Tabla 168 | 166,/m/07brj,Tambourine 169 | 167,/m/07qcpgn,Tap 170 | 168,/m/07qcx4z,Tearing 171 | 169,/m/07cx4,Telephone 172 | 170,/m/07qnq_y,Thump_and_thud 173 | 171,/m/0ngt1,Thunder 174 | 172,/m/0jb2l,Thunderstorm 
175 | 173,/m/07qjznt,Tick 176 | 174,/m/07qjznl,Tick-tock 177 | 175,/m/01jt3m,Toilet_flush 178 | 176,/m/07k1x,Tools 179 | 177,/m/0btp2,Traffic_noise_and_roadway_noise 180 | 178,/m/07jdr,Train 181 | 179,/m/07pqc89,Trickle_and_dribble 182 | 180,/m/07r04,Truck 183 | 181,/m/07gql,Trumpet 184 | 182,/m/0c2wf,Typewriter 185 | 183,/m/0316dw,Typing 186 | 184,/m/07yv9,Vehicle 187 | 185,/m/0912c9,Vehicle_horn_and_car_horn_and_honking 188 | 186,/m/07pbtc8,Walk_and_footsteps 189 | 187,/m/0838f,Water 190 | 188,/m/02jz0l,Water_tap_and_faucet 191 | 189,/m/034srq,Waves_and_surf 192 | 190,/m/02rtxlg,Whispering 193 | 191,/m/07rqsjt,Whoosh_and_swoosh_and_swish 194 | 192,/m/01280g,Wild_animals 195 | 193,/m/03m9d0z,Wind 196 | 194,/m/026fgl,Wind_chime 197 | 195,/m/085jw,Wind_instrument_and_woodwind_instrument 198 | 196,/m/083vt,Wood 199 | 197,/m/081rb,Writing 200 | 198,/m/07sr1lc,Yell 201 | 199,/m/01s0vc,Zipper_(clothing) -------------------------------------------------------------------------------- /src/label_enhancement/fix_type1.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 12/23/20 2:02 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : fix_type1.py 7 | 8 | # enhance the label based on the model prediction, fixing the TYPE-I error. 9 | # TYPE-I error: an audio clip is labeled with a parent class, 10 | # but not also labeled as a child class when it does in fact 11 | # contain the audio event of the child class. 12 | 13 | import json 14 | import os 15 | 16 | import numpy as np 17 | 18 | def generate_child_dict(): 19 | with open('../utilities/ontology.json', 'r', encoding='utf8')as fp: 20 | ontology = json.load(fp) 21 | # map each class to its children classes 22 | child_dict = {} 23 | for audio_class in ontology: 24 | cur_id = audio_class['id'] 25 | # avoid abstract and discountinued class 26 | cur_restriction = audio_class['restrictions'] 27 | if cur_restriction != ['abstract']: 28 | child_dict[cur_id] = audio_class['child_ids'] 29 | return child_dict 30 | 31 | def enhance_label_type1(json_path, output_path, child_dict, labels_code_list, score_threshold, pred, dataset='audioset'): 32 | num_class = 527 if dataset == 'audioset' else 200 33 | original_label_num, fixed_label_num, fix_case_num = 0, 0, 0 34 | # these are just to track the change for analysis 35 | child_case_cnt, par_case_cnt, class_sample_cnt = [0] * num_class, [0] * num_class, [0] * num_class 36 | par_child_dict = {} 37 | with open(json_path, 'r', encoding='utf8')as fp: 38 | data_file = json.load(fp) 39 | data = data_file['data'] 40 | # for each audio sample 41 | for i, sample in enumerate(data): 42 | sample_labels = sample['labels'].split(',') 43 | new_labels = sample_labels.copy() 44 | original_label_num += len(sample_labels) 45 | # fpr each label of the audio sample 46 | for label in sample_labels: 47 | class_sample_cnt[code2idx[label]] += 1 48 | # there are some FSD50K classes not included in the AudioSet ontology, ingore them 49 | if label not in ['/m/09l8g', '/m/0bm0k', '/t/dd00012', '/m/09hlz4', '/t/dd00071'] or dataset == 'audioset': 50 | # if this label has child class(es) 51 | if child_dict[label] != None: 52 | # one label might have multiple child classes 53 | for child_label in child_dict[label]: 54 | #if the child class is in 527-class list (i.e., not abstract, not discontinued, etc) 55 | if child_label in labels_code_list: 56 | # if the child label not already in the 
original label set 57 | if child_label not in new_labels: 58 | # get the index of the child class 59 | child_label_idx = code2idx[child_label] 60 | # the model prediction score on the child class of this sample 61 | pred_score = pred[i, child_label_idx] 62 | # if the prediction score is higher than the threshold 63 | if pred_score > score_threshold[child_label_idx]: 64 | # add the child label 65 | new_labels.append(child_label) 66 | # below are just to track the change for analysis 67 | fix_case_num += 1 68 | par_case_cnt[code2idx[label]] += 1 69 | child_case_cnt[code2idx[child_label]] += 1 70 | if str(code2idx[label]) + '_' + str(code2idx[child_label]) not in par_child_dict: 71 | par_child_dict[str(code2idx[label]) + '_' + str(code2idx[child_label])] = 1 72 | else: 73 | par_child_dict[str(code2idx[label]) + '_' + str(code2idx[child_label])] += 1 74 | # remove repeated labels and add the new labels to the dataset 75 | data[i]['labels'] = ','.join(list(set(new_labels))) 76 | fixed_label_num += len(list(set(new_labels))) 77 | output = {'data': data} 78 | with open(output_path, 'w') as f: 79 | json.dump(output, f, indent=1) 80 | print('Added {:d} ({:.1f}%) labels to original {:d} original labels'.format((fixed_label_num-original_label_num), (fixed_label_num / original_label_num-1)*100, original_label_num)) 81 | return child_case_cnt, par_case_cnt, par_child_dict, class_sample_cnt 82 | 83 | if __name__ == '__main__': 84 | # 'audioset' or 'fsd50k' 85 | # for audioset, we demo label enhancement on the balanced training set 86 | dataset = 'fsd50k' 87 | num_class = 527 if dataset == 'audioset' else 200 88 | 89 | # generate a dict that maps each class to its children classes 90 | child_dict = generate_child_dict() 91 | 92 | # generate a dict that maps label code to index 93 | with open('../../egs/' + dataset + '/class_labels_indices.csv') as f: 94 | labels = f.readlines() 95 | 96 | labels = labels[1:] 97 | labels_code_list = [label.strip('\n').split(',')[1] for label in labels] 98 | code2idx = {labels[i].strip('\n').split(',')[1]: i for i in range(len(labels))} 99 | 100 | # the label enhancement algorithm depends on the soft model output prediction and ground truth 101 | if dataset == 'fsd50k': 102 | target_path = "./predictions_fsd/target.csv" 103 | pred_path = "./predictions_fsd/predictions.csv" 104 | else: 105 | target_path = "./predictions_as_bal/target.csv" 106 | pred_path = "./predictions_as_bal/predictions.csv" 107 | target = np.loadtxt(target_path, delimiter=',') 108 | pred = np.loadtxt(pred_path, delimiter=',') 109 | 110 | # generate label modification thresholds for each class 111 | mean_score = [np.mean(pred[np.where(target[:, i] == 1)[0], i]) for i in range(num_class)] 112 | median_score = [np.median(pred[np.where(target[:, i] == 1)[0], i]) for i in range(num_class)] 113 | twentyfivepercentile = [np.percentile(pred[np.where(target[:, i] == 1)[0], i], 25) for i in range(num_class)] 114 | tenpercentile = [np.percentile(pred[np.where(target[:, i] == 1)[0], i], 10) for i in range(num_class)] 115 | fivepercentile = [np.percentile(pred[np.where(target[:, i] == 1)[0], i], 5) for i in range(num_class)] 116 | 117 | thres_dict = {'mean': mean_score, 'median': median_score, '25': twentyfivepercentile, '10': tenpercentile, '5': fivepercentile} 118 | # 119 | for p in ['mean', 'median', '25', '10', '5']: 120 | threshold = thres_dict[p] 121 | if dataset == 'fsd50k': 122 | original_datafile = "../../egs/fsd50k/datafiles/fsd50k_tr_full.json" 123 | enhanced_datafile = 
"../../egs/fsd50k/datafiles/fsd50k_tr_full_type1_"+ p +".json" 124 | elif dataset == 'audioset': 125 | original_datafile = "../../egs/audioset/datafiles/balanced_train_data.json" 126 | enhanced_datafile = "../../egs/audioset/datafiles/balanced_train_data_type1_"+ p +".json" 127 | 128 | enhance_label_type1(original_datafile, enhanced_datafile, child_dict, labels_code_list, threshold, pred, dataset=dataset) 129 | # (optional) generate balanced sampling weight for each enhanced label set. 130 | os.system('python ../gen_weight_file.py --dataset {:s} --label_indices_path {:s} --datafile_path {:s}'.format(dataset, '../../egs/' + dataset + '/class_labels_indices.csv', enhanced_datafile)) -------------------------------------------------------------------------------- /src/label_enhancement/fix_type2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 12/23/20 2:02 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : fix_type2.py 7 | 8 | # enhance the label based on the model prediction, fixing the TYPE-II error. 9 | # Type II error: an audio clip is labeled with a child class, 10 | # but not labeled with corresponding parent classes. 11 | 12 | import json 13 | import os 14 | import numpy as np 15 | 16 | # map each class to its direct parent class 17 | def generate_parent_dict(): 18 | with open('../utilities/ontology.json', 'r', encoding='utf8')as fp: 19 | ontology = json.load(fp) 20 | # label: direct parent class, none if it is root 21 | parent_dict = {} 22 | for audio_class in ontology: 23 | cur_id = audio_class['id'] 24 | # avoid abstract and discountinued class 25 | cur_restriction = audio_class['restrictions'] 26 | if cur_restriction != ['abstract']: 27 | if cur_id not in parent_dict: 28 | parent_dict[cur_id] = None 29 | cur_child = audio_class['child_ids'] 30 | for child in cur_child: 31 | if (child not in parent_dict) or parent_dict[child] == None: 32 | parent_dict[child] = [cur_id] 33 | else: 34 | parent_dict[child].append(cur_id) 35 | return parent_dict 36 | 37 | def dfs(cur_node, par_list, parent_dict): 38 | par_list.append(cur_node) 39 | if parent_dict[cur_node] != None: 40 | for par in parent_dict[cur_node]: 41 | dfs(par, par_list, parent_dict) 42 | 43 | # do dfs search to find all parent classes 44 | def dfs_dict(parent_dict): 45 | dfs_parent_dict = {} 46 | for label in parent_dict.keys(): 47 | if parent_dict[label] != None: 48 | par_list = [] 49 | dfs(label, par_list, parent_dict) 50 | dfs_parent_dict[label] = list(set(par_list)) 51 | else: 52 | dfs_parent_dict[label] = None 53 | return dfs_parent_dict 54 | 55 | # map all classes into a background class except interest classes, for general purpose, index in interest_cla 56 | def enhance_label_type2(json_path, output_path, par_dict, labels_code_list, score_threshold, pred, dataset='audioset'): 57 | num_class = 527 if dataset == 'audioset' else 200 58 | original_label_num, fixed_label_num, fix_case_num = 0, 0, 0 59 | # these are just to track the change for analysis 60 | child_case_cnt, par_case_cnt, class_sample_cnt = [0] * num_class, [0] * num_class, [0] * num_class 61 | child_par_dict = {} 62 | with open(json_path,'r',encoding='utf8')as fp: 63 | data_file = json.load(fp) 64 | data = data_file['data'] 65 | # for each audio sample 66 | for i, sample in enumerate(data): 67 | sample_labels = sample['labels'].split(',') 68 | new_labels = sample_labels.copy() 69 | original_label_num += 
len(sample_labels) 70 | # for each label of the audio sample 71 | for label in sample_labels: 72 | class_sample_cnt[code2idx[label]] += 1 73 | # there are some FSD50K classes not included in the AudioSet ontology, ingore them 74 | if label not in ['/m/09l8g', '/m/0bm0k', '/t/dd00012', '/m/09hlz4', '/t/dd00071'] or dataset=='audioset': 75 | # if this label has parent class 76 | if par_dict[label] != None: 77 | # one label might have multiple parent classes 78 | for par_label in par_dict[label]: 79 | #if the parent class is in 527-class list (i.e., not abstract, not discontinued, etc) 80 | if par_label in labels_code_list: 81 | # if the parent label not already in the original label set 82 | if par_label not in new_labels: 83 | # get the index of the parent class 84 | par_label_idx = code2idx[par_label] 85 | # the model prediction score on the parent class of this sample 86 | pred_score = pred[i, par_label_idx] 87 | # if the prediction score is higher than the threshold 88 | if pred_score > score_threshold[par_label_idx]: 89 | # add the parent label 90 | new_labels.append(par_label) 91 | # below are just to track the change for analysis 92 | fix_case_num += 1 93 | child_case_cnt[code2idx[label]] += 1 94 | par_case_cnt[code2idx[par_label]] += 1 95 | if str(code2idx[label]) + '_' + str(code2idx[par_label]) not in child_par_dict: 96 | child_par_dict[str(code2idx[label]) + '_' + str(code2idx[par_label])] = 1 97 | else: 98 | child_par_dict[str(code2idx[label]) + '_' + str(code2idx[par_label])] += 1 99 | # remove repeated labels and add the new labels to the dataset 100 | data[i]['labels'] = ','.join(list(set(new_labels))) 101 | fixed_label_num += len(list(set(new_labels))) 102 | output = {'data': data} 103 | with open(output_path, 'w') as f: 104 | json.dump(output, f, indent=1) 105 | print('Added {:d} ({:.1f}%) labels to original {:d} original labels'.format((fixed_label_num-original_label_num), (fixed_label_num / original_label_num-1)*100, original_label_num)) 106 | return child_case_cnt, par_case_cnt, child_par_dict, class_sample_cnt 107 | 108 | if __name__ == '__main__': 109 | # 'audioset' or 'fsd50k' 110 | # for audioset, we demo label enhancement on the balanced training set 111 | dataset = 'fsd50k' 112 | num_class = 527 if dataset == 'audioset' else 200 113 | 114 | # map each class to ALL its parent classes 115 | par_dict = generate_parent_dict() 116 | dfs_par_dict = dfs_dict(par_dict) 117 | 118 | # generate a dict that maps label code to index 119 | with open('../../egs/' + dataset + '/class_labels_indices.csv') as f: 120 | labels = f.readlines() 121 | 122 | labels = labels[1:] 123 | labels_code_list = [label.strip('\n').split(',')[1] for label in labels] 124 | code2idx = {labels[i].strip('\n').split(',')[1]: i for i in range(len(labels))} 125 | 126 | # the label enhancement algorithm depends on the soft model output prediction and ground truth 127 | if dataset == 'fsd50k': 128 | target_path = "./predictions_fsd/target.csv" 129 | pred_path = "./predictions_fsd/predictions.csv" 130 | else: 131 | target_path = "./predictions_as_bal/target.csv" 132 | pred_path = "./predictions_as_bal/predictions.csv" 133 | target = np.loadtxt(target_path, delimiter=',') 134 | pred = np.loadtxt(pred_path, delimiter=',') 135 | 136 | # first calculate a median score for each class 137 | mean_score = [np.mean(pred[np.where(target[:, i] == 1)[0], i]) for i in range(num_class)] 138 | median_score = [np.median(pred[np.where(target[:, i] == 1)[0], i]) for i in range(num_class)] 139 | twentyfivepercentile = 
[np.percentile(pred[np.where(target[:, i] == 1)[0], i], 25) for i in range(num_class)] 140 | tenpercentile = [np.percentile(pred[np.where(target[:, i] == 1)[0], i], 10) for i in range(num_class)] 141 | fivepercentile = [np.percentile(pred[np.where(target[:, i] == 1)[0], i], 5) for i in range(num_class)] 142 | 143 | thres_dict = {'mean': mean_score, 'median': median_score, '25': twentyfivepercentile, '10': tenpercentile, '5': fivepercentile} 144 | for p in ['mean', 'median', '25', '10', '5']: 145 | threshold = thres_dict[p] 146 | if dataset == 'fsd50k': 147 | original_datafile = "../../egs/fsd50k/datafiles/fsd50k_tr_full.json" 148 | enhanced_datafile = "../../egs/fsd50k/datafiles/fsd50k_tr_full_type2_"+ p +".json" 149 | elif dataset == 'audioset': 150 | original_datafile = "../../egs/audioset/datafiles/balanced_train_data.json" 151 | enhanced_datafile = "../../egs/audioset/datafiles/balanced_train_data_type2_"+ p +".json" 152 | 153 | enhance_label_type2(original_datafile, enhanced_datafile, dfs_par_dict, labels_code_list, threshold, pred, dataset) 154 | # (optional) generate balanced sampling weight for each enhanced label set. 155 | os.system('python ../gen_weight_file.py --dataset {:s} --label_indices_path {:s} --datafile_path {:s}'.format(dataset, '../../egs/' + dataset + '/class_labels_indices.csv', enhanced_datafile)) 156 | -------------------------------------------------------------------------------- /src/dataloaders/audioset_dataset.py: -------------------------------------------------------------------------------- 1 | # Author: David Harwath 2 | # with some functions borrowed from https://github.com/SeanNaren/deepspeech.pytorch 3 | import csv 4 | import json 5 | import torchaudio 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional 9 | from torch.utils.data import Dataset 10 | import random 11 | 12 | def make_index_dict(label_csv): 13 | index_lookup = {} 14 | with open(label_csv, 'r') as f: 15 | csv_reader = csv.DictReader(f) 16 | line_count = 0 17 | for row in csv_reader: 18 | index_lookup[row['mid']] = row['index'] 19 | line_count += 1 20 | return index_lookup 21 | 22 | def make_name_dict(label_csv): 23 | name_lookup = {} 24 | with open(label_csv, 'r') as f: 25 | csv_reader = csv.DictReader(f) 26 | line_count = 0 27 | for row in csv_reader: 28 | name_lookup[row['index']] = row['display_name'] 29 | line_count += 1 30 | return name_lookup 31 | 32 | def lookup_list(index_list, label_csv): 33 | label_list = [] 34 | table = make_name_dict(label_csv) 35 | for item in index_list: 36 | label_list.append(table[item]) 37 | return label_list 38 | 39 | def preemphasis(signal,coeff=0.97): 40 | """perform preemphasis on the input signal. 41 | :param signal: The signal to filter. 42 | :param coeff: The preemphasis coefficient. 0 is none, default 0.97. 43 | :returns: the filtered signal. 
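The implemented recurrence is y[0] = x[0] and y[n] = x[n] - coeff * x[n-1] for n >= 1; for example, preemphasis(np.array([1.0, 2.0, 3.0]), coeff=0.97) returns approximately [1.0, 1.03, 1.06].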
44 | """ 45 | return np.append(signal[0],signal[1:]-coeff*signal[:-1]) 46 | 47 | class AudiosetDataset(Dataset): 48 | def __init__(self, dataset_json_file, audio_conf, label_csv=None): 49 | """ 50 | Dataset that manages audio recordings 51 | :param audio_conf: Dictionary containing the audio loading and preprocessing settings 52 | :param dataset_json_file 53 | """ 54 | self.datapath = dataset_json_file 55 | with open(dataset_json_file, 'r') as fp: 56 | data_json = json.load(fp) 57 | 58 | self.data = data_json['data'] 59 | self.audio_conf = audio_conf 60 | print('---------------the {:s} dataloader---------------'.format(self.audio_conf.get('mode'))) 61 | self.melbins = self.audio_conf.get('num_mel_bins') 62 | self.freqm = self.audio_conf.get('freqm') 63 | self.timem = self.audio_conf.get('timem') 64 | print('now using following mask: {:d} freq, {:d} time'.format(self.audio_conf.get('freqm'), self.audio_conf.get('timem'))) 65 | self.mixup = self.audio_conf.get('mixup') 66 | print('now using mix-up with rate {:f}'.format(self.mixup)) 67 | self.dataset = self.audio_conf.get('dataset') 68 | print('now process ' + self.dataset) 69 | # dataset spectrogram mean and std, used to normalize the input 70 | self.norm_mean = self.audio_conf.get('mean') 71 | self.norm_std = self.audio_conf.get('std') 72 | # skip_norm is a flag that if you want to skip normalization to compute the normalization stats using src/get_norm_stats.py, if Ture, input normalization will be skipped for correctly calculating the stats. 73 | # set it as True ONLY when you are getting the normalization stats. 74 | self.skip_norm = self.audio_conf.get('skip_norm') if self.audio_conf.get('skip_norm') else False 75 | if self.skip_norm: 76 | print('now skip normalization (use it ONLY when you are computing the normalization stats).') 77 | else: 78 | print('use dataset mean {:.3f} and std {:.3f} to normalize the input.'.format(self.norm_mean, self.norm_std)) 79 | # if add noise for data augmentation 80 | self.noise = self.audio_conf.get('noise') 81 | if self.noise == True: 82 | print('now use noise augmentation') 83 | 84 | self.index_dict = make_index_dict(label_csv) 85 | self.label_num = len(self.index_dict) 86 | print('number of classes is {:d}'.format(self.label_num)) 87 | 88 | def _wav2fbank(self, filename, filename2=None): 89 | # mixup 90 | if filename2 == None: 91 | waveform, sr = torchaudio.load(filename) 92 | waveform = waveform - waveform.mean() 93 | # mixup 94 | else: 95 | waveform1, sr = torchaudio.load(filename) 96 | waveform2, _ = torchaudio.load(filename2) 97 | 98 | waveform1 = waveform1 - waveform1.mean() 99 | waveform2 = waveform2 - waveform2.mean() 100 | 101 | if waveform1.shape[1] != waveform2.shape[1]: 102 | if waveform1.shape[1] > waveform2.shape[1]: 103 | # padding 104 | temp_wav = torch.zeros(1, waveform1.shape[1]) 105 | temp_wav[0, 0:waveform2.shape[1]] = waveform2 106 | waveform2 = temp_wav 107 | else: 108 | # cutting 109 | waveform2 = waveform2[0, 0:waveform1.shape[1]] 110 | 111 | # sample lambda from uniform distribution 112 | #mix_lambda = random.random() 113 | # sample lambda from beta distribtion 114 | mix_lambda = np.random.beta(10, 10) 115 | 116 | mix_waveform = mix_lambda * waveform1 + (1 - mix_lambda) * waveform2 117 | waveform = mix_waveform - mix_waveform.mean() 118 | 119 | fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=sr, use_energy=False, 120 | window_type='hanning', num_mel_bins=self.melbins, dither=0.0, frame_shift=10) 121 | 122 | target_length = 
self.audio_conf.get('target_length') 123 | n_frames = fbank.shape[0] 124 | 125 | p = target_length - n_frames 126 | 127 | # cut and pad 128 | if p > 0: 129 | m = torch.nn.ZeroPad2d((0, 0, 0, p)) 130 | fbank = m(fbank) 131 | elif p < 0: 132 | fbank = fbank[0:target_length, :] 133 | 134 | if filename2 == None: 135 | return fbank, 0 136 | else: 137 | return fbank, mix_lambda 138 | 139 | def __getitem__(self, index): 140 | """ 141 | returns: image, audio, nframes 142 | where image is a FloatTensor of size (3, H, W) 143 | audio is a FloatTensor of size (N_freq, N_frames) for spectrogram, or (N_frames) for waveform 144 | nframes is an integer 145 | """ 146 | # do mix-up for this sample (controlled by the given mixup rate) 147 | if random.random() < self.mixup: 148 | datum = self.data[index] 149 | # find another sample to mix, also do balance sampling 150 | # sample the other sample from the multinomial distribution, will make the performance worse 151 | # mix_sample_idx = np.random.choice(len(self.data), p=self.sample_weight_file) 152 | # sample the other sample from the uniform distribution 153 | mix_sample_idx = random.randint(0, len(self.data)-1) 154 | mix_datum = self.data[mix_sample_idx] 155 | # get the mixed fbank 156 | fbank, mix_lambda = self._wav2fbank(datum['wav'], mix_datum['wav']) 157 | # initialize the label 158 | label_indices = np.zeros(self.label_num) 159 | # add sample 1 labels 160 | for label_str in datum['labels'].split(','): 161 | label_indices[int(self.index_dict[label_str])] += mix_lambda 162 | # add sample 2 labels 163 | for label_str in mix_datum['labels'].split(','): 164 | label_indices[int(self.index_dict[label_str])] += (1.0-mix_lambda) 165 | label_indices = torch.FloatTensor(label_indices) 166 | # if not do mixup 167 | else: 168 | datum = self.data[index] 169 | label_indices = np.zeros(self.label_num) 170 | fbank, mix_lambda = self._wav2fbank(datum['wav']) 171 | for label_str in datum['labels'].split(','): 172 | label_indices[int(self.index_dict[label_str])] = 1.0 173 | 174 | label_indices = torch.FloatTensor(label_indices) 175 | 176 | # SpecAug, not do for eval set 177 | freqm = torchaudio.transforms.FrequencyMasking(self.freqm) 178 | timem = torchaudio.transforms.TimeMasking(self.timem) 179 | fbank = torch.transpose(fbank, 0, 1) 180 | # this is just to satisfy new torchaudio version. 181 | fbank = fbank.unsqueeze(0) 182 | if self.freqm != 0: 183 | fbank = freqm(fbank) 184 | if self.timem != 0: 185 | fbank = timem(fbank) 186 | # squeeze back 187 | fbank = fbank.squeeze(0) 188 | fbank = torch.transpose(fbank, 0, 1) 189 | 190 | # normalize the input for both training and test 191 | if not self.skip_norm: 192 | fbank = (fbank - self.norm_mean) / (self.norm_std) 193 | # skip normalization the input if you are trying to get the normalization stats. 
194 | else: 195 | pass 196 | 197 | if self.noise == True: 198 | fbank = fbank + torch.rand(fbank.shape[0], fbank.shape[1]) * np.random.rand() / 10 199 | fbank = torch.roll(fbank, np.random.randint(-10, 10), 0) 200 | 201 | # the output fbank shape is [time_frame_num, frequency_bins], e.g., [1024, 128] 202 | return fbank, label_indices 203 | 204 | def __len__(self): 205 | return len(self.data) -------------------------------------------------------------------------------- /src/utilities/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import random 7 | from collections import namedtuple 8 | 9 | def calc_recalls(S): 10 | """ 11 | Computes recall at 1, 5, and 10 given a similarity matrix S. 12 | By convention, rows of S are assumed to correspond to images and columns are captions. 13 | """ 14 | assert(S.dim() == 2) 15 | assert(S.size(0) == S.size(1)) 16 | if isinstance(S, torch.autograd.Variable): 17 | S = S.data 18 | n = S.size(0) 19 | A2I_scores, A2I_ind = S.topk(10, 0) 20 | I2A_scores, I2A_ind = S.topk(10, 1) 21 | A_r1 = AverageMeter() 22 | A_r5 = AverageMeter() 23 | A_r10 = AverageMeter() 24 | I_r1 = AverageMeter() 25 | I_r5 = AverageMeter() 26 | I_r10 = AverageMeter() 27 | for i in range(n): 28 | A_foundind = -1 29 | I_foundind = -1 30 | for ind in range(10): 31 | if A2I_ind[ind, i] == i: 32 | I_foundind = ind 33 | if I2A_ind[i, ind] == i: 34 | A_foundind = ind 35 | # do r1s 36 | if A_foundind == 0: 37 | A_r1.update(1) 38 | else: 39 | A_r1.update(0) 40 | if I_foundind == 0: 41 | I_r1.update(1) 42 | else: 43 | I_r1.update(0) 44 | # do r5s 45 | if A_foundind >= 0 and A_foundind < 5: 46 | A_r5.update(1) 47 | else: 48 | A_r5.update(0) 49 | if I_foundind >= 0 and I_foundind < 5: 50 | I_r5.update(1) 51 | else: 52 | I_r5.update(0) 53 | # do r10s 54 | if A_foundind >= 0 and A_foundind < 10: 55 | A_r10.update(1) 56 | else: 57 | A_r10.update(0) 58 | if I_foundind >= 0 and I_foundind < 10: 59 | I_r10.update(1) 60 | else: 61 | I_r10.update(0) 62 | 63 | recalls = {'A_r1':A_r1.avg, 'A_r5':A_r5.avg, 'A_r10':A_r10.avg, 64 | 'I_r1':I_r1.avg, 'I_r5':I_r5.avg, 'I_r10':I_r10.avg} 65 | #'A_meanR':A_meanR.avg, 'I_meanR':I_meanR.avg} 66 | 67 | return recalls 68 | 69 | def computeMatchmap(I, A): 70 | assert(I.dim() == 3) 71 | assert(A.dim() == 2) 72 | D = I.size(0) 73 | H = I.size(1) 74 | W = I.size(2) 75 | T = A.size(1) 76 | Ir = I.view(D, -1).t() 77 | matchmap = torch.mm(Ir, A) 78 | matchmap = matchmap.view(H, W, T) 79 | return matchmap 80 | 81 | def matchmapSim(M, simtype): 82 | assert(M.dim() == 3) 83 | if simtype == 'SISA': 84 | return M.mean() 85 | elif simtype == 'MISA': 86 | M_maxH, _ = M.max(0) 87 | M_maxHW, _ = M_maxH.max(0) 88 | return M_maxHW.mean() 89 | elif simtype == 'SIMA': 90 | M_maxT, _ = M.max(2) 91 | return M_maxT.mean() 92 | else: 93 | raise ValueError 94 | 95 | def sampled_margin_rank_loss(image_outputs, audio_outputs, nframes, margin=1., simtype='MISA'): 96 | """ 97 | Computes the triplet margin ranking loss for each anchor image/caption pair 98 | The impostor image/caption is randomly sampled from the minibatch 99 | """ 100 | assert(image_outputs.dim() == 4) 101 | assert(audio_outputs.dim() == 3) 102 | n = image_outputs.size(0) 103 | loss = torch.zeros(1, device=image_outputs.device, requires_grad=True) 104 | for i in range(n): 105 | I_imp_ind = i 106 | A_imp_ind = i 107 | while I_imp_ind == i: 108 | I_imp_ind = np.random.randint(0, n) 109 | 
while A_imp_ind == i: 110 | A_imp_ind = np.random.randint(0, n) 111 | nF = nframes[i] 112 | nFimp = nframes[A_imp_ind] 113 | anchorsim = matchmapSim(computeMatchmap(image_outputs[i], audio_outputs[i][:, 0:nF]), simtype) 114 | Iimpsim = matchmapSim(computeMatchmap(image_outputs[I_imp_ind], audio_outputs[i][:, 0:nF]), simtype) 115 | Aimpsim = matchmapSim(computeMatchmap(image_outputs[i], audio_outputs[A_imp_ind][:, 0:nFimp]), simtype) 116 | A2I_simdif = margin + Iimpsim - anchorsim 117 | if (A2I_simdif.data > 0).all(): 118 | loss = loss + A2I_simdif 119 | I2A_simdif = margin + Aimpsim - anchorsim 120 | if (I2A_simdif.data > 0).all(): 121 | loss = loss + I2A_simdif 122 | loss = loss / n 123 | return loss 124 | 125 | def compute_matchmap_similarity_matrix(image_outputs, audio_outputs, nframes, simtype='MISA'): 126 | """ 127 | Assumes image_outputs is a (batchsize, embedding_dim, rows, height) tensor 128 | Assumes audio_outputs is a (batchsize, embedding_dim, 1, time) tensor 129 | Returns similarity matrix S where images are rows and audios are along the columns 130 | """ 131 | assert(image_outputs.dim() == 4) 132 | assert(audio_outputs.dim() == 3) 133 | n = image_outputs.size(0) 134 | S = torch.zeros(n, n, device=image_outputs.device) 135 | for image_idx in range(n): 136 | for audio_idx in range(n): 137 | nF = max(1, nframes[audio_idx]) 138 | S[image_idx, audio_idx] = matchmapSim(computeMatchmap(image_outputs[image_idx], audio_outputs[audio_idx][:, 0:nF]), simtype) 139 | return S 140 | 141 | def compute_pooldot_similarity_matrix(image_outputs, audio_outputs, nframes): 142 | """ 143 | Assumes image_outputs is a (batchsize, embedding_dim, rows, height) tensor 144 | Assumes audio_outputs is a (batchsize, embedding_dim, 1, time) tensor 145 | Returns similarity matrix S where images are rows and audios are along the columns 146 | S[i][j] is computed as the dot product between the meanpooled embeddings of 147 | the ith image output and jth audio output 148 | """ 149 | assert(image_outputs.dim() == 4) 150 | assert(audio_outputs.dim() == 4) 151 | n = image_outputs.size(0) 152 | imagePoolfunc = nn.AdaptiveAvgPool2d((1, 1)) 153 | pooled_image_outputs = imagePoolfunc(image_outputs).squeeze(3).squeeze(2) 154 | audioPoolfunc = nn.AdaptiveAvgPool2d((1, 1)) 155 | pooled_audio_outputs_list = [] 156 | for idx in range(n): 157 | nF = max(1, nframes[idx]) 158 | pooled_audio_outputs_list.append(audioPoolfunc(audio_outputs[idx][:, :, 0:nF]).unsqueeze(0)) 159 | pooled_audio_outputs = torch.cat(pooled_audio_outputs_list).squeeze(3).squeeze(2) 160 | S = torch.mm(pooled_image_outputs, pooled_audio_outputs.t()) 161 | return S 162 | 163 | def one_imposter_index(i, N): 164 | imp_ind = random.randint(0, N - 2) 165 | if imp_ind == i: 166 | imp_ind = N - 1 167 | return imp_ind 168 | 169 | def basic_get_imposter_indices(N): 170 | imposter_idc = [] 171 | for i in range(N): 172 | # Select an imposter index for example i: 173 | imp_ind = one_imposter_index(i, N) 174 | imposter_idc.append(imp_ind) 175 | return imposter_idc 176 | 177 | def semihardneg_triplet_loss_from_S(S, margin): 178 | """ 179 | Input: Similarity matrix S as an autograd.Variable 180 | Output: The one-way triplet loss from rows of S to columns of S. Impostors are taken 181 | to be the most similar point to the anchor that is still less similar to the anchor 182 | than the positive example. 183 | You would need to run this function twice, once with S and once with S.t(), 184 | in order to compute the triplet loss in both directions. 
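For example, a symmetric objective over a batch can be formed as semihardneg_triplet_loss_from_S(S, margin) + semihardneg_triplet_loss_from_S(S.t(), margin).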
185 | """ 186 | assert(S.dim() == 2) 187 | assert(S.size(0) == S.size(1)) 188 | N = S.size(0) 189 | loss = torch.autograd.Variable(torch.zeros(1).type(S.data.type()), requires_grad=True) 190 | # Imposter - ground truth 191 | Sdiff = S - torch.diag(S).view(-1, 1) 192 | eps = 1e-12 193 | # All examples less similar than ground truth 194 | mask = (Sdiff < -eps).type(torch.LongTensor) 195 | maskf = mask.type_as(S) 196 | # Mask out all examples >= gt with minimum similarity 197 | Sp = maskf * Sdiff + (1 - maskf) * torch.min(Sdiff).detach() 198 | # Find the index maximum similar of the remaining 199 | _, idc = Sp.max(dim=1) 200 | idc = idc.data.cpu() 201 | # Vector mask: 1 iff there exists an example < gt 202 | has_neg = (mask.sum(dim=1) > 0).data.type(torch.LongTensor) 203 | # Random imposter indices 204 | random_imp_ind = torch.LongTensor(basic_get_imposter_indices(N)) 205 | # Use hardneg if there exists an example < gt, otherwise use random imposter 206 | imp_idc = has_neg * idc + (1 - has_neg) * random_imp_ind 207 | # This could probably be vectorized too, but I haven't. 208 | for i, imp in enumerate(imp_idc): 209 | local_loss = Sdiff[i, imp] + margin 210 | if (local_loss.data > 0).all(): 211 | loss = loss + local_loss 212 | loss = loss / N 213 | return loss 214 | 215 | def sampled_triplet_loss_from_S(S, margin): 216 | """ 217 | Input: Similarity matrix S as an autograd.Variable 218 | Output: The one-way triplet loss from rows of S to columns of S. Imposters are 219 | randomly sampled from the columns of S. 220 | You would need to run this function twice, once with S and once with S.t(), 221 | in order to compute the triplet loss in both directions. 222 | """ 223 | assert(S.dim() == 2) 224 | assert(S.size(0) == S.size(1)) 225 | N = S.size(0) 226 | loss = torch.autograd.Variable(torch.zeros(1).type(S.data.type()), requires_grad=True) 227 | # Imposter - ground truth 228 | Sdiff = S - torch.diag(S).view(-1, 1) 229 | imp_ind = torch.LongTensor(basic_get_imposter_indices(N)) 230 | # This could probably be vectorized too, but I haven't. 
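# each term below is a hinge on the imposter-vs-ground-truth gap: Sdiff[i, imp] + margin = S[i, imp] - S[i, i] + margin, accumulated only when it is positive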
231 | for i, imp in enumerate(imp_ind): 232 | local_loss = Sdiff[i, imp] + margin 233 | if (local_loss.data > 0).all(): 234 | loss = loss + local_loss 235 | loss = loss / N 236 | return loss 237 | 238 | class AverageMeter(object): 239 | """Computes and stores the average and current value""" 240 | def __init__(self): 241 | self.reset() 242 | 243 | def reset(self): 244 | self.val = 0 245 | self.avg = 0 246 | self.sum = 0 247 | self.count = 0 248 | 249 | def update(self, val, n=1): 250 | self.val = val 251 | self.sum += val * n 252 | self.count += n 253 | self.avg = self.sum / self.count 254 | 255 | def adjust_learning_rate(base_lr, lr_decay, optimizer, epoch): 256 | """Sets the learning rate to the initial LR decayed by 10 every lr_decay epochs""" 257 | lr = base_lr * (0.1 ** (epoch // lr_decay)) 258 | print('now learning rate changed to {:f}'.format(lr)) 259 | for param_group in optimizer.param_groups: 260 | param_group['lr'] = lr 261 | 262 | def adjust_learning_rate2(base_lr, lr_decay, optimizer, epoch): 263 | """Sets the learning rate to the initial LR decayed by 10 every lr_decay epochs""" 264 | for param_group in optimizer.param_groups: 265 | cur_lr = param_group['lr'] 266 | print('current learing rate is {:f}'.format(lr)) 267 | lr = cur_lr * 0.1 268 | print('now learning rate changed to {:f}'.format(lr)) 269 | for param_group in optimizer.param_groups: 270 | param_group['lr'] = lr 271 | 272 | 273 | def load_progress(prog_pkl, quiet=False): 274 | """ 275 | load progress pkl file 276 | Args: 277 | prog_pkl(str): path to progress pkl file 278 | Return: 279 | progress(list): 280 | epoch(int): 281 | global_step(int): 282 | best_epoch(int): 283 | best_avg_r10(float): 284 | """ 285 | def _print(msg): 286 | if not quiet: 287 | print(msg) 288 | 289 | with open(prog_pkl, "rb") as f: 290 | prog = pickle.load(f) 291 | epoch, global_step, best_epoch, best_avg_r10, _ = prog[-1] 292 | 293 | _print("\nPrevious Progress:") 294 | msg = "[%5s %7s %5s %7s %6s]" % ("epoch", "step", "best_epoch", "best_avg_r10", "time") 295 | _print(msg) 296 | return prog, epoch, global_step, best_epoch, best_avg_r10 297 | 298 | def count_parameters(model): 299 | return sum([p.numel() for p in model.parameters() if p.requires_grad]) 300 | 301 | PrenetConfig = namedtuple( 302 | 'PrenetConfig', ['input_size', 'hidden_size', 'num_layers', 'dropout']) 303 | 304 | RNNConfig = namedtuple( 305 | 'RNNConfig', 306 | ['input_size', 'hidden_size', 'num_layers', 'dropout', 'residual']) 307 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | # Yuan Gong, modified from: 2 | # Author: David Harwath 3 | import argparse 4 | import os 5 | import pickle 6 | import sys 7 | from collections import OrderedDict 8 | import time 9 | import torch 10 | import shutil 11 | basepath = os.path.dirname(os.path.dirname(sys.path[0])) 12 | sys.path.append(basepath) 13 | import dataloaders 14 | from utilities import * 15 | import models 16 | from traintest import train, validate 17 | import ast 18 | from torch.utils.data import WeightedRandomSampler 19 | import numpy as np 20 | 21 | print("I am process %s, running on %s: starting (%s)" % ( 22 | os.getpid(), os.uname()[1], time.asctime())) 23 | 24 | # I/O args 25 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 26 | parser.add_argument("--data-train", type=str, default='', help="training data json") 27 | parser.add_argument("--data-val", 
type=str, default='', help="validation data json") 28 | parser.add_argument("--data-eval", type=str, default=None, help="evaluation data json") 29 | parser.add_argument("--label-csv", type=str, default=os.path.join(basepath, 'utilities/class_labels_indices_coarse.csv'), help="csv with class labels") 30 | parser.add_argument("--exp-dir", type=str, default="", help="directory to dump experiments") 31 | 32 | # training and optimization args 33 | parser.add_argument("--optim", type=str, default="adam", help="training optimizer", choices=["sgd", "adam"]) 34 | parser.add_argument('-b', '--batch-size', default=60, type=int, metavar='N', help='mini-batch size (default: 100)') 35 | parser.add_argument('-w', '--num-workers', default=8, type=int, metavar='NW', help='# of workers for dataloading (default: 8)') 36 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, metavar='LR', help='initial learning rate') 37 | parser.add_argument('--lr-decay', default=40, type=int, metavar='LRDECAY', help='Divide the learning rate by 10 every lr_decay epochs') 38 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') 39 | parser.add_argument('--weight-decay', '--wd', default=5e-7, type=float, metavar='W', help='weight decay (default: 1e-4)') 40 | parser.add_argument("--n-epochs", type=int, default=1, help="number of maximum training epochs") 41 | parser.add_argument("--n-print-steps", type=int, default=1, help="number of steps to print statistics") 42 | 43 | # model args 44 | parser.add_argument("--model", type=str, default="efficientnet", help="audio model architecture", choices=["efficientnet", "resnet", "mbnet"]) 45 | parser.add_argument("--dataset", type=str, default="audioset", help="the dataset used", choices=["audioset", "esc50", "speechcommands"]) 46 | 47 | parser.add_argument("--dataset_mean", type=float, default=-4.6476, help="the dataset mean, used for input normalization") 48 | parser.add_argument("--dataset_std", type=float, default=4.5699, help="the dataset std, used for input normalization") 49 | parser.add_argument("--target_length", type=int, default=1056, help="the input length in frames") 50 | parser.add_argument("--noise", help='if use balance sampling', type=ast.literal_eval) 51 | parser.add_argument("--metrics", type=str, default="mAP", help="the main evaluation metrics", choices=["mAP", "acc"]) 52 | parser.add_argument("--warmup", help='if use balance sampling', type=ast.literal_eval) 53 | parser.add_argument("--loss", type=str, default="BCE", help="the loss function", choices=["BCE", "CE"]) 54 | parser.add_argument("--lrscheduler_start", type=int, default=10, help="when to start decay") 55 | parser.add_argument("--lrscheduler_decay", type=float, default=0.5, help="the learning rate decay ratio") 56 | parser.add_argument("--wa", help='if do weight averaging', type=ast.literal_eval) 57 | parser.add_argument("--wa_start", type=int, default=16, help="which epoch to start weight averaging") 58 | parser.add_argument("--wa_end", type=int, default=30, help="which epoch to end weight averaging") 59 | 60 | parser.add_argument("--n_class", type=int, default=527, help="number of classes") 61 | parser.add_argument('--save_model', help='save the model or not', type=ast.literal_eval) 62 | parser.add_argument("--eff_b", type=int, default=0, help="which efficientnet to use, the larger number, the more complex") 63 | parser.add_argument('--esc', help='If doing an ESC exp, which will have some different behabvior', type=ast.literal_eval, 
default='False') 64 | parser.add_argument('--impretrain', help='if use imagenet pretrained CNNs', type=ast.literal_eval, default='True') 65 | parser.add_argument('--freqm', help='frequency mask max length', type=int, default=0) 66 | parser.add_argument('--timem', help='time mask max length', type=int, default=0) 67 | parser.add_argument("--mixup", type=float, default=0, help="how many (0-1) samples need to be mixup during training") 68 | parser.add_argument("--lr_patience", type=int, default=2, help="how many epoch to wait to reduce lr if mAP doesn't improve") 69 | parser.add_argument("--att_head", type=int, default=4, help="number of attention heads") 70 | parser.add_argument('--bal', help='if use balance sampling', type=ast.literal_eval) 71 | 72 | args = parser.parse_args() 73 | 74 | audio_conf = {'num_mel_bins': 128, 'target_length': args.target_length, 'freqm': args.freqm, 75 | 'timem': args.timem, 'mixup': args.mixup, 'dataset': args.dataset, 'mode': 'train', 76 | 'mean': args.dataset_mean, 'std': args.dataset_std, 77 | 'noise': False} 78 | val_audio_conf = {'num_mel_bins': 128, 'target_length': args.target_length, 'freqm': 0, 'timem': 0, 'mixup': 0, 79 | 'dataset': args.dataset, 'mode': 'evaluation', 'mean': args.dataset_mean, 80 | 'std': args.dataset_std, 'noise': False} 81 | 82 | if args.bal == True: 83 | print('balanced sampler is being used') 84 | samples_weight = np.loadtxt(args.data_train[:-5] + '_weight.csv', delimiter=',') 85 | sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True) 86 | 87 | train_loader = torch.utils.data.DataLoader( 88 | dataloaders.AudiosetDataset(args.data_train, label_csv=args.label_csv, audio_conf=audio_conf), 89 | batch_size=args.batch_size, sampler=sampler, num_workers=args.num_workers, pin_memory=False) 90 | else: 91 | print('balanced sampler is not used') 92 | train_loader = torch.utils.data.DataLoader( 93 | dataloaders.AudiosetDataset(args.data_train, label_csv=args.label_csv, audio_conf=audio_conf), 94 | batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=False) 95 | 96 | val_loader = torch.utils.data.DataLoader( 97 | dataloaders.AudiosetDataset(args.data_val, label_csv=args.label_csv, audio_conf=val_audio_conf), 98 | batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True) 99 | 100 | if args.data_eval != None: 101 | eval_loader = torch.utils.data.DataLoader( 102 | dataloaders.AudiosetDataset(args.data_eval, label_csv=args.label_csv, audio_conf=val_audio_conf), 103 | batch_size=args.batch_size*2, shuffle=False, num_workers=args.num_workers, pin_memory=True) 104 | 105 | if args.model == 'efficientnet': 106 | audio_model = models.EffNetAttention(label_dim=args.n_class, b=args.eff_b, pretrain=args.impretrain, head_num=args.att_head) 107 | elif args.model == 'resnet': 108 | audio_model = models.ResNetAttention(label_dim=args.n_class, pretrain=args.impretrain) 109 | elif args.model == 'mbnet': 110 | audio_model = models.MBNet(label_dim=args.n_class, pretrain=args.effpretrain) 111 | 112 | # if you want to use a pretrained model for fine-tuning, uncomment here. 
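# note (mirroring what src/ensemble/ensemble.py does before loading): if the checkpoint head size differs from --n_class, also drop the classifier weights from the state dict first, e.g. sd.pop('module.effnet._fc.weight', None); sd.pop('module.effnet._fc.bias', None), and keep strict=False in load_state_dict.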
113 | # if not isinstance(audio_model, nn.DataParallel): 114 | # audio_model = nn.DataParallel(audio_model) 115 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 116 | # sd = torch.load('../pretrained_models/as_mdl_0.pth', map_location=device) 117 | # audio_model.load_state_dict(sd, strict=False) 118 | 119 | if not bool(args.exp_dir): 120 | print("exp_dir not specified, automatically naming one...") 121 | args.exp_dir = "exp/Data-%s/AudioModel-%s_Optim-%s_LR-%s_Epochs-%s" % ( 122 | os.path.basename(args.data_train), args.model, args.optim, 123 | args.lr, args.n_epochs) 124 | 125 | print("\nCreating experiment directory: %s" % args.exp_dir) 126 | if os.path.exists("%s/models" % args.exp_dir) == False: 127 | os.makedirs("%s/models" % args.exp_dir) 128 | with open("%s/args.pkl" % args.exp_dir, "wb") as f: 129 | pickle.dump(args, f) 130 | 131 | train(audio_model, train_loader, val_loader, args) 132 | 133 | # if the dataset has a seperate evaluation set (e.g., FSD50K), then select the model using the validation set and eval on the evaluation set. 134 | print('---------------Result Summary---------------') 135 | if args.data_eval != None: 136 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 137 | 138 | # evaluate best single model 139 | sd = torch.load(args.exp_dir + '/models/best_audio_model.pth', map_location=device) 140 | if not isinstance(audio_model, nn.DataParallel): 141 | audio_model = nn.DataParallel(audio_model) 142 | audio_model.load_state_dict(sd) 143 | print('---------------evaluate best single model on the validation set---------------') 144 | stats, _ = validate(audio_model, val_loader, args, 'best_single_valid_set') 145 | val_mAP = np.mean([stat['AP'] for stat in stats]) 146 | val_mAUC = np.mean([stat['auc'] for stat in stats]) 147 | print("mAP: {:.6f}".format(val_mAP)) 148 | print("AUC: {:.6f}".format(val_mAUC)) 149 | print('---------------evaluate best single model on the evaluation set---------------') 150 | stats, _ = validate(audio_model, eval_loader, args, 'best_single_eval_set', eval_target=True) 151 | eval_mAP = np.mean([stat['AP'] for stat in stats]) 152 | eval_mAUC = np.mean([stat['auc'] for stat in stats]) 153 | print("mAP: {:.6f}".format(eval_mAP)) 154 | print("AUC: {:.6f}".format(eval_mAUC)) 155 | np.savetxt(args.exp_dir + '/best_single_result.csv', [val_mAP, val_mAUC, eval_mAP, eval_mAUC]) 156 | 157 | # evaluate weight average model 158 | sd = torch.load(args.exp_dir + '/models/audio_model_wa.pth', map_location=device) 159 | audio_model.load_state_dict(sd) 160 | print('---------------evaluate weight average model on the validation set---------------') 161 | stats, _ = validate(audio_model, val_loader, args, 'wa_valid_set') 162 | val_mAP = np.mean([stat['AP'] for stat in stats]) 163 | val_mAUC = np.mean([stat['auc'] for stat in stats]) 164 | print("mAP: {:.6f}".format(val_mAP)) 165 | print("AUC: {:.6f}".format(val_mAUC)) 166 | print('---------------evaluate weight averages model on the evaluation set---------------') 167 | stats, _ = validate(audio_model, eval_loader, args, 'wa_eval_set') 168 | eval_mAP = np.mean([stat['AP'] for stat in stats]) 169 | eval_mAUC = np.mean([stat['auc'] for stat in stats]) 170 | print("mAP: {:.6f}".format(eval_mAP)) 171 | print("AUC: {:.6f}".format(eval_mAUC)) 172 | np.savetxt(args.exp_dir + '/wa_result.csv', [val_mAP, val_mAUC, eval_mAP, eval_mAUC]) 173 | 174 | # evaluate the ensemble results 175 | print('---------------evaluate ensemble model on the validation set---------------') 176 | # 
this is already done in the training process, only need to load 177 | result = np.loadtxt(args.exp_dir + '/result.csv', delimiter=',') 178 | val_mAP = result[-1, -3] 179 | val_mAUC = result[-1, -2] 180 | print("mAP: {:.6f}".format(val_mAP)) 181 | print("AUC: {:.6f}".format(val_mAUC)) 182 | print('---------------evaluate ensemble model on the evaluation set---------------') 183 | # get the prediction of each checkpoint model 184 | for epoch in range(1, args.n_epochs+1): 185 | sd = torch.load(args.exp_dir + '/models/audio_model.' + str(epoch) + '.pth', map_location=device) 186 | audio_model.load_state_dict(sd) 187 | validate(audio_model, eval_loader, args, 'eval_'+str(epoch)) 188 | # average the checkpoint prediction and calculate the results 189 | target = np.loadtxt(args.exp_dir + '/predictions/eval_target.csv', delimiter=',') 190 | ensemble_predictions = np.zeros_like(target) 191 | for epoch in range(1, args.n_epochs + 1): 192 | cur_pred = np.loadtxt(args.exp_dir + '/predictions/predictions_eval_' + str(epoch) + '.csv', delimiter=',') 193 | ensemble_predictions += cur_pred 194 | ensemble_predictions = ensemble_predictions / args.n_epochs 195 | stats = calculate_stats(ensemble_predictions, target) 196 | eval_mAP = np.mean([stat['AP'] for stat in stats]) 197 | eval_mAUC = np.mean([stat['auc'] for stat in stats]) 198 | print("mAP: {:.6f}".format(eval_mAP)) 199 | print("AUC: {:.6f}".format(eval_mAUC)) 200 | np.savetxt(args.exp_dir + '/ensemble_result.csv', [val_mAP, val_mAUC, eval_mAP, eval_mAUC]) 201 | 202 | # if the dataset only has evaluation set (no validation set), e.g., AudioSet 203 | else: 204 | # evaluate single model 205 | print('---------------evaluate best single model on the evaluation set---------------') 206 | # result is the performance of each epoch, we average the results of the last 5 epochs 207 | result = np.loadtxt(args.exp_dir + '/result.csv', delimiter=',') 208 | last_five_epoch_mean = np.mean(result[-5: , :], axis=0) 209 | eval_mAP = last_five_epoch_mean[0] 210 | eval_mAUC = last_five_epoch_mean[1] 211 | print("mAP: {:.6f}".format(eval_mAP)) 212 | print("AUC: {:.6f}".format(eval_mAUC)) 213 | np.savetxt(args.exp_dir + '/best_single_result.csv', [eval_mAP, eval_mAUC]) 214 | 215 | # evaluate weight average model 216 | print('---------------evaluate weight average model on the evaluation set---------------') 217 | # already done in training process, only need to load 218 | result = np.loadtxt(args.exp_dir + '/wa_result.csv', delimiter=',') 219 | wa_mAP = result[0] 220 | wa_mAUC = result[1] 221 | print("mAP: {:.6f}".format(wa_mAP)) 222 | print("AUC: {:.6f}".format(wa_mAUC)) 223 | np.savetxt(args.exp_dir + '/wa_result.csv', [wa_mAP, wa_mAUC]) 224 | 225 | # evaluate ensemble 226 | print('---------------evaluate ensemble model on the evaluation set---------------') 227 | # already done in training process, only need to load 228 | result = np.loadtxt(args.exp_dir + '/result.csv', delimiter=',') 229 | ensemble_mAP = result[-1, -3] 230 | ensemble_mAUC = result[-1, -2] 231 | print("mAP: {:.6f}".format(ensemble_mAP)) 232 | print("AUC: {:.6f}".format(ensemble_mAUC)) 233 | np.savetxt(args.exp_dir + '/ensemble_result.csv', [ensemble_mAP, ensemble_mAUC]) -------------------------------------------------------------------------------- /src/traintest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 6/10/21 11:00 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email 
: yuangong@mit.edu 6 | # @File : traintest.py 7 | 8 | import sys 9 | import os 10 | import datetime 11 | sys.path.append(os.path.dirname(os.path.dirname(sys.path[0]))) 12 | from utilities import * 13 | import time 14 | import torch 15 | from torch import nn 16 | import numpy as np 17 | import pickle 18 | from torch.cuda.amp import autocast,GradScaler 19 | 20 | def train(audio_model, train_loader, test_loader, args): 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | print('running on ' + str(device)) 23 | torch.set_grad_enabled(True) 24 | 25 | # Initialize all of the statistics we want to keep track of 26 | batch_time = AverageMeter() 27 | per_sample_time = AverageMeter() 28 | data_time = AverageMeter() 29 | per_sample_data_time = AverageMeter() 30 | loss_meter = AverageMeter() 31 | per_sample_dnn_time = AverageMeter() 32 | progress = [] 33 | # best_ensemble_mAP is checkpoint ensemble from the first epoch to the best epoch 34 | best_epoch, best_ensemble_epoch, best_mAP, best_acc, best_ensemble_mAP = 0, 0, -np.inf, -np.inf, -np.inf 35 | global_step, epoch = 0, 0 36 | start_time = time.time() 37 | exp_dir = args.exp_dir 38 | 39 | def _save_progress(): 40 | progress.append([epoch, global_step, best_epoch, best_mAP, time.time() - start_time]) 41 | with open("%s/progress.pkl" % exp_dir, "wb") as f: 42 | pickle.dump(progress, f) 43 | 44 | if not isinstance(audio_model, nn.DataParallel): 45 | audio_model = nn.DataParallel(audio_model) 46 | 47 | audio_model = audio_model.to(device) 48 | # Set up the optimizer 49 | trainables = [p for p in audio_model.parameters() if p.requires_grad] 50 | print('Total parameter number is : {:.3f} million'.format(sum(p.numel() for p in audio_model.parameters()) / 1e6)) 51 | print('Total trainable parameter number is : {:.3f} million'.format(sum(p.numel() for p in trainables) / 1e6)) 52 | optimizer = torch.optim.Adam(trainables, args.lr, weight_decay=5e-7, betas=(0.95, 0.999)) 53 | 54 | # dataset specific settings 55 | #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=args.lr_patience, verbose=True) 56 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, list(range(args.lrscheduler_start, 1000, 5)), gamma=args.lrscheduler_decay, last_epoch=epoch - 1) 57 | main_metrics = args.metrics 58 | if args.loss == 'BCE': 59 | loss_fn = nn.BCELoss() 60 | elif args.loss == 'CE': 61 | loss_fn = nn.CrossEntropyLoss() 62 | warmup = args.warmup 63 | args.loss_fn = loss_fn 64 | print('now training with {:s}, main metrics: {:s}, loss function: {:s}, learning rate scheduler: {:s}'.format(str(args.dataset), str(main_metrics), str(loss_fn), str(scheduler))) 65 | print('The learning rate scheduler starts at {:d} epoch with decay rate of {:.3f} '.format(args.lrscheduler_start, args.lrscheduler_decay)) 66 | 67 | epoch += 1 68 | 69 | print("current #steps=%s, #epochs=%s" % (global_step, epoch)) 70 | print("start training...") 71 | result = np.zeros([args.n_epochs, 10]) 72 | audio_model.train() 73 | while epoch < args.n_epochs + 1: 74 | begin_time = time.time() 75 | end_time = time.time() 76 | audio_model.train() 77 | print('---------------') 78 | print(datetime.datetime.now()) 79 | print("current #epochs=%s, #steps=%s" % (epoch, global_step)) 80 | 81 | for i, (audio_input, labels) in enumerate(train_loader): 82 | 83 | B = audio_input.size(0) 84 | audio_input = audio_input.to(device, non_blocking=True) 85 | labels = labels.to(device, non_blocking=True) 86 | 87 | data_time.update(time.time() - end_time) 88 | 
per_sample_data_time.update((time.time() - end_time) / audio_input.shape[0]) 89 | dnn_start_time = time.time() 90 | 91 | # first several steps for warm-up 92 | if global_step <= 1000 and global_step % 50 == 0 and warmup == True: 93 | warm_lr = (global_step / 1000) * args.lr 94 | for param_group in optimizer.param_groups: 95 | param_group['lr'] = warm_lr 96 | print('warm-up learning rate is {:f}'.format(optimizer.param_groups[0]['lr'])) 97 | 98 | audio_output = audio_model(audio_input) 99 | if isinstance(loss_fn, torch.nn.CrossEntropyLoss): 100 | loss = loss_fn(audio_output, torch.argmax(labels.long(), axis=1)) 101 | else: 102 | epsilon = 1e-7 103 | audio_output = torch.clamp(audio_output, epsilon, 1. - epsilon) 104 | loss = loss_fn(audio_output, labels) 105 | 106 | # optimization if amp is not used 107 | optimizer.zero_grad() 108 | loss.backward() 109 | optimizer.step() 110 | 111 | # record loss 112 | loss_meter.update(loss.item(), B) 113 | batch_time.update(time.time() - end_time) 114 | per_sample_time.update((time.time() - end_time)/audio_input.shape[0]) 115 | per_sample_dnn_time.update((time.time() - dnn_start_time)/audio_input.shape[0]) 116 | 117 | print_step = global_step % args.n_print_steps == 0 118 | early_print_step = epoch == 0 and global_step % (args.n_print_steps/10) == 0 119 | print_step = print_step or early_print_step 120 | 121 | if print_step and global_step != 0: 122 | print('Epoch: [{0}][{1}/{2}]\t' 123 | 'Per Sample Total Time {per_sample_time.avg:.5f}\t' 124 | 'Per Sample Data Time {per_sample_data_time.avg:.5f}\t' 125 | 'Per Sample DNN Time {per_sample_dnn_time.avg:.5f}\t' 126 | 'Train Loss {loss_meter.avg:.4f}\t'.format( 127 | epoch, i, len(train_loader), per_sample_time=per_sample_time, per_sample_data_time=per_sample_data_time, 128 | per_sample_dnn_time=per_sample_dnn_time, loss_meter=loss_meter), flush=True) 129 | if np.isnan(loss_meter.avg): 130 | print("training diverged...") 131 | return 132 | 133 | end_time = time.time() 134 | global_step += 1 135 | 136 | print('start validation') 137 | stats, valid_loss = validate(audio_model, test_loader, args, epoch) 138 | 139 | # ensemble results 140 | ensemble_stats = validate_ensemble(args, epoch) 141 | ensemble_mAP = np.mean([stat['AP'] for stat in ensemble_stats]) 142 | ensemble_mAUC = np.mean([stat['auc'] for stat in ensemble_stats]) 143 | ensemble_acc = ensemble_stats[0]['acc'] 144 | 145 | mAP = np.mean([stat['AP'] for stat in stats]) 146 | mAUC = np.mean([stat['auc'] for stat in stats]) 147 | acc = stats[0]['acc'] 148 | 149 | middle_ps = [stat['precisions'][int(len(stat['precisions'])/2)] for stat in stats] 150 | middle_rs = [stat['recalls'][int(len(stat['recalls'])/2)] for stat in stats] 151 | average_precision = np.mean(middle_ps) 152 | average_recall = np.mean(middle_rs) 153 | 154 | if main_metrics == 'mAP': 155 | print("mAP: {:.6f}".format(mAP)) 156 | else: 157 | print("acc: {:.6f}".format(acc)) 158 | print("AUC: {:.6f}".format(mAUC)) 159 | print("Avg Precision: {:.6f}".format(average_precision)) 160 | print("Avg Recall: {:.6f}".format(average_recall)) 161 | print("d_prime: {:.6f}".format(d_prime(mAUC))) 162 | print("train_loss: {:.6f}".format(loss_meter.avg)) 163 | print("valid_loss: {:.6f}".format(valid_loss)) 164 | 165 | if main_metrics == 'mAP': 166 | result[epoch-1, :] = [mAP, mAUC, average_precision, average_recall, d_prime(mAUC), loss_meter.avg, valid_loss, ensemble_mAP, ensemble_mAUC, optimizer.param_groups[0]['lr']] 167 | else: 168 | result[epoch-1, :] = [acc, mAUC, average_precision, average_recall, 
d_prime(mAUC), loss_meter.avg, valid_loss, ensemble_acc, ensemble_mAUC, optimizer.param_groups[0]['lr']] 169 | np.savetxt(exp_dir + '/result.csv', result, delimiter=',') 170 | print('validation finished') 171 | 172 | if mAP > best_mAP: 173 | best_mAP = mAP 174 | if main_metrics == 'mAP': 175 | best_epoch = epoch 176 | 177 | if acc > best_acc: 178 | best_acc = acc 179 | if main_metrics == 'acc': 180 | best_epoch = epoch 181 | 182 | if ensemble_mAP > best_ensemble_mAP: 183 | best_ensemble_epoch = epoch 184 | best_ensemble_mAP = ensemble_mAP 185 | 186 | if best_epoch == epoch: 187 | torch.save(audio_model.state_dict(), "%s/models/best_audio_model.pth" % (exp_dir)) 188 | torch.save(optimizer.state_dict(), "%s/models/best_optim_state.pth" % (exp_dir)) 189 | 190 | torch.save(audio_model.state_dict(), "%s/models/audio_model.%d.pth" % (exp_dir, epoch)) 191 | if len(train_loader.dataset) > 2e5: 192 | torch.save(optimizer.state_dict(), "%s/models/optim_state.%d.pth" % (exp_dir, epoch)) 193 | 194 | scheduler.step() 195 | 196 | print('Epoch-{0} lr: {1}'.format(epoch, optimizer.param_groups[0]['lr'])) 197 | 198 | with open(exp_dir + '/stats_' + str(epoch) +'.pickle', 'wb') as handle: 199 | pickle.dump(stats, handle, protocol=pickle.HIGHEST_PROTOCOL) 200 | _save_progress() 201 | 202 | finish_time = time.time() 203 | print('epoch {:d} training time: {:.3f}'.format(epoch, finish_time-begin_time)) 204 | 205 | epoch += 1 206 | 207 | batch_time.reset() 208 | per_sample_time.reset() 209 | data_time.reset() 210 | per_sample_data_time.reset() 211 | loss_meter.reset() 212 | per_sample_dnn_time.reset() 213 | 214 | # if test weight averaging 215 | if args.wa == True: 216 | stats=validate_wa(audio_model, test_loader, args, args.wa_start, args.wa_end) 217 | mAP = np.mean([stat['AP'] for stat in stats]) 218 | mAUC = np.mean([stat['auc'] for stat in stats]) 219 | middle_ps = [stat['precisions'][int(len(stat['precisions'])/2)] for stat in stats] 220 | middle_rs = [stat['recalls'][int(len(stat['recalls'])/2)] for stat in stats] 221 | average_precision = np.mean(middle_ps) 222 | average_recall = np.mean(middle_rs) 223 | wa_result = [mAP, mAUC] 224 | print('---------------Training Finished---------------') 225 | # print('On Validation Set') 226 | # print('weighted averaged model results') 227 | # print("mAP: {:.6f}".format(mAP)) 228 | # print("AUC: {:.6f}".format(mAUC)) 229 | # print("d_prime: {:.6f}".format(d_prime(mAUC))) 230 | np.savetxt(exp_dir + '/wa_result.csv', wa_result) 231 | 232 | def validate(audio_model, val_loader, args, epoch, eval_target=False): 233 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 234 | batch_time = AverageMeter() 235 | if not isinstance(audio_model, nn.DataParallel): 236 | audio_model = nn.DataParallel(audio_model) 237 | audio_model = audio_model.to(device) 238 | # switch to evaluate mode 239 | audio_model.eval() 240 | end = time.time() 241 | A_predictions = [] 242 | A_targets = [] 243 | A_loss = [] 244 | with torch.no_grad(): 245 | for i, (audio_input, labels) in enumerate(val_loader): 246 | audio_input = audio_input.to(device) 247 | # compute output 248 | audio_output = audio_model(audio_input) 249 | predictions = audio_output.to('cpu').detach() 250 | A_predictions.append(predictions) 251 | A_targets.append(labels) 252 | # compute the loss 253 | labels = labels.to(device) 254 | epsilon = 1e-7 255 | audio_output = torch.clamp(audio_output, epsilon, 1. 
- epsilon) 256 | if isinstance(args.loss_fn, torch.nn.CrossEntropyLoss): 257 | loss = args.loss_fn(audio_output, torch.argmax(labels.long(), axis=1)) 258 | else: 259 | loss = args.loss_fn(audio_output, labels) 260 | A_loss.append(loss.to('cpu').detach()) 261 | batch_time.update(time.time() - end) 262 | end = time.time() 263 | audio_output = torch.cat(A_predictions) 264 | target = torch.cat(A_targets) 265 | loss = np.mean(A_loss) 266 | stats = calculate_stats(audio_output, target) 267 | # save the prediction here 268 | exp_dir = args.exp_dir 269 | if os.path.exists(exp_dir+'/predictions') == False: 270 | os.mkdir(exp_dir+'/predictions') 271 | np.savetxt(exp_dir+'/predictions/target.csv', target, delimiter=',') 272 | np.savetxt(exp_dir+'/predictions/predictions_' + str(epoch) + '.csv', audio_output, delimiter=',') 273 | # save the target for the separate eval set if there's one. 274 | if eval_target == True and os.path.exists(exp_dir+'/predictions/eval_target.csv') == False: 275 | np.savetxt(exp_dir + '/predictions/eval_target.csv', target, delimiter=',') 276 | return stats, loss 277 | 278 | def validate_ensemble(args, epoch): 279 | exp_dir = args.exp_dir 280 | target = np.loadtxt(exp_dir+'/predictions/target.csv', delimiter=',') 281 | if epoch == 1: 282 | ensemble_predictions = np.loadtxt(exp_dir + '/predictions/predictions_1.csv', delimiter=',') 283 | else: 284 | ensemble_predictions = np.loadtxt(exp_dir + '/predictions/ensemble_predictions.csv', delimiter=',') * (epoch - 1) 285 | predictions = np.loadtxt(exp_dir+'/predictions/predictions_' + str(epoch) + '.csv', delimiter=',') 286 | ensemble_predictions = ensemble_predictions + predictions 287 | # remove the prediction file to save storage space 288 | os.remove(exp_dir+'/predictions/predictions_' + str(epoch-1) + '.csv') 289 | 290 | ensemble_predictions = ensemble_predictions / epoch 291 | np.savetxt(exp_dir+'/predictions/ensemble_predictions.csv', ensemble_predictions, delimiter=',') 292 | 293 | stats = calculate_stats(ensemble_predictions, target) 294 | return stats 295 | 296 | def validate_wa(audio_model, val_loader, args, start_epoch, end_epoch): 297 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 298 | exp_dir = args.exp_dir 299 | 300 | sdA = torch.load(exp_dir + '/models/audio_model.' + str(start_epoch) + '.pth', map_location=device) 301 | 302 | model_cnt = 1 303 | for epoch in range(start_epoch, end_epoch+1): 304 | sdB = torch.load(exp_dir + '/models/audio_model.' + str(epoch) + '.pth', map_location=device) 305 | for key in sdA: 306 | sdA[key] = sdA[key] + sdB[key] 307 | model_cnt += 1 308 | 309 | # if choose not to save models of epoch, remove to save space 310 | if args.save_model == False: 311 | os.remove(exp_dir + '/models/audio_model.' 
+ str(epoch) + '.pth') 312 | 313 | # averaging 314 | for key in sdA: 315 | sdA[key] = sdA[key] / float(model_cnt) 316 | 317 | audio_model.load_state_dict(sdA) 318 | 319 | torch.save(audio_model.state_dict(), exp_dir + '/models/audio_model_wa.pth') 320 | 321 | stats, loss = validate(audio_model, val_loader, args, 'wa') 322 | return stats -------------------------------------------------------------------------------- /egs/audioset/class_labels_indices.csv: -------------------------------------------------------------------------------- 1 | index,mid,display_name 2 | 0,/m/09x0r,"Speech" 3 | 1,/m/05zppz,"Male speech, man speaking" 4 | 2,/m/02zsn,"Female speech, woman speaking" 5 | 3,/m/0ytgt,"Child speech, kid speaking" 6 | 4,/m/01h8n0,"Conversation" 7 | 5,/m/02qldy,"Narration, monologue" 8 | 6,/m/0261r1,"Babbling" 9 | 7,/m/0brhx,"Speech synthesizer" 10 | 8,/m/07p6fty,"Shout" 11 | 9,/m/07q4ntr,"Bellow" 12 | 10,/m/07rwj3x,"Whoop" 13 | 11,/m/07sr1lc,"Yell" 14 | 12,/m/04gy_2,"Battle cry" 15 | 13,/t/dd00135,"Children shouting" 16 | 14,/m/03qc9zr,"Screaming" 17 | 15,/m/02rtxlg,"Whispering" 18 | 16,/m/01j3sz,"Laughter" 19 | 17,/t/dd00001,"Baby laughter" 20 | 18,/m/07r660_,"Giggle" 21 | 19,/m/07s04w4,"Snicker" 22 | 20,/m/07sq110,"Belly laugh" 23 | 21,/m/07rgt08,"Chuckle, chortle" 24 | 22,/m/0463cq4,"Crying, sobbing" 25 | 23,/t/dd00002,"Baby cry, infant cry" 26 | 24,/m/07qz6j3,"Whimper" 27 | 25,/m/07qw_06,"Wail, moan" 28 | 26,/m/07plz5l,"Sigh" 29 | 27,/m/015lz1,"Singing" 30 | 28,/m/0l14jd,"Choir" 31 | 29,/m/01swy6,"Yodeling" 32 | 30,/m/02bk07,"Chant" 33 | 31,/m/01c194,"Mantra" 34 | 32,/t/dd00003,"Male singing" 35 | 33,/t/dd00004,"Female singing" 36 | 34,/t/dd00005,"Child singing" 37 | 35,/t/dd00006,"Synthetic singing" 38 | 36,/m/06bxc,"Rapping" 39 | 37,/m/02fxyj,"Humming" 40 | 38,/m/07s2xch,"Groan" 41 | 39,/m/07r4k75,"Grunt" 42 | 40,/m/01w250,"Whistling" 43 | 41,/m/0lyf6,"Breathing" 44 | 42,/m/07mzm6,"Wheeze" 45 | 43,/m/01d3sd,"Snoring" 46 | 44,/m/07s0dtb,"Gasp" 47 | 45,/m/07pyy8b,"Pant" 48 | 46,/m/07q0yl5,"Snort" 49 | 47,/m/01b_21,"Cough" 50 | 48,/m/0dl9sf8,"Throat clearing" 51 | 49,/m/01hsr_,"Sneeze" 52 | 50,/m/07ppn3j,"Sniff" 53 | 51,/m/06h7j,"Run" 54 | 52,/m/07qv_x_,"Shuffle" 55 | 53,/m/07pbtc8,"Walk, footsteps" 56 | 54,/m/03cczk,"Chewing, mastication" 57 | 55,/m/07pdhp0,"Biting" 58 | 56,/m/0939n_,"Gargling" 59 | 57,/m/01g90h,"Stomach rumble" 60 | 58,/m/03q5_w,"Burping, eructation" 61 | 59,/m/02p3nc,"Hiccup" 62 | 60,/m/02_nn,"Fart" 63 | 61,/m/0k65p,"Hands" 64 | 62,/m/025_jnm,"Finger snapping" 65 | 63,/m/0l15bq,"Clapping" 66 | 64,/m/01jg02,"Heart sounds, heartbeat" 67 | 65,/m/01jg1z,"Heart murmur" 68 | 66,/m/053hz1,"Cheering" 69 | 67,/m/028ght,"Applause" 70 | 68,/m/07rkbfh,"Chatter" 71 | 69,/m/03qtwd,"Crowd" 72 | 70,/m/07qfr4h,"Hubbub, speech noise, speech babble" 73 | 71,/t/dd00013,"Children playing" 74 | 72,/m/0jbk,"Animal" 75 | 73,/m/068hy,"Domestic animals, pets" 76 | 74,/m/0bt9lr,"Dog" 77 | 75,/m/05tny_,"Bark" 78 | 76,/m/07r_k2n,"Yip" 79 | 77,/m/07qf0zm,"Howl" 80 | 78,/m/07rc7d9,"Bow-wow" 81 | 79,/m/0ghcn6,"Growling" 82 | 80,/t/dd00136,"Whimper (dog)" 83 | 81,/m/01yrx,"Cat" 84 | 82,/m/02yds9,"Purr" 85 | 83,/m/07qrkrw,"Meow" 86 | 84,/m/07rjwbb,"Hiss" 87 | 85,/m/07r81j2,"Caterwaul" 88 | 86,/m/0ch8v,"Livestock, farm animals, working animals" 89 | 87,/m/03k3r,"Horse" 90 | 88,/m/07rv9rh,"Clip-clop" 91 | 89,/m/07q5rw0,"Neigh, whinny" 92 | 90,/m/01xq0k1,"Cattle, bovinae" 93 | 91,/m/07rpkh9,"Moo" 94 | 92,/m/0239kh,"Cowbell" 95 | 93,/m/068zj,"Pig" 96 | 94,/t/dd00018,"Oink" 97 | 95,/m/03fwl,"Goat" 98 | 
96,/m/07q0h5t,"Bleat" 99 | 97,/m/07bgp,"Sheep" 100 | 98,/m/025rv6n,"Fowl" 101 | 99,/m/09b5t,"Chicken, rooster" 102 | 100,/m/07st89h,"Cluck" 103 | 101,/m/07qn5dc,"Crowing, cock-a-doodle-doo" 104 | 102,/m/01rd7k,"Turkey" 105 | 103,/m/07svc2k,"Gobble" 106 | 104,/m/09ddx,"Duck" 107 | 105,/m/07qdb04,"Quack" 108 | 106,/m/0dbvp,"Goose" 109 | 107,/m/07qwf61,"Honk" 110 | 108,/m/01280g,"Wild animals" 111 | 109,/m/0cdnk,"Roaring cats (lions, tigers)" 112 | 110,/m/04cvmfc,"Roar" 113 | 111,/m/015p6,"Bird" 114 | 112,/m/020bb7,"Bird vocalization, bird call, bird song" 115 | 113,/m/07pggtn,"Chirp, tweet" 116 | 114,/m/07sx8x_,"Squawk" 117 | 115,/m/0h0rv,"Pigeon, dove" 118 | 116,/m/07r_25d,"Coo" 119 | 117,/m/04s8yn,"Crow" 120 | 118,/m/07r5c2p,"Caw" 121 | 119,/m/09d5_,"Owl" 122 | 120,/m/07r_80w,"Hoot" 123 | 121,/m/05_wcq,"Bird flight, flapping wings" 124 | 122,/m/01z5f,"Canidae, dogs, wolves" 125 | 123,/m/06hps,"Rodents, rats, mice" 126 | 124,/m/04rmv,"Mouse" 127 | 125,/m/07r4gkf,"Patter" 128 | 126,/m/03vt0,"Insect" 129 | 127,/m/09xqv,"Cricket" 130 | 128,/m/09f96,"Mosquito" 131 | 129,/m/0h2mp,"Fly, housefly" 132 | 130,/m/07pjwq1,"Buzz" 133 | 131,/m/01h3n,"Bee, wasp, etc." 134 | 132,/m/09ld4,"Frog" 135 | 133,/m/07st88b,"Croak" 136 | 134,/m/078jl,"Snake" 137 | 135,/m/07qn4z3,"Rattle" 138 | 136,/m/032n05,"Whale vocalization" 139 | 137,/m/04rlf,"Music" 140 | 138,/m/04szw,"Musical instrument" 141 | 139,/m/0fx80y,"Plucked string instrument" 142 | 140,/m/0342h,"Guitar" 143 | 141,/m/02sgy,"Electric guitar" 144 | 142,/m/018vs,"Bass guitar" 145 | 143,/m/042v_gx,"Acoustic guitar" 146 | 144,/m/06w87,"Steel guitar, slide guitar" 147 | 145,/m/01glhc,"Tapping (guitar technique)" 148 | 146,/m/07s0s5r,"Strum" 149 | 147,/m/018j2,"Banjo" 150 | 148,/m/0jtg0,"Sitar" 151 | 149,/m/04rzd,"Mandolin" 152 | 150,/m/01bns_,"Zither" 153 | 151,/m/07xzm,"Ukulele" 154 | 152,/m/05148p4,"Keyboard (musical)" 155 | 153,/m/05r5c,"Piano" 156 | 154,/m/01s0ps,"Electric piano" 157 | 155,/m/013y1f,"Organ" 158 | 156,/m/03xq_f,"Electronic organ" 159 | 157,/m/03gvt,"Hammond organ" 160 | 158,/m/0l14qv,"Synthesizer" 161 | 159,/m/01v1d8,"Sampler" 162 | 160,/m/03q5t,"Harpsichord" 163 | 161,/m/0l14md,"Percussion" 164 | 162,/m/02hnl,"Drum kit" 165 | 163,/m/0cfdd,"Drum machine" 166 | 164,/m/026t6,"Drum" 167 | 165,/m/06rvn,"Snare drum" 168 | 166,/m/03t3fj,"Rimshot" 169 | 167,/m/02k_mr,"Drum roll" 170 | 168,/m/0bm02,"Bass drum" 171 | 169,/m/011k_j,"Timpani" 172 | 170,/m/01p970,"Tabla" 173 | 171,/m/01qbl,"Cymbal" 174 | 172,/m/03qtq,"Hi-hat" 175 | 173,/m/01sm1g,"Wood block" 176 | 174,/m/07brj,"Tambourine" 177 | 175,/m/05r5wn,"Rattle (instrument)" 178 | 176,/m/0xzly,"Maraca" 179 | 177,/m/0mbct,"Gong" 180 | 178,/m/016622,"Tubular bells" 181 | 179,/m/0j45pbj,"Mallet percussion" 182 | 180,/m/0dwsp,"Marimba, xylophone" 183 | 181,/m/0dwtp,"Glockenspiel" 184 | 182,/m/0dwt5,"Vibraphone" 185 | 183,/m/0l156b,"Steelpan" 186 | 184,/m/05pd6,"Orchestra" 187 | 185,/m/01kcd,"Brass instrument" 188 | 186,/m/0319l,"French horn" 189 | 187,/m/07gql,"Trumpet" 190 | 188,/m/07c6l,"Trombone" 191 | 189,/m/0l14_3,"Bowed string instrument" 192 | 190,/m/02qmj0d,"String section" 193 | 191,/m/07y_7,"Violin, fiddle" 194 | 192,/m/0d8_n,"Pizzicato" 195 | 193,/m/01xqw,"Cello" 196 | 194,/m/02fsn,"Double bass" 197 | 195,/m/085jw,"Wind instrument, woodwind instrument" 198 | 196,/m/0l14j_,"Flute" 199 | 197,/m/06ncr,"Saxophone" 200 | 198,/m/01wy6,"Clarinet" 201 | 199,/m/03m5k,"Harp" 202 | 200,/m/0395lw,"Bell" 203 | 201,/m/03w41f,"Church bell" 204 | 202,/m/027m70_,"Jingle bell" 205 | 
203,/m/0gy1t2s,"Bicycle bell" 206 | 204,/m/07n_g,"Tuning fork" 207 | 205,/m/0f8s22,"Chime" 208 | 206,/m/026fgl,"Wind chime" 209 | 207,/m/0150b9,"Change ringing (campanology)" 210 | 208,/m/03qjg,"Harmonica" 211 | 209,/m/0mkg,"Accordion" 212 | 210,/m/0192l,"Bagpipes" 213 | 211,/m/02bxd,"Didgeridoo" 214 | 212,/m/0l14l2,"Shofar" 215 | 213,/m/07kc_,"Theremin" 216 | 214,/m/0l14t7,"Singing bowl" 217 | 215,/m/01hgjl,"Scratching (performance technique)" 218 | 216,/m/064t9,"Pop music" 219 | 217,/m/0glt670,"Hip hop music" 220 | 218,/m/02cz_7,"Beatboxing" 221 | 219,/m/06by7,"Rock music" 222 | 220,/m/03lty,"Heavy metal" 223 | 221,/m/05r6t,"Punk rock" 224 | 222,/m/0dls3,"Grunge" 225 | 223,/m/0dl5d,"Progressive rock" 226 | 224,/m/07sbbz2,"Rock and roll" 227 | 225,/m/05w3f,"Psychedelic rock" 228 | 226,/m/06j6l,"Rhythm and blues" 229 | 227,/m/0gywn,"Soul music" 230 | 228,/m/06cqb,"Reggae" 231 | 229,/m/01lyv,"Country" 232 | 230,/m/015y_n,"Swing music" 233 | 231,/m/0gg8l,"Bluegrass" 234 | 232,/m/02x8m,"Funk" 235 | 233,/m/02w4v,"Folk music" 236 | 234,/m/06j64v,"Middle Eastern music" 237 | 235,/m/03_d0,"Jazz" 238 | 236,/m/026z9,"Disco" 239 | 237,/m/0ggq0m,"Classical music" 240 | 238,/m/05lls,"Opera" 241 | 239,/m/02lkt,"Electronic music" 242 | 240,/m/03mb9,"House music" 243 | 241,/m/07gxw,"Techno" 244 | 242,/m/07s72n,"Dubstep" 245 | 243,/m/0283d,"Drum and bass" 246 | 244,/m/0m0jc,"Electronica" 247 | 245,/m/08cyft,"Electronic dance music" 248 | 246,/m/0fd3y,"Ambient music" 249 | 247,/m/07lnk,"Trance music" 250 | 248,/m/0g293,"Music of Latin America" 251 | 249,/m/0ln16,"Salsa music" 252 | 250,/m/0326g,"Flamenco" 253 | 251,/m/0155w,"Blues" 254 | 252,/m/05fw6t,"Music for children" 255 | 253,/m/02v2lh,"New-age music" 256 | 254,/m/0y4f8,"Vocal music" 257 | 255,/m/0z9c,"A capella" 258 | 256,/m/0164x2,"Music of Africa" 259 | 257,/m/0145m,"Afrobeat" 260 | 258,/m/02mscn,"Christian music" 261 | 259,/m/016cjb,"Gospel music" 262 | 260,/m/028sqc,"Music of Asia" 263 | 261,/m/015vgc,"Carnatic music" 264 | 262,/m/0dq0md,"Music of Bollywood" 265 | 263,/m/06rqw,"Ska" 266 | 264,/m/02p0sh1,"Traditional music" 267 | 265,/m/05rwpb,"Independent music" 268 | 266,/m/074ft,"Song" 269 | 267,/m/025td0t,"Background music" 270 | 268,/m/02cjck,"Theme music" 271 | 269,/m/03r5q_,"Jingle (music)" 272 | 270,/m/0l14gg,"Soundtrack music" 273 | 271,/m/07pkxdp,"Lullaby" 274 | 272,/m/01z7dr,"Video game music" 275 | 273,/m/0140xf,"Christmas music" 276 | 274,/m/0ggx5q,"Dance music" 277 | 275,/m/04wptg,"Wedding music" 278 | 276,/t/dd00031,"Happy music" 279 | 277,/t/dd00032,"Funny music" 280 | 278,/t/dd00033,"Sad music" 281 | 279,/t/dd00034,"Tender music" 282 | 280,/t/dd00035,"Exciting music" 283 | 281,/t/dd00036,"Angry music" 284 | 282,/t/dd00037,"Scary music" 285 | 283,/m/03m9d0z,"Wind" 286 | 284,/m/09t49,"Rustling leaves" 287 | 285,/t/dd00092,"Wind noise (microphone)" 288 | 286,/m/0jb2l,"Thunderstorm" 289 | 287,/m/0ngt1,"Thunder" 290 | 288,/m/0838f,"Water" 291 | 289,/m/06mb1,"Rain" 292 | 290,/m/07r10fb,"Raindrop" 293 | 291,/t/dd00038,"Rain on surface" 294 | 292,/m/0j6m2,"Stream" 295 | 293,/m/0j2kx,"Waterfall" 296 | 294,/m/05kq4,"Ocean" 297 | 295,/m/034srq,"Waves, surf" 298 | 296,/m/06wzb,"Steam" 299 | 297,/m/07swgks,"Gurgling" 300 | 298,/m/02_41,"Fire" 301 | 299,/m/07pzfmf,"Crackle" 302 | 300,/m/07yv9,"Vehicle" 303 | 301,/m/019jd,"Boat, Water vehicle" 304 | 302,/m/0hsrw,"Sailboat, sailing ship" 305 | 303,/m/056ks2,"Rowboat, canoe, kayak" 306 | 304,/m/02rlv9,"Motorboat, speedboat" 307 | 305,/m/06q74,"Ship" 308 | 306,/m/012f08,"Motor vehicle (road)" 309 
| 307,/m/0k4j,"Car" 310 | 308,/m/0912c9,"Vehicle horn, car horn, honking" 311 | 309,/m/07qv_d5,"Toot" 312 | 310,/m/02mfyn,"Car alarm" 313 | 311,/m/04gxbd,"Power windows, electric windows" 314 | 312,/m/07rknqz,"Skidding" 315 | 313,/m/0h9mv,"Tire squeal" 316 | 314,/t/dd00134,"Car passing by" 317 | 315,/m/0ltv,"Race car, auto racing" 318 | 316,/m/07r04,"Truck" 319 | 317,/m/0gvgw0,"Air brake" 320 | 318,/m/05x_td,"Air horn, truck horn" 321 | 319,/m/02rhddq,"Reversing beeps" 322 | 320,/m/03cl9h,"Ice cream truck, ice cream van" 323 | 321,/m/01bjv,"Bus" 324 | 322,/m/03j1ly,"Emergency vehicle" 325 | 323,/m/04qvtq,"Police car (siren)" 326 | 324,/m/012n7d,"Ambulance (siren)" 327 | 325,/m/012ndj,"Fire engine, fire truck (siren)" 328 | 326,/m/04_sv,"Motorcycle" 329 | 327,/m/0btp2,"Traffic noise, roadway noise" 330 | 328,/m/06d_3,"Rail transport" 331 | 329,/m/07jdr,"Train" 332 | 330,/m/04zmvq,"Train whistle" 333 | 331,/m/0284vy3,"Train horn" 334 | 332,/m/01g50p,"Railroad car, train wagon" 335 | 333,/t/dd00048,"Train wheels squealing" 336 | 334,/m/0195fx,"Subway, metro, underground" 337 | 335,/m/0k5j,"Aircraft" 338 | 336,/m/014yck,"Aircraft engine" 339 | 337,/m/04229,"Jet engine" 340 | 338,/m/02l6bg,"Propeller, airscrew" 341 | 339,/m/09ct_,"Helicopter" 342 | 340,/m/0cmf2,"Fixed-wing aircraft, airplane" 343 | 341,/m/0199g,"Bicycle" 344 | 342,/m/06_fw,"Skateboard" 345 | 343,/m/02mk9,"Engine" 346 | 344,/t/dd00065,"Light engine (high frequency)" 347 | 345,/m/08j51y,"Dental drill, dentist's drill" 348 | 346,/m/01yg9g,"Lawn mower" 349 | 347,/m/01j4z9,"Chainsaw" 350 | 348,/t/dd00066,"Medium engine (mid frequency)" 351 | 349,/t/dd00067,"Heavy engine (low frequency)" 352 | 350,/m/01h82_,"Engine knocking" 353 | 351,/t/dd00130,"Engine starting" 354 | 352,/m/07pb8fc,"Idling" 355 | 353,/m/07q2z82,"Accelerating, revving, vroom" 356 | 354,/m/02dgv,"Door" 357 | 355,/m/03wwcy,"Doorbell" 358 | 356,/m/07r67yg,"Ding-dong" 359 | 357,/m/02y_763,"Sliding door" 360 | 358,/m/07rjzl8,"Slam" 361 | 359,/m/07r4wb8,"Knock" 362 | 360,/m/07qcpgn,"Tap" 363 | 361,/m/07q6cd_,"Squeak" 364 | 362,/m/0642b4,"Cupboard open or close" 365 | 363,/m/0fqfqc,"Drawer open or close" 366 | 364,/m/04brg2,"Dishes, pots, and pans" 367 | 365,/m/023pjk,"Cutlery, silverware" 368 | 366,/m/07pn_8q,"Chopping (food)" 369 | 367,/m/0dxrf,"Frying (food)" 370 | 368,/m/0fx9l,"Microwave oven" 371 | 369,/m/02pjr4,"Blender" 372 | 370,/m/02jz0l,"Water tap, faucet" 373 | 371,/m/0130jx,"Sink (filling or washing)" 374 | 372,/m/03dnzn,"Bathtub (filling or washing)" 375 | 373,/m/03wvsk,"Hair dryer" 376 | 374,/m/01jt3m,"Toilet flush" 377 | 375,/m/012xff,"Toothbrush" 378 | 376,/m/04fgwm,"Electric toothbrush" 379 | 377,/m/0d31p,"Vacuum cleaner" 380 | 378,/m/01s0vc,"Zipper (clothing)" 381 | 379,/m/03v3yw,"Keys jangling" 382 | 380,/m/0242l,"Coin (dropping)" 383 | 381,/m/01lsmm,"Scissors" 384 | 382,/m/02g901,"Electric shaver, electric razor" 385 | 383,/m/05rj2,"Shuffling cards" 386 | 384,/m/0316dw,"Typing" 387 | 385,/m/0c2wf,"Typewriter" 388 | 386,/m/01m2v,"Computer keyboard" 389 | 387,/m/081rb,"Writing" 390 | 388,/m/07pp_mv,"Alarm" 391 | 389,/m/07cx4,"Telephone" 392 | 390,/m/07pp8cl,"Telephone bell ringing" 393 | 391,/m/01hnzm,"Ringtone" 394 | 392,/m/02c8p,"Telephone dialing, DTMF" 395 | 393,/m/015jpf,"Dial tone" 396 | 394,/m/01z47d,"Busy signal" 397 | 395,/m/046dlr,"Alarm clock" 398 | 396,/m/03kmc9,"Siren" 399 | 397,/m/0dgbq,"Civil defense siren" 400 | 398,/m/030rvx,"Buzzer" 401 | 399,/m/01y3hg,"Smoke detector, smoke alarm" 402 | 400,/m/0c3f7m,"Fire alarm" 403 | 
401,/m/04fq5q,"Foghorn" 404 | 402,/m/0l156k,"Whistle" 405 | 403,/m/06hck5,"Steam whistle" 406 | 404,/t/dd00077,"Mechanisms" 407 | 405,/m/02bm9n,"Ratchet, pawl" 408 | 406,/m/01x3z,"Clock" 409 | 407,/m/07qjznt,"Tick" 410 | 408,/m/07qjznl,"Tick-tock" 411 | 409,/m/0l7xg,"Gears" 412 | 410,/m/05zc1,"Pulleys" 413 | 411,/m/0llzx,"Sewing machine" 414 | 412,/m/02x984l,"Mechanical fan" 415 | 413,/m/025wky1,"Air conditioning" 416 | 414,/m/024dl,"Cash register" 417 | 415,/m/01m4t,"Printer" 418 | 416,/m/0dv5r,"Camera" 419 | 417,/m/07bjf,"Single-lens reflex camera" 420 | 418,/m/07k1x,"Tools" 421 | 419,/m/03l9g,"Hammer" 422 | 420,/m/03p19w,"Jackhammer" 423 | 421,/m/01b82r,"Sawing" 424 | 422,/m/02p01q,"Filing (rasp)" 425 | 423,/m/023vsd,"Sanding" 426 | 424,/m/0_ksk,"Power tool" 427 | 425,/m/01d380,"Drill" 428 | 426,/m/014zdl,"Explosion" 429 | 427,/m/032s66,"Gunshot, gunfire" 430 | 428,/m/04zjc,"Machine gun" 431 | 429,/m/02z32qm,"Fusillade" 432 | 430,/m/0_1c,"Artillery fire" 433 | 431,/m/073cg4,"Cap gun" 434 | 432,/m/0g6b5,"Fireworks" 435 | 433,/g/122z_qxw,"Firecracker" 436 | 434,/m/07qsvvw,"Burst, pop" 437 | 435,/m/07pxg6y,"Eruption" 438 | 436,/m/07qqyl4,"Boom" 439 | 437,/m/083vt,"Wood" 440 | 438,/m/07pczhz,"Chop" 441 | 439,/m/07pl1bw,"Splinter" 442 | 440,/m/07qs1cx,"Crack" 443 | 441,/m/039jq,"Glass" 444 | 442,/m/07q7njn,"Chink, clink" 445 | 443,/m/07rn7sz,"Shatter" 446 | 444,/m/04k94,"Liquid" 447 | 445,/m/07rrlb6,"Splash, splatter" 448 | 446,/m/07p6mqd,"Slosh" 449 | 447,/m/07qlwh6,"Squish" 450 | 448,/m/07r5v4s,"Drip" 451 | 449,/m/07prgkl,"Pour" 452 | 450,/m/07pqc89,"Trickle, dribble" 453 | 451,/t/dd00088,"Gush" 454 | 452,/m/07p7b8y,"Fill (with liquid)" 455 | 453,/m/07qlf79,"Spray" 456 | 454,/m/07ptzwd,"Pump (liquid)" 457 | 455,/m/07ptfmf,"Stir" 458 | 456,/m/0dv3j,"Boiling" 459 | 457,/m/0790c,"Sonar" 460 | 458,/m/0dl83,"Arrow" 461 | 459,/m/07rqsjt,"Whoosh, swoosh, swish" 462 | 460,/m/07qnq_y,"Thump, thud" 463 | 461,/m/07rrh0c,"Thunk" 464 | 462,/m/0b_fwt,"Electronic tuner" 465 | 463,/m/02rr_,"Effects unit" 466 | 464,/m/07m2kt,"Chorus effect" 467 | 465,/m/018w8,"Basketball bounce" 468 | 466,/m/07pws3f,"Bang" 469 | 467,/m/07ryjzk,"Slap, smack" 470 | 468,/m/07rdhzs,"Whack, thwack" 471 | 469,/m/07pjjrj,"Smash, crash" 472 | 470,/m/07pc8lb,"Breaking" 473 | 471,/m/07pqn27,"Bouncing" 474 | 472,/m/07rbp7_,"Whip" 475 | 473,/m/07pyf11,"Flap" 476 | 474,/m/07qb_dv,"Scratch" 477 | 475,/m/07qv4k0,"Scrape" 478 | 476,/m/07pdjhy,"Rub" 479 | 477,/m/07s8j8t,"Roll" 480 | 478,/m/07plct2,"Crushing" 481 | 479,/t/dd00112,"Crumpling, crinkling" 482 | 480,/m/07qcx4z,"Tearing" 483 | 481,/m/02fs_r,"Beep, bleep" 484 | 482,/m/07qwdck,"Ping" 485 | 483,/m/07phxs1,"Ding" 486 | 484,/m/07rv4dm,"Clang" 487 | 485,/m/07s02z0,"Squeal" 488 | 486,/m/07qh7jl,"Creak" 489 | 487,/m/07qwyj0,"Rustle" 490 | 488,/m/07s34ls,"Whir" 491 | 489,/m/07qmpdm,"Clatter" 492 | 490,/m/07p9k1k,"Sizzle" 493 | 491,/m/07qc9xj,"Clicking" 494 | 492,/m/07rwm0c,"Clickety-clack" 495 | 493,/m/07phhsh,"Rumble" 496 | 494,/m/07qyrcz,"Plop" 497 | 495,/m/07qfgpx,"Jingle, tinkle" 498 | 496,/m/07rcgpl,"Hum" 499 | 497,/m/07p78v5,"Zing" 500 | 498,/t/dd00121,"Boing" 501 | 499,/m/07s12q4,"Crunch" 502 | 500,/m/028v0c,"Silence" 503 | 501,/m/01v_m0,"Sine wave" 504 | 502,/m/0b9m1,"Harmonic" 505 | 503,/m/0hdsk,"Chirp tone" 506 | 504,/m/0c1dj,"Sound effect" 507 | 505,/m/07pt_g0,"Pulse" 508 | 506,/t/dd00125,"Inside, small room" 509 | 507,/t/dd00126,"Inside, large room or hall" 510 | 508,/t/dd00127,"Inside, public space" 511 | 509,/t/dd00128,"Outside, urban or manmade" 512 | 
510,/t/dd00129,"Outside, rural or natural" 513 | 511,/m/01b9nn,"Reverberation" 514 | 512,/m/01jnbd,"Echo" 515 | 513,/m/096m7z,"Noise" 516 | 514,/m/06_y0by,"Environmental noise" 517 | 515,/m/07rgkc5,"Static" 518 | 516,/m/06xkwv,"Mains hum" 519 | 517,/m/0g12c5,"Distortion" 520 | 518,/m/08p9q4,"Sidetone" 521 | 519,/m/07szfh9,"Cacophony" 522 | 520,/m/0chx_,"White noise" 523 | 521,/m/0cj0r,"Pink noise" 524 | 522,/m/07p_0gm,"Throbbing" 525 | 523,/m/01jwx6,"Vibration" 526 | 524,/m/07c52,"Television" 527 | 525,/m/06bz3,"Radio" 528 | 526,/m/07hvw1,"Field recording" 529 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PSLA: Improving Audio Tagging with Pretraining, Sampling, Labeling, and Aggregation 2 | - [News](#News) 3 | - [Introduction](#Introduction) 4 | - [Getting Started](#Getting-Started) 5 | - [FSD50K Recipe](#FSD50K-Recipe) 6 | - [AudioSet Recipe](#Audioset-Recipe) 7 | - [Label Enhancement](#Label-Enhancement) 8 | - [Ensemble and Weight Averaging](#Ensemble-and-Weight-Averaging) 9 | - [Pretrained Models](#Pretrained-Models) 10 | - [Pretrained Enhanced Label Sets](#Pretrained-Enhanced-Label-Sets) 11 | - [Use Pretrained Model for Audio Tagging Inference in One-Click](#Use-Pretrained-Model-for-Audio-Tagging-Inference-in-One-Click) 12 | - [Use PSLA Training Pipeline For New Models](#Use-PSLA-Training-Pipeline-For-New-Models) 13 | - [Use PSLA Training Pipeline For New Datasets and Tasks](#Use-PSLA-Training-Pipeline-For-New-Datasets-and-Tasks) 14 | - [Use Pretrained CNN+Attention Model For New Tasks](#Use-Pretrained-CNN+Attention-Model-For-New-Tasks) 15 | - [Contact](#Contact) 16 | 17 | ## News 18 | * April 2022: I will present PSLA at [13 May (Friday), 10:00 - 10:45 AM, New York Time at ICASSP 2022](https://2022.ieeeicassp.org/view_paper.php?PaperNum=9274). 19 | 20 | ## Introduction 21 | 22 |
