├── .gitignore
├── LICENSE
├── README.md
├── data
│   └── generate_json.py
├── data_loader.py
├── environment.yml
├── main.py
├── methods
│   └── base.py
├── models
│   ├── frontend.py
│   ├── model.py
│   └── module.py
├── train.sh
└── utils
    ├── evaluate.py
    ├── get_methods.py
    ├── losses.py
    ├── models.py
    └── pytorch_utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | data/MSoS
6 | data/TAU_ASC
7 | data/Speech_commands_v1
8 | # C extensions
9 | *.so
10 | workspace
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 | 
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 | 
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 | 
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | 
56 | # Translations
57 | *.mo
58 | *.pot
59 | 
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 | 
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 | 
70 | # Scrapy stuff:
71 | .scrapy
72 | 
73 | # Sphinx documentation
74 | docs/_build/
75 | 
76 | # PyBuilder
77 | target/
78 | 
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 | 
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 | 
86 | # pyenv
87 | .python-version
88 | 
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 | 
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 | 
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 | 
103 | # SageMath parsed files
104 | *.sage.py
105 | 
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 | 
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 | 
119 | # Rope project settings
120 | .ropeproject
121 | 
122 | # mkdocs documentation
123 | /site
124 | 
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 | 
130 | # Pyre type checker
131 | .pyre/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 YANG XIAO(肖扬)
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![LICENSE](https://img.shields.io/badge/license-MIT-green?style=flat-square)](https://github.com/swagshaw/ASC-CL/blob/master/LICENSE)
2 | 
3 | # ASC-CL Official PyTorch Implementation
4 | Official PyTorch implementation of [Continual Learning For On-Device Environmental Sound Classification](https://arxiv.org/abs/2207.07429).
5 | 
6 | If you have any questions about this repository or the related paper, feel free to [create an issue](https://github.com/swagshaw/ASC-CL/issues/new) or [send me an email](mailto:yxiao009+github@e.ntu.edu.sg).
7 | ## Abstract
8 | Continuously learning new classes without catastrophic forgetting is a challenging problem for on-device environmental sound classification, given the restrictions on computation resources (e.g., model size, running memory). To address this issue, we propose a simple and efficient continual learning method. Our method selects the historical data for training by measuring per-sample classification uncertainty. Specifically, we measure the uncertainty by observing how the classification probability of data fluctuates against parallel perturbations added to the classifier embedding. In this way, the computation cost can be significantly reduced compared with adding perturbations to the raw data. Experimental results on the DCASE 2019 Task 1 and ESC-50 datasets show that our proposed method outperforms baseline continual learning methods on classification accuracy and computational efficiency, indicating that our method can efficiently and incrementally learn new classes without the catastrophic forgetting problem for on-device environmental sound classification.
9 | ## Getting Started
10 | ### Setup Environment
11 | 
12 | Create the running environment with [Anaconda](https://www.anaconda.com/):
13 | 
14 | ```bash
15 | conda env create -f environment.yml
16 | conda activate asc
17 | ```
18 | ### Results
19 | Running an experiment produces two kinds of output: logs and results.
20 | The log files are saved in the `logs` directory, and the results, which contain the accuracy of each task and the memory updating time, are saved in the `results` directory, both under `workspace`:
21 | ```
22 | workspace
23 | |_ logs
24 |    |_ [dataset]
25 |       |_ .log
26 |       |_ ...
27 | |_ results
28 |    |_ [dataset]
29 |       |_ .npy
30 |       |_ ...
31 | ```
32 | ### Data
33 | 
34 | We use the [TAU-ASC](https://zenodo.org/record/2589280#.YtJiNHbP1UE) and [ESC-50](https://github.com/karoldvl/ESC-50/archive/master.zip) datasets as the training data.
35 | Put them into:
36 | 
37 | ```bash
38 | your_project_path/data/TAU_ASC
39 | your_project_path/data/ESC-50-master
40 | ```
41 | 
42 | Then run `./data/generate_json.py`:
43 | 
44 | ```bash
45 | python ./data/generate_json.py --mode train --dpath your_project_path/data
46 | python ./data/generate_json.py --mode test --dpath your_project_path/data
47 | ```
48 | 
49 | ### Usage
50 | 
51 | To reproduce the experiments in the paper, simply run `train.sh`.
52 | For other experiments, you should know the role of each argument (the uncertainty-based selection behind `UNCERT_METRIC` is sketched at the end of this README):
53 | 
54 | - `MODE`: whether to use a CL method [finetune, replay]
55 | - `MODEL`: use the baseline CNN model or BC-ResNet [baseline, BC-ResNet]
56 | - `MEM_MANAGE`: memory update method [random, reservoir, uncertainty, prototype]
57 | - `RND_SEED`: random seed number
58 | - `DATASET`: dataset name [TAU-ASC, ESC-50]
59 | - `MEM_SIZE`: memory size, k = {300, 500}
60 | - `UNCERT_METRIC`: perturbation method for uncertainty [shift, noise, noisytune (ours)]
61 | 
62 | ## Acknowledgements
63 | Our implementation uses source code from the following repositories and users:
64 | 
65 | - [Rainbow-Keywords](https://github.com/swagshaw/Rainbow-Keywords)
66 | 
67 | ## License
68 | The project is available as open source under the terms of the [MIT License](./LICENSE).
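## Appendix: Uncertainty Scoring (Sketch)
For intuition, below is a minimal sketch of the variance-ratio rule that `methods/base.py` (`montecarlo` / `variance_ratio`) uses to turn K perturbed forward passes into one uncertainty score per clip. Names and shapes here are illustrative, not the exact training code; the repository stores per-pass scores as 1 − probability and takes an argmin, which is equivalent to the argmax vote used here.

```python
import torch

def variance_ratio_uncertainty(perturbed_logits):
    """perturbed_logits: K tensors of shape (num_classes,) for one clip, each
    produced by a forward pass with a different random perturbation of the
    classifier embedding (cf. BCResNet_Mod with add_noise=True)."""
    k = len(perturbed_logits)
    votes = torch.zeros(perturbed_logits[0].numel())
    for logit in perturbed_logits:
        votes[int(torch.argmax(logit))] += 1  # predicted class under this perturbation
    # 0.0 when all K passes agree (confident); approaches 1 as the predictions scatter
    return (1.0 - votes.max() / k).item()
```

Per class, candidates are then sorted by this score and picked with an even stride (`uncertainty_sampling` in `methods/base.py`), so the memory covers the whole certainty spectrum instead of keeping only the most uncertain clips.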
69 | -------------------------------------------------------------------------------- /data/generate_json.py: -------------------------------------------------------------------------------- 1 | """ 2 | # !/usr/bin/env python 3 | -*- coding: utf-8 -*- 4 | @Time : 2022/6/19 下午4:38 5 | @Author : Yang "Jan" Xiao 6 | @Description : generate_json 7 | """ 8 | import argparse 9 | import json 10 | import random 11 | import os 12 | 13 | import pandas as pd 14 | from soundata.datasets import tau2019uas 15 | from tqdm import tqdm 16 | 17 | os.chdir("..") 18 | print(os.getcwd()) 19 | 20 | TAU_class = ['airport', 'bus', 'shopping_mall', 'street_pedestrian', 'street_traffic', 'metro_station', 'metro', 21 | 'public_square', 'tram', 'park'] 22 | 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser(description="Input optional guidance for training") 26 | parser.add_argument("--dpath", default="/home/xiaoyang/Dev/asc-continual-learning/data", type=str, 27 | help="The path of dataset") 28 | parser.add_argument("--seed", type=int, default=3, help="Random seed number.") 29 | parser.add_argument("--dataset", type=str, default="TAU-ASC", help="[TAU-ASC, ESC-50]") 30 | parser.add_argument("--n_tasks", type=int, default=5, help="The number of tasks") 31 | parser.add_argument("--n_cls_a_task", type=int, default=2, help="The number of class of each task") 32 | parser.add_argument("--n_init_cls", type=int, default=2, help="The number of classes of initial task") 33 | parser.add_argument("--exp_name", type=str, default="disjoint", help="[disjoint, blurry]") 34 | parser.add_argument("--mode", type=str, default="test", help="[train, test]") 35 | args = parser.parse_args() 36 | 37 | print(f"[0] Start to generate the {args.n_tasks} tasks of {args.dataset}.") 38 | if args.dataset == "TAU-ASC": 39 | class_list = TAU_class 40 | tau_dataset = tau2019uas.Dataset(data_home='data/TAU_ASC') 41 | clip_ids = tau_dataset.clip_ids 42 | dev_clip_ids = [id for id in clip_ids if 'development' in id] 43 | all_clips = tau_dataset.load_clips() 44 | random.seed(args.seed) 45 | random.shuffle(class_list) 46 | total_list = [] 47 | for i in range(args.n_tasks): 48 | if i == 0: 49 | t_list = [] 50 | for j in range(args.n_init_cls): 51 | t_list.append(class_list[j]) 52 | total_list.append(t_list) 53 | else: 54 | t_list = [] 55 | for j in range(args.n_cls_a_task): 56 | t_list.append((class_list[j + args.n_init_cls + (i - 1) * args.n_cls_a_task])) 57 | total_list.append(t_list) 58 | 59 | print(total_list) 60 | label_list = [] 61 | for i in range(len(total_list)): 62 | class_list = total_list[i] 63 | label_list = label_list + class_list 64 | if args.mode == 'train': 65 | collection_name = "collection/{dataset}_{mode}_{exp}_rand{rnd}_cls{n_cls}_task{iter}.json".format( 66 | dataset=args.dataset, mode='train', exp=args.exp_name, rnd=args.seed, n_cls=args.n_cls_a_task, 67 | iter=i 68 | ) 69 | clip_ids = [id for id in dev_clip_ids if all_clips[id].split == 'development.train'] 70 | else: 71 | collection_name = "collection/{dataset}_test_rand{rnd}_cls{n_cls}_task{iter}.json".format( 72 | dataset=args.dataset, rnd=args.seed, n_cls=args.n_cls_a_task, iter=i 73 | ) 74 | clip_ids = [id for id in dev_clip_ids if all_clips[id].split == 'development.test'] 75 | f = open(os.path.join(args.dpath, collection_name), 'w') 76 | class_encoding = {category: index for index, category in enumerate(label_list)} 77 | dataset_list = [] 78 | for id in clip_ids: 79 | clip = all_clips.get(id) 80 | tag = clip.tags.labels[0] 81 | if tag in class_list: 82 | 
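# clips whose first tag is outside this task's class list are skipped;
# each kept clip is serialized below as {"tag": ..., "audio_name": ..., "label": ...}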
dataset_list.append([id, tag, class_encoding.get(tag)]) 83 | res = [{"tag": item[1], "audio_name": item[0], "label": item[2]} for item in dataset_list] 84 | print("Task ID is {}".format(i)) 85 | print("Total samples are {}".format(len(res))) 86 | f.write(json.dumps(res)) 87 | f.close() 88 | elif args.dataset == "ESC-50": 89 | data_list = [] 90 | meta = pd.read_csv(os.path.join(args.dpath, 'ESC-50-master/meta/esc50.csv')) 91 | for test_fold_num in range(1, 6): 92 | if args.mode == 'train': 93 | data_list = meta[meta['fold'] != test_fold_num] 94 | elif args.mode == 'test': 95 | data_list = meta[meta['fold'] == test_fold_num] 96 | print(f'ESC-50 {args.mode} set using fold {test_fold_num} is creating, using sample rate {44100} Hz ...') 97 | class_list = sorted(data_list["category"].unique()) 98 | random.seed(args.seed) 99 | random.shuffle(class_list) 100 | total_list = [] 101 | for i in range(args.n_tasks): 102 | if i == 0: 103 | t_list = [] 104 | for j in range(args.n_init_cls): 105 | t_list.append(class_list[j]) 106 | total_list.append(t_list) 107 | else: 108 | t_list = [] 109 | for j in range(args.n_cls_a_task): 110 | t_list.append((class_list[j + args.n_init_cls + (i - 1) * args.n_cls_a_task])) 111 | total_list.append(t_list) 112 | 113 | print(total_list) 114 | label_list = [] 115 | for i in range(len(total_list)): 116 | class_list = total_list[i] 117 | label_list = label_list + class_list 118 | if args.mode == 'train': 119 | collection_name = "collection/{dataset}_{mode}_{exp}_rand{rnd}_cls{n_cls}" \ 120 | "_task{iter}_{test_fold_num}.json".format(dataset=args.dataset, mode='train', 121 | exp=args.exp_name, rnd=args.seed, 122 | n_cls=args.n_cls_a_task, 123 | iter=i, test_fold_num=test_fold_num 124 | ) 125 | 126 | else: 127 | collection_name = "collection/{dataset}_test_rand{rnd}_cls{n_cls}_task{iter}" \ 128 | "_{test_fold_num}.json".format(dataset=args.dataset, rnd=args.seed, 129 | n_cls=args.n_cls_a_task, iter=i, 130 | test_fold_num=test_fold_num 131 | ) 132 | f = open(os.path.join(args.dpath, collection_name), 'w') 133 | class_encoding = {category: index for index, category in enumerate(label_list)} 134 | dataset_list = [] 135 | 136 | for index in tqdm(range(len(data_list))): 137 | row = data_list.iloc[index] 138 | file_path = os.path.join(args.dpath, 'ESC-50-master', 'audio', row["filename"]) 139 | if row['category'] in class_list: 140 | dataset_list.append([file_path, row['category'], class_encoding.get(row['category'])]) 141 | res = [{"tag": item[1], "audio_name": item[0], "label": item[2]} for item in dataset_list] 142 | print("Task ID is {}".format(i)) 143 | print("Total samples are {}".format(len(res))) 144 | f.write(json.dumps(res)) 145 | f.close() 146 | 147 | else: 148 | raise NotImplementedError 149 | 150 | 151 | if __name__ == "__main__": 152 | main() 153 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | from soundata.datasets import tau2019uas 5 | from torch.utils.data import Dataset, DataLoader 6 | import numpy as np 7 | import torch 8 | import os 9 | import warnings 10 | import pandas as pd 11 | 12 | warnings.filterwarnings("ignore") 13 | import librosa 14 | 15 | logger = logging.getLogger() 16 | 17 | 18 | def load_audio(path, sr): 19 | y, _ = librosa.load(path, sr=sr) 20 | return y 21 | 22 | 23 | class ASC_Dataset(Dataset): 24 | def __init__(self, data_frame: pd.DataFrame): 25 | 
self.data_frame = data_frame 26 | self.dataset = tau2019uas.Dataset(data_home='data/TAU_ASC') 27 | 28 | def __len__(self): 29 | return len(self.data_frame) 30 | 31 | def __getitem__(self, index): 32 | if torch.is_tensor(index): 33 | index = index.tolist() 34 | audio_id = self.data_frame.iloc[index]["audio_name"] 35 | clip = self.dataset.clip(audio_id) 36 | 37 | waveform, sr = clip.audio 38 | waveform = np.array((waveform[0] + waveform[1]) / 2) 39 | max_length = sr * 10 40 | 41 | if len(waveform) > max_length: 42 | waveform = waveform[0:max_length] 43 | else: 44 | waveform = np.pad(waveform, (0, max_length - len(waveform)), 'constant') 45 | 46 | tag = clip.tags.labels[0] 47 | target = self.data_frame.iloc[index]["label"] 48 | target = np.eye(10)[target] 49 | 50 | data_dict = { 51 | 'audio_name': audio_id, 'waveform': waveform, 'target': target, 'tag': tag} 52 | 53 | return data_dict 54 | 55 | 56 | class ESC50_Dataset(Dataset): 57 | def __init__(self, data_frame: pd.DataFrame, sr=44100): 58 | self.data_frame = data_frame 59 | self.sr = sr 60 | 61 | def __len__(self): 62 | return len(self.data_frame) 63 | 64 | def __getitem__(self, index): 65 | if torch.is_tensor(index): 66 | index = index.tolist() 67 | audio_name = self.data_frame.iloc[index]["audio_name"] 68 | 69 | waveform = load_audio(audio_name, self.sr) 70 | tag = self.data_frame.iloc[index]["tag"] 71 | target = self.data_frame.iloc[index]["label"] 72 | target = np.eye(50)[target] 73 | data_dict = {'audio_name': audio_name, 'waveform': waveform, 'target': target, 'tag': tag} 74 | 75 | return data_dict 76 | 77 | 78 | def default_collate_fn(batch): 79 | audio_name = [data['audio_name'] for data in batch] 80 | waveform = [data['waveform'] for data in batch] 81 | target = [data['target'] for data in batch] 82 | 83 | waveform = torch.FloatTensor(waveform) 84 | target = torch.FloatTensor(target) 85 | 86 | return {'audio_name': audio_name, 'waveform': waveform, 'target': target} 87 | 88 | 89 | def get_train_datalist(args, cur_iter: int) -> List: 90 | datalist = [] 91 | if args.dataset == "ESC-50": 92 | for test_fold_num in range(1, 6): 93 | collection_name = get_train_collection_name( 94 | dataset=args.dataset, 95 | exp=args.exp_name, 96 | rnd=args.rnd_seed, 97 | n_cls=args.n_cls_a_task, 98 | iter=cur_iter, 99 | ) 100 | datalist.append(pd.read_json( 101 | os.path.join(args.data_root, f"{collection_name}_{test_fold_num}.json") 102 | ).to_dict(orient="records")) 103 | logger.info(f"[Train] Get datalist from {collection_name}_{test_fold_num}.json") 104 | else: 105 | collection_name = get_train_collection_name( 106 | dataset=args.dataset, 107 | exp=args.exp_name, 108 | rnd=args.rnd_seed, 109 | n_cls=args.n_cls_a_task, 110 | iter=cur_iter, 111 | ) 112 | 113 | datalist = pd.read_json(os.path.join(args.data_root, f"{collection_name}.json") 114 | ).to_dict(orient="records") 115 | logger.info(f"[Train] Get datalist from {collection_name}.json") 116 | 117 | return datalist 118 | 119 | 120 | def get_train_collection_name(dataset, exp, rnd, n_cls, iter): 121 | collection_name = "{dataset}_train_{exp}_rand{rnd}_cls{n_cls}_task{iter}".format( 122 | dataset=dataset, exp=exp, rnd=rnd, n_cls=n_cls, iter=iter 123 | ) 124 | return collection_name 125 | 126 | 127 | def get_test_datalist(args, exp_name: str, cur_iter: int) -> List: 128 | if exp_name is None: 129 | exp_name = args.exp_name 130 | 131 | if exp_name == "disjoint": 132 | # merge current and all previous tasks 133 | tasks = list(range(cur_iter + 1)) 134 | else: 135 | raise NotImplementedError 136 | 137 
|     datalist = []
138 |     if args.dataset == "ESC-50":
139 |         for test_fold_num in range(1, 6):
140 |             fold_list = []
141 |             for iter_ in tasks:
142 |                 collection_name = "{dataset}_test_rand{rnd}_cls{n_cls}_task{iter}".format(
143 |                     dataset=args.dataset, rnd=args.rnd_seed, n_cls=args.n_cls_a_task, iter=iter_
144 |                 )
145 |                 fold_list += pd.read_json(
146 |                     os.path.join(args.data_root, f"{collection_name}_{test_fold_num}.json")
147 |                 ).to_dict(orient="records")
148 |                 logger.info(f"[Test ] Get datalist from {collection_name}_{test_fold_num}.json")
149 |             datalist.append(fold_list)
150 |     else:
151 |         for iter_ in tasks:
152 |             collection_name = "{dataset}_test_rand{rnd}_cls{n_cls}_task{iter}".format(
153 |                 dataset=args.dataset, rnd=args.rnd_seed, n_cls=args.n_cls_a_task, iter=iter_
154 |             )
155 |             datalist += pd.read_json(
156 |                 os.path.join(args.data_root, f"{collection_name}.json")
157 |             ).to_dict(orient="records")
158 |             logger.info(f"[Test ] Get datalist from {collection_name}.json")
159 | 
160 |     return datalist
161 | 
162 | 
163 | def get_dataloader(data_frame, dataset, split, batch_size, num_workers=8):
164 |     if dataset == 'TAU-ASC':
165 |         dataset = ASC_Dataset(data_frame=data_frame)
166 |     elif dataset == "ESC-50":
167 |         dataset = ESC50_Dataset(data_frame=data_frame)
168 |     is_train = True if split == 'train' else False
169 | 
170 |     return DataLoader(dataset=dataset, batch_size=batch_size,
171 |                       shuffle=is_train, drop_last=False,
172 |                       num_workers=num_workers, collate_fn=default_collate_fn)
173 | 
174 | 
175 | if __name__ == '__main__':
176 |     pass
177 | 
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: asc
2 | channels:
3 |   - pytorch
4 |   - conda-forge
5 |   - defaults
6 | dependencies:
7 |   - _libgcc_mutex=0.1=conda_forge
8 |   - _openmp_mutex=4.5=1_gnu
9 |   - backcall=0.2.0=pyh9f0ad1d_0
10 |   - backports=1.0=py_2
11 |   - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
12 |   - blas=1.0=mkl
13 |   - brotlipy=0.7.0=py37h540881e_1004
14 |   - bzip2=1.0.8=h7f98852_4
15 |   - ca-certificates=2021.10.8=ha878542_0
16 |   - cffi=1.15.0=py37h036bc23_0
17 |   - cryptography=36.0.0=py37h9ce1e76_0
18 |   - cudatoolkit=11.3.1=h2bc3f7f_2
19 |   - debugpy=1.5.1=py37hcd2ae1e_0
20 |   - decorator=5.1.1=pyhd8ed1ab_0
21 |   - entrypoints=0.4=pyhd8ed1ab_0
22 |   - ffmpeg=4.3=hf484d3e_0
23 |   - freetype=2.10.4=h0708190_1
24 |   - gmp=6.2.1=h58526e2_0
25 |   - gnutls=3.6.13=h85f3911_1
26 |   - idna=3.3=pyhd8ed1ab_0
27 |   - intel-openmp=2021.4.0=h06a4308_3561
28 |   - ipykernel=6.8.0=py37h6531663_0
29 |   - ipython=7.31.1=py37h89c1867_0
30 |   - jbig=2.1=h7f98852_2003
31 |   - jedi=0.18.1=py37h89c1867_0
32 |   - jpeg=9d=h36c2ea0_0
33 |   - jupyter_client=7.1.2=pyhd8ed1ab_0
34 |   - jupyter_core=4.9.1=py37h89c1867_1
35 |   - lame=3.100=h7f98852_1001
36 |   - lcms2=2.12=hddcbb42_0
37 |   - ld_impl_linux-64=2.36.1=hea4e1c9_2
38 |   - lerc=3.0=h9c3ff4c_0
39 |   - libdeflate=1.8=h7f98852_0
40 |   - libffi=3.4.2=h7f98852_5
41 |   - libgcc-ng=11.2.0=h1d223b6_11
42 |   - libgomp=11.2.0=h1d223b6_11
43 |   - libiconv=1.16=h516909a_0
44 |   - libnsl=2.0.0=h7f98852_0
45 |   - libpng=1.6.37=h21135ba_2
46 |   - libsodium=1.0.18=h36c2ea0_1
47 |   - libstdcxx-ng=11.2.0=he4da1e4_11
48 |   - libtiff=4.3.0=h6f004c6_2
49 |   - libuv=1.42.0=h7f98852_0
50 |   - libwebp-base=1.2.1=h7f98852_0
51 |   - libzlib=1.2.11=h36c2ea0_1013
52 |   - lz4-c=1.9.3=h9c3ff4c_1
53 |   - matplotlib-inline=0.1.3=pyhd8ed1ab_0
54 |   - mkl=2021.4.0=h06a4308_640
55 |   - mkl-service=2.4.0=py37h402132d_0
56 |   - mkl_fft=1.3.1=py37h3e078e5_1
57 |   - mkl_random=1.2.2=py37h219a48f_0
58 |   - ncurses=6.2=h58526e2_4
59 |   - nest-asyncio=1.5.4=pyhd8ed1ab_0
60 |   - nettle=3.6=he412f7d_0
61 |   - numpy-base=1.21.2=py37h79a1101_0
62 |   - olefile=0.46=pyh9f0ad1d_1
63 |   - openh264=2.1.1=h780b84a_0
64 |   - openjpeg=2.4.0=hb52868f_1
65 |   - openssl=3.0.2=h166bdaf_1
66 |   - parso=0.8.3=pyhd8ed1ab_0
67 |   - pexpect=4.8.0=pyh9f0ad1d_2
68 |   - pickleshare=0.7.5=py_1003
69 |   - pillow=8.4.0=py37h0f21c89_0
70 |   - pip=21.3.1=pyhd8ed1ab_0
71 |   - prompt-toolkit=3.0.26=pyha770c72_0
72 |   - ptyprocess=0.7.0=pyhd3deb0d_0
73 |   - pycparser=2.21=pyhd8ed1ab_0
74 |   - pygments=2.11.2=pyhd8ed1ab_0
75 |   - pyopenssl=22.0.0=pyhd8ed1ab_0
76 |   - pysocks=1.7.1=py37h89c1867_5
77 |   - python=3.7.12=hf930737_100_cpython
78 |   - python-dateutil=2.8.2=pyhd8ed1ab_0
79 |   - python_abi=3.7=2_cp37m
80 |   - pytorch=1.11.0=py3.7_cuda11.3_cudnn8.2.0_0
81 |   - pytorch-mutex=1.0=cuda
82 |   - pyzmq=22.3.0=py37h336d617_1
83 |   - readline=8.1=h46c0cb4_0
84 |   - requests=2.27.1=pyhd8ed1ab_0
85 |   - setuptools=60.5.0=py37h89c1867_0
86 |   - six=1.16.0=pyh6c4a22f_0
87 |   - sqlite=3.37.0=h9cd32fc_0
88 |   - tk=8.6.11=h27826a3_1
89 |   - torchaudio=0.11.0=py37_cu113
90 |   - torchinfo=1.6.5=pyhd8ed1ab_0
91 |   - torchvision=0.12.0=py37_cu113
92 |   - tornado=6.1=py37h5e8e339_2
93 |   - traitlets=5.1.1=pyhd8ed1ab_0
94 |   - wcwidth=0.2.5=pyh9f0ad1d_2
95 |   - wheel=0.37.1=pyhd8ed1ab_0
96 |   - xz=5.2.5=h516909a_1
97 |   - zeromq=4.3.4=h9c3ff4c_1
98 |   - zlib=1.2.11=h36c2ea0_1013
99 |   - zstd=1.5.1=ha95c52a_0
100 |   - pip:
101 |     - absl-py==1.0.0
102 |     - appdirs==1.4.4
103 |     - astunparse==1.6.3
104 |     - attrs==21.4.0
105 |     - audioread==2.1.9
106 |     - cached-property==1.5.2
107 |     - cachetools==5.0.0
108 |     - certifi==2021.10.8
109 |     - charset-normalizer==2.0.10
110 |     - cycler==0.11.0
111 |     - dcase-util==0.2.19
112 |     - flatbuffers==2.0
113 |     - fonttools==4.33.1
114 |     - future==0.18.2
115 |     - gast==0.5.3
116 |     - google-auth==2.6.6
117 |     - google-auth-oauthlib==0.4.6
118 |     - google-pasta==0.2.0
119 |     - grpcio==1.44.0
120 |     - h5py==3.6.0
121 |     - importlib-metadata==4.10.0
122 |     - importlib-resources==5.4.0
123 |     - jams==0.3.4
124 |     - joblib==1.1.0
125 |     - jsonschema==4.4.0
126 |     - keras==2.8.0
127 |     - keras-preprocessing==1.1.2
128 |     - kiwisolver==1.4.2
129 |     - libclang==14.0.1
130 |     - librosa==0.8.1
131 |     - llvmlite==0.36.0
132 |     - markdown==3.3.6
133 |     - matplotlib==3.5.1
134 |     - mir-eval==0.6
135 |     - numba==0.53.0
136 |     - numpy==1.20.0
137 |     - oauthlib==3.2.0
138 |     - onnx==1.11.0
139 |     - opt-einsum==3.3.0
140 |     - packaging==21.3
141 |     - pandas==1.3.5
142 |     - pathlib==1.0.1
143 |     - pooch==1.5.2
144 |     - prettytable==3.2.0
145 |     - protobuf==3.20.0
146 |     - pyasn1==0.4.8
147 |     - pyasn1-modules==0.2.8
148 |     - pydot-ng==2.0.0
149 |     - pyparsing==3.0.6
150 |     - pyrsistent==0.18.0
151 |     - python-magic==0.4.25
152 |     - pytz==2021.3
153 |     - pyyaml==6.0
154 |     - requests-oauthlib==1.3.1
155 |     - resampy==0.2.2
156 |     - rsa==4.8
157 |     - scikit-learn==1.0.2
158 |     - scipy==1.7.3
159 |     - sortedcontainers==2.4.0
160 |     - soundata==0.1.1
161 |     - soundfile==0.10.3.post1
162 |     - tensorboard==2.8.0
163 |     - tensorboard-data-server==0.6.1
164 |     - tensorboard-plugin-wit==1.8.1
165 |     - tensorflow==2.8.0
166 |     - tensorflow-io-gcs-filesystem==0.25.0
167 |     - termcolor==1.1.0
168 |     - tf-estimator-nightly==2.8.0.dev2021122109
169 |     - threadpoolctl==3.0.0
170 |     - torch==1.11.0
171 |     - torchlibrosa==0.0.9
172 |     - torchsort==0.1.9
173 |     - tqdm==4.62.3
174 |     - typing-extensions==4.2.0
175 |     - urllib3==1.26.8
176 |     - validators==0.18.2
177 |     - werkzeug==2.1.1
178 |     - wrapt==1.14.0
179 |     - zipp==3.7.0
180 |     - torch-audiomentations
181 |     - audiomentations
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | from collections import defaultdict
4 | import torch
5 | import os
6 | import time
7 | import numpy as np
8 | import logging as log_config
9 | from utils.losses import get_loss_func
10 | from data_loader import get_train_datalist, get_test_datalist
11 | from models.model import Baseline_CNN, BCResNet_Mod
12 | from models.frontend import Audio_Frontend
13 | from utils.get_methods import get_methods
14 | 
15 | 
16 | def save_model(model, optimizer, step, acc, name):
17 |     save_path = os.path.join(ckpt_dir, name + '.pt')
18 |     torch.save({
19 |         'model': model.state_dict(),
20 |         'optimizer': optimizer.state_dict(),
21 |         'step': step,
22 |         'acc': acc
23 |     }, save_path)
24 | 
25 | 
26 | if __name__ == '__main__':
27 |     parser = argparse.ArgumentParser(description='Training script for on-device ASC continual learning.')
28 |     # Data root.
29 |     parser.add_argument("--data_root", type=str, default='/home/xiaoyang/Dev/asc-continual-learning/data/collection')
30 |     parser.add_argument('--exp_name', type=str, default='disjoint')
31 |     parser.add_argument('--workspace', type=str, default='workspace')
32 |     parser.add_argument('--batch_size', type=int, default=32)
33 |     parser.add_argument('--epoch', type=int, default=30)
34 |     parser.add_argument('--lr', type=float, default=1e-3)
35 |     parser.add_argument('--model_name', type=str, default='BC-ResNet')  # 'baseline' | 'BC-ResNet'
36 |     parser.add_argument('--dataset', type=str, default='TAU-ASC')  # 'TAU-ASC' | 'ESC-50' |
37 |     parser.add_argument("--mode", type=str, default="replay", help="CIL methods [finetune, replay]", )
38 |     parser.add_argument(
39 |         "--mem_manage",
40 |         type=str,
41 |         default='prototype',
42 |         help="memory management [random, uncertainty, reservoir, prototype]",
43 |     )
44 |     parser.add_argument("--n_tasks", type=int, default=5, help="The number of tasks")
45 |     parser.add_argument(
46 |         "--n_cls_a_task", type=int, default=2, help="The number of classes per task"
47 |     )
48 |     parser.add_argument(
49 |         "--n_init_cls",
50 |         type=int,
51 |         default=2,
52 |         help="The number of classes of the initial task",
53 |     )
54 |     parser.add_argument("--rnd_seed", type=int, default=3, help="Random seed number.")
55 |     parser.add_argument(
56 |         "--memory_size", type=int, default=500, help="Episodic memory size"
57 |     )
58 |     # Uncertainty
59 |     parser.add_argument(
60 |         "--uncert_metric",
61 |         type=str,
62 |         default="noisytune",
63 |         choices=["shift", "noise", "mask", "combination", "noisytune"],
64 |         help="A type of uncertainty metric",
65 |     )
66 |     parser.add_argument("--metric_k", type=int, default=6, choices=[2, 4, 6],
67 |                         help="The number of the uncertainty metric functions")
68 |     parser.add_argument("--noise_lambda", type=float, default=0.2,
69 |                         help="The scale of the noise added for the noisytune metric")
70 |     # Debug
71 |     parser.add_argument("--debug", action="store_true", help="Turn on Debug mode")
72 |     args = parser.parse_args()
73 |     if args.mode == "finetune":
74 |         save_path = f"{args.dataset}_{args.mode}_cls{args.n_cls_a_task}" \
75 |                     f"_epoch{args.epoch}_lr{args.lr}_rnd{args.rnd_seed}"
76 |     elif args.mem_manage == "uncertainty":
77 |         save_path = f"{args.dataset}_{args.mode}_cls{args.n_cls_a_task}_{args.mem_manage}_{args.uncert_metric}" \
f"_{args.metric_k}_{args.noise_lambda}_epoch{args.epoch}" \ 79 | f"_lr{args.lr}_msz{args.memory_size}_rnd{args.rnd_seed}" 80 | else: 81 | save_path = f"{args.dataset}_{args.mode}_cls{args.n_cls_a_task}_{args.mem_manage}" \ 82 | f"_epoch{args.epoch}_lr{args.lr}_msz{args.memory_size}_rnd{args.rnd_seed}" 83 | 84 | # Training parameters 85 | exp_name = args.exp_name 86 | batch_size = args.batch_size 87 | epoch = args.epoch 88 | learning_rate = args.lr 89 | model_name = args.model_name 90 | dataset = args.dataset 91 | 92 | # Log file initalization 93 | ckpt_dir = os.path.join('workspace', dataset, exp_name, 'save_models') 94 | os.makedirs(ckpt_dir, exist_ok=True) 95 | ckpt_name = os.path.join(ckpt_dir, 'last.pt') 96 | ckpt_path = ckpt_name if os.path.exists(ckpt_name) else None 97 | 98 | log_dir = os.path.join('workspace', dataset, exp_name, 'logs') 99 | os.makedirs(log_dir, exist_ok=True) 100 | 101 | root_logger = log_config.getLogger() 102 | for h in root_logger.handlers: 103 | root_logger.removeHandler(h) 104 | 105 | log_config.basicConfig( 106 | level=log_config.INFO, 107 | format=' %(asctime)s - %(levelname)s - %(message)s', 108 | handlers=[ 109 | log_config.FileHandler(os.path.join(log_dir, 110 | f'{save_path}.log')), 111 | log_config.StreamHandler() 112 | ] 113 | ) 114 | 115 | logger = log_config.getLogger() 116 | 117 | # Device Setup 118 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 119 | if 'cuda' in str(device): 120 | logger.info(f'Exp name: {exp_name} | Using GPU') 121 | device = 'cuda' 122 | else: 123 | logger.info(f'Exp name: {exp_name} | Using CPU. Set --cuda flag to use GPU') 124 | device = 'cpu' 125 | logger.info(f'{args.__dict__}') 126 | # Default audio frontend Hyperparameters setup for TAU-ASC 127 | frontend_params = { 128 | 'sample_rate': 48000, 129 | 'window_size': 1024, 130 | 'hop_size': 320, 131 | 'mel_bins': 64, 132 | 'fmin': 50, 133 | 'fmax': 14000} 134 | 135 | num_class = 10 136 | if dataset == 'ESC-50': 137 | frontend_params['sample_rate'] = 44100 138 | num_class = 50 139 | frontend = Audio_Frontend(**frontend_params) 140 | 141 | # Fix the random seeds 142 | # https://hoya012.github.io/blog/reproducible_pytorch/ 143 | torch.manual_seed(args.rnd_seed) 144 | torch.backends.cudnn.deterministic = True 145 | torch.backends.cudnn.benchmark = False 146 | np.random.seed(args.rnd_seed) 147 | random.seed(args.rnd_seed) 148 | 149 | # [1] Select a CIL method 150 | logger.info(f"[1] Select a CIL method ({args.mode})") 151 | if args.mem_manage == 'uncertainty': 152 | logger.info(f"Select uncertainty measure approach ({args.uncert_metric})") 153 | loss_func = get_loss_func('clip_ce') 154 | if model_name == 'baseline': 155 | model = Baseline_CNN(num_class=num_class, frontend=frontend) 156 | elif model_name == 'BC-ResNet': 157 | model = BCResNet_Mod(num_class=num_class, frontend=frontend) 158 | else: 159 | raise Exception 160 | method = get_methods( 161 | args, loss_func, device, num_class, model 162 | ) 163 | 164 | # Incrementally training 165 | logger.info(f"[2] Incrementally training {args.n_tasks} tasks") 166 | task_records = defaultdict(list) 167 | start_time = time.time() 168 | 169 | # start to train each tasks 170 | logger.info(f'Audio frontend param:\n{frontend_params}\n') 171 | logger.info(f'Model:\n{model}\n') 172 | logger.info(f"Exp: {exp_name} | batch_size: {batch_size} | learning_rate: {learning_rate} | dataset: {dataset}") 173 | for cur_iter in range(args.n_tasks): 174 | print("\n" + "#" * 50) 175 | print(f"# Task {cur_iter} 
iteration") 176 | print("#" * 50 + "\n") 177 | 178 | logger.info("[2-1] Prepare a datalist for the current task") 179 | task_acc = 0.0 180 | if args.dataset == "ESC-50": 181 | cur_train_datalist = get_train_datalist(args, cur_iter) 182 | cur_test_datalist = get_test_datalist(args, args.exp_name, cur_iter) 183 | fold_acc = 0.0 184 | for test_fold in range(1, 6): 185 | logger.info(f"Set the test fold number {test_fold} of the current task") 186 | method.set_current_dataset(cur_train_datalist[test_fold - 1], cur_test_datalist[test_fold - 1]) 187 | # Increment known class for current task iteration. 188 | method.before_task(datalist=cur_train_datalist[test_fold - 1], init_opt=True) 189 | logger.info(f"[2-3] Start to train") 190 | fold_acc += method.train( 191 | n_epoch=args.epoch, 192 | batch_size=args.batch_size, 193 | n_worker=8, 194 | ) 195 | logger.info("[2-4] Update the information for the current task") 196 | method.after_task(cur_iter) 197 | task_acc = fold_acc / 5 198 | else: 199 | # get datalist 200 | cur_train_datalist = get_train_datalist(args, cur_iter) 201 | cur_test_datalist = get_test_datalist(args, args.exp_name, cur_iter) 202 | logger.info("[2-2] Set environment for the current task") 203 | method.set_current_dataset(cur_train_datalist, cur_test_datalist) 204 | # Increment known class for current task iteration. 205 | method.before_task(datalist=cur_train_datalist, init_opt=True) 206 | 207 | logger.info(f"[2-3] Start to train") 208 | task_acc = method.train( 209 | n_epoch=args.epoch, 210 | batch_size=args.batch_size, 211 | n_worker=8, 212 | ) 213 | before_update = time.time() 214 | logger.info("[2-4] Update the information for the current task") 215 | method.after_task(cur_iter) 216 | update_time = time.time() - before_update 217 | if cur_iter != 0: 218 | task_records["update_time"].append(update_time) 219 | task_records["task_acc"].append(task_acc) 220 | 221 | if cur_iter > 0: 222 | task_records["bwt_list"].append(np.mean( 223 | [task_records["task_acc"][i + 1] - task_records["task_acc"][i] for i in 224 | range(len(task_records["task_acc"]) - 1)])) 225 | logger.info("[2-5] Report task result") 226 | np.save(f"{log_dir}/{save_path}.npy", task_records) 227 | # Total time (T) 228 | duration = time.time() - start_time 229 | # Accuracy(A) 230 | A_avg = np.mean(task_records["task_acc"]) 231 | A_last = task_records["task_acc"][args.n_tasks - 1] 232 | 233 | logger.info(f"======== Summary =======") 234 | logger.info(f"Total time {duration}, Avg: {duration / args.n_tasks}s") 235 | logger.info(f'BWT: {np.mean(task_records["bwt_list"])}, std: {np.std(task_records["bwt_list"])}') 236 | logger.info(f"A_last {A_last} | A_avg {A_avg}") 237 | logger.info(f'Update time {task_records["update_time"]}') 238 | -------------------------------------------------------------------------------- /methods/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | # !/usr/bin/env python 3 | -*- coding: utf-8 -*- 4 | @Time : 2022/6/19 下午6:48 5 | @Author : Yang "Jan" Xiao 6 | @Description : base 7 | When we make a new one, we should inherit the BaseMethod class. 
8 | """ 9 | import logging 10 | import random 11 | import numpy as np 12 | import pandas as pd 13 | import torch 14 | from soundata.datasets import tau2019uas 15 | from torch import optim 16 | from tqdm import tqdm 17 | 18 | from torch_audiomentations import Compose, PitchShift, Shift, AddColoredNoise 19 | from audiomentations import Compose as Normal_Compose 20 | from audiomentations import FrequencyMask 21 | from data_loader import get_dataloader, load_audio 22 | from utils.evaluate import Evaluator 23 | 24 | logger = logging.getLogger() 25 | 26 | 27 | class BaseMethod: 28 | def __init__( 29 | self, criterion, device, n_classes, model, **kwargs 30 | ): 31 | # Parameters for Dataloader 32 | self.num_learned_class = 0 33 | self.num_learning_class = kwargs["n_init_cls"] 34 | self.learned_classes = [] 35 | self.class_mean = [None] * n_classes 36 | self.exposed_classes = [] 37 | self.seen = 0 38 | self.dataset = kwargs["dataset"] 39 | 40 | # Parameters for Trainer 41 | self.patience = 7 42 | self.device = device 43 | self.criterion = criterion 44 | self.lr = kwargs["lr"] 45 | self.optimizer, self.scheduler = None, None 46 | # self.criterion = self.criterion.to(self.device) 47 | self.evaluator = Evaluator(model=model) 48 | self.counter = 0 49 | 50 | # Parameters for Model 51 | self.model_name = kwargs["model_name"] 52 | self.model = model 53 | self.model = self.model.to(self.device) 54 | 55 | # Parameters for Prototype Sampler 56 | self.feature_extractor = model 57 | self.sample_length = 48000 58 | if self.dataset == 'ESC-50': 59 | self.sample_length = 44100 60 | 61 | # Parameters for Memory Updating 62 | self.prev_streamed_list = [] 63 | self.streamed_list = [] 64 | self.test_list = [] 65 | self.memory_list = [] 66 | self.memory_size = kwargs["memory_size"] 67 | self.mem_manage = kwargs["mem_manage"] 68 | self.already_mem_update = False 69 | self.mode = kwargs["mode"] 70 | if self.mode == "finetune": 71 | self.memory_size = 0 72 | self.mem_manage = "random" 73 | self.uncert_metric = kwargs["uncert_metric"] 74 | self.metric_k = kwargs["metric_k"] 75 | self.noise_lambda = kwargs["noise_lambda"] 76 | 77 | def set_current_dataset(self, train_datalist, test_datalist): 78 | random.shuffle(train_datalist) 79 | self.prev_streamed_list = self.streamed_list 80 | self.streamed_list = train_datalist 81 | self.test_list = test_datalist 82 | 83 | def before_task(self, datalist, init_opt=True): 84 | logger.info("Apply before_task") 85 | 86 | # Confirm incoming classes 87 | incoming_classes = pd.DataFrame(datalist)["tag"].unique().tolist() 88 | self.exposed_classes = list(set(self.learned_classes + incoming_classes)) 89 | self.num_learning_class = max( 90 | len(self.exposed_classes), self.num_learning_class 91 | ) 92 | 93 | self.model.num_class = self.num_learning_class 94 | self.model = self.model.to(self.device) 95 | if init_opt: 96 | # reinitialize the optimizer and scheduler 97 | logger.info("Reset the optimizer and scheduler states") 98 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999)) 99 | 100 | logger.info(f"Increasing the head of fc {self.learned_classes} -> {self.num_learning_class}") 101 | 102 | self.already_mem_update = False 103 | 104 | def after_task(self, cur_iter): 105 | logger.info("Apply after_task") 106 | self.learned_classes = self.exposed_classes 107 | self.num_learned_class = self.num_learning_class 108 | self.update_memory(cur_iter) 109 | 110 | def update_memory(self, cur_iter, num_class=None): 111 | if num_class is None: 112 | num_class = 
self.num_learning_class
113 | 
114 |         if not self.already_mem_update:
115 |             logger.info(f"Update memory over {num_class} classes by {self.mem_manage}")
116 |             candidates = self.streamed_list + self.memory_list
117 |             if len(candidates) <= self.memory_size:
118 |                 self.memory_list = candidates
119 |                 self.seen = len(candidates)
120 |                 logger.warning("Candidates < Memory size")
121 |             else:
122 |                 if self.mem_manage == "random":
123 |                     self.memory_list = self.rnd_sampling(candidates)
124 |                 elif self.mem_manage == "reservoir":
125 |                     self.reservoir_sampling(self.streamed_list)
126 |                 elif self.mem_manage == "prototype":
127 |                     self.memory_list = self.mean_feature_sampling(
128 |                         exemplars=self.memory_list,
129 |                         samples=self.streamed_list,
130 |                         num_class=num_class,
131 |                     )
132 |                 elif self.mem_manage == "uncertainty":
133 |                     if cur_iter == 0:
134 |                         self.memory_list = self.equal_class_sampling(
135 |                             candidates, num_class
136 |                         )
137 |                     else:
138 |                         self.memory_list = self.uncertainty_sampling(
139 |                             candidates,
140 |                             num_class=num_class,
141 |                         )
142 |                 else:
143 |                     logger.error("Not implemented memory management")
144 |                     raise NotImplementedError
145 | 
146 |             assert len(self.memory_list) <= self.memory_size
147 |             logger.info("Memory statistic")
148 |             memory_df = pd.DataFrame(self.memory_list)
149 |             if len(self.memory_list) > 0:
150 |                 logger.info(f"\n{memory_df.tag.value_counts(sort=True)}")
151 |             # memory update happens only once per task iteration.
152 |             self.already_mem_update = True
153 |         else:
154 |             logger.warning(f"Already updated the memory during this iter ({cur_iter})")
155 | 
156 |     def get_dataloader(self, batch_size, n_worker, train_list, test_list):
157 |         train_loader = get_dataloader(pd.DataFrame(train_list), self.dataset, split='train', batch_size=batch_size,
158 |                                       num_workers=n_worker)
159 |         test_loader = get_dataloader(pd.DataFrame(test_list), self.dataset, split='test', batch_size=128,
160 |                                       num_workers=n_worker)
161 |         return train_loader, test_loader
162 | 
163 |     def train(self, n_epoch, batch_size, n_worker):
164 |         self.counter = 0
165 |         train_list = self.streamed_list + self.memory_list
166 |         random.shuffle(train_list)
167 |         test_list = self.test_list
168 |         train_loader, test_loader = self.get_dataloader(
169 |             batch_size, n_worker, train_list, test_list
170 |         )
171 |         logger.info(f"Streamed samples: {len(self.streamed_list)}")
172 |         logger.info(f"In-memory samples: {len(self.memory_list)}")
173 |         logger.info(f"Train samples: {len(train_list)}")
174 |         logger.info(f"Test samples: {len(test_list)}")
175 |         acc_list = []
176 |         best = {'acc': 0, 'epoch': 0}
177 |         for epoch in range(n_epoch):
178 |             mean_loss = 0
179 |             for batch_data_dict in tqdm(train_loader):
180 |                 batch_data_dict['waveform'] = batch_data_dict['waveform'].to(self.device)
181 |                 batch_data_dict['target'] = batch_data_dict['target'].to(self.device)
182 | 
183 |                 # Forward
184 |                 self.model.train()
185 | 
186 |                 batch_output_dict = self.model(batch_data_dict['waveform'], training=True)
187 |                 """{'clipwise_output': (batch_size, classes_num), ...}"""
188 |                 batch_target_dict = {'target': batch_data_dict['target']}
189 |                 """{'target': (batch_size, classes_num)}"""
190 | 
191 |                 # Loss
192 |                 loss = self.criterion(batch_output_dict, batch_target_dict)
193 | 
194 |                 # Backward
195 |                 loss.backward()
196 |                 self.optimizer.step()
197 |                 self.optimizer.zero_grad()
198 | 
199 |                 loss = loss.item()
200 |                 mean_loss += loss
201 |             epoch_loss = mean_loss / len(train_loader)
202 |             logger.info(f'Epoch {epoch} | Training Loss: {epoch_loss}')
203 |             # Evaluate
204 | 
test_statistics = self.evaluator.evaluate(test_loader) 205 | ave_acc = np.mean(test_statistics['accuracy']) 206 | acc_list.append(ave_acc) 207 | logger.info(f"Epoch {epoch} | Evaluation Accuracy: {ave_acc}") 208 | 209 | if ave_acc > best['acc']: 210 | best['acc'] = ave_acc 211 | best['epoch'] = epoch 212 | logger.info(f'Best Accuracy: {ave_acc} in epoch {epoch}.') 213 | self.counter = 0 214 | else: 215 | self.counter += 1 216 | logger.info(f'EarlyStopping counter: {self.counter} out of {self.patience}.') 217 | if self.counter >= self.patience: 218 | break 219 | return best['acc'] 220 | 221 | def rnd_sampling(self, samples): 222 | random.shuffle(samples) 223 | return samples[: self.memory_size] 224 | 225 | def reservoir_sampling(self, samples): 226 | for sample in samples: 227 | if len(self.memory_list) < self.memory_size: 228 | self.memory_list += [sample] 229 | else: 230 | j = np.random.randint(0, self.seen) 231 | if j < self.memory_size: 232 | self.memory_list[j] = sample 233 | self.seen += 1 234 | 235 | def mean_feature_sampling(self, exemplars, samples, num_class): 236 | """Prototype sampling 237 | 238 | Args: 239 | features ([Tensor]): [features corresponding to the samples] 240 | samples ([Datalist]): [datalist for a class] 241 | 242 | Returns: 243 | [type]: [Sampled datalist] 244 | """ 245 | 246 | def _reduce_exemplar_sets(exemplars, mem_per_cls): 247 | if len(exemplars) == 0: 248 | return exemplars 249 | 250 | exemplar_df = pd.DataFrame(exemplars) 251 | ret = [] 252 | for y in range(self.num_learned_class): 253 | cls_df = exemplar_df[exemplar_df["label"] == y] 254 | ret += cls_df.sample(n=min(mem_per_cls, len(cls_df))).to_dict( 255 | orient="records" 256 | ) 257 | 258 | num_dups = pd.DataFrame(ret).duplicated().sum() 259 | if num_dups > 0: 260 | logger.warning(f"Duplicated samples in memory: {num_dups}") 261 | 262 | return ret 263 | 264 | mem_per_cls = self.memory_size // num_class 265 | exemplars = _reduce_exemplar_sets(exemplars, mem_per_cls) 266 | old_exemplar_df = pd.DataFrame(exemplars) 267 | 268 | new_exemplar_set = [] 269 | sample_df = pd.DataFrame(samples) 270 | for y in range(self.num_learning_class): 271 | cls_samples = [] 272 | cls_exemplars = [] 273 | if len(sample_df) != 0: 274 | cls_samples = sample_df[sample_df["label"] == y].to_dict( 275 | orient="records" 276 | ) 277 | if len(old_exemplar_df) != 0: 278 | cls_exemplars = old_exemplar_df[old_exemplar_df["label"] == y].to_dict( 279 | orient="records" 280 | ) 281 | 282 | if len(cls_exemplars) >= mem_per_cls: 283 | new_exemplar_set += cls_exemplars 284 | continue 285 | 286 | # Assign old exemplars to the samples 287 | cls_samples += cls_exemplars 288 | if len(cls_samples) <= mem_per_cls: 289 | new_exemplar_set += cls_samples 290 | continue 291 | 292 | features = [] 293 | self.feature_extractor.eval() 294 | with torch.no_grad(): 295 | for data in cls_samples: 296 | if self.dataset == 'TAU-ASC': 297 | tau_dataset = tau2019uas.Dataset(data_home='data/TAU_ASC') 298 | clip = tau_dataset.clip(data['audio_name']) 299 | waveform, _ = clip.audio 300 | waveform = np.array((waveform[0] + waveform[1]) / 2) 301 | max_length = self.sample_length * 10 302 | if len(waveform) > max_length: 303 | waveform = waveform[0:max_length] 304 | else: 305 | waveform = np.pad(waveform, (0, max_length - len(waveform)), 'constant') 306 | else: 307 | waveform = load_audio(data['audio_name'], 44100) 308 | waveform = torch.as_tensor(waveform, dtype=torch.float32) 309 | waveform = waveform.to(self.device) 310 | feature = ( 311 | 
self.feature_extractor(waveform.unsqueeze(0))['embedding'].detach().cpu().numpy() 312 | ) 313 | feature = feature / np.linalg.norm(feature, axis=1) # Normalize 314 | features.append(feature.squeeze()) 315 | 316 | features = np.array(features) 317 | logger.debug(f"[Prototype] features: {features.shape}") 318 | 319 | # do not replace the existing class mean 320 | if self.class_mean[y] is None: 321 | cls_mean = np.mean(features, axis=0) 322 | cls_mean /= np.linalg.norm(cls_mean) 323 | self.class_mean[y] = cls_mean 324 | else: 325 | cls_mean = self.class_mean[y] 326 | assert cls_mean.ndim == 1 327 | 328 | phi = features 329 | mu = cls_mean 330 | # select exemplars from the scratch 331 | exemplar_features = [] 332 | num_exemplars = min(mem_per_cls, len(cls_samples)) 333 | for j in range(num_exemplars): 334 | S = np.sum(exemplar_features, axis=0) 335 | mu_p = 1.0 / (j + 1) * (phi + S) 336 | mu_p = mu_p / np.linalg.norm(mu_p, axis=1, keepdims=True) 337 | 338 | dist = np.sqrt(np.sum((mu - mu_p) ** 2, axis=1)) 339 | i = np.argmin(dist) 340 | 341 | new_exemplar_set.append(cls_samples[i]) 342 | exemplar_features.append(phi[i]) 343 | 344 | # Avoid to sample the duplicated one. 345 | del cls_samples[i] 346 | phi = np.delete(phi, i, 0) 347 | 348 | return new_exemplar_set 349 | 350 | def uncertainty_sampling(self, samples, num_class): 351 | """uncertainty based sampling 352 | 353 | Args: 354 | samples ([list]): [training_list + memory_list] 355 | """ 356 | self.montecarlo(samples, uncert_metric=self.uncert_metric) 357 | 358 | sample_df = pd.DataFrame(samples) 359 | mem_per_cls = self.memory_size // num_class # kc: the number of the samples of each class 360 | 361 | ret = [] 362 | """ 363 | Sampling class by class 364 | """ 365 | for i in range(num_class): 366 | cls_df = sample_df[sample_df["label"] == i] 367 | if len(cls_df) <= mem_per_cls: 368 | ret += cls_df.to_dict(orient="records") 369 | else: 370 | jump_idx = len(cls_df) // mem_per_cls 371 | uncertain_samples = cls_df.sort_values(by="uncertainty")[::jump_idx] 372 | ret += uncertain_samples[:mem_per_cls].to_dict(orient="records") 373 | 374 | num_rest_slots = self.memory_size - len(ret) 375 | if num_rest_slots > 0: 376 | logger.warning("Fill the unused slots by breaking the equilibrium.") 377 | ret += ( 378 | sample_df[~sample_df.audio_name.isin(pd.DataFrame(ret).audio_name)] 379 | .sample(n=num_rest_slots) 380 | .to_dict(orient="records") 381 | ) 382 | 383 | num_dups = pd.DataFrame(ret).audio_name.duplicated().sum() 384 | if num_dups > 0: 385 | logger.warning(f"Duplicated samples in memory: {num_dups}") 386 | 387 | return ret 388 | 389 | def _compute_uncert(self, infer_list, infer_transform, uncert_name): 390 | batch_size = 128 391 | infer_df = pd.DataFrame(infer_list) 392 | infer_loader = get_dataloader(infer_df, self.dataset, split='test', batch_size=batch_size) 393 | 394 | self.model.eval() 395 | with torch.no_grad(): 396 | for n_batch, batch_data_dict in enumerate(infer_loader): 397 | if self.uncert_metric != "noisytune": 398 | batch_data_dict['waveform'] = infer_transform(batch_data_dict['waveform'].unsqueeze(1), 399 | self.sample_length) 400 | batch_data_dict['waveform'] = torch.as_tensor(batch_data_dict['waveform'], dtype=torch.float32) 401 | batch_data_dict['waveform'] = batch_data_dict['waveform'].squeeze() 402 | batch_data_dict['waveform'] = batch_data_dict['waveform'].to(self.device) 403 | logit = self.model(batch_data_dict['waveform']) 404 | logit = logit['clipwise_output'].detach().cpu() 405 | """{'clipwise_output': (batch_size, 
classes_num), ...}""" 406 | for i, cert_value in enumerate(logit): 407 | sample = infer_list[batch_size * n_batch + i] 408 | sample[uncert_name] = 1 - cert_value 409 | else: 410 | batch_data_dict['waveform'] = batch_data_dict['waveform'].to(self.device) 411 | logit = self.model(input=batch_data_dict['waveform'], training=False, add_noise=True, 412 | noise_lambda=self.noise_lambda, 413 | k=self.metric_k) 414 | logit = logit['clipwise_output'] 415 | for j in range(len(logit)): 416 | logit[j] = logit[j].detach().cpu() 417 | uncert_name = f"uncert_{str(j)}" 418 | for i, cert_value in enumerate(logit[j]): 419 | sample = infer_list[batch_size * n_batch + i] 420 | sample[uncert_name] = 1 - cert_value 421 | 422 | def montecarlo(self, candidates, uncert_metric="shift"): 423 | transform_cands = [] 424 | logger.info(f"Compute uncertainty by {uncert_metric}!") 425 | if uncert_metric == "shift": 426 | transform_cands = [PitchShift(sample_rate=self.sample_length, p=1.0), 427 | Shift(sample_rate=self.sample_length, p=1.0) 428 | ] * (self.metric_k // 2) 429 | for idx, tr in enumerate(transform_cands): 430 | _tr = Compose([tr]) 431 | self._compute_uncert(candidates, _tr, uncert_name=f"uncert_{str(idx)}") 432 | elif uncert_metric == "noise": 433 | transform_cands = [AddColoredNoise(sample_rate=self.sample_length, p=1.0)] * self.metric_k 434 | for idx, tr in enumerate(transform_cands): 435 | _tr = Compose([tr]) 436 | self._compute_uncert(candidates, _tr, uncert_name=f"uncert_{str(idx)}") 437 | elif uncert_metric == "mask": 438 | transform_cands = [TimeMask(min_band_part=0, max_band_part=0.1), 439 | FrequencyMask(min_frequency_band=0, max_frequency_band=0.1, p=1)] * (self.metric_k // 2) 440 | for idx, tr in enumerate(transform_cands): 441 | _tr = Normal_Compose([tr]) 442 | self._compute_uncert(candidates, _tr, uncert_name=f"uncert_{str(idx)}") 443 | elif uncert_metric == "combination": 444 | transform_cands = [TimeMask(min_band_part=0, max_band_part=0.1), 445 | FrequencyMask(min_frequency_band=0, max_frequency_band=0.1, p=1), 446 | AddColoredNoise(sample_rate=self.sample_length, p=1.0), 447 | AddColoredNoise(sample_rate=self.sample_length, p=1.0), 448 | PitchShift(sample_rate=self.sample_length, p=1.0), 449 | Shift(sample_rate=self.sample_length, p=1.0) 450 | ] 451 | random.shuffle(transform_cands) 452 | transform_cands = transform_cands[:self.metric_k] 453 | for idx, tr in enumerate(transform_cands): 454 | if 'audiomentations' in str(tr): 455 | _tr = Normal_Compose([tr]) 456 | else: 457 | _tr = Compose([tr]) 458 | self._compute_uncert(candidates, _tr, uncert_name=f"uncert_{str(idx)}") 459 | elif uncert_metric == "noisytune": 460 | self._compute_uncert(candidates, None, uncert_name=None) 461 | 462 | n_transforms = self.metric_k 463 | 464 | for sample in candidates: 465 | self.variance_ratio(sample, n_transforms) 466 | 467 | def variance_ratio(self, sample, cand_length): 468 | vote_counter = torch.zeros(sample["uncert_0"].size(0)) 469 | for i in range(cand_length): 470 | top_class = int(torch.argmin(sample[f"uncert_{i}"])) # uncert argmin. 471 | vote_counter[top_class] += 1 472 | assert vote_counter.sum() == cand_length 473 | sample["uncertainty"] = (1 - vote_counter.max() / cand_length).item() 474 | 475 | def equal_class_sampling(self, samples, num_class): 476 | mem_per_cls = self.memory_size // num_class 477 | sample_df = pd.DataFrame(samples) 478 | # Warning: assuming the classes were ordered following task number. 
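# e.g. memory_size=500 and num_class=10 give mem_per_cls=50 slots per class;
# classes with fewer candidates leave free slots, which are refilled at random below.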
479 |         ret = []
480 |         for y in range(self.num_learning_class):
481 |             cls_df = sample_df[sample_df["label"] == y]
482 |             ret += cls_df.sample(n=min(mem_per_cls, len(cls_df))).to_dict(
483 |                 orient="records"
484 |             )
485 | 
486 |         num_rest_slots = self.memory_size - len(ret)
487 |         if num_rest_slots > 0:
488 |             logger.warning("Fill the unused slots by breaking the equilibrium.")
489 |             ret += (
490 |                 sample_df[~sample_df.audio_name.isin(pd.DataFrame(ret).audio_name)]
491 |                 .sample(n=num_rest_slots)
492 |                 .to_dict(orient="records")
493 |             )
494 | 
495 |         num_dups = pd.DataFrame(ret).audio_name.duplicated().sum()
496 |         if num_dups > 0:
497 |             logger.warning(f"Duplicated samples in memory: {num_dups}")
498 | 
499 |         return ret
500 | 
501 | 
502 | class TimeMask:
503 |     def __init__(self, min_band_part=0.0, max_band_part=0.5):
504 |         self.min_band_part = min_band_part
505 |         self.max_band_part = max_band_part
506 | 
507 |     def __call__(self, samples, sample_rate):
508 |         num_samples = samples.shape[-1]
509 |         t = random.randint(
510 |             int(num_samples * self.min_band_part),
511 |             int(num_samples * self.max_band_part),
512 |         )
513 |         t0 = random.randint(
514 |             0, num_samples - t
515 |         )
516 |         new_samples = samples.clone()
517 |         mask = torch.zeros(t)
518 |         new_samples[..., t0: t0 + t] *= mask
519 |         return new_samples
520 | 
--------------------------------------------------------------------------------
/models/frontend.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | from utils.models import init_bn
4 | from torchlibrosa.augmentation import SpecAugmentation
5 | from torchlibrosa.stft import Spectrogram, LogmelFilterBank
6 | 
7 | 
8 | class Audio_Frontend(nn.Module):
9 |     """
10 |     Wav2Mel transformation & SpecAug data augmentation
11 |     """
12 | 
13 |     def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
14 |                  fmax):
15 |         super(Audio_Frontend, self).__init__()
16 | 
17 |         window = 'hann'
18 |         center = True
19 |         pad_mode = 'reflect'
20 |         ref = 1.0
21 |         amin = 1e-10
22 |         top_db = None
23 | 
24 |         # Spectrogram extractor
25 |         self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size,
26 |                                                  win_length=window_size, window=window, center=center,
27 |                                                  pad_mode=pad_mode,
28 |                                                  freeze_parameters=True)
29 | 
30 |         # Logmel feature extractor
31 |         self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size,
32 |                                                  n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin,
33 |                                                  top_db=top_db,
34 |                                                  freeze_parameters=True)
35 | 
36 |         # Spec augmenter
37 |         self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
38 |                                                freq_drop_width=8, freq_stripes_num=2)
39 | 
40 |         self.bn0 = nn.BatchNorm2d(64)
41 |         init_bn(self.bn0)
42 | 
43 |     def forward(self, input, spec_aug=False):
44 |         """
45 |         Input: (batch_size, data_length)
46 |         """
47 | 
48 |         x = self.spectrogram_extractor(input)  # (batch_size, 1, time_steps, freq_bins)
49 |         x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)
50 | 
51 |         x = x.transpose(1, 3)
52 |         x = self.bn0(x)
53 |         x = x.transpose(1, 3)
54 | 
55 |         if self.training and spec_aug:
56 |             x = self.spec_augmenter(x)
57 | 
58 |         return x
59 | 
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from utils.models import init_layer
5 | from models.frontend import Audio_Frontend
6 | from
models.module import ConvBlock3x3, TransitionBlock, BroadcastedBlock 7 | 8 | 9 | class Baseline_CNN(nn.Module): 10 | def __init__(self, num_class=10, frontend=None): 11 | super(Baseline_CNN, self).__init__() 12 | self.conv_block1 = ConvBlock3x3(in_channels=1, out_channels=16) 13 | self.conv_block2 = ConvBlock3x3(in_channels=16, out_channels=32) 14 | self.conv_block3 = ConvBlock3x3(in_channels=32, out_channels=64) 15 | self.conv_block4 = ConvBlock3x3(in_channels=64, out_channels=128) 16 | 17 | self.fc1 = nn.Linear(128, 32, bias=True) 18 | self.fc_audioset = nn.Linear(32, num_class, bias=True) 19 | 20 | self.frontend = frontend 21 | 22 | self.init_weight() 23 | 24 | def init_weight(self): 25 | init_layer(self.fc1) 26 | init_layer(self.fc_audioset) 27 | 28 | def forward(self, input): 29 | # Input: (batch_size, data_length) 30 | if self.frontend is not None: 31 | x = self.frontend(input) 32 | # Input: (batch_size, 1, T, F) 33 | else: 34 | x = input 35 | 36 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 37 | x = F.dropout(x, p=0.2, training=self.training) 38 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 39 | x = F.dropout(x, p=0.2, training=self.training) 40 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 41 | x = F.dropout(x, p=0.2, training=self.training) 42 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 43 | x = F.dropout(x, p=0.2, training=self.training) 44 | x = torch.mean(x, dim=3) 45 | 46 | (x1, _) = torch.max(x, dim=2) 47 | x2 = torch.mean(x, dim=2) 48 | x = x1 + x2 49 | x = F.dropout(x, p=0.2, training=self.training) 50 | x = F.relu_(self.fc1(x)) 51 | embedding = F.dropout(x, p=0.2, training=self.training) 52 | clipwise_output = self.fc_audioset(x) 53 | 54 | output_dict = { 55 | 'clipwise_output': clipwise_output, 56 | 'embedding': embedding} 57 | 58 | return output_dict 59 | 60 | 61 | class BCResNet_Mod(torch.nn.Module): 62 | def __init__(self, c=4, num_class=10, frontend=None, norm=False): 63 | self.lamb = 0.1 64 | super(BCResNet_Mod, self).__init__() 65 | c = 10 * c 66 | self.conv1 = nn.Conv2d(1, 2 * c, 5, stride=(2, 2), padding=(2, 2)) 67 | self.block1_1 = TransitionBlock(2 * c, c) 68 | self.block1_2 = BroadcastedBlock(c) 69 | 70 | self.block2_1 = nn.MaxPool2d(2) 71 | 72 | self.block3_1 = TransitionBlock(c, int(1.5 * c)) 73 | self.block3_2 = BroadcastedBlock(int(1.5 * c)) 74 | 75 | self.block4_1 = nn.MaxPool2d(2) 76 | 77 | self.block5_1 = TransitionBlock(int(1.5 * c), int(2 * c)) 78 | self.block5_2 = BroadcastedBlock(int(2 * c)) 79 | 80 | self.block6_1 = TransitionBlock(int(2 * c), int(2.5 * c)) 81 | self.block6_2 = BroadcastedBlock(int(2.5 * c)) 82 | self.block6_3 = BroadcastedBlock(int(2.5 * c)) 83 | 84 | self.block7_1 = nn.Conv2d(int(2.5 * c), num_class, 1) 85 | 86 | self.block8_1 = nn.AdaptiveAvgPool2d((1, 1)) 87 | self.norm = norm 88 | self.fc_audioset = nn.Linear(1, num_class, bias=True) 89 | if norm: 90 | self.one = nn.InstanceNorm2d(1) 91 | self.two = nn.InstanceNorm2d(int(1)) 92 | self.three = nn.InstanceNorm2d(int(1)) 93 | self.four = nn.InstanceNorm2d(int(1)) 94 | self.five = nn.InstanceNorm2d(int(1)) 95 | 96 | self.frontend = frontend 97 | 98 | def forward(self, input, add_noise=False, training=False, noise_lambda=0.1, k=2): 99 | 100 | if self.frontend is not None: 101 | out = self.frontend(input) 102 | # Input: (batch_size, 1, T, F) 103 | else: 104 | out = input 105 | 106 | if self.norm: 107 | out = self.lamb * out + self.one(out) 108 | out = self.conv1(out) 109 | 110 | out = self.block1_1(out) 111 | 112 | out = 
self.block1_2(out) 113 | if self.norm: 114 | out = self.lamb * out + self.two(out) 115 | 116 | out = self.block2_1(out) 117 | 118 | out = self.block3_1(out) 119 | out = self.block3_2(out) 120 | if self.norm: 121 | out = self.lamb * out + self.three(out) 122 | 123 | out = self.block4_1(out) 124 | 125 | out = self.block5_1(out) 126 | out = self.block5_2(out) 127 | if self.norm: 128 | out = self.lamb * out + self.four(out) 129 | 130 | out = self.block6_1(out) 131 | out = self.block6_2(out) 132 | out = self.block6_3(out) 133 | embedding = F.dropout(out, p=0.2, training=training) 134 | embedding = self.block8_1(embedding) 135 | embedding = self.block8_1(embedding) 136 | if self.norm: 137 | out = self.lamb * out + self.five(out) 138 | if not training and add_noise is True: 139 | x_hat = [] 140 | for i in range(k): 141 | # draw an independent perturbation per pass, on the same device as the features 142 | noise = (torch.rand_like(out) - 0.5) * noise_lambda * torch.std(out) 143 | feat = out + noise # out-of-place add, so the k draws do not accumulate in `out` 144 | feat = self.block7_1(feat) 145 | 146 | feat = self.block8_1(feat) 147 | feat = self.block8_1(feat) 148 | 149 | clipwise_output = torch.squeeze(torch.squeeze(feat, dim=2), dim=2) 150 | x_hat.append(clipwise_output) 151 | clipwise_output = x_hat 152 | 153 | else: 154 | out = self.block7_1(out) 155 | 156 | out = self.block8_1(out) 157 | out = self.block8_1(out) 158 | 159 | clipwise_output = torch.squeeze(torch.squeeze(out, dim=2), dim=2) 160 | 161 | output_dict = { 162 | 'clipwise_output': clipwise_output, 163 | 'embedding': embedding} 164 | 165 | return output_dict 166 | 167 | 168 | if __name__ == '__main__': 169 | panns_params = { 170 | 'sample_rate': 48000, 171 | 'window_size': 1024, 172 | 'hop_size': 320, 173 | 'mel_bins': 64, 174 | 'fmin': 50, 175 | 'fmax': 14000} 176 | 177 | frontend = Audio_Frontend(**panns_params) 178 | model = BCResNet_Mod(frontend=frontend) 179 | 180 | print(model(torch.randn(32, 48000))) 181 | -------------------------------------------------------------------------------- /models/module.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from utils.models import init_layer, init_bn 4 | 5 | """adapted from https://github.com/roman-vygon/BCResNet""" 6 | 7 | 8 | class ConvBlock3x3(nn.Module): 9 | def __init__(self, in_channels, out_channels): 10 | 11 | super(ConvBlock3x3, self).__init__() 12 | 13 | self.conv1 = nn.Conv2d(in_channels=in_channels, 14 | out_channels=out_channels, 15 | kernel_size=(3, 3), stride=(1, 1), 16 | padding=(1, 1), bias=False) 17 | 18 | self.bn1 = nn.BatchNorm2d(out_channels) 19 | self.init_weight() 20 | 21 | def init_weight(self): 22 | init_layer(self.conv1) 23 | init_bn(self.bn1) 24 | 25 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 26 | 27 | x = input 28 | x = F.relu_(self.bn1(self.conv1(x))) 29 | if pool_type == 'max': 30 | x = F.max_pool2d(x, kernel_size=pool_size) 31 | elif pool_type == 'avg': 32 | x = F.avg_pool2d(x, kernel_size=pool_size) 33 | elif pool_type == 'avg+max': 34 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 35 | x2 = F.max_pool2d(x, kernel_size=pool_size) 36 | x = x1 + x2 37 | else: 38 | raise Exception('Incorrect argument!') 39 | 40 | return x 41 | 42 | 43 | class SubSpectralNorm(nn.Module): 44 | def __init__(self, C, S, eps=1e-5): 45 | super(SubSpectralNorm, self).__init__() 46 | self.S = S 47 | self.eps = eps 48 | self.bn = nn.BatchNorm2d(C * S) 49 | 50 | def forward(self, x): 51 | N, C, T, F = x.size() 52 | # Changed view to reshape for quantisation error fix 53 | x = x.reshape(N, C * self.S, T, F // self.S) 54 | x = self.bn(x) 55 | return x.reshape(N, C, T, F)
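# Shape sketch for SubSpectralNorm above (illustrative values, not from the repo): with
# C=40 channels, S=4 sub-bands and an (N, 40, T, 64) input, the reshape yields
# (N, 160, T, 16), so BatchNorm2d keeps separate statistics per 16-bin frequency group
# before the tensor is restored to (N, 40, T, 64).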
56 | 57 | 58 | class BroadcastedBlock(nn.Module): 59 | def __init__( 60 | self, 61 | planes: int, 62 | dilation=1, 63 | stride=1, 64 | temp_pad=(0, 1), 65 | ): 66 | super(BroadcastedBlock, self).__init__() 67 | self.freq_dw_conv = nn.Conv2d(planes, planes, kernel_size=(3, 1), padding=(1, 0), groups=planes, 68 | dilation=dilation, 69 | stride=stride, bias=False) 70 | self.ssn1 = SubSpectralNorm(planes, 4) 71 | self.temp_dw_conv = nn.Conv2d(planes, planes, kernel_size=(1, 3), padding=temp_pad, groups=planes, 72 | dilation=dilation, stride=stride, bias=False) 73 | self.bn = nn.BatchNorm2d(planes) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.channel_drop = nn.Dropout2d(p=0.1) 76 | self.swish = nn.SiLU() 77 | self.conv1x1 = nn.Conv2d(planes, planes, kernel_size=(1, 1), bias=False) 78 | 79 | def forward(self, x): 80 | identity = x 81 | 82 | # f2 83 | ########################## 84 | out = self.freq_dw_conv(x) 85 | out = self.ssn1(out) 86 | ########################## 87 | 88 | auxiliary = out 89 | out = out.mean(2, keepdim=True) # frequency average pooling 90 | 91 | # f1 92 | ############################ 93 | out = self.temp_dw_conv(out) 94 | out = self.bn(out) 95 | out = self.swish(out) 96 | out = self.conv1x1(out) 97 | out = self.channel_drop(out) 98 | ############################ 99 | 100 | out = out + identity + auxiliary 101 | out = self.relu(out) 102 | 103 | return out 104 | 105 | 106 | class TransitionBlock(nn.Module): 107 | 108 | def __init__( 109 | self, 110 | inplanes: int, 111 | planes: int, 112 | dilation=1, 113 | stride=1, 114 | temp_pad=(0, 1), 115 | ): 116 | super(TransitionBlock, self).__init__() 117 | 118 | self.freq_dw_conv = nn.Conv2d(planes, planes, kernel_size=(3, 1), padding=(1, 0), groups=planes, 119 | stride=stride, 120 | dilation=dilation, bias=False) 121 | self.ssn = SubSpectralNorm(planes, 4) 122 | # self.ssn = nn.BatchNorm2d(planes) 123 | self.temp_dw_conv = nn.Conv2d(planes, planes, kernel_size=(1, 3), padding=temp_pad, groups=planes, 124 | dilation=dilation, stride=stride, bias=False) 125 | self.bn1 = nn.BatchNorm2d(planes) 126 | self.bn2 = nn.BatchNorm2d(planes) 127 | self.relu = nn.ReLU(inplace=True) 128 | self.channel_drop = nn.Dropout2d(p=0.1) 129 | self.swish = nn.SiLU() 130 | self.conv1x1_1 = nn.Conv2d(inplanes, planes, kernel_size=(1, 1), bias=False) 131 | self.conv1x1_2 = nn.Conv2d(planes, planes, kernel_size=(1, 1), bias=False) 132 | self.flag = False 133 | 134 | def forward(self, x): 135 | # f2 136 | ############################# 137 | out = self.conv1x1_1(x) 138 | out = self.bn1(out) 139 | out = self.relu(out) 140 | out = self.freq_dw_conv(out) 141 | out = self.ssn(out) 142 | ############################# 143 | auxiliary = out 144 | out = out.mean(2, keepdim=True) # frequency average pooling 145 | 146 | # f1 147 | ############################# 148 | out = self.temp_dw_conv(out) 149 | out = self.bn2(out) 150 | out = self.swish(out) 151 | out = self.conv1x1_2(out) 152 | out = self.channel_drop(out) 153 | ############################# 154 | 155 | out = auxiliary + out 156 | out = self.relu(out) 157 | 158 | return out 159 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | # DCASE 2019 Task 1 2 | 3 | # Finetune 4 | python main.py --dataset TAU-ASC --mode finetune 5 | 6 | # Random 7 | python main.py --dataset TAU-ASC --mode replay --mem_manage 
random 8 | 9 | # Reservoir 10 | python main.py --dataset TAU-ASC --mode replay --mem_manage reservoir 11 | 12 | # Prototype 13 | python main.py --dataset TAU-ASC --mode replay --mem_manage prototype 14 | 15 | # Uncertainty shift 16 | python main.py --dataset TAU-ASC --mode replay --mem_manage uncertainty --uncert_metric shift --metric_k 2 17 | python main.py --dataset TAU-ASC --mode replay --mem_manage uncertainty --uncert_metric shift --metric_k 4 18 | python main.py --dataset TAU-ASC --mode replay --mem_manage uncertainty --uncert_metric shift --metric_k 6 19 | 20 | # Uncertainty noise 21 | python main.py --dataset TAU-ASC --mode replay --mem_manage uncertainty --uncert_metric noise --metric_k 2 22 | python main.py --dataset TAU-ASC --mode replay --mem_manage uncertainty --uncert_metric noise --metric_k 4 23 | python main.py --dataset TAU-ASC --mode replay --mem_manage uncertainty --uncert_metric noise --metric_k 6 24 | 25 | # Uncertainty++ 26 | python main.py --dataset TAU-ASC --mode replay --mem_manage uncertainty --uncert_metric noisytune --metric_k 2 27 | python main.py --dataset TAU-ASC --mode replay --mem_manage uncertainty --uncert_metric noisytune --metric_k 4 28 | python main.py --dataset TAU-ASC --mode replay --mem_manage uncertainty --uncert_metric noisytune --metric_k 6 29 | 30 | -------------------------------------------------------------------------------- /utils/evaluate.py: -------------------------------------------------------------------------------- 1 | from sklearn import metrics 2 | from sklearn.metrics import accuracy_score 3 | import numpy as np 4 | 5 | from utils.pytorch_utils import forward 6 | 7 | 8 | class Evaluator(object): 9 | def __init__(self, model): 10 | """Evaluator. 11 | 12 | Args: 13 | model: object 14 | """ 15 | self.model = model 16 | 17 | def evaluate(self, data_loader): 18 | """Forward evaluation data and calculate statistics. 
19 | 20 | Args: 21 | data_loader: object 22 | 23 | Returns: 24 | statistics: dict, 25 | {'average_precision': (classes_num,), 'accuracy': float} 26 | """ 27 | 28 | # Forward 29 | output_dict = forward( 30 | model=self.model, 31 | generator=data_loader, 32 | return_target=True) 33 | 34 | clipwise_output = output_dict['clipwise_output'] # (audios_num, classes_num) 35 | target = output_dict['target'] # (audios_num, classes_num) 36 | 37 | average_precision = metrics.average_precision_score( 38 | target, clipwise_output, average=None) 39 | 40 | # auc = metrics.roc_auc_score(target, clipwise_output, average=None) 41 | 42 | target_acc = np.argmax(target, axis=1) 43 | clipwise_output_acc = np.argmax(clipwise_output, axis=1) 44 | acc = accuracy_score(target_acc, clipwise_output_acc) 45 | 46 | statistics = {'average_precision': average_precision, 'accuracy': acc} 47 | 48 | return statistics -------------------------------------------------------------------------------- /utils/get_methods.py: -------------------------------------------------------------------------------- 1 | """ 2 | # !/usr/bin/env python 3 | -*- coding: utf-8 -*- 4 | @Time : 2022/7/15 10:56 PM 5 | @Author : Yang "Jan" Xiao 6 | @Description : get_methods 7 | """ 8 | import logging 9 | from methods.base import BaseMethod 10 | 11 | logger = logging.getLogger() 12 | 13 | 14 | def get_methods(args, criterion, device, n_classes, model): 15 | kwargs = vars(args) 16 | if args.mode == "finetune": 17 | method = BaseMethod( 18 | criterion=criterion, 19 | device=device, 20 | n_classes=n_classes, 21 | model=model, 22 | **kwargs, 23 | ) 24 | elif args.mode == "replay": 25 | method = BaseMethod( 26 | criterion=criterion, 27 | device=device, 28 | n_classes=n_classes, 29 | model=model, 30 | **kwargs, 31 | ) 32 | else: 33 | raise NotImplementedError( 34 | "Choose the args.mode in " 35 | "[finetune, replay]" 36 | ) 37 | logger.info(f"CIL Scenario: {args.mode}") 38 | print(f"\nn_tasks: {args.n_tasks}") 39 | print(f"n_init_cls: {args.n_init_cls}") 40 | print(f"n_cls_a_task: {args.n_cls_a_task}") 41 | print(f"total cls: {args.n_tasks * args.n_cls_a_task}") 42 | 43 | return method 44 | -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | def clip_kl(student_dict, teacher_dict): 6 | """KL divergence loss. 7 | """ 8 | kl_loss = nn.KLDivLoss(reduction="batchmean") 9 | # input should be a distribution in the log space 10 | student_log_probs = F.log_softmax(student_dict['clipwise_output'], dim=-1) 11 | teacher_probs = F.softmax(teacher_dict['clipwise_output'], dim=-1) 12 | return kl_loss(student_log_probs, teacher_probs) 13 | 14 | 15 | def clip_bce(output_dict, target_dict): 16 | """Binary crossentropy loss. 17 | """ 18 | return F.binary_cross_entropy( 19 | output_dict['clipwise_output'], target_dict['target']) 20 | 21 | 22 | def clip_ce(output_dict, target_dict): 23 | """Crossentropy loss. 24 | """ 25 | return F.cross_entropy( 26 | output_dict['clipwise_output'], target_dict['target']) 27 | 28 |
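# Minimal usage sketch (illustrative tensors, not from the training loop): with
# student = {'clipwise_output': torch.randn(8, 10)} and a teacher dict of the same
# shape, get_loss_func('clip_kl')(student, teacher) returns KL(teacher || student)
# averaged over the batch, since KLDivLoss takes log-probabilities as input and
# probabilities as target.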
24 | """ 25 | return F.cross_entropy( 26 | output_dict['clipwise_output'], target_dict['target']) 27 | 28 | 29 | def get_loss_func(loss_type): 30 | if loss_type == 'clip_bce': 31 | return clip_bce 32 | elif loss_type == 'clip_ce': 33 | return clip_ce 34 | elif loss_type == 'clip_kl': 35 | return clip_kl 36 | -------------------------------------------------------------------------------- /utils/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchlibrosa.stft import Spectrogram, LogmelFilterBank 5 | from torchlibrosa.augmentation import SpecAugmentation 6 | 7 | from utils.pytorch_utils import do_mixup, interpolate, pad_framewise_output 8 | 9 | 10 | def init_layer(layer): 11 | """Initialize a Linear or Convolutional layer. """ 12 | nn.init.xavier_uniform_(layer.weight) 13 | 14 | if hasattr(layer, 'bias'): 15 | if layer.bias is not None: 16 | layer.bias.data.fill_(0.) 17 | 18 | 19 | def init_bn(bn): 20 | """Initialize a Batchnorm layer. """ 21 | bn.bias.data.fill_(0.) 22 | bn.weight.data.fill_(1.) 23 | 24 | 25 | class ConvBlock(nn.Module): 26 | def __init__(self, in_channels, out_channels): 27 | 28 | super(ConvBlock, self).__init__() 29 | 30 | self.conv1 = nn.Conv2d(in_channels=in_channels, 31 | out_channels=out_channels, 32 | kernel_size=(3, 3), stride=(1, 1), 33 | padding=(1, 1), bias=False) 34 | 35 | self.conv2 = nn.Conv2d(in_channels=out_channels, 36 | out_channels=out_channels, 37 | kernel_size=(3, 3), stride=(1, 1), 38 | padding=(1, 1), bias=False) 39 | 40 | self.bn1 = nn.BatchNorm2d(out_channels) 41 | self.bn2 = nn.BatchNorm2d(out_channels) 42 | 43 | self.init_weight() 44 | 45 | def init_weight(self): 46 | init_layer(self.conv1) 47 | init_layer(self.conv2) 48 | init_bn(self.bn1) 49 | init_bn(self.bn2) 50 | 51 | 52 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 53 | 54 | x = input 55 | x = F.relu_(self.bn1(self.conv1(x))) 56 | x = F.relu_(self.bn2(self.conv2(x))) 57 | if pool_type == 'max': 58 | x = F.max_pool2d(x, kernel_size=pool_size) 59 | elif pool_type == 'avg': 60 | x = F.avg_pool2d(x, kernel_size=pool_size) 61 | elif pool_type == 'avg+max': 62 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 63 | x2 = F.max_pool2d(x, kernel_size=pool_size) 64 | x = x1 + x2 65 | else: 66 | raise Exception('Incorrect argument!') 67 | 68 | return x 69 | 70 | 71 | class ConvBlock5x5(nn.Module): 72 | def __init__(self, in_channels, out_channels): 73 | 74 | super(ConvBlock5x5, self).__init__() 75 | 76 | self.conv1 = nn.Conv2d(in_channels=in_channels, 77 | out_channels=out_channels, 78 | kernel_size=(5, 5), stride=(1, 1), 79 | padding=(2, 2), bias=False) 80 | 81 | self.bn1 = nn.BatchNorm2d(out_channels) 82 | 83 | self.init_weight() 84 | 85 | def init_weight(self): 86 | init_layer(self.conv1) 87 | init_bn(self.bn1) 88 | 89 | 90 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 91 | 92 | x = input 93 | x = F.relu_(self.bn1(self.conv1(x))) 94 | if pool_type == 'max': 95 | x = F.max_pool2d(x, kernel_size=pool_size) 96 | elif pool_type == 'avg': 97 | x = F.avg_pool2d(x, kernel_size=pool_size) 98 | elif pool_type == 'avg+max': 99 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 100 | x2 = F.max_pool2d(x, kernel_size=pool_size) 101 | x = x1 + x2 102 | else: 103 | raise Exception('Incorrect argument!') 104 | 105 | return x 106 | 107 | 108 | class AttBlock(nn.Module): 109 | def __init__(self, n_in, n_out, activation='linear', temperature=1.): 110 | 
super(AttBlock, self).__init__() 111 | 112 | self.activation = activation 113 | self.temperature = temperature 114 | self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) 115 | self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) 116 | 117 | self.bn_att = nn.BatchNorm1d(n_out) 118 | self.init_weights() 119 | 120 | def init_weights(self): 121 | init_layer(self.att) 122 | init_layer(self.cla) 123 | init_bn(self.bn_att) 124 | 125 | def forward(self, x): 126 | # x: (n_samples, n_in, n_time) 127 | norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) 128 | cla = self.nonlinear_transform(self.cla(x)) 129 | x = torch.sum(norm_att * cla, dim=2) 130 | return x, norm_att, cla 131 | 132 | def nonlinear_transform(self, x): 133 | if self.activation == 'linear': 134 | return x 135 | elif self.activation == 'sigmoid': 136 | return torch.sigmoid(x) 137 | 138 | 139 | class Cnn14(nn.Module): 140 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 141 | fmax, classes_num): 142 | 143 | super(Cnn14, self).__init__() 144 | 145 | window = 'hann' 146 | center = True 147 | pad_mode = 'reflect' 148 | ref = 1.0 149 | amin = 1e-10 150 | top_db = None 151 | 152 | # Spectrogram extractor 153 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 154 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 155 | freeze_parameters=True) 156 | 157 | # Logmel feature extractor 158 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 159 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 160 | freeze_parameters=True) 161 | 162 | # Spec augmenter 163 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 164 | freq_drop_width=8, freq_stripes_num=2) 165 | 166 | self.bn0 = nn.BatchNorm2d(64) 167 | 168 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 169 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 170 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 171 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 172 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 173 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 174 | 175 | self.fc1 = nn.Linear(2048, 2048, bias=True) 176 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 177 | 178 | self.init_weight() 179 | 180 | def init_weight(self): 181 | init_bn(self.bn0) 182 | init_layer(self.fc1) 183 | init_layer(self.fc_audioset) 184 | 185 | def forward(self, input, mixup_lambda=None): 186 | """ 187 | Input: (batch_size, data_length)""" 188 | 189 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 190 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 191 | 192 | x = x.transpose(1, 3) 193 | x = self.bn0(x) 194 | x = x.transpose(1, 3) 195 | 196 | if self.training: 197 | x = self.spec_augmenter(x) 198 | 199 | # Mixup on spectrogram 200 | if self.training and mixup_lambda is not None: 201 | x = do_mixup(x, mixup_lambda) 202 | 203 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 204 | x = F.dropout(x, p=0.2, training=self.training) 205 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 206 | x = F.dropout(x, p=0.2, training=self.training) 207 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 208 | x = F.dropout(x, p=0.2, training=self.training) 
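# Note: the p=0.2 dropout between conv blocks regularises the CNN trunk; the heavier
# p=0.5 dropout is reserved for the fully connected head further down.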
209 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 210 | x = F.dropout(x, p=0.2, training=self.training) 211 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 212 | x = F.dropout(x, p=0.2, training=self.training) 213 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 214 | x = F.dropout(x, p=0.2, training=self.training) 215 | tf_embed = x 216 | x = torch.mean(x, dim=3) 217 | 218 | (x1, _) = torch.max(x, dim=2) 219 | x2 = torch.mean(x, dim=2) 220 | x = x1 + x2 221 | x = F.dropout(x, p=0.5, training=self.training) 222 | x = F.relu_(self.fc1(x)) 223 | embedding = F.dropout(x, p=0.5, training=self.training) 224 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 225 | 226 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding, 'TF-Embed': tf_embed} 227 | 228 | return output_dict 229 | 230 | 231 | class Cnn14_no_specaug(nn.Module): 232 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 233 | fmax, classes_num): 234 | 235 | super(Cnn14_no_specaug, self).__init__() 236 | 237 | window = 'hann' 238 | center = True 239 | pad_mode = 'reflect' 240 | ref = 1.0 241 | amin = 1e-10 242 | top_db = None 243 | 244 | # Spectrogram extractor 245 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 246 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 247 | freeze_parameters=True) 248 | 249 | # Logmel feature extractor 250 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 251 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 252 | freeze_parameters=True) 253 | 254 | self.bn0 = nn.BatchNorm2d(64) 255 | 256 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 257 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 258 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 259 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 260 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 261 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 262 | 263 | self.fc1 = nn.Linear(2048, 2048, bias=True) 264 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 265 | 266 | self.init_weight() 267 | 268 | def init_weight(self): 269 | init_bn(self.bn0) 270 | init_layer(self.fc1) 271 | init_layer(self.fc_audioset) 272 | 273 | def forward(self, input, mixup_lambda=None): 274 | """ 275 | Input: (batch_size, data_length)""" 276 | 277 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 278 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 279 | 280 | x = x.transpose(1, 3) 281 | x = self.bn0(x) 282 | x = x.transpose(1, 3) 283 | 284 | # Mixup on spectrogram 285 | if self.training and mixup_lambda is not None: 286 | x = do_mixup(x, mixup_lambda) 287 | 288 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 289 | x = F.dropout(x, p=0.2, training=self.training) 290 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 291 | x = F.dropout(x, p=0.2, training=self.training) 292 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 293 | x = F.dropout(x, p=0.2, training=self.training) 294 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 295 | x = F.dropout(x, p=0.2, training=self.training) 296 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 297 | x = F.dropout(x, p=0.2, training=self.training) 298 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 299 | x = 
F.dropout(x, p=0.2, training=self.training) 300 | x = torch.mean(x, dim=3) 301 | 302 | (x1, _) = torch.max(x, dim=2) 303 | x2 = torch.mean(x, dim=2) 304 | x = x1 + x2 305 | x = F.dropout(x, p=0.5, training=self.training) 306 | x = F.relu_(self.fc1(x)) 307 | embedding = F.dropout(x, p=0.5, training=self.training) 308 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 309 | 310 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 311 | 312 | return output_dict 313 | 314 | 315 | class Cnn14_no_dropout(nn.Module): 316 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 317 | fmax, classes_num): 318 | 319 | super(Cnn14_no_dropout, self).__init__() 320 | 321 | window = 'hann' 322 | center = True 323 | pad_mode = 'reflect' 324 | ref = 1.0 325 | amin = 1e-10 326 | top_db = None 327 | 328 | # Spectrogram extractor 329 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 330 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 331 | freeze_parameters=True) 332 | 333 | # Logmel feature extractor 334 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 335 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 336 | freeze_parameters=True) 337 | 338 | # Spec augmenter 339 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 340 | freq_drop_width=8, freq_stripes_num=2) 341 | 342 | self.bn0 = nn.BatchNorm2d(64) 343 | 344 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 345 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 346 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 347 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 348 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 349 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 350 | 351 | self.fc1 = nn.Linear(2048, 2048, bias=True) 352 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 353 | 354 | self.init_weight() 355 | 356 | def init_weight(self): 357 | init_bn(self.bn0) 358 | init_layer(self.fc1) 359 | init_layer(self.fc_audioset) 360 | 361 | def forward(self, input, mixup_lambda=None): 362 | """ 363 | Input: (batch_size, data_length)""" 364 | 365 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 366 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 367 | 368 | x = x.transpose(1, 3) 369 | x = self.bn0(x) 370 | x = x.transpose(1, 3) 371 | 372 | if self.training: 373 | x = self.spec_augmenter(x) 374 | 375 | # Mixup on spectrogram 376 | if self.training and mixup_lambda is not None: 377 | x = do_mixup(x, mixup_lambda) 378 | 379 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 380 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 381 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 382 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 383 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 384 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 385 | x = torch.mean(x, dim=3) 386 | 387 | (x1, _) = torch.max(x, dim=2) 388 | x2 = torch.mean(x, dim=2) 389 | x = x1 + x2 390 | x = F.dropout(x, p=0.5, training=self.training) 391 | x = F.relu_(self.fc1(x)) 392 | embedding = F.dropout(x, p=0.5, training=self.training) 393 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 394 | 395 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 396 
| 397 | return output_dict 398 | 399 | 400 | class Cnn6(nn.Module): 401 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 402 | fmax, classes_num): 403 | 404 | super(Cnn6, self).__init__() 405 | 406 | window = 'hann' 407 | center = True 408 | pad_mode = 'reflect' 409 | ref = 1.0 410 | amin = 1e-10 411 | top_db = None 412 | 413 | # Spectrogram extractor 414 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 415 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 416 | freeze_parameters=True) 417 | 418 | # Logmel feature extractor 419 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 420 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 421 | freeze_parameters=True) 422 | 423 | # Spec augmenter 424 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 425 | freq_drop_width=8, freq_stripes_num=2) 426 | 427 | self.bn0 = nn.BatchNorm2d(64) 428 | 429 | self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64) 430 | self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128) 431 | self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256) 432 | self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512) 433 | 434 | self.fc1 = nn.Linear(512, 512, bias=True) 435 | self.fc_audioset = nn.Linear(512, classes_num, bias=True) 436 | 437 | self.init_weight() 438 | 439 | def init_weight(self): 440 | init_bn(self.bn0) 441 | init_layer(self.fc1) 442 | init_layer(self.fc_audioset) 443 | 444 | def forward(self, input, mixup_lambda=None): 445 | """ 446 | Input: (batch_size, data_length)""" 447 | 448 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 449 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 450 | 451 | x = x.transpose(1, 3) 452 | x = self.bn0(x) 453 | x = x.transpose(1, 3) 454 | 455 | if self.training: 456 | x = self.spec_augmenter(x) 457 | 458 | # Mixup on spectrogram 459 | if self.training and mixup_lambda is not None: 460 | x = do_mixup(x, mixup_lambda) 461 | 462 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 463 | x = F.dropout(x, p=0.2, training=self.training) 464 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 465 | x = F.dropout(x, p=0.2, training=self.training) 466 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 467 | x = F.dropout(x, p=0.2, training=self.training) 468 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 469 | x = F.dropout(x, p=0.2, training=self.training) 470 | tf_embed = x 471 | x = torch.mean(x, dim=3) 472 | 473 | (x1, _) = torch.max(x, dim=2) 474 | x2 = torch.mean(x, dim=2) 475 | x = x1 + x2 476 | last_embed = x 477 | x = F.dropout(x, p=0.5, training=self.training) 478 | x = F.relu_(self.fc1(x)) 479 | embedding = F.dropout(x, p=0.5, training=self.training) 480 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 481 | 482 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding, 'TF-Embed': tf_embed, 'last_embed': last_embed} 483 | 484 | return output_dict 485 | 486 | 487 | class Cnn10(nn.Module): 488 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 489 | fmax, classes_num): 490 | 491 | super(Cnn10, self).__init__() 492 | 493 | window = 'hann' 494 | center = True 495 | pad_mode = 'reflect' 496 | ref = 1.0 497 | amin = 1e-10 498 | top_db = None 499 | 500 | # Spectrogram extractor 501 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, 
hop_length=hop_size, 502 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 503 | freeze_parameters=True) 504 | 505 | # Logmel feature extractor 506 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 507 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 508 | freeze_parameters=True) 509 | 510 | # Spec augmenter 511 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 512 | freq_drop_width=8, freq_stripes_num=2) 513 | 514 | self.bn0 = nn.BatchNorm2d(64) 515 | 516 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 517 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 518 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 519 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 520 | 521 | self.fc1 = nn.Linear(512, 512, bias=True) 522 | self.fc_audioset = nn.Linear(512, classes_num, bias=True) 523 | 524 | self.init_weight() 525 | 526 | def init_weight(self): 527 | init_bn(self.bn0) 528 | init_layer(self.fc1) 529 | init_layer(self.fc_audioset) 530 | 531 | def forward(self, input, mixup_lambda=None): 532 | """ 533 | Input: (batch_size, data_length)""" 534 | 535 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 536 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 537 | 538 | x = x.transpose(1, 3) 539 | x = self.bn0(x) 540 | x = x.transpose(1, 3) 541 | 542 | if self.training: 543 | x = self.spec_augmenter(x) 544 | 545 | # Mixup on spectrogram 546 | if self.training and mixup_lambda is not None: 547 | x = do_mixup(x, mixup_lambda) 548 | 549 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 550 | x = F.dropout(x, p=0.2, training=self.training) 551 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 552 | x = F.dropout(x, p=0.2, training=self.training) 553 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 554 | x = F.dropout(x, p=0.2, training=self.training) 555 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 556 | x = F.dropout(x, p=0.2, training=self.training) 557 | tf_embed = x 558 | x = torch.mean(x, dim=3) 559 | 560 | (x1, _) = torch.max(x, dim=2) 561 | x2 = torch.mean(x, dim=2) 562 | x = x1 + x2 563 | last_embed = x 564 | x = F.dropout(x, p=0.5, training=self.training) 565 | x = F.relu_(self.fc1(x)) 566 | embedding = F.dropout(x, p=0.5, training=self.training) 567 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 568 | 569 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding, 'TF-Embed': tf_embed, 'last_embed': last_embed} 570 | 571 | return output_dict 572 | 573 | 574 | def _resnet_conv3x3(in_planes, out_planes): 575 | #3x3 convolution with padding 576 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, 577 | padding=1, groups=1, bias=False, dilation=1) 578 | 579 | 580 | def _resnet_conv1x1(in_planes, out_planes): 581 | #1x1 convolution 582 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False) 583 | 584 | 585 | class _ResnetBasicBlock(nn.Module): 586 | expansion = 1 587 | 588 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 589 | base_width=64, dilation=1, norm_layer=None): 590 | super(_ResnetBasicBlock, self).__init__() 591 | if norm_layer is None: 592 | norm_layer = nn.BatchNorm2d 593 | if groups != 1 or base_width != 64: 594 | raise ValueError('_ResnetBasicBlock only supports groups=1 and base_width=64') 595 | if dilation > 1: 596 | raise 
NotImplementedError("Dilation > 1 not supported in _ResnetBasicBlock") 597 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 598 | 599 | self.stride = stride 600 | 601 | self.conv1 = _resnet_conv3x3(inplanes, planes) 602 | self.bn1 = norm_layer(planes) 603 | self.relu = nn.ReLU(inplace=True) 604 | self.conv2 = _resnet_conv3x3(planes, planes) 605 | self.bn2 = norm_layer(planes) 606 | self.downsample = downsample 607 | self.stride = stride 608 | 609 | self.init_weights() 610 | 611 | def init_weights(self): 612 | init_layer(self.conv1) 613 | init_bn(self.bn1) 614 | init_layer(self.conv2) 615 | init_bn(self.bn2) 616 | nn.init.constant_(self.bn2.weight, 0) 617 | 618 | def forward(self, x): 619 | identity = x 620 | 621 | if self.stride == 2: 622 | out = F.avg_pool2d(x, kernel_size=(2, 2)) 623 | else: 624 | out = x 625 | 626 | out = self.conv1(out) 627 | out = self.bn1(out) 628 | out = self.relu(out) 629 | out = F.dropout(out, p=0.1, training=self.training) 630 | 631 | out = self.conv2(out) 632 | out = self.bn2(out) 633 | 634 | if self.downsample is not None: 635 | identity = self.downsample(identity) 636 | 637 | out += identity 638 | out = self.relu(out) 639 | 640 | return out 641 | 642 | 643 | class _ResnetBottleneck(nn.Module): 644 | expansion = 4 645 | 646 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 647 | base_width=64, dilation=1, norm_layer=None): 648 | super(_ResnetBottleneck, self).__init__() 649 | if norm_layer is None: 650 | norm_layer = nn.BatchNorm2d 651 | width = int(planes * (base_width / 64.)) * groups 652 | self.stride = stride 653 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 654 | self.conv1 = _resnet_conv1x1(inplanes, width) 655 | self.bn1 = norm_layer(width) 656 | self.conv2 = _resnet_conv3x3(width, width) 657 | self.bn2 = norm_layer(width) 658 | self.conv3 = _resnet_conv1x1(width, planes * self.expansion) 659 | self.bn3 = norm_layer(planes * self.expansion) 660 | self.relu = nn.ReLU(inplace=True) 661 | self.downsample = downsample 662 | self.stride = stride 663 | 664 | self.init_weights() 665 | 666 | def init_weights(self): 667 | init_layer(self.conv1) 668 | init_bn(self.bn1) 669 | init_layer(self.conv2) 670 | init_bn(self.bn2) 671 | init_layer(self.conv3) 672 | init_bn(self.bn3) 673 | nn.init.constant_(self.bn3.weight, 0) 674 | 675 | def forward(self, x): 676 | identity = x 677 | 678 | if self.stride == 2: 679 | x = F.avg_pool2d(x, kernel_size=(2, 2)) 680 | 681 | out = self.conv1(x) 682 | out = self.bn1(out) 683 | out = self.relu(out) 684 | 685 | out = self.conv2(out) 686 | out = self.bn2(out) 687 | out = self.relu(out) 688 | out = F.dropout(out, p=0.1, training=self.training) 689 | 690 | out = self.conv3(out) 691 | out = self.bn3(out) 692 | 693 | if self.downsample is not None: 694 | identity = self.downsample(identity) 695 | 696 | out += identity 697 | out = self.relu(out) 698 | 699 | return out 700 | 701 | 702 | class _ResNet(nn.Module): 703 | def __init__(self, block, layers, zero_init_residual=False, 704 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 705 | norm_layer=None): 706 | super(_ResNet, self).__init__() 707 | 708 | if norm_layer is None: 709 | norm_layer = nn.BatchNorm2d 710 | self._norm_layer = norm_layer 711 | 712 | self.inplanes = 64 713 | self.dilation = 1 714 | if replace_stride_with_dilation is None: 715 | # each element in the tuple indicates if we should replace 716 | # the 2x2 stride with a dilated convolution instead 717 | 
replace_stride_with_dilation = [False, False, False] 718 | if len(replace_stride_with_dilation) != 3: 719 | raise ValueError("replace_stride_with_dilation should be None " 720 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 721 | self.groups = groups 722 | self.base_width = width_per_group 723 | 724 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 725 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 726 | dilate=replace_stride_with_dilation[0]) 727 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 728 | dilate=replace_stride_with_dilation[1]) 729 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 730 | dilate=replace_stride_with_dilation[2]) 731 | 732 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 733 | norm_layer = self._norm_layer 734 | downsample = None 735 | previous_dilation = self.dilation 736 | if dilate: 737 | self.dilation *= stride 738 | stride = 1 739 | if stride != 1 or self.inplanes != planes * block.expansion: 740 | if stride == 1: 741 | downsample = nn.Sequential( 742 | _resnet_conv1x1(self.inplanes, planes * block.expansion), 743 | norm_layer(planes * block.expansion), 744 | ) 745 | init_layer(downsample[0]) 746 | init_bn(downsample[1]) 747 | elif stride == 2: 748 | downsample = nn.Sequential( 749 | nn.AvgPool2d(kernel_size=2), 750 | _resnet_conv1x1(self.inplanes, planes * block.expansion), 751 | norm_layer(planes * block.expansion), 752 | ) 753 | init_layer(downsample[1]) 754 | init_bn(downsample[2]) 755 | 756 | layers = [] 757 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 758 | self.base_width, previous_dilation, norm_layer)) 759 | self.inplanes = planes * block.expansion 760 | for _ in range(1, blocks): 761 | layers.append(block(self.inplanes, planes, groups=self.groups, 762 | base_width=self.base_width, dilation=self.dilation, 763 | norm_layer=norm_layer)) 764 | 765 | return nn.Sequential(*layers) 766 | 767 | def forward(self, x): 768 | x = self.layer1(x) 769 | x = self.layer2(x) 770 | x = self.layer3(x) 771 | x = self.layer4(x) 772 | 773 | return x 774 | 775 | 776 | class ResNet22(nn.Module): 777 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 778 | fmax, classes_num): 779 | 780 | super(ResNet22, self).__init__() 781 | 782 | window = 'hann' 783 | center = True 784 | pad_mode = 'reflect' 785 | ref = 1.0 786 | amin = 1e-10 787 | top_db = None 788 | 789 | # Spectrogram extractor 790 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 791 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 792 | freeze_parameters=True) 793 | 794 | # Logmel feature extractor 795 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 796 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 797 | freeze_parameters=True) 798 | 799 | # Spec augmenter 800 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 801 | freq_drop_width=8, freq_stripes_num=2) 802 | 803 | self.bn0 = nn.BatchNorm2d(64) 804 | 805 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 806 | # self.conv_block2 = ConvBlock(in_channels=64, out_channels=64) 807 | 808 | self.resnet = _ResNet(block=_ResnetBasicBlock, layers=[2, 2, 2, 2], zero_init_residual=True) 809 | 810 | self.conv_block_after1 = ConvBlock(in_channels=512, out_channels=2048) 811 | 812 | self.fc1 = nn.Linear(2048, 2048) 813 | self.fc_audioset = nn.Linear(2048, 
classes_num, bias=True) 814 | 815 | self.init_weights() 816 | 817 | def init_weights(self): 818 | init_bn(self.bn0) 819 | init_layer(self.fc1) 820 | init_layer(self.fc_audioset) 821 | 822 | 823 | def forward(self, input, mixup_lambda=None): 824 | """ 825 | Input: (batch_size, data_length)""" 826 | 827 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 828 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 829 | 830 | x = x.transpose(1, 3) 831 | x = self.bn0(x) 832 | x = x.transpose(1, 3) 833 | 834 | if self.training: 835 | x = self.spec_augmenter(x) 836 | 837 | # Mixup on spectrogram 838 | if self.training and mixup_lambda is not None: 839 | x = do_mixup(x, mixup_lambda) 840 | 841 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 842 | x = F.dropout(x, p=0.2, training=self.training, inplace=True) 843 | x = self.resnet(x) 844 | x = F.avg_pool2d(x, kernel_size=(2, 2)) 845 | x = F.dropout(x, p=0.2, training=self.training, inplace=True) 846 | x = self.conv_block_after1(x, pool_size=(1, 1), pool_type='avg') 847 | x = F.dropout(x, p=0.2, training=self.training, inplace=True) 848 | x = torch.mean(x, dim=3) 849 | 850 | (x1, _) = torch.max(x, dim=2) 851 | x2 = torch.mean(x, dim=2) 852 | x = x1 + x2 853 | x = F.dropout(x, p=0.5, training=self.training) 854 | x = F.relu_(self.fc1(x)) 855 | embedding = F.dropout(x, p=0.5, training=self.training) 856 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 857 | 858 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 859 | 860 | return output_dict 861 | 862 | 863 | class ResNet38(nn.Module): 864 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 865 | fmax, classes_num): 866 | 867 | super(ResNet38, self).__init__() 868 | 869 | window = 'hann' 870 | center = True 871 | pad_mode = 'reflect' 872 | ref = 1.0 873 | amin = 1e-10 874 | top_db = None 875 | 876 | # Spectrogram extractor 877 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 878 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 879 | freeze_parameters=True) 880 | 881 | # Logmel feature extractor 882 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 883 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 884 | freeze_parameters=True) 885 | 886 | # Spec augmenter 887 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 888 | freq_drop_width=8, freq_stripes_num=2) 889 | 890 | self.bn0 = nn.BatchNorm2d(64) 891 | 892 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 893 | # self.conv_block2 = ConvBlock(in_channels=64, out_channels=64) 894 | 895 | self.resnet = _ResNet(block=_ResnetBasicBlock, layers=[3, 4, 6, 3], zero_init_residual=True) 896 | 897 | self.conv_block_after1 = ConvBlock(in_channels=512, out_channels=2048) 898 | 899 | self.fc1 = nn.Linear(2048, 2048) 900 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 901 | 902 | self.init_weights() 903 | 904 | def init_weights(self): 905 | init_bn(self.bn0) 906 | init_layer(self.fc1) 907 | init_layer(self.fc_audioset) 908 | 909 | 910 | def forward(self, input, mixup_lambda=None): 911 | """ 912 | Input: (batch_size, data_length)""" 913 | 914 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 915 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 916 | 917 | x = x.transpose(1, 3) 918 | x = self.bn0(x) 919 | x = x.transpose(1, 3) 920 | 921 | 
if self.training: 922 | x = self.spec_augmenter(x) 923 | 924 | # Mixup on spectrogram 925 | if self.training and mixup_lambda is not None: 926 | x = do_mixup(x, mixup_lambda) 927 | 928 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 929 | x = F.dropout(x, p=0.2, training=self.training, inplace=True) 930 | x = self.resnet(x) 931 | x = F.avg_pool2d(x, kernel_size=(2, 2)) 932 | x = F.dropout(x, p=0.2, training=self.training, inplace=True) 933 | x = self.conv_block_after1(x, pool_size=(1, 1), pool_type='avg') 934 | x = F.dropout(x, p=0.2, training=self.training, inplace=True) 935 | x = torch.mean(x, dim=3) 936 | 937 | (x1, _) = torch.max(x, dim=2) 938 | x2 = torch.mean(x, dim=2) 939 | x = x1 + x2 940 | x = F.dropout(x, p=0.5, training=self.training) 941 | x = F.relu_(self.fc1(x)) 942 | embedding = F.dropout(x, p=0.5, training=self.training) 943 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 944 | 945 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 946 | 947 | return output_dict 948 | 949 | 950 | class ResNet54(nn.Module): 951 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 952 | fmax, classes_num): 953 | 954 | super(ResNet54, self).__init__() 955 | 956 | window = 'hann' 957 | center = True 958 | pad_mode = 'reflect' 959 | ref = 1.0 960 | amin = 1e-10 961 | top_db = None 962 | 963 | # Spectrogram extractor 964 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 965 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 966 | freeze_parameters=True) 967 | 968 | # Logmel feature extractor 969 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 970 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 971 | freeze_parameters=True) 972 | 973 | # Spec augmenter 974 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 975 | freq_drop_width=8, freq_stripes_num=2) 976 | 977 | self.bn0 = nn.BatchNorm2d(64) 978 | 979 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 980 | # self.conv_block2 = ConvBlock(in_channels=64, out_channels=64) 981 | 982 | self.resnet = _ResNet(block=_ResnetBottleneck, layers=[3, 4, 6, 3], zero_init_residual=True) 983 | 984 | self.conv_block_after1 = ConvBlock(in_channels=2048, out_channels=2048) 985 | 986 | self.fc1 = nn.Linear(2048, 2048) 987 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 988 | 989 | self.init_weights() 990 | 991 | def init_weights(self): 992 | init_bn(self.bn0) 993 | init_layer(self.fc1) 994 | init_layer(self.fc_audioset) 995 | 996 | 997 | def forward(self, input, mixup_lambda=None): 998 | """ 999 | Input: (batch_size, data_length)""" 1000 | 1001 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 1002 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 1003 | 1004 | x = x.transpose(1, 3) 1005 | x = self.bn0(x) 1006 | x = x.transpose(1, 3) 1007 | 1008 | if self.training: 1009 | x = self.spec_augmenter(x) 1010 | 1011 | # Mixup on spectrogram 1012 | if self.training and mixup_lambda is not None: 1013 | x = do_mixup(x, mixup_lambda) 1014 | 1015 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 1016 | x = F.dropout(x, p=0.2, training=self.training, inplace=True) 1017 | x = self.resnet(x) 1018 | x = F.avg_pool2d(x, kernel_size=(2, 2)) 1019 | x = F.dropout(x, p=0.2, training=self.training, inplace=True) 1020 | x = self.conv_block_after1(x, pool_size=(1, 1), pool_type='avg') 1021 | x = 
F.dropout(x, p=0.2, training=self.training, inplace=True) 1022 | x = torch.mean(x, dim=3) 1023 | 1024 | (x1, _) = torch.max(x, dim=2) 1025 | x2 = torch.mean(x, dim=2) 1026 | x = x1 + x2 1027 | x = F.dropout(x, p=0.5, training=self.training) 1028 | x = F.relu_(self.fc1(x)) 1029 | embedding = F.dropout(x, p=0.5, training=self.training) 1030 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 1031 | 1032 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 1033 | 1034 | return output_dict 1035 | 1036 | 1037 | class Cnn14_emb512(nn.Module): 1038 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 1039 | fmax, classes_num): 1040 | 1041 | super(Cnn14_emb512, self).__init__() 1042 | 1043 | window = 'hann' 1044 | center = True 1045 | pad_mode = 'reflect' 1046 | ref = 1.0 1047 | amin = 1e-10 1048 | top_db = None 1049 | 1050 | # Spectrogram extractor 1051 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 1052 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 1053 | freeze_parameters=True) 1054 | 1055 | # Logmel feature extractor 1056 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 1057 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 1058 | freeze_parameters=True) 1059 | 1060 | # Spec augmenter 1061 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 1062 | freq_drop_width=8, freq_stripes_num=2) 1063 | 1064 | self.bn0 = nn.BatchNorm2d(64) 1065 | 1066 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 1067 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 1068 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 1069 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 1070 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 1071 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 1072 | 1073 | self.fc1 = nn.Linear(2048, 512, bias=True) 1074 | self.fc_audioset = nn.Linear(512, classes_num, bias=True) 1075 | 1076 | self.init_weight() 1077 | 1078 | def init_weight(self): 1079 | init_bn(self.bn0) 1080 | init_layer(self.fc1) 1081 | init_layer(self.fc_audioset) 1082 | 1083 | def forward(self, input, mixup_lambda=None): 1084 | """ 1085 | Input: (batch_size, data_length)""" 1086 | 1087 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 1088 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 1089 | 1090 | x = x.transpose(1, 3) 1091 | x = self.bn0(x) 1092 | x = x.transpose(1, 3) 1093 | 1094 | if self.training: 1095 | x = self.spec_augmenter(x) 1096 | 1097 | # Mixup on spectrogram 1098 | if self.training and mixup_lambda is not None: 1099 | x = do_mixup(x, mixup_lambda) 1100 | 1101 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 1102 | x = F.dropout(x, p=0.2, training=self.training) 1103 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 1104 | x = F.dropout(x, p=0.2, training=self.training) 1105 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 1106 | x = F.dropout(x, p=0.2, training=self.training) 1107 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 1108 | x = F.dropout(x, p=0.2, training=self.training) 1109 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 1110 | x = F.dropout(x, p=0.2, training=self.training) 1111 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 1112 | x = F.dropout(x, p=0.2, training=self.training) 
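# Pooling head shared by the Cnn14_emb* variants: average over the mel axis, then the
# sum of max- and mean-pooling over time; only the fc1 bottleneck width differs
# (512 here, 128 and 32 in the variants below).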
1113 | x = torch.mean(x, dim=3) 1114 | 1115 | (x1, _) = torch.max(x, dim=2) 1116 | x2 = torch.mean(x, dim=2) 1117 | x = x1 + x2 1118 | x = F.dropout(x, p=0.5, training=self.training) 1119 | x = F.relu_(self.fc1(x)) 1120 | embedding = F.dropout(x, p=0.5, training=self.training) 1121 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 1122 | 1123 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 1124 | 1125 | return output_dict 1126 | 1127 | 1128 | class Cnn14_emb128(nn.Module): 1129 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 1130 | fmax, classes_num): 1131 | 1132 | super(Cnn14_emb128, self).__init__() 1133 | 1134 | window = 'hann' 1135 | center = True 1136 | pad_mode = 'reflect' 1137 | ref = 1.0 1138 | amin = 1e-10 1139 | top_db = None 1140 | 1141 | # Spectrogram extractor 1142 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 1143 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 1144 | freeze_parameters=True) 1145 | 1146 | # Logmel feature extractor 1147 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 1148 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 1149 | freeze_parameters=True) 1150 | 1151 | # Spec augmenter 1152 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 1153 | freq_drop_width=8, freq_stripes_num=2) 1154 | 1155 | self.bn0 = nn.BatchNorm2d(64) 1156 | 1157 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 1158 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 1159 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 1160 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 1161 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 1162 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 1163 | 1164 | self.fc1 = nn.Linear(2048, 128, bias=True) 1165 | self.fc_audioset = nn.Linear(128, classes_num, bias=True) 1166 | 1167 | self.init_weight() 1168 | 1169 | def init_weight(self): 1170 | init_bn(self.bn0) 1171 | init_layer(self.fc1) 1172 | init_layer(self.fc_audioset) 1173 | 1174 | def forward(self, input, mixup_lambda=None): 1175 | """ 1176 | Input: (batch_size, data_length)""" 1177 | 1178 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 1179 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 1180 | 1181 | x = x.transpose(1, 3) 1182 | x = self.bn0(x) 1183 | x = x.transpose(1, 3) 1184 | 1185 | if self.training: 1186 | x = self.spec_augmenter(x) 1187 | 1188 | # Mixup on spectrogram 1189 | if self.training and mixup_lambda is not None: 1190 | x = do_mixup(x, mixup_lambda) 1191 | 1192 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 1193 | x = F.dropout(x, p=0.2, training=self.training) 1194 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 1195 | x = F.dropout(x, p=0.2, training=self.training) 1196 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 1197 | x = F.dropout(x, p=0.2, training=self.training) 1198 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 1199 | x = F.dropout(x, p=0.2, training=self.training) 1200 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 1201 | x = F.dropout(x, p=0.2, training=self.training) 1202 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 1203 | x = F.dropout(x, p=0.2, training=self.training) 1204 | x = torch.mean(x, dim=3) 1205 | 1206 | (x1, _) = 
torch.max(x, dim=2) 1207 | x2 = torch.mean(x, dim=2) 1208 | x = x1 + x2 1209 | x = F.dropout(x, p=0.5, training=self.training) 1210 | x = F.relu_(self.fc1(x)) 1211 | embedding = F.dropout(x, p=0.5, training=self.training) 1212 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 1213 | 1214 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 1215 | 1216 | return output_dict 1217 | 1218 | 1219 | class Cnn14_emb32(nn.Module): 1220 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 1221 | fmax, classes_num): 1222 | 1223 | super(Cnn14_emb32, self).__init__() 1224 | 1225 | window = 'hann' 1226 | center = True 1227 | pad_mode = 'reflect' 1228 | ref = 1.0 1229 | amin = 1e-10 1230 | top_db = None 1231 | 1232 | # Spectrogram extractor 1233 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 1234 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 1235 | freeze_parameters=True) 1236 | 1237 | # Logmel feature extractor 1238 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 1239 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 1240 | freeze_parameters=True) 1241 | 1242 | # Spec augmenter 1243 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 1244 | freq_drop_width=8, freq_stripes_num=2) 1245 | 1246 | self.bn0 = nn.BatchNorm2d(64) 1247 | 1248 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 1249 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 1250 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 1251 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 1252 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 1253 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 1254 | 1255 | self.fc1 = nn.Linear(2048, 32, bias=True) 1256 | self.fc_audioset = nn.Linear(32, classes_num, bias=True) 1257 | 1258 | self.init_weight() 1259 | 1260 | def init_weight(self): 1261 | init_bn(self.bn0) 1262 | init_layer(self.fc1) 1263 | init_layer(self.fc_audioset) 1264 | 1265 | def forward(self, input, mixup_lambda=None): 1266 | """ 1267 | Input: (batch_size, data_length)""" 1268 | 1269 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 1270 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 1271 | 1272 | x = x.transpose(1, 3) 1273 | x = self.bn0(x) 1274 | x = x.transpose(1, 3) 1275 | 1276 | if self.training: 1277 | x = self.spec_augmenter(x) 1278 | 1279 | # Mixup on spectrogram 1280 | if self.training and mixup_lambda is not None: 1281 | x = do_mixup(x, mixup_lambda) 1282 | 1283 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 1284 | x = F.dropout(x, p=0.2, training=self.training) 1285 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 1286 | x = F.dropout(x, p=0.2, training=self.training) 1287 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 1288 | x = F.dropout(x, p=0.2, training=self.training) 1289 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 1290 | x = F.dropout(x, p=0.2, training=self.training) 1291 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 1292 | x = F.dropout(x, p=0.2, training=self.training) 1293 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 1294 | x = F.dropout(x, p=0.2, training=self.training) 1295 | x = torch.mean(x, dim=3) 1296 | 1297 | (x1, _) = torch.max(x, dim=2) 1298 | x2 = torch.mean(x, dim=2) 1299 | x = 
x1 + x2 1300 | x = F.dropout(x, p=0.5, training=self.training) 1301 | x = F.relu_(self.fc1(x)) 1302 | embedding = F.dropout(x, p=0.5, training=self.training) 1303 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 1304 | 1305 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 1306 | 1307 | return output_dict 1308 | 1309 | 1310 | class MobileNetV1(nn.Module): 1311 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 1312 | fmax, classes_num): 1313 | 1314 | super(MobileNetV1, self).__init__() 1315 | 1316 | window = 'hann' 1317 | center = True 1318 | pad_mode = 'reflect' 1319 | ref = 1.0 1320 | amin = 1e-10 1321 | top_db = None 1322 | 1323 | # Spectrogram extractor 1324 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 1325 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 1326 | freeze_parameters=True) 1327 | 1328 | # Logmel feature extractor 1329 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 1330 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 1331 | freeze_parameters=True) 1332 | 1333 | # Spec augmenter 1334 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 1335 | freq_drop_width=8, freq_stripes_num=2) 1336 | 1337 | self.bn0 = nn.BatchNorm2d(64) 1338 | 1339 | def conv_bn(inp, oup, stride): 1340 | _layers = [ 1341 | nn.Conv2d(inp, oup, 3, 1, 1, bias=False), 1342 | nn.AvgPool2d(stride), 1343 | nn.BatchNorm2d(oup), 1344 | nn.ReLU(inplace=True) 1345 | ] 1346 | _layers = nn.Sequential(*_layers) 1347 | init_layer(_layers[0]) 1348 | init_bn(_layers[2]) 1349 | return _layers 1350 | 1351 | def conv_dw(inp, oup, stride): 1352 | _layers = [ 1353 | nn.Conv2d(inp, inp, 3, 1, 1, groups=inp, bias=False), 1354 | nn.AvgPool2d(stride), 1355 | nn.BatchNorm2d(inp), 1356 | nn.ReLU(inplace=True), 1357 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 1358 | nn.BatchNorm2d(oup), 1359 | nn.ReLU(inplace=True) 1360 | ] 1361 | _layers = nn.Sequential(*_layers) 1362 | init_layer(_layers[0]) 1363 | init_bn(_layers[2]) 1364 | init_layer(_layers[4]) 1365 | init_bn(_layers[5]) 1366 | return _layers 1367 | 1368 | self.features = nn.Sequential( 1369 | conv_bn( 1, 32, 2), 1370 | conv_dw( 32, 64, 1), 1371 | conv_dw( 64, 128, 2), 1372 | conv_dw(128, 128, 1), 1373 | conv_dw(128, 256, 2), 1374 | conv_dw(256, 256, 1), 1375 | conv_dw(256, 512, 2), 1376 | conv_dw(512, 512, 1), 1377 | conv_dw(512, 512, 1), 1378 | conv_dw(512, 512, 1), 1379 | conv_dw(512, 512, 1), 1380 | conv_dw(512, 512, 1), 1381 | conv_dw(512, 1024, 2), 1382 | conv_dw(1024, 1024, 1)) 1383 | 1384 | self.fc1 = nn.Linear(1024, 1024, bias=True) 1385 | self.fc_audioset = nn.Linear(1024, classes_num, bias=True) 1386 | 1387 | self.init_weights() 1388 | 1389 | def init_weights(self): 1390 | init_bn(self.bn0) 1391 | init_layer(self.fc1) 1392 | init_layer(self.fc_audioset) 1393 | 1394 | def forward(self, input, mixup_lambda=None): 1395 | """ 1396 | Input: (batch_size, data_length)""" 1397 | 1398 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 1399 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 1400 | 1401 | x = x.transpose(1, 3) 1402 | x = self.bn0(x) 1403 | x = x.transpose(1, 3) 1404 | 1405 | if self.training: 1406 | x = self.spec_augmenter(x) 1407 | 1408 | # Mixup on spectrogram 1409 | if self.training and mixup_lambda is not None: 1410 | x = do_mixup(x, mixup_lambda) 1411 | 1412 | x = self.features(x) 1413 | x = torch.mean(x, 
dim=3) 1414 | 1415 | (x1, _) = torch.max(x, dim=2) 1416 | x2 = torch.mean(x, dim=2) 1417 | x = x1 + x2 1418 | x = F.dropout(x, p=0.5, training=self.training) 1419 | x = F.relu_(self.fc1(x)) 1420 | embedding = F.dropout(x, p=0.5, training=self.training) 1421 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 1422 | 1423 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 1424 | 1425 | return output_dict 1426 | 1427 | 1428 | class InvertedResidual(nn.Module): 1429 | def __init__(self, inp, oup, stride, expand_ratio): 1430 | super(InvertedResidual, self).__init__() 1431 | self.stride = stride 1432 | assert stride in [1, 2] 1433 | 1434 | hidden_dim = round(inp * expand_ratio) 1435 | self.use_res_connect = self.stride == 1 and inp == oup 1436 | 1437 | if expand_ratio == 1: 1438 | _layers = [ 1439 | nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim, bias=False), 1440 | nn.AvgPool2d(stride), 1441 | nn.BatchNorm2d(hidden_dim), 1442 | nn.ReLU6(inplace=True), 1443 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 1444 | nn.BatchNorm2d(oup) 1445 | ] 1446 | _layers = nn.Sequential(*_layers) 1447 | init_layer(_layers[0]) 1448 | init_bn(_layers[2]) 1449 | init_layer(_layers[4]) 1450 | init_bn(_layers[5]) 1451 | self.conv = _layers 1452 | else: 1453 | _layers = [ 1454 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 1455 | nn.BatchNorm2d(hidden_dim), 1456 | nn.ReLU6(inplace=True), 1457 | nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim, bias=False), 1458 | nn.AvgPool2d(stride), 1459 | nn.BatchNorm2d(hidden_dim), 1460 | nn.ReLU6(inplace=True), 1461 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 1462 | nn.BatchNorm2d(oup) 1463 | ] 1464 | _layers = nn.Sequential(*_layers) 1465 | init_layer(_layers[0]) 1466 | init_bn(_layers[1]) 1467 | init_layer(_layers[3]) 1468 | init_bn(_layers[5]) 1469 | init_layer(_layers[7]) 1470 | init_bn(_layers[8]) 1471 | self.conv = _layers 1472 | 1473 | def forward(self, x): 1474 | if self.use_res_connect: 1475 | return x + self.conv(x) 1476 | else: 1477 | return self.conv(x) 1478 | 1479 | 1480 | class MobileNetV2(nn.Module): 1481 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 1482 | fmax, classes_num): 1483 | 1484 | super(MobileNetV2, self).__init__() 1485 | 1486 | window = 'hann' 1487 | center = True 1488 | pad_mode = 'reflect' 1489 | ref = 1.0 1490 | amin = 1e-10 1491 | top_db = None 1492 | 1493 | # Spectrogram extractor 1494 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 1495 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 1496 | freeze_parameters=True) 1497 | 1498 | # Logmel feature extractor 1499 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 1500 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 1501 | freeze_parameters=True) 1502 | 1503 | # Spec augmenter 1504 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 1505 | freq_drop_width=8, freq_stripes_num=2) 1506 | 1507 | self.bn0 = nn.BatchNorm2d(64) 1508 | 1509 | width_mult=1. 
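# width_mult uniformly scales the channel widths; 1.0 keeps the
# reference MobileNetV2 configuration. In the setting table below each
# row reads (t, c, n, s): t = expansion ratio of the inverted residual,
# c = output channels of the stage, n = number of repeated blocks, and
# s = stride of the stage's first block (realised here with AvgPool2d
# rather than a strided convolution). For example, under a hypothetical
# width_mult = 0.5, a stage with c = 64 would emit int(64 * 0.5) = 32
# channels.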
1510 | block = InvertedResidual 1511 | input_channel = 32 1512 | last_channel = 1280 1513 | interverted_residual_setting = [ 1514 | # t, c, n, s 1515 | [1, 16, 1, 1], 1516 | [6, 24, 2, 2], 1517 | [6, 32, 3, 2], 1518 | [6, 64, 4, 2], 1519 | [6, 96, 3, 2], 1520 | [6, 160, 3, 1], 1521 | [6, 320, 1, 1], 1522 | ] 1523 | 1524 | def conv_bn(inp, oup, stride): 1525 | _layers = [ 1526 | nn.Conv2d(inp, oup, 3, 1, 1, bias=False), 1527 | nn.AvgPool2d(stride), 1528 | nn.BatchNorm2d(oup), 1529 | nn.ReLU6(inplace=True) 1530 | ] 1531 | _layers = nn.Sequential(*_layers) 1532 | init_layer(_layers[0]) 1533 | init_bn(_layers[2]) 1534 | return _layers 1535 | 1536 | 1537 | def conv_1x1_bn(inp, oup): 1538 | _layers = nn.Sequential( 1539 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 1540 | nn.BatchNorm2d(oup), 1541 | nn.ReLU6(inplace=True) 1542 | ) 1543 | init_layer(_layers[0]) 1544 | init_bn(_layers[1]) 1545 | return _layers 1546 | 1547 | # building first layer 1548 | input_channel = int(input_channel * width_mult) 1549 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 1550 | self.features = [conv_bn(1, input_channel, 2)] 1551 | # building inverted residual blocks 1552 | for t, c, n, s in interverted_residual_setting: 1553 | output_channel = int(c * width_mult) 1554 | for i in range(n): 1555 | if i == 0: 1556 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 1557 | else: 1558 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 1559 | input_channel = output_channel 1560 | # building last several layers 1561 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 1562 | # make it nn.Sequential 1563 | self.features = nn.Sequential(*self.features) 1564 | 1565 | self.fc1 = nn.Linear(1280, 1024, bias=True) 1566 | self.fc_audioset = nn.Linear(1024, classes_num, bias=True) 1567 | 1568 | self.init_weight() 1569 | 1570 | def init_weight(self): 1571 | init_bn(self.bn0) 1572 | init_layer(self.fc1) 1573 | init_layer(self.fc_audioset) 1574 | 1575 | def forward(self, input, mixup_lambda=None): 1576 | """ 1577 | Input: (batch_size, data_length)""" 1578 | 1579 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 1580 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 1581 | 1582 | x = x.transpose(1, 3) 1583 | x = self.bn0(x) 1584 | x = x.transpose(1, 3) 1585 | 1586 | if self.training: 1587 | x = self.spec_augmenter(x) 1588 | 1589 | # Mixup on spectrogram 1590 | if self.training and mixup_lambda is not None: 1591 | x = do_mixup(x, mixup_lambda) 1592 | 1593 | x = self.features(x) 1594 | 1595 | x = torch.mean(x, dim=3) 1596 | 1597 | (x1, _) = torch.max(x, dim=2) 1598 | x2 = torch.mean(x, dim=2) 1599 | x = x1 + x2 1600 | # x = F.dropout(x, p=0.5, training=self.training) 1601 | x = F.relu_(self.fc1(x)) 1602 | embedding = F.dropout(x, p=0.5, training=self.training) 1603 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 1604 | 1605 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 1606 | 1607 | return output_dict 1608 | 1609 | 1610 | class LeeNetConvBlock(nn.Module): 1611 | def __init__(self, in_channels, out_channels, kernel_size, stride): 1612 | 1613 | super(LeeNetConvBlock, self).__init__() 1614 | 1615 | self.conv1 = nn.Conv1d(in_channels=in_channels, 1616 | out_channels=out_channels, 1617 | kernel_size=kernel_size, stride=stride, 1618 | padding=kernel_size // 2, bias=False) 1619 | 1620 | self.bn1 = nn.BatchNorm1d(out_channels) 1621 | 1622 | 
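# self.init_weight() below applies the shared initialisation helpers
# (init_layer / init_bn, defined elsewhere in this repository) to the
# freshly built convolution and batch-norm layers, as every other block
# in this file does after construction. In the upstream PANNs code
# these helpers perform Xavier-style weight initialisation and
# identity batch-norm; that behaviour is assumed, not shown here.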
self.init_weight() 1623 | 1624 | def init_weight(self): 1625 | init_layer(self.conv1) 1626 | init_bn(self.bn1) 1627 | 1628 | def forward(self, x, pool_size=1): 1629 | x = F.relu_(self.bn1(self.conv1(x))) 1630 | if pool_size != 1: 1631 | x = F.max_pool1d(x, kernel_size=pool_size, padding=pool_size // 2) 1632 | return x 1633 | 1634 | 1635 | class LeeNet11(nn.Module): 1636 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 1637 | fmax, classes_num): 1638 | 1639 | super(LeeNet11, self).__init__() 1640 | 1641 | window = 'hann' 1642 | center = True 1643 | pad_mode = 'reflect' 1644 | ref = 1.0 1645 | amin = 1e-10 1646 | top_db = None 1647 | 1648 | self.conv_block1 = LeeNetConvBlock(1, 64, 3, 3) 1649 | self.conv_block2 = LeeNetConvBlock(64, 64, 3, 1) 1650 | self.conv_block3 = LeeNetConvBlock(64, 64, 3, 1) 1651 | self.conv_block4 = LeeNetConvBlock(64, 128, 3, 1) 1652 | self.conv_block5 = LeeNetConvBlock(128, 128, 3, 1) 1653 | self.conv_block6 = LeeNetConvBlock(128, 128, 3, 1) 1654 | self.conv_block7 = LeeNetConvBlock(128, 128, 3, 1) 1655 | self.conv_block8 = LeeNetConvBlock(128, 128, 3, 1) 1656 | self.conv_block9 = LeeNetConvBlock(128, 256, 3, 1) 1657 | 1658 | 1659 | self.fc1 = nn.Linear(256, 512, bias=True) 1660 | self.fc_audioset = nn.Linear(512, classes_num, bias=True) 1661 | 1662 | self.init_weight() 1663 | 1664 | def init_weight(self): 1665 | init_layer(self.fc1) 1666 | init_layer(self.fc_audioset) 1667 | 1668 | def forward(self, input, mixup_lambda=None): 1669 | """ 1670 | Input: (batch_size, data_length)""" 1671 | 1672 | x = input[:, None, :] 1673 | 1674 | # Mixup on spectrogram 1675 | if self.training and mixup_lambda is not None: 1676 | x = do_mixup(x, mixup_lambda) 1677 | 1678 | x = self.conv_block1(x) 1679 | x = self.conv_block2(x, pool_size=3) 1680 | x = self.conv_block3(x, pool_size=3) 1681 | x = self.conv_block4(x, pool_size=3) 1682 | x = self.conv_block5(x, pool_size=3) 1683 | x = self.conv_block6(x, pool_size=3) 1684 | x = self.conv_block7(x, pool_size=3) 1685 | x = self.conv_block8(x, pool_size=3) 1686 | x = self.conv_block9(x, pool_size=3) 1687 | 1688 | (x1, _) = torch.max(x, dim=2) 1689 | x2 = torch.mean(x, dim=2) 1690 | x = x1 + x2 1691 | x = F.dropout(x, p=0.5, training=self.training) 1692 | x = F.relu_(self.fc1(x)) 1693 | embedding = F.dropout(x, p=0.5, training=self.training) 1694 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 1695 | 1696 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 1697 | 1698 | return output_dict 1699 | 1700 | 1701 | class LeeNetConvBlock2(nn.Module): 1702 | def __init__(self, in_channels, out_channels, kernel_size, stride): 1703 | 1704 | super(LeeNetConvBlock2, self).__init__() 1705 | 1706 | self.conv1 = nn.Conv1d(in_channels=in_channels, 1707 | out_channels=out_channels, 1708 | kernel_size=kernel_size, stride=stride, 1709 | padding=kernel_size // 2, bias=False) 1710 | 1711 | self.conv2 = nn.Conv1d(in_channels=out_channels, 1712 | out_channels=out_channels, 1713 | kernel_size=kernel_size, stride=1, 1714 | padding=kernel_size // 2, bias=False) 1715 | 1716 | self.bn1 = nn.BatchNorm1d(out_channels) 1717 | self.bn2 = nn.BatchNorm1d(out_channels) 1718 | 1719 | self.init_weight() 1720 | 1721 | def init_weight(self): 1722 | init_layer(self.conv1) 1723 | init_layer(self.conv2) 1724 | init_bn(self.bn1) 1725 | init_bn(self.bn2) 1726 | 1727 | def forward(self, x, pool_size=1): 1728 | x = F.relu_(self.bn1(self.conv1(x))) 1729 | x = F.relu_(self.bn2(self.conv2(x))) 1730 | if pool_size != 1: 1731 | x = 
F.max_pool1d(x, kernel_size=pool_size, padding=pool_size // 2) 1732 | return x 1733 | 1734 | 1735 | class LeeNet24(nn.Module): 1736 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 1737 | fmax, classes_num): 1738 | 1739 | super(LeeNet24, self).__init__() 1740 | 1741 | window = 'hann' 1742 | center = True 1743 | pad_mode = 'reflect' 1744 | ref = 1.0 1745 | amin = 1e-10 1746 | top_db = None 1747 | 1748 | self.conv_block1 = LeeNetConvBlock2(1, 64, 3, 3) 1749 | self.conv_block2 = LeeNetConvBlock2(64, 96, 3, 1) 1750 | self.conv_block3 = LeeNetConvBlock2(96, 128, 3, 1) 1751 | self.conv_block4 = LeeNetConvBlock2(128, 128, 3, 1) 1752 | self.conv_block5 = LeeNetConvBlock2(128, 256, 3, 1) 1753 | self.conv_block6 = LeeNetConvBlock2(256, 256, 3, 1) 1754 | self.conv_block7 = LeeNetConvBlock2(256, 512, 3, 1) 1755 | self.conv_block8 = LeeNetConvBlock2(512, 512, 3, 1) 1756 | self.conv_block9 = LeeNetConvBlock2(512, 1024, 3, 1) 1757 | 1758 | self.fc1 = nn.Linear(1024, 1024, bias=True) 1759 | self.fc_audioset = nn.Linear(1024, classes_num, bias=True) 1760 | 1761 | self.init_weight() 1762 | 1763 | def init_weight(self): 1764 | init_layer(self.fc1) 1765 | init_layer(self.fc_audioset) 1766 | 1767 | def forward(self, input, mixup_lambda=None): 1768 | """ 1769 | Input: (batch_size, data_length)""" 1770 | 1771 | x = input[:, None, :] 1772 | 1773 | # Mixup on spectrogram 1774 | if self.training and mixup_lambda is not None: 1775 | x = do_mixup(x, mixup_lambda) 1776 | 1777 | x = self.conv_block1(x) 1778 | x = F.dropout(x, p=0.1, training=self.training) 1779 | x = self.conv_block2(x, pool_size=3) 1780 | x = F.dropout(x, p=0.1, training=self.training) 1781 | x = self.conv_block3(x, pool_size=3) 1782 | x = F.dropout(x, p=0.1, training=self.training) 1783 | x = self.conv_block4(x, pool_size=3) 1784 | x = F.dropout(x, p=0.1, training=self.training) 1785 | x = self.conv_block5(x, pool_size=3) 1786 | x = F.dropout(x, p=0.1, training=self.training) 1787 | x = self.conv_block6(x, pool_size=3) 1788 | x = F.dropout(x, p=0.1, training=self.training) 1789 | x = self.conv_block7(x, pool_size=3) 1790 | x = F.dropout(x, p=0.1, training=self.training) 1791 | x = self.conv_block8(x, pool_size=3) 1792 | x = F.dropout(x, p=0.1, training=self.training) 1793 | x = self.conv_block9(x, pool_size=1) 1794 | 1795 | (x1, _) = torch.max(x, dim=2) 1796 | x2 = torch.mean(x, dim=2) 1797 | x = x1 + x2 1798 | x = F.dropout(x, p=0.5, training=self.training) 1799 | x = F.relu_(self.fc1(x)) 1800 | embedding = F.dropout(x, p=0.5, training=self.training) 1801 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 1802 | 1803 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 1804 | 1805 | return output_dict 1806 | 1807 | 1808 | class DaiNetResBlock(nn.Module): 1809 | def __init__(self, in_channels, out_channels, kernel_size): 1810 | 1811 | super(DaiNetResBlock, self).__init__() 1812 | 1813 | self.conv1 = nn.Conv1d(in_channels=in_channels, 1814 | out_channels=out_channels, 1815 | kernel_size=kernel_size, stride=1, 1816 | padding=kernel_size // 2, bias=False) 1817 | 1818 | self.conv2 = nn.Conv1d(in_channels=out_channels, 1819 | out_channels=out_channels, 1820 | kernel_size=kernel_size, stride=1, 1821 | padding=kernel_size // 2, bias=False) 1822 | 1823 | self.conv3 = nn.Conv1d(in_channels=out_channels, 1824 | out_channels=out_channels, 1825 | kernel_size=kernel_size, stride=1, 1826 | padding=kernel_size // 2, bias=False) 1827 | 1828 | self.conv4 = nn.Conv1d(in_channels=out_channels, 1829 | 
out_channels=out_channels, 1830 | kernel_size=kernel_size, stride=1, 1831 | padding=kernel_size // 2, bias=False) 1832 | 1833 | self.downsample = nn.Conv1d(in_channels=in_channels, 1834 | out_channels=out_channels, 1835 | kernel_size=1, stride=1, 1836 | padding=0, bias=False) 1837 | 1838 | self.bn1 = nn.BatchNorm1d(out_channels) 1839 | self.bn2 = nn.BatchNorm1d(out_channels) 1840 | self.bn3 = nn.BatchNorm1d(out_channels) 1841 | self.bn4 = nn.BatchNorm1d(out_channels) 1842 | self.bn_downsample = nn.BatchNorm1d(out_channels) 1843 | 1844 | self.init_weight() 1845 | 1846 | def init_weight(self): 1847 | init_layer(self.conv1) 1848 | init_layer(self.conv2) 1849 | init_layer(self.conv3) 1850 | init_layer(self.conv4) 1851 | init_layer(self.downsample) 1852 | init_bn(self.bn1) 1853 | init_bn(self.bn2) 1854 | init_bn(self.bn3) 1855 | init_bn(self.bn4) 1856 | nn.init.constant_(self.bn4.weight, 0) 1857 | init_bn(self.bn_downsample) 1858 | 1859 | def forward(self, input, pool_size=1): 1860 | x = F.relu_(self.bn1(self.conv1(input))) 1861 | x = F.relu_(self.bn2(self.conv2(x))) 1862 | x = F.relu_(self.bn3(self.conv3(x))) 1863 | x = self.bn4(self.conv4(x)) 1864 | if input.shape == x.shape: 1865 | x = F.relu_(x + input) 1866 | else: 1867 | x = F.relu(x + self.bn_downsample(self.downsample(input))) 1868 | 1869 | if pool_size != 1: 1870 | x = F.max_pool1d(x, kernel_size=pool_size, padding=pool_size // 2) 1871 | return x 1872 | 1873 | 1874 | class DaiNet19(nn.Module): 1875 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 1876 | fmax, classes_num): 1877 | 1878 | super(DaiNet19, self).__init__() 1879 | 1880 | window = 'hann' 1881 | center = True 1882 | pad_mode = 'reflect' 1883 | ref = 1.0 1884 | amin = 1e-10 1885 | top_db = None 1886 | 1887 | self.conv0 = nn.Conv1d(in_channels=1, out_channels=64, kernel_size=80, stride=4, padding=0, bias=False) 1888 | self.bn0 = nn.BatchNorm1d(64) 1889 | self.conv_block1 = DaiNetResBlock(64, 64, 3) 1890 | self.conv_block2 = DaiNetResBlock(64, 128, 3) 1891 | self.conv_block3 = DaiNetResBlock(128, 256, 3) 1892 | self.conv_block4 = DaiNetResBlock(256, 512, 3) 1893 | 1894 | self.fc1 = nn.Linear(512, 512, bias=True) 1895 | self.fc_audioset = nn.Linear(512, classes_num, bias=True) 1896 | 1897 | self.init_weight() 1898 | 1899 | def init_weight(self): 1900 | init_layer(self.conv0) 1901 | init_bn(self.bn0) 1902 | init_layer(self.fc1) 1903 | init_layer(self.fc_audioset) 1904 | 1905 | def forward(self, input, mixup_lambda=None): 1906 | """ 1907 | Input: (batch_size, data_length)""" 1908 | 1909 | x = input[:, None, :] 1910 | 1911 | # Mixup on spectrogram 1912 | if self.training and mixup_lambda is not None: 1913 | x = do_mixup(x, mixup_lambda) 1914 | 1915 | x = self.bn0(self.conv0(x)) 1916 | x = self.conv_block1(x) 1917 | x = F.max_pool1d(x, kernel_size=4) 1918 | x = self.conv_block2(x) 1919 | x = F.max_pool1d(x, kernel_size=4) 1920 | x = self.conv_block3(x) 1921 | x = F.max_pool1d(x, kernel_size=4) 1922 | x = self.conv_block4(x) 1923 | 1924 | (x1, _) = torch.max(x, dim=2) 1925 | x2 = torch.mean(x, dim=2) 1926 | x = x1 + x2 1927 | x = F.dropout(x, p=0.5, training=self.training) 1928 | x = F.relu_(self.fc1(x)) 1929 | embedding = F.dropout(x, p=0.5, training=self.training) 1930 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 1931 | 1932 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 1933 | 1934 | return output_dict 1935 | 1936 | 1937 | def _resnet_conv3x1_wav1d(in_planes, out_planes, dilation): 1938 | #3x3 convolution with padding 
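# (A note on the comment above: despite saying "3x3", this helper
# builds a 1-D convolution with kernel size 3; the naming appears to be
# carried over from the 2-D ResNet template. With padding set equal to
# the dilation, the sequence length is preserved at stride 1, since
# 2*padding exactly cancels dilation*(kernel_size - 1) = 2*dilation.)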
1939 | return nn.Conv1d(in_planes, out_planes, kernel_size=3, stride=1, 1940 | padding=dilation, groups=1, bias=False, dilation=dilation) 1941 | 1942 | 1943 | def _resnet_conv1x1_wav1d(in_planes, out_planes): 1944 | #1x1 convolution 1945 | return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=1, bias=False) 1946 | 1947 | 1948 | class _ResnetBasicBlockWav1d(nn.Module): 1949 | expansion = 1 1950 | 1951 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 1952 | base_width=64, dilation=1, norm_layer=None): 1953 | super(_ResnetBasicBlockWav1d, self).__init__() 1954 | if norm_layer is None: 1955 | norm_layer = nn.BatchNorm1d 1956 | if groups != 1 or base_width != 64: 1957 | raise ValueError('_ResnetBasicBlock only supports groups=1 and base_width=64') 1958 | if dilation > 1: 1959 | raise NotImplementedError("Dilation > 1 not supported in _ResnetBasicBlock") 1960 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 1961 | 1962 | self.stride = stride 1963 | 1964 | self.conv1 = _resnet_conv3x1_wav1d(inplanes, planes, dilation=1) 1965 | self.bn1 = norm_layer(planes) 1966 | self.relu = nn.ReLU(inplace=True) 1967 | self.conv2 = _resnet_conv3x1_wav1d(planes, planes, dilation=2) 1968 | self.bn2 = norm_layer(planes) 1969 | self.downsample = downsample 1970 | self.stride = stride 1971 | 1972 | self.init_weights() 1973 | 1974 | def init_weights(self): 1975 | init_layer(self.conv1) 1976 | init_bn(self.bn1) 1977 | init_layer(self.conv2) 1978 | init_bn(self.bn2) 1979 | nn.init.constant_(self.bn2.weight, 0) 1980 | 1981 | def forward(self, x): 1982 | identity = x 1983 | 1984 | if self.stride != 1: 1985 | out = F.max_pool1d(x, kernel_size=self.stride) 1986 | else: 1987 | out = x 1988 | 1989 | out = self.conv1(out) 1990 | out = self.bn1(out) 1991 | out = self.relu(out) 1992 | out = F.dropout(out, p=0.1, training=self.training) 1993 | 1994 | out = self.conv2(out) 1995 | out = self.bn2(out) 1996 | 1997 | if self.downsample is not None: 1998 | identity = self.downsample(identity) 1999 | 2000 | out += identity 2001 | out = self.relu(out) 2002 | 2003 | return out 2004 | 2005 | 2006 | class _ResNetWav1d(nn.Module): 2007 | def __init__(self, block, layers, zero_init_residual=False, 2008 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 2009 | norm_layer=None): 2010 | super(_ResNetWav1d, self).__init__() 2011 | 2012 | if norm_layer is None: 2013 | norm_layer = nn.BatchNorm1d 2014 | self._norm_layer = norm_layer 2015 | 2016 | self.inplanes = 64 2017 | self.dilation = 1 2018 | if replace_stride_with_dilation is None: 2019 | # each element in the tuple indicates if we should replace 2020 | # the 2x2 stride with a dilated convolution instead 2021 | replace_stride_with_dilation = [False, False, False] 2022 | if len(replace_stride_with_dilation) != 3: 2023 | raise ValueError("replace_stride_with_dilation should be None " 2024 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 2025 | self.groups = groups 2026 | self.base_width = width_per_group 2027 | 2028 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 2029 | self.layer2 = self._make_layer(block, 128, layers[1], stride=4) 2030 | self.layer3 = self._make_layer(block, 256, layers[2], stride=4) 2031 | self.layer4 = self._make_layer(block, 512, layers[3], stride=4) 2032 | self.layer5 = self._make_layer(block, 1024, layers[4], stride=4) 2033 | self.layer6 = self._make_layer(block, 1024, layers[5], stride=4) 2034 | self.layer7 = self._make_layer(block, 2048, 
layers[6], stride=4) 2035 | 2036 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 2037 | norm_layer = self._norm_layer 2038 | downsample = None 2039 | previous_dilation = self.dilation 2040 | if dilate: 2041 | self.dilation *= stride 2042 | stride = 1 2043 | if stride != 1 or self.inplanes != planes * block.expansion: 2044 | if stride == 1: 2045 | downsample = nn.Sequential( 2046 | _resnet_conv1x1_wav1d(self.inplanes, planes * block.expansion), 2047 | norm_layer(planes * block.expansion), 2048 | ) 2049 | init_layer(downsample[0]) 2050 | init_bn(downsample[1]) 2051 | else: 2052 | downsample = nn.Sequential( 2053 | nn.AvgPool1d(kernel_size=stride), 2054 | _resnet_conv1x1_wav1d(self.inplanes, planes * block.expansion), 2055 | norm_layer(planes * block.expansion), 2056 | ) 2057 | init_layer(downsample[1]) 2058 | init_bn(downsample[2]) 2059 | 2060 | layers = [] 2061 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 2062 | self.base_width, previous_dilation, norm_layer)) 2063 | self.inplanes = planes * block.expansion 2064 | for _ in range(1, blocks): 2065 | layers.append(block(self.inplanes, planes, groups=self.groups, 2066 | base_width=self.base_width, dilation=self.dilation, 2067 | norm_layer=norm_layer)) 2068 | 2069 | return nn.Sequential(*layers) 2070 | 2071 | def forward(self, x): 2072 | 2073 | x = self.layer1(x) 2074 | x = self.layer2(x) 2075 | x = self.layer3(x) 2076 | x = self.layer4(x) 2077 | x = self.layer5(x) 2078 | x = self.layer6(x) 2079 | x = self.layer7(x) 2080 | 2081 | return x 2082 | 2083 | 2084 | class Res1dNet31(nn.Module): 2085 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 2086 | fmax, classes_num): 2087 | 2088 | super(Res1dNet31, self).__init__() 2089 | 2090 | window = 'hann' 2091 | center = True 2092 | pad_mode = 'reflect' 2093 | ref = 1.0 2094 | amin = 1e-10 2095 | top_db = None 2096 | 2097 | self.conv0 = nn.Conv1d(in_channels=1, out_channels=64, kernel_size=11, stride=5, padding=5, bias=False) 2098 | self.bn0 = nn.BatchNorm1d(64) 2099 | 2100 | self.resnet = _ResNetWav1d(_ResnetBasicBlockWav1d, [2, 2, 2, 2, 2, 2, 2]) 2101 | 2102 | self.fc1 = nn.Linear(2048, 2048, bias=True) 2103 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 2104 | 2105 | self.init_weight() 2106 | 2107 | def init_weight(self): 2108 | init_layer(self.conv0) 2109 | init_bn(self.bn0) 2110 | init_layer(self.fc1) 2111 | init_layer(self.fc_audioset) 2112 | 2113 | def forward(self, input, mixup_lambda=None): 2114 | """ 2115 | Input: (batch_size, data_length)""" 2116 | 2117 | x = input[:, None, :] 2118 | 2119 | # Mixup on spectrogram 2120 | if self.training and mixup_lambda is not None: 2121 | x = do_mixup(x, mixup_lambda) 2122 | 2123 | x = self.bn0(self.conv0(x)) 2124 | x = self.resnet(x) 2125 | 2126 | (x1, _) = torch.max(x, dim=2) 2127 | x2 = torch.mean(x, dim=2) 2128 | x = x1 + x2 2129 | x = F.dropout(x, p=0.5, training=self.training) 2130 | x = F.relu_(self.fc1(x)) 2131 | embedding = F.dropout(x, p=0.5, training=self.training) 2132 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 2133 | 2134 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 2135 | 2136 | return output_dict 2137 | 2138 | 2139 | class Res1dNet51(nn.Module): 2140 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 2141 | fmax, classes_num): 2142 | 2143 | super(Res1dNet51, self).__init__() 2144 | 2145 | window = 'hann' 2146 | center = True 2147 | pad_mode = 'reflect' 2148 | ref = 1.0 2149 | amin = 1e-10 
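# Note: window, center, pad_mode, ref, amin and top_db are assigned in
# this constructor only for interface parity with the spectrogram-based
# models; the Res1dNet variants never build an STFT front-end and run
# directly on the raw waveform. The overall temporal downsampling is
# conv0's stride times the stride-4 stages of _ResNetWav1d (layers 2-7),
# i.e. 5 * 4**6 = 20480 input samples per output frame ahead of the
# max+mean pooling head.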
2150 | top_db = None 2151 | 2152 | self.conv0 = nn.Conv1d(in_channels=1, out_channels=64, kernel_size=11, stride=5, padding=5, bias=False) 2153 | self.bn0 = nn.BatchNorm1d(64) 2154 | 2155 | self.resnet = _ResNetWav1d(_ResnetBasicBlockWav1d, [2, 3, 4, 6, 4, 3, 2]) 2156 | 2157 | self.fc1 = nn.Linear(2048, 2048, bias=True) 2158 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 2159 | 2160 | self.init_weight() 2161 | 2162 | def init_weight(self): 2163 | init_layer(self.conv0) 2164 | init_bn(self.bn0) 2165 | init_layer(self.fc1) 2166 | init_layer(self.fc_audioset) 2167 | 2168 | def forward(self, input, mixup_lambda=None): 2169 | """ 2170 | Input: (batch_size, data_length)""" 2171 | 2172 | x = input[:, None, :] 2173 | 2174 | # Mixup on spectrogram 2175 | if self.training and mixup_lambda is not None: 2176 | x = do_mixup(x, mixup_lambda) 2177 | 2178 | x = self.bn0(self.conv0(x)) 2179 | x = self.resnet(x) 2180 | 2181 | (x1, _) = torch.max(x, dim=2) 2182 | x2 = torch.mean(x, dim=2) 2183 | x = x1 + x2 2184 | x = F.dropout(x, p=0.5, training=self.training) 2185 | x = F.relu_(self.fc1(x)) 2186 | embedding = F.dropout(x, p=0.5, training=self.training) 2187 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 2188 | 2189 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 2190 | 2191 | return output_dict 2192 | 2193 | 2194 | class ConvPreWavBlock(nn.Module): 2195 | def __init__(self, in_channels, out_channels): 2196 | 2197 | super(ConvPreWavBlock, self).__init__() 2198 | 2199 | self.conv1 = nn.Conv1d(in_channels=in_channels, 2200 | out_channels=out_channels, 2201 | kernel_size=3, stride=1, 2202 | padding=1, bias=False) 2203 | 2204 | self.conv2 = nn.Conv1d(in_channels=out_channels, 2205 | out_channels=out_channels, 2206 | kernel_size=3, stride=1, dilation=2, 2207 | padding=2, bias=False) 2208 | 2209 | self.bn1 = nn.BatchNorm1d(out_channels) 2210 | self.bn2 = nn.BatchNorm1d(out_channels) 2211 | 2212 | self.init_weight() 2213 | 2214 | def init_weight(self): 2215 | init_layer(self.conv1) 2216 | init_layer(self.conv2) 2217 | init_bn(self.bn1) 2218 | init_bn(self.bn2) 2219 | 2220 | 2221 | def forward(self, input, pool_size): 2222 | 2223 | x = input 2224 | x = F.relu_(self.bn1(self.conv1(x))) 2225 | x = F.relu_(self.bn2(self.conv2(x))) 2226 | x = F.max_pool1d(x, kernel_size=pool_size) 2227 | 2228 | return x 2229 | 2230 | 2231 | class Wavegram_Cnn14(nn.Module): 2232 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 2233 | fmax, classes_num): 2234 | 2235 | super(Wavegram_Cnn14, self).__init__() 2236 | 2237 | window = 'hann' 2238 | center = True 2239 | pad_mode = 'reflect' 2240 | ref = 1.0 2241 | amin = 1e-10 2242 | top_db = None 2243 | 2244 | self.pre_conv0 = nn.Conv1d(in_channels=1, out_channels=64, kernel_size=11, stride=5, padding=5, bias=False) 2245 | self.pre_bn0 = nn.BatchNorm1d(64) 2246 | self.pre_block1 = ConvPreWavBlock(64, 64) 2247 | self.pre_block2 = ConvPreWavBlock(64, 128) 2248 | self.pre_block3 = ConvPreWavBlock(128, 128) 2249 | self.pre_block4 = ConvBlock(in_channels=4, out_channels=64) 2250 | 2251 | # Spec augmenter 2252 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 2253 | freq_drop_width=8, freq_stripes_num=2) 2254 | 2255 | self.bn0 = nn.BatchNorm2d(64) 2256 | 2257 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 2258 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 2259 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 2260 | self.conv_block4 = 
ConvBlock(in_channels=256, out_channels=512) 2261 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 2262 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 2263 | 2264 | self.fc1 = nn.Linear(2048, 2048, bias=True) 2265 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 2266 | 2267 | self.init_weight() 2268 | 2269 | def init_weight(self): 2270 | init_layer(self.pre_conv0) 2271 | init_bn(self.pre_bn0) 2272 | init_bn(self.bn0) 2273 | init_layer(self.fc1) 2274 | init_layer(self.fc_audioset) 2275 | 2276 | def forward(self, input, mixup_lambda=None): 2277 | """ 2278 | Input: (batch_size, data_length)""" 2279 | 2280 | # Wavegram 2281 | a1 = F.relu_(self.pre_bn0(self.pre_conv0(input[:, None, :]))) 2282 | a1 = self.pre_block1(a1, pool_size=4) 2283 | a1 = self.pre_block2(a1, pool_size=4) 2284 | a1 = self.pre_block3(a1, pool_size=4) 2285 | a1 = a1.reshape((a1.shape[0], -1, 32, a1.shape[-1])).transpose(2, 3) 2286 | a1 = self.pre_block4(a1, pool_size=(2, 1)) 2287 | 2288 | # Mixup on spectrogram 2289 | if self.training and mixup_lambda is not None: 2290 | a1 = do_mixup(a1, mixup_lambda) 2291 | 2292 | x = a1 2293 | x = F.dropout(x, p=0.2, training=self.training) 2294 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 2295 | x = F.dropout(x, p=0.2, training=self.training) 2296 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 2297 | x = F.dropout(x, p=0.2, training=self.training) 2298 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 2299 | x = F.dropout(x, p=0.2, training=self.training) 2300 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 2301 | x = F.dropout(x, p=0.2, training=self.training) 2302 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 2303 | x = F.dropout(x, p=0.2, training=self.training) 2304 | x = torch.mean(x, dim=3) 2305 | 2306 | (x1, _) = torch.max(x, dim=2) 2307 | x2 = torch.mean(x, dim=2) 2308 | x = x1 + x2 2309 | x = F.dropout(x, p=0.5, training=self.training) 2310 | x = F.relu_(self.fc1(x)) 2311 | embedding = F.dropout(x, p=0.5, training=self.training) 2312 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 2313 | 2314 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 2315 | 2316 | return output_dict 2317 | 2318 | 2319 | class Wavegram_Logmel_Cnn14(nn.Module): 2320 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 2321 | fmax, classes_num): 2322 | 2323 | super(Wavegram_Logmel_Cnn14, self).__init__() 2324 | 2325 | window = 'hann' 2326 | center = True 2327 | pad_mode = 'reflect' 2328 | ref = 1.0 2329 | amin = 1e-10 2330 | top_db = None 2331 | 2332 | self.pre_conv0 = nn.Conv1d(in_channels=1, out_channels=64, kernel_size=11, stride=5, padding=5, bias=False) 2333 | self.pre_bn0 = nn.BatchNorm1d(64) 2334 | self.pre_block1 = ConvPreWavBlock(64, 64) 2335 | self.pre_block2 = ConvPreWavBlock(64, 128) 2336 | self.pre_block3 = ConvPreWavBlock(128, 128) 2337 | self.pre_block4 = ConvBlock(in_channels=4, out_channels=64) 2338 | 2339 | # Spectrogram extractor 2340 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 2341 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 2342 | freeze_parameters=True) 2343 | 2344 | # Logmel feature extractor 2345 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 2346 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 2347 | freeze_parameters=True) 2348 | 2349 | # Spec augmenter 2350 | self.spec_augmenter = 
SpecAugmentation(time_drop_width=64, time_stripes_num=2, 2351 | freq_drop_width=8, freq_stripes_num=2) 2352 | 2353 | self.bn0 = nn.BatchNorm2d(64) 2354 | 2355 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 2356 | self.conv_block2 = ConvBlock(in_channels=128, out_channels=128) 2357 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 2358 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 2359 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 2360 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 2361 | 2362 | self.fc1 = nn.Linear(2048, 2048, bias=True) 2363 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 2364 | 2365 | self.init_weight() 2366 | 2367 | def init_weight(self): 2368 | init_layer(self.pre_conv0) 2369 | init_bn(self.pre_bn0) 2370 | init_bn(self.bn0) 2371 | init_layer(self.fc1) 2372 | init_layer(self.fc_audioset) 2373 | 2374 | def forward(self, input, mixup_lambda=None): 2375 | """ 2376 | Input: (batch_size, data_length)""" 2377 | 2378 | # Wavegram 2379 | a1 = F.relu_(self.pre_bn0(self.pre_conv0(input[:, None, :]))) 2380 | a1 = self.pre_block1(a1, pool_size=4) 2381 | a1 = self.pre_block2(a1, pool_size=4) 2382 | a1 = self.pre_block3(a1, pool_size=4) 2383 | a1 = a1.reshape((a1.shape[0], -1, 32, a1.shape[-1])).transpose(2, 3) 2384 | a1 = self.pre_block4(a1, pool_size=(2, 1)) 2385 | 2386 | # Log mel spectrogram 2387 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 2388 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 2389 | 2390 | x = x.transpose(1, 3) 2391 | x = self.bn0(x) 2392 | x = x.transpose(1, 3) 2393 | 2394 | if self.training: 2395 | x = self.spec_augmenter(x) 2396 | 2397 | # Mixup on spectrogram 2398 | if self.training and mixup_lambda is not None: 2399 | x = do_mixup(x, mixup_lambda) 2400 | a1 = do_mixup(a1, mixup_lambda) 2401 | 2402 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 2403 | 2404 | # Concatenate Wavegram and Log mel spectrogram along the channel dimension 2405 | x = torch.cat((x, a1), dim=1) 2406 | 2407 | x = F.dropout(x, p=0.2, training=self.training) 2408 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 2409 | x = F.dropout(x, p=0.2, training=self.training) 2410 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 2411 | x = F.dropout(x, p=0.2, training=self.training) 2412 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 2413 | x = F.dropout(x, p=0.2, training=self.training) 2414 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 2415 | x = F.dropout(x, p=0.2, training=self.training) 2416 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 2417 | x = F.dropout(x, p=0.2, training=self.training) 2418 | x = torch.mean(x, dim=3) 2419 | 2420 | (x1, _) = torch.max(x, dim=2) 2421 | x2 = torch.mean(x, dim=2) 2422 | x = x1 + x2 2423 | x = F.dropout(x, p=0.5, training=self.training) 2424 | x = F.relu_(self.fc1(x)) 2425 | embedding = F.dropout(x, p=0.5, training=self.training) 2426 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 2427 | 2428 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 2429 | 2430 | return output_dict 2431 | 2432 | 2433 | class Wavegram_Logmel128_Cnn14(nn.Module): 2434 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 2435 | fmax, classes_num): 2436 | 2437 | super(Wavegram_Logmel128_Cnn14, self).__init__() 2438 | 2439 | window = 'hann' 2440 | center = True 2441 | pad_mode = 'reflect' 2442 
| ref = 1.0 2443 | amin = 1e-10 2444 | top_db = None 2445 | 2446 | self.pre_conv0 = nn.Conv1d(in_channels=1, out_channels=64, kernel_size=11, stride=5, padding=5, bias=False) 2447 | self.pre_bn0 = nn.BatchNorm1d(64) 2448 | self.pre_block1 = ConvPreWavBlock(64, 64) 2449 | self.pre_block2 = ConvPreWavBlock(64, 128) 2450 | self.pre_block3 = ConvPreWavBlock(128, 256) 2451 | self.pre_block4 = ConvBlock(in_channels=4, out_channels=64) 2452 | 2453 | # Spectrogram extractor 2454 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 2455 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 2456 | freeze_parameters=True) 2457 | 2458 | # Logmel feature extractor 2459 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 2460 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 2461 | freeze_parameters=True) 2462 | 2463 | # Spec augmenter 2464 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 2465 | freq_drop_width=16, freq_stripes_num=2) 2466 | 2467 | self.bn0 = nn.BatchNorm2d(128) 2468 | 2469 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 2470 | self.conv_block2 = ConvBlock(in_channels=128, out_channels=128) 2471 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 2472 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 2473 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 2474 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 2475 | 2476 | self.fc1 = nn.Linear(2048, 2048, bias=True) 2477 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 2478 | 2479 | self.init_weight() 2480 | 2481 | def init_weight(self): 2482 | init_layer(self.pre_conv0) 2483 | init_bn(self.pre_bn0) 2484 | init_bn(self.bn0) 2485 | init_layer(self.fc1) 2486 | init_layer(self.fc_audioset) 2487 | 2488 | def forward(self, input, mixup_lambda=None): 2489 | """ 2490 | Input: (batch_size, data_length)""" 2491 | 2492 | # Wavegram 2493 | a1 = F.relu_(self.pre_bn0(self.pre_conv0(input[:, None, :]))) 2494 | a1 = self.pre_block1(a1, pool_size=4) 2495 | a1 = self.pre_block2(a1, pool_size=4) 2496 | a1 = self.pre_block3(a1, pool_size=4) 2497 | a1 = a1.reshape((a1.shape[0], -1, 64, a1.shape[-1])).transpose(2, 3) 2498 | a1 = self.pre_block4(a1, pool_size=(2, 1)) 2499 | 2500 | # Log mel spectrogram 2501 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 2502 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 2503 | 2504 | x = x.transpose(1, 3) 2505 | x = self.bn0(x) 2506 | x = x.transpose(1, 3) 2507 | 2508 | if self.training: 2509 | x = self.spec_augmenter(x) 2510 | 2511 | # Mixup on spectrogram 2512 | if self.training and mixup_lambda is not None: 2513 | x = do_mixup(x, mixup_lambda) 2514 | a1 = do_mixup(a1, mixup_lambda) 2515 | 2516 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 2517 | 2518 | # Concatenate Wavegram and Log mel spectrogram along the channel dimension 2519 | x = torch.cat((x, a1), dim=1) 2520 | 2521 | x = F.dropout(x, p=0.2, training=self.training) 2522 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 2523 | x = F.dropout(x, p=0.2, training=self.training) 2524 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 2525 | x = F.dropout(x, p=0.2, training=self.training) 2526 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 2527 | x = F.dropout(x, p=0.2, training=self.training) 2528 | x = self.conv_block5(x, pool_size=(2, 2), 
pool_type='avg') 2529 | x = F.dropout(x, p=0.2, training=self.training) 2530 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 2531 | x = F.dropout(x, p=0.2, training=self.training) 2532 | x = torch.mean(x, dim=3) 2533 | 2534 | (x1, _) = torch.max(x, dim=2) 2535 | x2 = torch.mean(x, dim=2) 2536 | x = x1 + x2 2537 | x = F.dropout(x, p=0.5, training=self.training) 2538 | x = F.relu_(self.fc1(x)) 2539 | embedding = F.dropout(x, p=0.5, training=self.training) 2540 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 2541 | 2542 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 2543 | 2544 | return output_dict 2545 | 2546 | 2547 | class Cnn14_16k(nn.Module): 2548 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num): 2549 | 2550 | super(Cnn14_16k, self).__init__() 2551 | 2552 | assert sample_rate == 16000 2553 | assert window_size == 512 2554 | assert hop_size == 160 2555 | assert mel_bins == 64 2556 | assert fmin == 50 2557 | assert fmax == 8000 2558 | 2559 | window = 'hann' 2560 | center = True 2561 | pad_mode = 'reflect' 2562 | ref = 1.0 2563 | amin = 1e-10 2564 | top_db = None 2565 | 2566 | # Spectrogram extractor 2567 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 2568 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 2569 | freeze_parameters=True) 2570 | 2571 | # Logmel feature extractor 2572 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 2573 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 2574 | freeze_parameters=True) 2575 | 2576 | # Spec augmenter 2577 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 2578 | freq_drop_width=8, freq_stripes_num=2) 2579 | 2580 | self.bn0 = nn.BatchNorm2d(64) 2581 | 2582 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 2583 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 2584 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 2585 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 2586 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 2587 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 2588 | 2589 | self.fc1 = nn.Linear(2048, 2048, bias=True) 2590 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 2591 | 2592 | self.init_weight() 2593 | 2594 | def init_weight(self): 2595 | init_bn(self.bn0) 2596 | init_layer(self.fc1) 2597 | init_layer(self.fc_audioset) 2598 | 2599 | def forward(self, input, mixup_lambda=None): 2600 | """ 2601 | Input: (batch_size, data_length)""" 2602 | 2603 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 2604 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 2605 | 2606 | x = x.transpose(1, 3) 2607 | x = self.bn0(x) 2608 | x = x.transpose(1, 3) 2609 | 2610 | if self.training: 2611 | x = self.spec_augmenter(x) 2612 | 2613 | # Mixup on spectrogram 2614 | if self.training and mixup_lambda is not None: 2615 | x = do_mixup(x, mixup_lambda) 2616 | 2617 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 2618 | x = F.dropout(x, p=0.2, training=self.training) 2619 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 2620 | x = F.dropout(x, p=0.2, training=self.training) 2621 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 2622 | x = F.dropout(x, p=0.2, training=self.training) 2623 | x = self.conv_block4(x, pool_size=(2, 2), 
pool_type='avg') 2624 | x = F.dropout(x, p=0.2, training=self.training) 2625 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 2626 | x = F.dropout(x, p=0.2, training=self.training) 2627 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 2628 | x = F.dropout(x, p=0.2, training=self.training) 2629 | x = torch.mean(x, dim=3) 2630 | 2631 | (x1, _) = torch.max(x, dim=2) 2632 | x2 = torch.mean(x, dim=2) 2633 | x = x1 + x2 2634 | x = F.dropout(x, p=0.5, training=self.training) 2635 | x = F.relu_(self.fc1(x)) 2636 | embedding = F.dropout(x, p=0.5, training=self.training) 2637 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 2638 | 2639 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 2640 | 2641 | return output_dict 2642 | 2643 | 2644 | class Cnn14_8k(nn.Module): 2645 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num): 2646 | 2647 | super(Cnn14_8k, self).__init__() 2648 | 2649 | assert sample_rate == 8000 2650 | assert window_size == 256 2651 | assert hop_size == 80 2652 | assert mel_bins == 64 2653 | assert fmin == 50 2654 | assert fmax == 4000 2655 | 2656 | window = 'hann' 2657 | center = True 2658 | pad_mode = 'reflect' 2659 | ref = 1.0 2660 | amin = 1e-10 2661 | top_db = None 2662 | 2663 | # Spectrogram extractor 2664 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 2665 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 2666 | freeze_parameters=True) 2667 | 2668 | # Logmel feature extractor 2669 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 2670 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 2671 | freeze_parameters=True) 2672 | 2673 | # Spec augmenter 2674 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 2675 | freq_drop_width=8, freq_stripes_num=2) 2676 | 2677 | self.bn0 = nn.BatchNorm2d(64) 2678 | 2679 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 2680 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 2681 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 2682 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 2683 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 2684 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 2685 | 2686 | self.fc1 = nn.Linear(2048, 2048, bias=True) 2687 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 2688 | 2689 | self.init_weight() 2690 | 2691 | def init_weight(self): 2692 | init_bn(self.bn0) 2693 | init_layer(self.fc1) 2694 | init_layer(self.fc_audioset) 2695 | 2696 | def forward(self, input, mixup_lambda=None): 2697 | """ 2698 | Input: (batch_size, data_length)""" 2699 | 2700 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 2701 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 2702 | 2703 | x = x.transpose(1, 3) 2704 | x = self.bn0(x) 2705 | x = x.transpose(1, 3) 2706 | 2707 | if self.training: 2708 | x = self.spec_augmenter(x) 2709 | 2710 | # Mixup on spectrogram 2711 | if self.training and mixup_lambda is not None: 2712 | x = do_mixup(x, mixup_lambda) 2713 | 2714 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 2715 | x = F.dropout(x, p=0.2, training=self.training) 2716 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 2717 | x = F.dropout(x, p=0.2, training=self.training) 2718 | x = self.conv_block3(x, pool_size=(2, 2), 
pool_type='avg') 2719 | x = F.dropout(x, p=0.2, training=self.training) 2720 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 2721 | x = F.dropout(x, p=0.2, training=self.training) 2722 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 2723 | x = F.dropout(x, p=0.2, training=self.training) 2724 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 2725 | x = F.dropout(x, p=0.2, training=self.training) 2726 | x = torch.mean(x, dim=3) 2727 | 2728 | (x1, _) = torch.max(x, dim=2) 2729 | x2 = torch.mean(x, dim=2) 2730 | x = x1 + x2 2731 | x = F.dropout(x, p=0.5, training=self.training) 2732 | x = F.relu_(self.fc1(x)) 2733 | embedding = F.dropout(x, p=0.5, training=self.training) 2734 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 2735 | 2736 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 2737 | 2738 | return output_dict 2739 | 2740 | 2741 | class Cnn14_mixup_time_domain(nn.Module): 2742 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 2743 | fmax, classes_num): 2744 | 2745 | super(Cnn14_mixup_time_domain, self).__init__() 2746 | 2747 | window = 'hann' 2748 | center = True 2749 | pad_mode = 'reflect' 2750 | ref = 1.0 2751 | amin = 1e-10 2752 | top_db = None 2753 | 2754 | # Spectrogram extractor 2755 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 2756 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 2757 | freeze_parameters=True) 2758 | 2759 | # Logmel feature extractor 2760 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 2761 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 2762 | freeze_parameters=True) 2763 | 2764 | # Spec augmenter 2765 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 2766 | freq_drop_width=8, freq_stripes_num=2) 2767 | 2768 | self.bn0 = nn.BatchNorm2d(64) 2769 | 2770 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 2771 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 2772 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 2773 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 2774 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 2775 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 2776 | 2777 | self.fc1 = nn.Linear(2048, 2048, bias=True) 2778 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 2779 | 2780 | self.init_weight() 2781 | 2782 | def init_weight(self): 2783 | init_bn(self.bn0) 2784 | init_layer(self.fc1) 2785 | init_layer(self.fc_audioset) 2786 | 2787 | def forward(self, input, mixup_lambda=None): 2788 | """ 2789 | Input: (batch_size, data_length)""" 2790 | 2791 | x = input 2792 | 2793 | # Mixup in time domain 2794 | if self.training and mixup_lambda is not None: 2795 | x = do_mixup(x, mixup_lambda) 2796 | 2797 | x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins) 2798 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 2799 | 2800 | x = x.transpose(1, 3) 2801 | x = self.bn0(x) 2802 | x = x.transpose(1, 3) 2803 | 2804 | if self.training: 2805 | x = self.spec_augmenter(x) 2806 | 2807 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 2808 | x = F.dropout(x, p=0.2, training=self.training) 2809 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 2810 | x = F.dropout(x, p=0.2, training=self.training) 2811 | x = self.conv_block3(x, pool_size=(2, 2), 
pool_type='avg') 2812 | x = F.dropout(x, p=0.2, training=self.training) 2813 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 2814 | x = F.dropout(x, p=0.2, training=self.training) 2815 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 2816 | x = F.dropout(x, p=0.2, training=self.training) 2817 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 2818 | x = F.dropout(x, p=0.2, training=self.training) 2819 | x = torch.mean(x, dim=3) 2820 | 2821 | (x1, _) = torch.max(x, dim=2) 2822 | x2 = torch.mean(x, dim=2) 2823 | x = x1 + x2 2824 | x = F.dropout(x, p=0.5, training=self.training) 2825 | x = F.relu_(self.fc1(x)) 2826 | embedding = F.dropout(x, p=0.5, training=self.training) 2827 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 2828 | 2829 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 2830 | 2831 | return output_dict 2832 | 2833 | 2834 | class Cnn14_mel32(nn.Module): 2835 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 2836 | fmax, classes_num): 2837 | 2838 | super(Cnn14_mel32, self).__init__() 2839 | 2840 | window = 'hann' 2841 | center = True 2842 | pad_mode = 'reflect' 2843 | ref = 1.0 2844 | amin = 1e-10 2845 | top_db = None 2846 | 2847 | # Spectrogram extractor 2848 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 2849 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 2850 | freeze_parameters=True) 2851 | 2852 | # Logmel feature extractor 2853 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 2854 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 2855 | freeze_parameters=True) 2856 | 2857 | # Spec augmenter 2858 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 2859 | freq_drop_width=4, freq_stripes_num=2) 2860 | 2861 | self.bn0 = nn.BatchNorm2d(32) 2862 | 2863 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 2864 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 2865 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 2866 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 2867 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 2868 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 2869 | 2870 | self.fc1 = nn.Linear(2048, 2048, bias=True) 2871 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 2872 | 2873 | self.init_weight() 2874 | 2875 | def init_weight(self): 2876 | init_bn(self.bn0) 2877 | init_layer(self.fc1) 2878 | init_layer(self.fc_audioset) 2879 | 2880 | def forward(self, input, mixup_lambda=None): 2881 | """ 2882 | Input: (batch_size, data_length)""" 2883 | 2884 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 2885 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 2886 | 2887 | x = x.transpose(1, 3) 2888 | x = self.bn0(x) 2889 | x = x.transpose(1, 3) 2890 | 2891 | if self.training: 2892 | x = self.spec_augmenter(x) 2893 | 2894 | # Mixup on spectrogram 2895 | if self.training and mixup_lambda is not None: 2896 | x = do_mixup(x, mixup_lambda) 2897 | 2898 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 2899 | x = F.dropout(x, p=0.2, training=self.training) 2900 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 2901 | x = F.dropout(x, p=0.2, training=self.training) 2902 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 2903 | x = F.dropout(x, p=0.2, 
training=self.training) 2904 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 2905 | x = F.dropout(x, p=0.2, training=self.training) 2906 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 2907 | x = F.dropout(x, p=0.2, training=self.training) 2908 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 2909 | x = F.dropout(x, p=0.2, training=self.training) 2910 | x = torch.mean(x, dim=3) 2911 | 2912 | (x1, _) = torch.max(x, dim=2) 2913 | x2 = torch.mean(x, dim=2) 2914 | x = x1 + x2 2915 | x = F.dropout(x, p=0.5, training=self.training) 2916 | x = F.relu_(self.fc1(x)) 2917 | embedding = F.dropout(x, p=0.5, training=self.training) 2918 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 2919 | 2920 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 2921 | 2922 | return output_dict 2923 | 2924 | 2925 | class Cnn14_mel128(nn.Module): 2926 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 2927 | fmax, classes_num): 2928 | 2929 | super(Cnn14_mel128, self).__init__() 2930 | 2931 | window = 'hann' 2932 | center = True 2933 | pad_mode = 'reflect' 2934 | ref = 1.0 2935 | amin = 1e-10 2936 | top_db = None 2937 | 2938 | # Spectrogram extractor 2939 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 2940 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 2941 | freeze_parameters=True) 2942 | 2943 | # Logmel feature extractor 2944 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 2945 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 2946 | freeze_parameters=True) 2947 | 2948 | # Spec augmenter 2949 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 2950 | freq_drop_width=16, freq_stripes_num=2) 2951 | 2952 | self.bn0 = nn.BatchNorm2d(128) 2953 | 2954 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 2955 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 2956 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 2957 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 2958 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 2959 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 2960 | 2961 | self.fc1 = nn.Linear(2048, 2048, bias=True) 2962 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 2963 | 2964 | self.init_weight() 2965 | 2966 | def init_weight(self): 2967 | init_bn(self.bn0) 2968 | init_layer(self.fc1) 2969 | init_layer(self.fc_audioset) 2970 | 2971 | def forward(self, input, mixup_lambda=None): 2972 | """ 2973 | Input: (batch_size, data_length)""" 2974 | 2975 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 2976 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 2977 | 2978 | x = x.transpose(1, 3) 2979 | x = self.bn0(x) 2980 | x = x.transpose(1, 3) 2981 | 2982 | if self.training: 2983 | x = self.spec_augmenter(x) 2984 | 2985 | # Mixup on spectrogram 2986 | if self.training and mixup_lambda is not None: 2987 | x = do_mixup(x, mixup_lambda) 2988 | 2989 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 2990 | x = F.dropout(x, p=0.2, training=self.training) 2991 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 2992 | x = F.dropout(x, p=0.2, training=self.training) 2993 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 2994 | x = F.dropout(x, p=0.2, training=self.training) 2995 | x = self.conv_block4(x, 
pool_size=(2, 2), pool_type='avg') 2996 | x = F.dropout(x, p=0.2, training=self.training) 2997 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 2998 | x = F.dropout(x, p=0.2, training=self.training) 2999 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 3000 | x = F.dropout(x, p=0.2, training=self.training) 3001 | x = torch.mean(x, dim=3) 3002 | 3003 | (x1, _) = torch.max(x, dim=2) 3004 | x2 = torch.mean(x, dim=2) 3005 | x = x1 + x2 3006 | x = F.dropout(x, p=0.5, training=self.training) 3007 | x = F.relu_(self.fc1(x)) 3008 | embedding = F.dropout(x, p=0.5, training=self.training) 3009 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 3010 | 3011 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 3012 | 3013 | return output_dict 3014 | 3015 | 3016 | ############ 3017 | class Cnn14_DecisionLevelMax(nn.Module): 3018 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 3019 | fmax, classes_num): 3020 | 3021 | super(Cnn14_DecisionLevelMax, self).__init__() 3022 | 3023 | window = 'hann' 3024 | center = True 3025 | pad_mode = 'reflect' 3026 | ref = 1.0 3027 | amin = 1e-10 3028 | top_db = None 3029 | self.interpolate_ratio = 32 # Downsampled ratio 3030 | 3031 | # Spectrogram extractor 3032 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 3033 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 3034 | freeze_parameters=True) 3035 | 3036 | # Logmel feature extractor 3037 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 3038 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 3039 | freeze_parameters=True) 3040 | 3041 | # Spec augmenter 3042 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 3043 | freq_drop_width=8, freq_stripes_num=2) 3044 | 3045 | self.bn0 = nn.BatchNorm2d(64) 3046 | 3047 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 3048 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 3049 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 3050 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 3051 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 3052 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 3053 | 3054 | self.fc1 = nn.Linear(2048, 2048, bias=True) 3055 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 3056 | 3057 | self.init_weight() 3058 | 3059 | def init_weight(self): 3060 | init_bn(self.bn0) 3061 | init_layer(self.fc1) 3062 | init_layer(self.fc_audioset) 3063 | 3064 | def forward(self, input, mixup_lambda=None): 3065 | """ 3066 | Input: (batch_size, data_length)""" 3067 | 3068 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 3069 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 3070 | 3071 | frames_num = x.shape[2] 3072 | 3073 | x = x.transpose(1, 3) 3074 | x = self.bn0(x) 3075 | x = x.transpose(1, 3) 3076 | 3077 | if self.training: 3078 | x = self.spec_augmenter(x) 3079 | 3080 | # Mixup on spectrogram 3081 | if self.training and mixup_lambda is not None: 3082 | x = do_mixup(x, mixup_lambda) 3083 | 3084 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 3085 | x = F.dropout(x, p=0.2, training=self.training) 3086 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 3087 | x = F.dropout(x, p=0.2, training=self.training) 3088 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 3089 | x = 
F.dropout(x, p=0.2, training=self.training) 3090 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 3091 | x = F.dropout(x, p=0.2, training=self.training) 3092 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 3093 | x = F.dropout(x, p=0.2, training=self.training) 3094 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 3095 | x = F.dropout(x, p=0.2, training=self.training) 3096 | x = torch.mean(x, dim=3) 3097 | 3098 | x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) 3099 | x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) 3100 | x = x1 + x2 3101 | x = F.dropout(x, p=0.5, training=self.training) 3102 | x = x.transpose(1, 2) 3103 | x = F.relu_(self.fc1(x)) 3104 | x = F.dropout(x, p=0.5, training=self.training) 3105 | segmentwise_output = torch.sigmoid(self.fc_audioset(x)) 3106 | (clipwise_output, _) = torch.max(segmentwise_output, dim=1) 3107 | 3108 | # Get framewise output 3109 | framewise_output = interpolate(segmentwise_output, self.interpolate_ratio) 3110 | framewise_output = pad_framewise_output(framewise_output, frames_num) 3111 | 3112 | output_dict = {'framewise_output': framewise_output, 3113 | 'clipwise_output': clipwise_output} 3114 | 3115 | return output_dict 3116 | 3117 | 3118 | class Cnn14_DecisionLevelAvg(nn.Module): 3119 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 3120 | fmax, classes_num): 3121 | 3122 | super(Cnn14_DecisionLevelAvg, self).__init__() 3123 | 3124 | window = 'hann' 3125 | center = True 3126 | pad_mode = 'reflect' 3127 | ref = 1.0 3128 | amin = 1e-10 3129 | top_db = None 3130 | self.interpolate_ratio = 32 # Downsampled ratio 3131 | 3132 | # Spectrogram extractor 3133 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 3134 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 3135 | freeze_parameters=True) 3136 | 3137 | # Logmel feature extractor 3138 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 3139 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 3140 | freeze_parameters=True) 3141 | 3142 | # Spec augmenter 3143 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 3144 | freq_drop_width=8, freq_stripes_num=2) 3145 | 3146 | self.bn0 = nn.BatchNorm2d(64) 3147 | 3148 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 3149 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 3150 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 3151 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 3152 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 3153 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 3154 | 3155 | self.fc1 = nn.Linear(2048, 2048, bias=True) 3156 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 3157 | 3158 | self.init_weight() 3159 | 3160 | def init_weight(self): 3161 | init_bn(self.bn0) 3162 | init_layer(self.fc1) 3163 | init_layer(self.fc_audioset) 3164 | 3165 | def forward(self, input, mixup_lambda=None): 3166 | """ 3167 | Input: (batch_size, data_length)""" 3168 | 3169 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 3170 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 3171 | 3172 | frames_num = x.shape[2] 3173 | 3174 | x = x.transpose(1, 3) 3175 | x = self.bn0(x) 3176 | x = x.transpose(1, 3) 3177 | 3178 | if self.training: 3179 | x = self.spec_augmenter(x) 3180 | 3181 | # Mixup on 
spectrogram 3182 | if self.training and mixup_lambda is not None: 3183 | x = do_mixup(x, mixup_lambda) 3184 | 3185 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 3186 | x = F.dropout(x, p=0.2, training=self.training) 3187 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 3188 | x = F.dropout(x, p=0.2, training=self.training) 3189 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 3190 | x = F.dropout(x, p=0.2, training=self.training) 3191 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 3192 | x = F.dropout(x, p=0.2, training=self.training) 3193 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 3194 | x = F.dropout(x, p=0.2, training=self.training) 3195 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 3196 | x = F.dropout(x, p=0.2, training=self.training) 3197 | x = torch.mean(x, dim=3) 3198 | 3199 | x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) 3200 | x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) 3201 | x = x1 + x2 3202 | x = F.dropout(x, p=0.5, training=self.training) 3203 | x = x.transpose(1, 2) 3204 | x = F.relu_(self.fc1(x)) 3205 | x = F.dropout(x, p=0.5, training=self.training) 3206 | segmentwise_output = torch.sigmoid(self.fc_audioset(x)) 3207 | clipwise_output = torch.mean(segmentwise_output, dim=1) 3208 | 3209 | # Get framewise output 3210 | framewise_output = interpolate(segmentwise_output, self.interpolate_ratio) 3211 | framewise_output = pad_framewise_output(framewise_output, frames_num) 3212 | 3217 | output_dict = {'framewise_output': framewise_output, 3218 | 'clipwise_output': clipwise_output} 3219 | 3220 | return output_dict 3221 | 3222 | 3223 | class Cnn14_DecisionLevelAtt(nn.Module): 3224 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 3225 | fmax, classes_num): 3226 | 3227 | super(Cnn14_DecisionLevelAtt, self).__init__() 3228 | 3229 | window = 'hann' 3230 | center = True 3231 | pad_mode = 'reflect' 3232 | ref = 1.0 3233 | amin = 1e-10 3234 | top_db = None 3235 | self.interpolate_ratio = 32 # Downsampled ratio 3236 | 3237 | # Spectrogram extractor 3238 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 3239 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 3240 | freeze_parameters=True) 3241 | 3242 | # Logmel feature extractor 3243 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 3244 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 3245 | freeze_parameters=True) 3246 | 3247 | # Spec augmenter 3248 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 3249 | freq_drop_width=8, freq_stripes_num=2) 3250 | 3251 | self.bn0 = nn.BatchNorm2d(64) 3252 | 3253 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 3254 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 3255 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 3256 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 3257 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 3258 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 3259 | 3260 | self.fc1 = nn.Linear(2048, 2048, bias=True) 3261 | self.att_block = AttBlock(2048, classes_num, activation='sigmoid') 3262 | 3263 | self.init_weight()
3264 | 3265 | def init_weight(self): 3266 | init_bn(self.bn0) 3267 | init_layer(self.fc1) 3268 | 3269 | def forward(self, input, mixup_lambda=None): 3270 | """ 3271 | Input: (batch_size, data_length)""" 3272 | 3273 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 3274 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 3275 | 3276 | frames_num = x.shape[2] 3277 | 3278 | x = x.transpose(1, 3) 3279 | x = self.bn0(x) 3280 | x = x.transpose(1, 3) 3281 | 3282 | if self.training: 3283 | x = self.spec_augmenter(x) 3284 | 3285 | # Mixup on spectrogram 3286 | if self.training and mixup_lambda is not None: 3287 | x = do_mixup(x, mixup_lambda) 3288 | 3289 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 3290 | x = F.dropout(x, p=0.2, training=self.training) 3291 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 3292 | x = F.dropout(x, p=0.2, training=self.training) 3293 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 3294 | x = F.dropout(x, p=0.2, training=self.training) 3295 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 3296 | x = F.dropout(x, p=0.2, training=self.training) 3297 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 3298 | x = F.dropout(x, p=0.2, training=self.training) 3299 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 3300 | x = F.dropout(x, p=0.2, training=self.training) 3301 | x = torch.mean(x, dim=3) 3302 | 3303 | x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) 3304 | x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) 3305 | x = x1 + x2 3306 | x = F.dropout(x, p=0.5, training=self.training) 3307 | x = x.transpose(1, 2) 3308 | x = F.relu_(self.fc1(x)) 3309 | x = x.transpose(1, 2) 3310 | x = F.dropout(x, p=0.5, training=self.training) 3311 | (clipwise_output, _, segmentwise_output) = self.att_block(x) 3312 | segmentwise_output = segmentwise_output.transpose(1, 2) 3313 | 3314 | # Get framewise output 3315 | framewise_output = interpolate(segmentwise_output, self.interpolate_ratio) 3316 | framewise_output = pad_framewise_output(framewise_output, frames_num) 3317 | 3318 | output_dict = {'framewise_output': framewise_output, 3319 | 'clipwise_output': clipwise_output} 3320 | 3321 | return output_dict 3322 | -------------------------------------------------------------------------------- /utils/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | 7 | 8 | def move_data_to_device(x, device): 9 | if 'float' in str(x.dtype): 10 | x = torch.Tensor(x) 11 | elif 'int' in str(x.dtype): 12 | x = torch.LongTensor(x) 13 | else: 14 | return x 15 | 16 | return x.to(device) 17 | 18 | 19 | def do_mixup(x, mixup_lambda): 20 | """Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes 21 | (1, 3, 5, ...). 22 | 23 | Args: 24 | x: (batch_size * 2, ...) 25 | mixup_lambda: (batch_size * 2,) 26 | 27 | Returns: 28 | out: (batch_size, ...) 29 | """ 30 | out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \ 31 | x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1) 32 | return out 33 | 34 | 35 | def append_to_dict(dict, key, value): 36 | if key in dict.keys(): 37 | dict[key].append(value) 38 | else: 39 | dict[key] = [value] 40 | 41 | 42 | def forward(model, generator, return_input=False, 43 | return_target=False): 44 | """Forward data to a model. 
45 | 46 | Args: 47 | model: torch.nn.Module 48 | generator: iterable that yields batch dicts containing a 'waveform' key 49 | return_input: bool 50 | return_target: bool 51 | 52 | Returns: 53 | audio_name: (audios_num,) 54 | clipwise_output: (audios_num, classes_num) 55 | (if exists) segmentwise_output: (audios_num, segments_num, classes_num) 56 | (if exists) framewise_output: (audios_num, frames_num, classes_num) 57 | (if return_input) waveform: (audios_num, segment_samples) 58 | (if return_target) target: (audios_num, classes_num) 59 | """ 60 | output_dict = {} 61 | device = next(model.parameters()).device 62 | time1 = time.time() 63 | 64 | pbar = tqdm(generator) 65 | pbar.set_description('Evaluating ...') 66 | 67 | # Forward data to a model in mini-batches 68 | for n, batch_data_dict in enumerate(pbar): 69 | # print(n) 70 | batch_waveform = move_data_to_device(batch_data_dict['waveform'], device) 71 | 72 | with torch.no_grad(): 73 | model.eval() 74 | batch_output = model(batch_waveform) 75 | 76 | append_to_dict(output_dict, 'audio_name', batch_data_dict['audio_name']) 77 | 78 | append_to_dict(output_dict, 'clipwise_output', 79 | batch_output['clipwise_output'].data.cpu().numpy()) 80 | 81 | if 'segmentwise_output' in batch_output.keys(): 82 | append_to_dict(output_dict, 'segmentwise_output', 83 | batch_output['segmentwise_output'].data.cpu().numpy()) 84 | 85 | if 'framewise_output' in batch_output.keys(): 86 | append_to_dict(output_dict, 'framewise_output', 87 | batch_output['framewise_output'].data.cpu().numpy()) 88 | 89 | if return_input: 90 | append_to_dict(output_dict, 'waveform', batch_data_dict['waveform']) 91 | 92 | if return_target: 93 | if 'target' in batch_data_dict.keys(): 94 | append_to_dict(output_dict, 'target', batch_data_dict['target']) 95 | 96 | # if n % 10 == 0: 97 | # print(' --- Inference time: {:.3f} s / 10 iterations ---'.format( 98 | # time.time() - time1)) 99 | # time1 = time.time() 100 | 101 | for key in output_dict.keys(): 102 | output_dict[key] = np.concatenate(output_dict[key], axis=0) 103 | 104 | return output_dict 105 | 106 | 107 | def interpolate(x, ratio): 108 | """Interpolate data in the time domain. This is used to compensate for the 109 | resolution reduction caused by downsampling in a CNN. 110 | 111 | Args: 112 | x: (batch_size, time_steps, classes_num) 113 | ratio: int, ratio to interpolate 114 | 115 | Returns: 116 | upsampled: (batch_size, time_steps * ratio, classes_num) 117 | """ 118 | (batch_size, time_steps, classes_num) = x.shape 119 | upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) 120 | upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) 121 | return upsampled 122 | 123 | 124 | def pad_framewise_output(framewise_output, frames_num): 125 | """Pad framewise_output to the same length as the input frames. The pad value 126 | is the same as the value of the last frame. 127 | 128 | Args: 129 | framewise_output: (batch_size, frames_num, classes_num) 130 | frames_num: int, target number of frames after padding 131 | 132 | Returns: 133 | output: (batch_size, frames_num, classes_num) 134 | """ 135 | pad = framewise_output[:, -1 :, :].repeat(1, frames_num - framewise_output.shape[1], 1) 136 | # Tensor used for padding 137 | 138 | output = torch.cat((framewise_output, pad), dim=1) 139 | # (batch_size, frames_num, classes_num) 140 | 141 | return output 142 | 143 | 144 | def count_parameters(model): 145 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 146 | 147 | 148 | def count_flops(model, audio_length): 149 | """Count the FLOPs of a single forward pass. Adapted from a third-party implementation.
150 | """ 151 | multiply_adds = True 152 | list_conv2d=[] 153 | def conv2d_hook(self, input, output): 154 | batch_size, input_channels, input_height, input_width = input[0].size() 155 | output_channels, output_height, output_width = output[0].size() 156 | 157 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (2 if multiply_adds else 1) 158 | bias_ops = 1 if self.bias is not None else 0 159 | 160 | params = output_channels * (kernel_ops + bias_ops) 161 | flops = batch_size * params * output_height * output_width 162 | 163 | list_conv2d.append(flops) 164 | 165 | list_conv1d=[] 166 | def conv1d_hook(self, input, output): 167 | batch_size, input_channels, input_length = input[0].size() 168 | output_channels, output_length = output[0].size() 169 | 170 | kernel_ops = self.kernel_size[0] * (self.in_channels / self.groups) * (2 if multiply_adds else 1) 171 | bias_ops = 1 if self.bias is not None else 0 172 | 173 | params = output_channels * (kernel_ops + bias_ops) 174 | flops = batch_size * params * output_length 175 | 176 | list_conv1d.append(flops) 177 | 178 | list_linear=[] 179 | def linear_hook(self, input, output): 180 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 181 | 182 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 183 | bias_ops = self.bias.nelement() 184 | 185 | flops = batch_size * (weight_ops + bias_ops) 186 | list_linear.append(flops) 187 | 188 | list_bn=[] 189 | def bn_hook(self, input, output): 190 | list_bn.append(input[0].nelement() * 2) 191 | 192 | list_relu=[] 193 | def relu_hook(self, input, output): 194 | list_relu.append(input[0].nelement() * 2) 195 | 196 | list_pooling2d=[] 197 | def pooling2d_hook(self, input, output): 198 | batch_size, input_channels, input_height, input_width = input[0].size() 199 | output_channels, output_height, output_width = output[0].size() 200 | 201 | kernel_ops = self.kernel_size * self.kernel_size 202 | bias_ops = 0 203 | params = output_channels * (kernel_ops + bias_ops) 204 | flops = batch_size * params * output_height * output_width 205 | 206 | list_pooling2d.append(flops) 207 | 208 | list_pooling1d=[] 209 | def pooling1d_hook(self, input, output): 210 | batch_size, input_channels, input_length = input[0].size() 211 | output_channels, output_length = output[0].size() 212 | 213 | kernel_ops = self.kernel_size[0] 214 | bias_ops = 0 215 | 216 | params = output_channels * (kernel_ops + bias_ops) 217 | flops = batch_size * params * output_length 218 | 219 | list_pooling2d.append(flops) 220 | 221 | def foo(net): 222 | childrens = list(net.children()) 223 | if not childrens: 224 | if isinstance(net, nn.Conv2d): 225 | net.register_forward_hook(conv2d_hook) 226 | elif isinstance(net, nn.Conv1d): 227 | net.register_forward_hook(conv1d_hook) 228 | elif isinstance(net, nn.Linear): 229 | net.register_forward_hook(linear_hook) 230 | elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d): 231 | net.register_forward_hook(bn_hook) 232 | elif isinstance(net, nn.ReLU): 233 | net.register_forward_hook(relu_hook) 234 | elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d): 235 | net.register_forward_hook(pooling2d_hook) 236 | elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d): 237 | net.register_forward_hook(pooling1d_hook) 238 | else: 239 | print('Warning: flop of module {} is not counted!'.format(net)) 240 | return 241 | for c in childrens: 242 | foo(c) 243 | 244 | # Register hook 245 | foo(model) 246 | 247 | device = device = 
next(model.parameters()).device 248 | input = torch.rand(1, audio_length).to(device) 249 | 250 | out = model(input) # Dummy forward pass that triggers the registered hooks 251 | 252 | total_flops = sum(list_conv2d) + sum(list_conv1d) + sum(list_linear) + \ 253 | sum(list_bn) + sum(list_relu) + sum(list_pooling2d) + sum(list_pooling1d) 254 | 255 | return total_flops --------------------------------------------------------------------------------
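A minimal usage sketch of the helpers above, assuming the `Cnn14_*` classes live in `models/model.py` as the repository tree suggests. The hyperparameters below are illustrative AudioSet-style PANNs defaults, not values fixed by this project; adjust them to your dataset.

```python
import torch

# Assumed import paths, following the repository tree; adapt if the
# classes live elsewhere in your checkout.
from models.model import Cnn14_DecisionLevelMax
from utils.pytorch_utils import count_flops, count_parameters, do_mixup

# Illustrative AudioSet-style settings, not prescribed by this repo.
model = Cnn14_DecisionLevelMax(
    sample_rate=32000, window_size=1024, hop_size=320,
    mel_bins=64, fmin=50, fmax=14000, classes_num=10)

audio_length = 32000  # one second of audio at 32 kHz
print('Trainable parameters:', count_parameters(model))
print('Approximate FLOPs:', count_flops(model, audio_length))

# do_mixup mixes even-indexed samples with odd-indexed ones, so the batch
# size must be even and mixup_lambda holds one weight per sample.
batch = torch.rand(4, audio_length)
lam = torch.tensor([0.7, 0.3, 0.6, 0.4])
mixed = do_mixup(batch, lam)  # shape: (2, audio_length)
```

Note that `count_flops` registers forward hooks and then runs a dummy forward pass; calling it repeatedly on the same model instance keeps adding hooks and inflates the count, so measure a freshly constructed model.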