├── .dockerignore ├── workspace ├── trainers │ ├── __init__.py │ ├── trainer.py │ ├── trainer_mixup.py │ └── trainer_mixup_finetune.py ├── data_loaders │ ├── __init__.py │ └── 1s_dataset.py ├── notebooks │ └── README.md ├── models │ ├── __init__.py │ ├── modules │ │ ├── __init__.py │ │ ├── fusion.py │ │ ├── pytorch_utils.py │ │ └── sincnet │ │ │ └── dnn_models.py │ ├── loss.py │ ├── metric.py │ └── model.py ├── utils │ ├── __init__.py │ ├── mixup.py │ ├── spec_timeshift_transform.py │ └── util.py ├── base │ ├── __init__.py │ ├── base_model.py │ ├── base_trainer.py │ └── base_dataloader.py ├── logger │ ├── __init__.py │ ├── logger_config_mp.json │ ├── logger_config.json │ ├── logger.py │ └── visualization.py ├── scripts │ ├── unzip_dataset.sh │ ├── download_dataset.sh │ └── process_dcase_1s.py ├── configs │ └── wave_spec_fusion.json ├── parse_config.py ├── train.py └── test_dcasetask1b.py ├── model_figure.png ├── apt_requirements.txt ├── docker-compose.yaml ├── Dockerfile ├── Makefile ├── .gitignore ├── README.md └── pip_requirements.txt /.dockerignore: -------------------------------------------------------------------------------- 1 | workspace/ -------------------------------------------------------------------------------- /workspace/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /workspace/data_loaders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /workspace/notebooks/README.md: -------------------------------------------------------------------------------- 1 | Notebooks go here! 
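A note on the layout above: everything under `workspace/` follows a config-driven template. The JSON files in `configs/` name a `module` and a `type` for each dataset, data loader, model, and trainer, and `parse_config.py` imports and instantiates them dynamically with `importlib`. A minimal sketch of that lookup pattern (`init_obj` is a hypothetical helper written only to illustrate the idea, not a function from this repo):

import importlib

def init_obj(package, spec, **extra):
    # spec is one config entry, e.g.
    # {"module": ".model", "type": "SmallMultiModalFusionClassifier", "kwargs": {...}}
    module = importlib.import_module(spec['module'], package=package)  # relative import, e.g. models.model
    cls = getattr(module, spec['type'])                                # look up the class by name
    return cls(**spec.get('kwargs', {}), **extra)                      # instantiate with config kwargs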
-------------------------------------------------------------------------------- /workspace/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules.fusion import * -------------------------------------------------------------------------------- /workspace/models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .fusion import * -------------------------------------------------------------------------------- /model_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/denfed/wave-spec-fusion/HEAD/model_figure.png -------------------------------------------------------------------------------- /apt_requirements.txt: -------------------------------------------------------------------------------- 1 | python3 2 | python3-pip 3 | libsndfile1-dev 4 | libsm6 5 | libxext6 6 | libxrender-dev 7 | ffmpeg -------------------------------------------------------------------------------- /workspace/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | from .mixup import * 3 | from .spec_timeshift_transform import * -------------------------------------------------------------------------------- /workspace/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_dataloader import * 2 | from .base_trainer import BaseTrainer 3 | from .base_model import * -------------------------------------------------------------------------------- /workspace/logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import setup_logging, get_logger 2 | from .visualization import TensorboardWriter 3 | -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: '2.3' 2 | 3 | services: 4 | dcase_2021: 5 | runtime: nvidia 6 | image: dcase_2021 7 | ipc: host 8 | build: 9 | context: . 
10 | args: 11 | - UID 12 | - GID 13 | - USER_PASSWORD 14 | command: bash -c "jupyter lab --ip=0.0.0.0 --NotebookApp.token='' --NotebookApp.password='' --allow-root" 15 | ports: 16 | - "9999:8888" 17 | - "6066:6066" 18 | volumes: 19 | - ./workspace/:/home/src 20 | - /mnt/ssd/data/:/mnt/ssd/data 21 | networks: 22 | default: 23 | external: 24 | name: glados-docker 25 | -------------------------------------------------------------------------------- /workspace/scripts/unzip_dataset.sh: -------------------------------------------------------------------------------- 1 | unzip data/audio_1.zip -d data/tau_audiovisual_2021/ 2 | unzip data/audio_2.zip -d data/tau_audiovisual_2021/ 3 | unzip data/audio_3.zip -d data/tau_audiovisual_2021/ 4 | unzip data/audio_4.zip -d data/tau_audiovisual_2021/ 5 | unzip data/audio_5.zip -d data/tau_audiovisual_2021/ 6 | unzip data/audio_6.zip -d data/tau_audiovisual_2021/ 7 | unzip data/audio_7.zip -d data/tau_audiovisual_2021/ 8 | unzip data/audio_8.zip -d data/tau_audiovisual_2021/ 9 | unzip data/doc.zip -d data/tau_audiovisual_2021/ 10 | unzip data/meta.zip -d data/tau_audiovisual_2021/ 11 | unzip data/examples.zip -d data/tau_audiovisual_2021/ 12 | 13 | rm data/*.zip -------------------------------------------------------------------------------- /workspace/utils/mixup.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | 5 | 6 | def mixup_data(x, y, alpha=0.2, use_cuda=True): 7 | """x IS A DICT""" 8 | '''Returns mixed inputs, pairs of targets, and lambda''' 9 | if alpha > 0: 10 | lam = np.random.beta(alpha, alpha) 11 | else: 12 | lam = 1 13 | # batch_size = x.size()[0] 14 | batch_size = x[next(iter(x))].size()[0] 15 | if use_cuda: 16 | index = torch.randperm(batch_size).cuda() 17 | else: 18 | index = torch.randperm(batch_size) 19 | 20 | for key, value in x.items(): 21 | x[key] = lam * value + (1 - lam) * value[index, :] 22 | 23 | y_a, y_b = y, y[index] 24 | return x, y_a, y_b, lam -------------------------------------------------------------------------------- /workspace/base/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | from abc import abstractmethod 4 | 5 | 6 | class BaseModel(nn.Module): 7 | """ 8 | Base class for all models 9 | """ 10 | @abstractmethod 11 | def forward(self, *inputs): 12 | """ 13 | Forward pass logic 14 | 15 | :return: Model output 16 | """ 17 | raise NotImplementedError 18 | 19 | def __str__(self): 20 | """ 21 | Model prints with number of trainable parameters 22 | """ 23 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 24 | params = sum([np.prod(p.size()) for p in model_parameters]) 25 | return super().__str__() + '\nTrainable parameters: {}'.format(params) 26 | -------------------------------------------------------------------------------- /workspace/logger/logger_config_mp.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": 1, 3 | "disable_existing_loggers": false, 4 | "formatters": { 5 | "simple": {"format": "%(message)s"}, 6 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - \n%(message)s"} 7 | }, 8 | "handlers": { 9 | "console": { 10 | "class": "logging.StreamHandler", 11 | "level": "DEBUG", 12 | "formatter": "simple", 13 | "stream": "ext://sys.stdout" 14 | }, 15 | "info_file_handler": { 16 | "class": "logging.handlers.RotatingFileHandler", 17 
| "level": "INFO", 18 | "formatter": "datetime", 19 | "filename": "info.log", 20 | "maxBytes": 10485760, 21 | "backupCount": 20, "encoding": "utf8" 22 | } 23 | }, 24 | "loggers": { 25 | "model": { 26 | "level": "INFO", 27 | "propagate": false, 28 | "handlers": ["info_file_handler"] 29 | } 30 | }, 31 | "root": { 32 | "level": "INFO", 33 | "handlers": ["info_file_handler"] 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /workspace/logger/logger_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": 1, 3 | "disable_existing_loggers": false, 4 | "formatters": { 5 | "simple": {"format": "%(message)s"}, 6 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - \n%(message)s"} 7 | }, 8 | "handlers": { 9 | "console": { 10 | "class": "logging.StreamHandler", 11 | "level": "DEBUG", 12 | "formatter": "simple", 13 | "stream": "ext://sys.stdout" 14 | }, 15 | "info_file_handler": { 16 | "class": "logging.handlers.RotatingFileHandler", 17 | "level": "INFO", 18 | "formatter": "datetime", 19 | "filename": "info.log", 20 | "maxBytes": 10485760, 21 | "backupCount": 20, "encoding": "utf8" 22 | } 23 | }, 24 | "loggers": { 25 | "model": { 26 | "level": "INFO", 27 | "propagate": false, 28 | "handlers": ["info_file_handler"] 29 | } 30 | }, 31 | "root": { 32 | "level": "DEBUG", 33 | "handlers": ["console", "info_file_handler"] 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CUDAVERSION=10.2 2 | 3 | FROM nvidia/cuda:${CUDAVERSION}-base 4 | 5 | # User Setup 6 | ARG UID 7 | ARG GID 8 | ARG USER_PASSWORD 9 | RUN adduser --disabled-password --gecos "" container_user 10 | RUN usermod -u ${UID} container_user 11 | RUN groupmod -g ${GID} container_user 12 | RUN echo container_user:${USER_PASSWORD} | chpasswd 13 | RUN usermod -aG sudo container_user 14 | 15 | RUN mkdir /home/src 16 | WORKDIR /home/src 17 | ENV HOME /home/src 18 | 19 | RUN apt update 20 | 21 | # Global Apt Dependencies 22 | COPY apt_requirements.txt $HOME/apt_requirements.txt 23 | RUN cat apt_requirements.txt | xargs apt install -y 24 | RUN rm apt_requirements.txt 25 | 26 | # Update pip3 27 | RUN pip3 install --upgrade pip==21.1 28 | 29 | # Install wandb and initialize 30 | RUN pip3 install --upgrade wandb 31 | ENV LC_ALL=C.UTF-8 32 | ENV LANG=C.UTF-8 33 | # RUN wandb login --host= 34 | 35 | # Cache pytorch so it doesn't re-download on requirements change 36 | RUN pip3 install torch==1.8.1 37 | 38 | # Global Python Dependencies 39 | COPY pip_requirements.txt $HOME/pip_requirements.txt 40 | RUN pip3 install -r pip_requirements.txt 41 | RUN rm pip_requirements.txt 42 | 43 | USER container_user 44 | -------------------------------------------------------------------------------- /workspace/logger/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | from pathlib import Path 4 | 5 | from utils import read_json 6 | 7 | 8 | def setup_logging(save_dir, root_dir='./', filename=None, 9 | log_config="logger/logger_config.json", default_level=logging.INFO): 10 | """ 11 | setup logging configuration 12 | """ 13 | log_config = Path(root_dir) / log_config 14 | if log_config.is_file(): 15 | config = read_json(log_config) 16 | # modify logging paths based on run config 17 | for _, handler in 
config['handlers'].items(): 18 | if 'filename' in handler: 19 | if filename is None: 20 | handler['filename'] = str(save_dir / handler['filename']) 21 | else: 22 | handler['filename'] = str(save_dir / filename) 23 | 24 | logging.config.dictConfig(config) 25 | else: 26 | print("Warning: logging configuration file not found at {}.".format(log_config)) 27 | logging.basicConfig(level=default_level) 28 | 29 | 30 | log_levels = { 31 | 0: logging.WARNING, 32 | 1: logging.INFO, 33 | 2: logging.DEBUG, 34 | } 35 | def get_logger(name, verbosity=2): 36 | assert verbosity in log_levels, \ 37 | f"verbosity option {verbosity} is invalid. \ 38 | Valid options are {log_levels.keys()}." 39 | logger = logging.getLogger(name) 40 | logger.setLevel(log_levels[verbosity]) 41 | return logger 42 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | UID ?= $(shell id -u) # to set UID of container user to match host 2 | GID ?= $(shell id -g) # to set GID of container user to match host 3 | USER_PASSWORD ?= password # container user password (for sudo) 4 | PORT ?= 6066 # tensorboard port 5 | 6 | # Build without cache 7 | .PHONY: build-nocache 8 | build-nocache: 9 | docker-compose stop 10 | docker-compose build --no-cache --build-arg UID=$(UID) --build-arg GID=$(GID) --build-arg USER_PASSWORD=$(USER_PASSWORD) dcase_2021 11 | 12 | # Build with cache 13 | .PHONY: build 14 | build: 15 | docker-compose stop 16 | docker-compose build --build-arg UID=$(UID) --build-arg GID=$(GID) --build-arg USER_PASSWORD=$(USER_PASSWORD) dcase_2021 17 | 18 | # Start the container and the Jupyterlab environment 19 | .PHONY: run 20 | run: 21 | docker-compose stop 22 | docker-compose up -d dcase_2021 23 | 24 | # Create terminal inside container 25 | .PHONY: terminal 26 | terminal: 27 | docker-compose run dcase_2021 bash 28 | 29 | # Start tensorboard environment 30 | .PHONY: tb 31 | tb: 32 | tensorboard --logdir=saved/log/ --port=$(PORT) --bind_all --max_reload_threads 4 33 | 34 | # Start the wandb environment at port 8080 35 | .PHONY: wandb 36 | wandb: 37 | docker run --rm -d -v wandb:/vol -p 8080:8080 --name wandb-local wandb/local 38 | 39 | # Stop the wandb environment 40 | .PHONY: wandb-stop 41 | wandb-stop: 42 | docker stop wandb-local 43 | 44 | # Upgrade the wandb container 45 | .PHONY: wandb-upgrade 46 | wandb-upgrade: 47 | docker pull wandb/local 48 | -------------------------------------------------------------------------------- /workspace/scripts/download_dataset.sh: -------------------------------------------------------------------------------- 1 | curl -o data/audio_1.zip https://zenodo.org/record/4477542/files/TAU-urban-audio-visual-scenes-2021-development.audio.1.zip?download=1 2 | curl -o data/audio_2.zip https://zenodo.org/record/4477542/files/TAU-urban-audio-visual-scenes-2021-development.audio.2.zip?download=1 3 | curl -o data/audio_3.zip https://zenodo.org/record/4477542/files/TAU-urban-audio-visual-scenes-2021-development.audio.3.zip?download=1 4 | curl -o data/audio_4.zip https://zenodo.org/record/4477542/files/TAU-urban-audio-visual-scenes-2021-development.audio.4.zip?download=1 5 | curl -o data/audio_5.zip https://zenodo.org/record/4477542/files/TAU-urban-audio-visual-scenes-2021-development.audio.5.zip?download=1 6 | curl -o data/audio_6.zip https://zenodo.org/record/4477542/files/TAU-urban-audio-visual-scenes-2021-development.audio.6.zip?download=1 7 | curl -o data/audio_7.zip
https://zenodo.org/record/4477542/files/TAU-urban-audio-visual-scenes-2021-development.audio.7.zip?download=1 8 | curl -o data/audio_8.zip https://zenodo.org/record/4477542/files/TAU-urban-audio-visual-scenes-2021-development.audio.8.zip?download=1 9 | curl -o data/doc.zip https://zenodo.org/record/4477542/files/TAU-urban-audio-visual-scenes-2021-development.doc.zip?download=1 10 | curl -o data/meta.zip https://zenodo.org/record/4477542/files/TAU-urban-audio-visual-scenes-2021-development.meta.zip?download=1 11 | curl -o data/examples.zip https://zenodo.org/record/4477542/files/TAU-urban-audio-visual-scenes-2021-development.examples.zip?download=1 12 | -------------------------------------------------------------------------------- /workspace/models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import CrossEntropyLoss, MSELoss, BCELoss, BCEWithLogitsLoss 4 | 5 | 6 | def weighted_bce_loss(output, target, weight: torch.Tensor=None): 7 | weighted_bce_loss = BCELoss(weight=weight[target.long()]) 8 | return weighted_bce_loss(output, target) 9 | 10 | 11 | # ref: https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py 12 | def binary_focal_loss(output: torch.Tensor, 13 | target: torch.Tensor, 14 | alpha: float = 0.5, 15 | gamma: float = 2.0, 16 | reduction: str = 'sum', 17 | eps: float = 1e-8) -> torch.Tensor: 18 | p_t = output 19 | loss_tmp = -alpha * torch.pow(1 - p_t, gamma) * target * torch.log(p_t + eps) \ 20 | - (1 - alpha) * torch.pow(p_t, gamma) * (1 - target) * torch.log(1 - p_t + eps) 21 | 22 | if reduction == 'none': 23 | loss = loss_tmp 24 | elif reduction == 'mean': 25 | loss = torch.mean(loss_tmp) 26 | elif reduction == 'sum': 27 | loss = torch.sum(loss_tmp) 28 | 29 | return loss 30 | 31 | 32 | # model output with log_softmax 33 | def nll_loss(output, target): 34 | return F.nll_loss(output, target) 35 | 36 | 37 | def vae_loss(recon_x, mu, logvar, x, lm=1e-4): 38 | mse_loss = MSELoss() 39 | loss_recon = mse_loss(recon_x, x) 40 | kl_divergence = torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())) 41 | return loss_recon + lm * kl_divergence, loss_recon, kl_divergence -------------------------------------------------------------------------------- /workspace/utils/spec_timeshift_transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | import librosa 4 | import numpy as np 5 | 6 | 7 | class Transform(object): 8 | def transform_data(self, data): 9 | # Mandatory to be defined by subclasses 10 | raise NotImplementedError("Abstract object") 11 | 12 | def transform_label(self, label): 13 | # Do nothing, to be changed in subclasses if needed 14 | return label 15 | 16 | def _apply_transform(self, sample_no_index): 17 | data, label = sample_no_index 18 | if type(data) is tuple: # meaning there is more than one data_input (could be duet, triplet...) 
19 | data = list(data) 20 | for k in range(len(data)): 21 | data[k] = self.transform_data(data[k]) 22 | data = tuple(data) 23 | else: 24 | data = self.transform_data(data) 25 | label = self.transform_label(label) 26 | return data, label 27 | 28 | def __call__(self, sample): 29 | """Apply the transformation 30 | Args: 31 | sample: tuple, a sample defined by a DataLoad class 32 | Returns: 33 | tuple 34 | The transformed tuple 35 | """ 36 | if type(sample[1]) is int: # Means there is an index, may be another way to make it cleaner 37 | sample_data, index = sample 38 | sample_data = self._apply_transform(sample_data) 39 | sample = sample_data, index 40 | else: 41 | sample = self._apply_transform(sample) 42 | return sample 43 | 44 | 45 | class TimeShift(Transform): 46 | def __init__(self, mean=0, std=90): 47 | self.mean = mean 48 | self.std = std 49 | 50 | def __call__(self, sample): 51 | data = sample 52 | shift = int(np.random.normal(self.mean, self.std)) 53 | 54 | if type(data) is tuple: # meaning there is more than one data_input (could be duet, triplet...) 55 | data = list(data) 56 | for k in range(len(data)): 57 | data[k] = np.roll(data[k], shift, axis=1) 58 | data = tuple(data) 59 | else: 60 | data = np.roll(data, shift, axis=1) 61 | 62 | # if len(label.shape) == 2: 63 | # label = np.roll(label, shift, axis=0) # strong label only 64 | 65 | # sample = (data, label) 66 | return data -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Workspace files 2 | workspace/.bash_history 3 | workspace/.jupyter 4 | workspace/.local 5 | workspace/.python_history 6 | 7 | wandb/ 8 | workspace/wandb/ 9 | workspace/saved/ 10 | 11 | workspace/data/ 12 | 13 | *data/ 14 | log 15 | output/ 16 | __pycache__/ 17 | saved/ 18 | 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | .idea 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | lib/ 38 | lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | wheels/ 43 | pip-wheel-metadata/ 44 | share/python-wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | MANIFEST 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .nox/ 64 | .coverage 65 | .coverage.* 66 | .cache 67 | nosetests.xml 68 | coverage.xml 69 | *.cover 70 | *.py,cover 71 | .hypothesis/ 72 | .pytest_cache/ 73 | 74 | # Translations 75 | *.mo 76 | *.pot 77 | 78 | # Django stuff: 79 | *.log 80 | local_settings.py 81 | db.sqlite3 82 | db.sqlite3-journal 83 | 84 | # Flask stuff: 85 | instance/ 86 | .webassets-cache 87 | 88 | # Scrapy stuff: 89 | .scrapy 90 | 91 | # Sphinx documentation 92 | docs/_build/ 93 | 94 | # PyBuilder 95 | target/ 96 | 97 | # Jupyter Notebook 98 | .ipynb_checkpoints 99 | 100 | # IPython 101 | profile_default/ 102 | ipython_config.py 103 | 104 | # pyenv 105 | .python-version 106 | 107 | # pipenv 108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 111 | # install all needed dependencies. 112 | #Pipfile.lock 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | -------------------------------------------------------------------------------- /workspace/logger/visualization.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from datetime import datetime 3 | 4 | 5 | class TensorboardWriter(): 6 | def __init__(self, log_dir, logger, enabled): 7 | self.writer = None 8 | self.selected_module = "" 9 | 10 | if enabled: 11 | log_dir = str(log_dir) 12 | 13 | # Retrieve vizualization writer. 14 | succeeded = False 15 | for module in ["torch.utils.tensorboard", "tensorboardX"]: 16 | try: 17 | self.writer = importlib.import_module(module).SummaryWriter(log_dir) 18 | succeeded = True 19 | break 20 | except ImportError: 21 | succeeded = False 22 | self.selected_module = module 23 | 24 | if not succeeded: 25 | message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \ 26 | "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \ 27 | "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file." 28 | logger.warning(message) 29 | 30 | self.step = 0 31 | self.mode = '' 32 | 33 | self.tb_writer_ftns = { 34 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', 35 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding' 36 | } 37 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'} 38 | self.timer = datetime.now() 39 | 40 | def set_step(self, step, mode='train'): 41 | self.mode = mode 42 | self.step = step 43 | if step == 0: 44 | self.timer = datetime.now() 45 | else: 46 | duration = datetime.now() - self.timer 47 | self.add_scalar('steps_per_sec', 1 / duration.total_seconds()) 48 | self.timer = datetime.now() 49 | 50 | def __getattr__(self, name): 51 | """ 52 | If visualization is configured to use: 53 | return add_data() methods of tensorboard with additional information (step, tag) added. 54 | Otherwise: 55 | return a blank function handle that does nothing 56 | """ 57 | if name in self.tb_writer_ftns: 58 | add_data = getattr(self.writer, name, None) 59 | 60 | def wrapper(tag, data, *args, **kwargs): 61 | if add_data is not None: 62 | # add mode(train/valid) tag 63 | if name not in self.tag_mode_exceptions: 64 | tag = '{}/{}'.format(tag, self.mode) 65 | add_data(tag, data, self.step, *args, **kwargs) 66 | return wrapper 67 | else: 68 | # default action for returning methods defined in this class, set_step() for instance. 
69 | try: 70 | attr = object.__getattribute__(self, name) 71 | except AttributeError: 72 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name)) 73 | return attr 74 | -------------------------------------------------------------------------------- /workspace/models/modules/fusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | # class Cfgs(BaseCfgs): 6 | # def __init__(self): 7 | # super(Cfgs, self).__init__() 8 | 9 | # self.HIGH_ORDER = False 10 | # self.HIDDEN_SIZE = 2048 11 | # self.MFB_K = 5 12 | # self.MFB_O = 1000 13 | # self.LSTM_OUT_SIZE = 1024 14 | # self.DROPOUT_R = 0.1 15 | # self.I_GLIMPSES = 2 16 | # self.Q_GLIMPSES = 2 17 | 18 | 19 | class MFB(nn.Module): 20 | def __init__(self, img_feat_size, ques_feat_size, is_first=True, MFB_K=5, MFB_O=2048, DROPOUT_R=0.1): 21 | super(MFB, self).__init__() 22 | self.is_first = is_first 23 | self.MFB_K = MFB_K 24 | self.MFB_O = MFB_O 25 | self.DROPOUT_R = DROPOUT_R 26 | self.proj_i = nn.Linear(img_feat_size, self.MFB_K * self.MFB_O) 27 | self.proj_q = nn.Linear(ques_feat_size, self.MFB_K * self.MFB_O) 28 | self.dropout = nn.Dropout(self.DROPOUT_R) 29 | self.pool = nn.AvgPool1d(self.MFB_K, stride=self.MFB_K) 30 | 31 | def forward(self, img_feat, ques_feat, exp_in=1): 32 | ''' 33 | img_feat.size() -> (N, C, img_feat_size) C = 1 or 100 34 | ques_feat.size() -> (N, 1, ques_feat_size) 35 | z.size() -> (N, C, MFB_O) 36 | exp_out.size() -> (N, C, K*O) 37 | ''' 38 | batch_size = img_feat.shape[0] 39 | img_feat = self.proj_i(img_feat) # (N, C, K*O) 40 | ques_feat = self.proj_q(ques_feat) # (N, 1, K*O) 41 | 42 | exp_out = img_feat * ques_feat # (N, C, K*O) 43 | exp_out = self.dropout(exp_out) if self.is_first else self.dropout(exp_out * exp_in) # (N, C, K*O) 44 | z = self.pool(exp_out) * self.MFB_K # (N, C, O) 45 | z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z)) 46 | z = F.normalize(z.view(batch_size, -1)) # (N, C*O) 47 | z = z.view(batch_size, -1, self.MFB_O) # (N, C, O) 48 | return z, exp_out 49 | 50 | 51 | class CoAtt(nn.Module): 52 | def __init__(self, __C): 53 | super(CoAtt, self).__init__() 54 | self.__C = __C 55 | 56 | img_feat_size = __C.FEAT_SIZE[__C.DATASET]['FRCN_FEAT_SIZE'][1] 57 | img_att_feat_size = img_feat_size * __C.I_GLIMPSES 58 | ques_att_feat_size = __C.LSTM_OUT_SIZE * __C.Q_GLIMPSES 59 | 60 | self.q_att = QAtt(__C) # QAtt and IAtt are attention modules from the openvqa reference implementation (see Code References in the README); they are not defined in this repo 61 | self.i_att = IAtt(__C, img_feat_size, ques_att_feat_size) 62 | 63 | if self.__C.HIGH_ORDER: # MFH 64 | self.mfh1 = MFB(img_att_feat_size, ques_att_feat_size, True) 65 | self.mfh2 = MFB(img_att_feat_size, ques_att_feat_size, False) 66 | else: # MFB 67 | self.mfb = MFB(img_att_feat_size, ques_att_feat_size, True) 68 | 69 | def forward(self, img_feat, ques_feat): 70 | ''' 71 | img_feat.size() -> (N, C, FRCN_FEAT_SIZE) 72 | ques_feat.size() -> (N, T, LSTM_OUT_SIZE) 73 | z.size() -> MFH:(N, 2*O) / MFB:(N, O) 74 | ''' 75 | ques_feat = self.q_att(ques_feat) # (N, LSTM_OUT_SIZE*Q_GLIMPSES) 76 | fuse_feat = self.i_att(img_feat, ques_feat) # (N, FRCN_FEAT_SIZE*I_GLIMPSES) 77 | 78 | if self.__C.HIGH_ORDER: # MFH 79 | z1, exp1 = self.mfh1(fuse_feat.unsqueeze(1), ques_feat.unsqueeze(1)) # z1:(N, 1, O) exp1:(N, C, K*O) 80 | z2, _ = self.mfh2(fuse_feat.unsqueeze(1), ques_feat.unsqueeze(1), exp1) # z2:(N, 1, O) _:(N, C, K*O) 81 | z = torch.cat((z1.squeeze(1), z2.squeeze(1)), 1) # (N, 2*O) 82 | else: # MFB 83 | z, _ = self.mfb(fuse_feat.unsqueeze(1), ques_feat.unsqueeze(1)) # z:(N, 1, O) _:(N, C, K*O) 84 | z = z.squeeze(1) # (N, O) 85 | 86 | return z
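The MFB block above fuses two modality embeddings into a single vector and is carried over from the openvqa codebase cited in the README. A minimal usage sketch, with illustrative feature sizes (the defaults MFB_K=5 and MFB_O=2048 come from the constructor):

import torch

mfb = MFB(img_feat_size=512, ques_feat_size=512)
a = torch.randn(8, 1, 512)  # one branch's features, (N, C, img_feat_size); sizes are illustrative
b = torch.randn(8, 1, 512)  # the other branch's features, (N, 1, ques_feat_size)
z, exp_out = mfb(a, b)      # z: (8, 1, 2048), exp_out: (8, 1, 5 * 2048)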
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wave-spec-fusion 2 | Code for the submitted 2021 DCASE Workshop paper: "Waveforms and Spectrograms: Enhancing Acoustic Scene Classification Using Multimodal Feature Fusion" 3 | 4 | ## Abstract 5 | 6 | Acoustic scene classification (ASC) has seen tremendous progress from the combined use of convolutional neural networks (CNNs) and signal processing strategies. In this paper, we investigate the use of two common feature representations within the audio understanding domain, the raw waveform and Mel-spectrogram, and measure their degree of complementarity when using both representations for feature fusion. We introduce a new model paradigm for acoustic scene classification by fusing features learned from Mel-spectrograms and the raw waveform from separate feature extraction branches. Our experimental results show that our proposed fusion model significantly outperforms the baseline audio-only sub-network on the DCASE 2021 Challenge Task 1B (an increase of 5.7% in accuracy and a 12.7% reduction in loss). We further show that the learned features of raw waveforms and Mel-spectrograms are indeed complementary to each other and that there is a consistent improvement in classification performance over models trained on Mel-spectrograms or waveforms alone. 7 | 8 | ![Model Figure](model_figure.png) 9 | 10 | 11 | | Model | Accuracy % | Log Loss | # Params | 12 | | ----------- | ----------- | ----------- | ----------- | 13 | | Audio baseline | 65.1 | 1.048 | - | 14 | | Waveform sub-network | 64.79 | 1.045 | 1.0M | 15 | | Spectrogram sub-network | 66.46 | 1.072 | 1.1M | 16 | | Fusion model | **70.78** | **0.915** | 1.4M | 17 | 18 | ## Installing the Environment 19 | 20 | We use Docker! Specifically, we use Docker, Docker Compose, and Nvidia Docker to have consistent environments to work in. To use this environment, there are two options: use the Nvidia Docker environment, or install your own Python virtual environment. 21 | 22 | ### Container Installation 23 | 24 | * First, you must install [Docker](https://docs.docker.com/get-docker/), [Docker Compose](https://docs.docker.com/compose/install/), and [Nvidia Docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker). You must also have Nvidia drivers installed on your machine. 25 | * Next, run `make build` to build the container. 26 | * Run `make run` to start the Jupyter Lab environment, served on `localhost:9999` (the container's port 8888 is mapped to host port 9999 in `docker-compose.yaml`). 27 | * Note: You may also run `make terminal` to open a bash terminal inside the container. 28 | 29 | ### Virtual Environment Installation 30 | 31 | In the container, we use CUDA 10.2, but a newer installation should work as well. If you prefer a virtual environment, install the system packages listed in `apt_requirements.txt` and the Python packages in `pip_requirements.txt`. 32 | 33 | 34 | 35 | ## Preprocessing the Dataset 36 | We provide a few scripts to automatically process the 2021 TAU Audio-Visual Scene dataset into the 1-second files described in the DCASE 2021 Challenge Task 1B.
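For reference, the per-clip transform is compact; the sketch below condenses the loop in `scripts/process_dcase_1s.py` (same parameters as the script; the function wrapper is ours):

```python
import librosa
import numpy as np

def split_and_featurize(path, sr=48000):
    """One 10-second clip -> ten 1-second (waveform, log-mel) examples."""
    audio, _ = librosa.load(path, sr=sr)
    padded = np.zeros(10 * sr, dtype='float32')        # zero-pad/trim to exactly 10 s
    padded[:min(len(audio), 10 * sr)] = audio[:10 * sr]
    for segment in np.split(padded, 10):               # ten 1-second waveforms
        spec = librosa.feature.melspectrogram(segment, sr=sr, n_fft=2048,
                                              hop_length=256, n_mels=128,
                                              fmin=0, fmax=24000)
        spec = np.log(spec)                            # log-mel
        spec = (spec - spec.mean()) / spec.std()       # per-segment z-score
        yield segment, spec                            # each pair is cached as .npy files
```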
37 | 38 | To process the dataset, run these commands from the repository root: 39 | 40 | * `cd workspace/` 41 | * `chmod +x scripts/download_dataset.sh` 42 | * `chmod +x scripts/unzip_dataset.sh` 43 | * `./scripts/download_dataset.sh` 44 | * `./scripts/unzip_dataset.sh` 45 | 46 | Next, inside the container, run the following script: 47 | 48 | * `python3 scripts/process_dcase_1s.py` 49 | 50 | ## Training the Fusion Model 51 | 52 | Inside the container, run the following command in the root directory of the container: 53 | 54 | * `python3 train.py -c configs/wave_spec_fusion.json -d <GPU index, e.g. 0 or 0,1> --run_id 'Fusion Model'` 55 | 56 | This will train the model using the configuration defined in `configs/wave_spec_fusion.json`. We use Weights and Biases logging, but this is optional. 57 | If you wish to use Weights and Biases, you can log in with your personal API key and runs will be logged. 58 | 59 | ## Code References 60 | 61 | In this codebase, we use code from the following sources: 62 | 63 | * [https://github.com/m-koichi/ConformerSED](https://github.com/m-koichi/ConformerSED) 64 | * [https://github.com/mravanelli/SincNet](https://github.com/mravanelli/SincNet) 65 | * [https://github.com/MILVLG/openvqa](https://github.com/MILVLG/openvqa) 66 | -------------------------------------------------------------------------------- /workspace/models/metric.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | import pandas as pd 4 | import numpy as np 5 | import torch 6 | from sklearn.metrics import roc_auc_score, average_precision_score 7 | 8 | smooth = 1e-6 9 | 10 | 11 | class MetricTracker: 12 | def __init__(self, keys_iter: list, keys_epoch: list, writer=None): 13 | self.writer = writer 14 | self.metrics_iter = pd.DataFrame(index=keys_iter, columns=['current', 'sum', 'square_sum', 'counts', 15 | 'mean', 'square_avg', 'std']) 16 | self.metrics_epoch = pd.DataFrame(index=keys_epoch, columns=['mean']) 17 | self.reset() 18 | 19 | def reset(self): 20 | for col in self.metrics_iter.columns: 21 | self.metrics_iter[col].values[:] = 0 22 | 23 | def iter_update(self, key, value, n=1): 24 | if self.writer is not None: 25 | self.writer.add_scalar(key, value) 26 | self.metrics_iter.at[key, 'current'] = value 27 | self.metrics_iter.at[key, 'sum'] += value * n 28 | self.metrics_iter.at[key, 'square_sum'] += value * value * n 29 | self.metrics_iter.at[key, 'counts'] += n 30 | 31 | def epoch_update(self, key, value): 32 | self.metrics_epoch.at[key, 'mean'] = value 33 | 34 | def current(self): 35 | return dict(self.metrics_iter['current']) 36 | 37 | def avg(self): 38 | for key, row in self.metrics_iter.iterrows(): 39 | self.metrics_iter.at[key, 'mean'] = row['sum'] / row['counts'] 40 | self.metrics_iter.at[key, 'square_avg'] = row['square_sum'] / row['counts'] 41 | 42 | def std(self): 43 | for key, row in self.metrics_iter.iterrows(): 44 | self.metrics_iter.at[key, 'std'] = sqrt(row['square_avg'] - row['mean']**2 + smooth) 45 | 46 | def result(self): 47 | self.avg() 48 | self.std() 49 | iter_result = self.metrics_iter[['mean', 'std']] 50 | epoch_result = self.metrics_epoch 51 | 52 | return pd.concat([iter_result, epoch_result]) 53 | 54 | 55 | def accuracy(output, target): 56 | with torch.no_grad(): 57 | pred = torch.argmax(output, dim=1) 58 | assert pred.shape[0] == len(target) 59 | correct = 0 60 | correct += torch.sum(pred == target).item() 61 | return correct / len(target) 62 | 63 | 64 | def top_k_acc(output, target, k=3): 65
| with torch.no_grad(): 66 | pred = torch.topk(output, k, dim=1)[1] 67 | assert pred.shape[0] == len(target) 68 | correct = 0 69 | for i in range(k): 70 | correct += torch.sum(pred[:, i] == target).item() 71 | return correct / len(target) 72 | 73 | 74 | def binary_accuracy(output, target): 75 | with torch.no_grad(): 76 | correct = 0 77 | correct += torch.sum(torch.abs(output - target) < 0.5).item() 78 | return correct / len(target) 79 | 80 | 81 | def AUROC(output, target): 82 | with torch.no_grad(): 83 | value = roc_auc_score(target.cpu().numpy(), output.cpu().numpy()) 84 | return value 85 | 86 | 87 | def AUPRC(output, target): 88 | with torch.no_grad(): 89 | value = average_precision_score(target.cpu().numpy(), output.cpu().numpy()) 90 | return value 91 | 92 | 93 | def noise_RoCAUC(output, target): 94 | with torch.no_grad(): 95 | value = roc_auc_score(target.cpu().numpy()[:,0], output.cpu().numpy()[:,0]) 96 | return value 97 | 98 | 99 | def noise_AP(output, target): 100 | with torch.no_grad(): 101 | value = average_precision_score(target.cpu().numpy()[:,0], output.cpu().numpy()[:,0]) 102 | return value 103 | 104 | 105 | def mean_iou_score(output, labels): 106 | ''' 107 | Compute mean IoU score over 6 classes 108 | ''' 109 | with torch.no_grad(): 110 | pred = torch.argmax(output, dim=1) 111 | pred = pred.data.cpu().numpy() 112 | labels = labels.data.cpu().numpy() 113 | mean_iou = 0 114 | for i in range(6): 115 | tp_fp = np.sum(pred == i) 116 | tp_fn = np.sum(labels == i) 117 | tp = np.sum((pred == i) * (labels == i)) 118 | iou = (tp + smooth) / (tp_fp + tp_fn - tp + smooth) 119 | mean_iou += iou / 6 120 | 121 | return mean_iou 122 | -------------------------------------------------------------------------------- /pip_requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | aiohttp==3.7.4.post0 3 | aiohttp-cors==0.7.0 4 | aioredis==1.3.1 5 | anyio==2.1.0 6 | appdirs==1.4.4 7 | argon2-cffi==20.1.0 8 | asn1crypto==0.24.0 9 | async-generator==1.10 10 | async-timeout==3.0.1 11 | attrs==20.3.0 12 | audiomentations==0.16.0 13 | audioread==2.1.9 14 | Babel==2.9.0 15 | backcall==0.2.0 16 | beautifulsoup4==4.9.3 17 | bleach==3.3.0 18 | blessings==1.7 19 | blis==0.4.1 20 | Bottleneck==1.3.2 21 | cachetools==4.2.1 22 | catalogue==1.0.0 23 | certifi==2020.12.5 24 | cffi==1.14.5 25 | chardet==4.0.0 26 | click==7.1.2 27 | colorama==0.4.4 28 | colorful==0.5.4 29 | configparser==5.0.1 30 | contextvars==2.4 31 | cryptography==2.1.4 32 | cycler==0.10.0 33 | cymem==2.0.3 34 | dataclasses==0.7 35 | decorator==5.0.6 36 | defusedxml==0.6.0 37 | docker-pycreds==0.4.0 38 | entrypoints==0.3 39 | fastai==1.0.61 40 | fastprogress==0.2.3 41 | ffmpeg==1.4 42 | ffmpeg-python==0.2.0 43 | filelock==3.0.12 44 | Flask==1.1.2 45 | future==0.18.2 46 | gitdb==4.0.5 47 | GitPython==3.1.13 48 | google==3.0.0 49 | google-api-core==1.26.3 50 | google-auth==1.28.0 51 | google-auth-oauthlib==0.4.1 52 | google-cloud-bigquery==2.14.0 53 | google-cloud-bigquery-storage==2.4.0 54 | google-cloud-core==1.6.0 55 | google-crc32c==1.1.2 56 | google-resumable-media==1.2.0 57 | googleapis-common-protos==1.53.0 58 | gpustat==0.6.0 59 | grpcio==1.37.0 60 | hiredis==2.0.0 61 | idna==2.10 62 | idna-ssl==1.1.0 63 | immutables==0.15 64 | importlib-metadata==3.10.0 65 | ipykernel==5.4.3 66 | ipython==7.16.1 67 | ipython-genutils==0.2.0 68 | ipywidgets==7.6.3 69 | itsdangerous==1.1.0 70 | jedi==0.18.0 71 | Jinja2==2.11.2 72 | joblib==1.0.1 73 | json5==0.9.5 74 | jsonschema==3.2.0 75 | 
jupyter==1.0.0 76 | jupyter-client==6.1.11 77 | jupyter-console==6.2.0 78 | jupyter-core==4.7.1 79 | jupyter-server==1.3.0 80 | jupyterlab==3.0.0 81 | jupyterlab-pygments==0.1.2 82 | jupyterlab-server==2.2.0 83 | jupyterlab-widgets==1.0.0 84 | keyring==10.6.0 85 | keyrings.alt==3.0 86 | kiwisolver==1.2.0 87 | libcst==0.3.18 88 | librosa==0.8.0 89 | llvmlite==0.31.0 90 | Markdown==3.2.2 91 | MarkupSafe==1.1.1 92 | matplotlib==3.2.2 93 | mistune==0.8.4 94 | msgpack==1.0.2 95 | multidict==5.1.0 96 | murmurhash==1.0.2 97 | mypy-extensions==0.4.3 98 | nbclassic==0.2.6 99 | nbclient==0.5.2 100 | nbconvert==6.0.7 101 | nbformat==5.1.2 102 | nest-asyncio==1.5.1 103 | notebook==6.2.0 104 | numba==0.48.0 105 | numexpr==2.7.1 106 | numpy==1.19.5 107 | nvidia-ml-py3==7.352.0 108 | oauthlib==3.1.0 109 | opencensus==0.7.12 110 | opencensus-context==0.1.2 111 | opencv-python==4.2.0.34 112 | packaging==20.9 113 | pandas==1.0.5 114 | pandas-gbq==0.14.1 115 | pandocfilters==1.4.3 116 | parso==0.8.1 117 | pathtools==0.1.2 118 | pexpect==4.8.0 119 | pickleshare==0.7.5 120 | Pillow==7.2.0 121 | plac==1.1.3 122 | pooch==1.3.0 123 | preshed==3.0.2 124 | prometheus-client==0.10.0 125 | promise==2.3 126 | prompt-toolkit==3.0.16 127 | proto-plus==1.18.1 128 | protobuf==3.15.7 129 | psutil==5.8.0 130 | ptyprocess==0.7.0 131 | py-spy==0.3.5 132 | pyarrow==3.0.0 133 | pyasn1==0.4.8 134 | pyasn1-modules==0.2.8 135 | pycparser==2.20 136 | pycrypto==2.6.1 137 | pydata-google-auth==1.2.0 138 | Pygments==2.7.4 139 | pygobject==3.26.1 140 | pyparsing==2.4.7 141 | pyrsistent==0.17.3 142 | python-dateutil==2.8.1 143 | pytz==2021.1 144 | pyxdg==0.25 145 | PyYAML==5.4.1 146 | pyzmq==22.0.2 147 | qtconsole==5.0.2 148 | QtPy==1.9.0 149 | ray==1.0.1.post1 150 | redis==3.4.1 151 | requests==2.25.1 152 | requests-oauthlib==1.3.0 153 | resampy==0.2.2 154 | rsa==4.7.2 155 | scikit-learn==0.24.1 156 | scipy==1.5.4 157 | SecretStorage==2.3.1 158 | Send2Trash==1.5.0 159 | sentry-sdk==0.20.0 160 | shortuuid==1.0.1 161 | six==1.15.0 162 | smmap==3.0.5 163 | sniffio==1.2.0 164 | SoundFile==0.10.3.post1 165 | soupsieve==2.2.1 166 | spacy==2.3.1 167 | srsly==1.0.2 168 | subprocess32==3.5.4 169 | tabulate==0.8.9 170 | tensorboard==2.3.0 171 | tensorboard-plugin-wit==1.7.0 172 | tensorboardX==2.1 173 | terminado==0.9.2 174 | testpath==0.4.4 175 | thinc==7.4.1 176 | threadpoolctl==2.1.0 177 | torchaudio==0.8.1 178 | torchlibrosa==0.0.9 179 | torchsummary==1.5.1 180 | torchvision==0.9.1 181 | tornado==6.1 182 | tqdm==4.47.0 183 | traitlets==4.3.3 184 | typing-extensions==3.7.4.3 185 | typing-inspect==0.6.0 186 | urllib3==1.26.4 187 | wandb==0.10.18 188 | wasabi==0.7.0 189 | wcwidth==0.2.5 190 | webencodings==0.5.1 191 | Werkzeug==1.0.1 192 | widgetsnbextension==3.5.1 193 | yarl==1.6.3 194 | zipp==3.4.1 -------------------------------------------------------------------------------- /workspace/scripts/process_dcase_1s.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("../") 3 | import librosa 4 | import torch 5 | import pandas as pd 6 | import os 7 | import numpy as np 8 | from tqdm import tqdm 9 | import pathlib 10 | 11 | train = pd.read_csv("data/tau_audiovisual_2021/evaluation_setup/fold1_train.csv", sep='\t') 12 | val = pd.read_csv("data/tau_audiovisual_2021/evaluation_setup/fold1_evaluate.csv", sep='\t') 13 | 14 | 15 | new_df = pd.DataFrame(columns=["filename_wave", "filename_spec", "scene_label"]) 16 | 17 | audio_dir = "data/tau_audiovisual_2021/" 18 | spec_cachedir = 
"data/tau_audiovisual_2021/spectrogram_2048window_256hop_cache" 19 | wave_cachedir = "data/tau_audiovisual_2021/waveform_cache" 20 | 21 | for idx, row in tqdm(train.iterrows(), total=len(train)): 22 | audio, sr = librosa.load(os.path.join(audio_dir, row['filename_audio']), sr=48000) 23 | padded = np.zeros(480000, dtype='float32') 24 | wave = audio[:480000] 25 | padded[0:len(wave)] = wave 26 | 27 | 28 | # WAVEFORM 29 | wave = padded 30 | 31 | for idx, i in enumerate(np.split(wave, 10)): 32 | spec = librosa.feature.melspectrogram(i, n_fft=2048, hop_length=256, n_mels=128, sr=48000, fmin=0, fmax=24000) 33 | # print(spec.shape) 34 | # print(spec) 35 | # if spec.shape[1] == 501: 36 | # spec = spec[:,:-1] 37 | spec = np.log(spec) 38 | # print(spec) 39 | 40 | # z-score normalization 41 | std = spec.std() 42 | mean = spec.mean() 43 | spec = (spec - mean) / std 44 | 45 | # print(row['filename_audio']) 46 | s = row['filename_audio'].replace(".wav", "") 47 | 48 | fname_wave = os.path.join(wave_cachedir, f"{s}_{idx}.npy") 49 | fname_spec = os.path.join(spec_cachedir, f"{s}_{idx}.npy") 50 | 51 | # set paths for wave and spec 52 | pathlib.Path(os.path.join(spec_cachedir, f"{s}_{idx}.npy")).parent.mkdir(parents=True, exist_ok=True) 53 | pathlib.Path(os.path.join(wave_cachedir, f"{s}_{idx}.npy")).parent.mkdir(parents=True, exist_ok=True) 54 | 55 | # np.save(os.path.join(self.spectrogram_cachedir, item['filename_audio']).replace(".wav", ".npy"), spec) 56 | np.save(os.path.join(spec_cachedir, f"{s}_{idx}.npy"), spec) 57 | np.save(os.path.join(wave_cachedir, f"{s}_{idx}.npy"), i) 58 | 59 | new_df = new_df.append({"filename_wave": fname_wave, "filename_spec": fname_spec, "scene_label": row['scene_label']}, ignore_index=True) 60 | 61 | 62 | new_df.to_csv("data/tau_audiovisual_2021/train_1sec.csv") 63 | 64 | 65 | ### VALIDATION 66 | 67 | 68 | new_df = pd.DataFrame(columns=["filename_wave", "filename_spec", "scene_label"]) 69 | 70 | audio_dir = "data/tau_audiovisual_2021/" 71 | spec_cachedir = "data/tau_audiovisual_2021/spectrogram_2048window_256hop_cache" 72 | wave_cachedir = "data/tau_audiovisual_2021/waveform_cache" 73 | 74 | for idx, row in tqdm(val.iterrows(), total=len(val)): 75 | audio, sr = librosa.load(os.path.join(audio_dir, row['filename_audio']), sr=48000) 76 | padded = np.zeros(480000, dtype='float32') 77 | wave = audio[:480000] 78 | padded[0:len(wave)] = wave 79 | 80 | 81 | # WAVEFORM 82 | wave = padded 83 | 84 | for idx, i in enumerate(np.split(wave, 10)): 85 | spec = librosa.feature.melspectrogram(i, n_fft=2048, hop_length=256, n_mels=128, sr=48000, fmin=0, fmax=24000) 86 | # print(spec.shape) 87 | # print(spec) 88 | # if spec.shape[1] == 501: 89 | # spec = spec[:,:-1] 90 | 91 | spec = np.log(spec) 92 | # print(spec) 93 | 94 | # z-score normalization 95 | std = spec.std() 96 | mean = spec.mean() 97 | spec = (spec - mean) / std 98 | 99 | # print(row['filename_audio']) 100 | s = row['filename_audio'].replace(".wav", "") 101 | 102 | fname_wave = os.path.join(wave_cachedir, f"{s}_{idx}.npy") 103 | fname_spec = os.path.join(spec_cachedir, f"{s}_{idx}.npy") 104 | 105 | # set paths for wave and spec 106 | pathlib.Path(os.path.join(spec_cachedir, f"{s}_{idx}.npy")).parent.mkdir(parents=True, exist_ok=True) 107 | pathlib.Path(os.path.join(wave_cachedir, f"{s}_{idx}.npy")).parent.mkdir(parents=True, exist_ok=True) 108 | 109 | # np.save(os.path.join(self.spectrogram_cachedir, item['filename_audio']).replace(".wav", ".npy"), spec) 110 | np.save(os.path.join(spec_cachedir, f"{s}_{idx}.npy"), spec) 111 | 
np.save(os.path.join(wave_cachedir, f"{s}_{idx}.npy"), i) 112 | 113 | new_df = new_df.append({"filename_wave": fname_wave, "filename_spec": fname_spec, "scene_label": row['scene_label']}, ignore_index=True) 114 | 115 | 116 | new_df.to_csv("data/tau_audiovisual_2021/val_1sec.csv") 117 | 118 | print("Finished processing dataset to 1-second samples!") -------------------------------------------------------------------------------- /workspace/utils/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from itertools import repeat 4 | from collections import OrderedDict 5 | from functools import reduce 6 | from operator import getitem 7 | import wandb 8 | import torch 9 | import os 10 | 11 | 12 | def ensure_dir(dirname): 13 | dirname = Path(dirname) 14 | if not dirname.is_dir(): 15 | dirname.mkdir(parents=True, exist_ok=False) 16 | 17 | 18 | def read_json(fname): 19 | fname = Path(fname) 20 | with fname.open('rt') as handle: 21 | return json.load(handle, object_hook=OrderedDict) 22 | 23 | 24 | def write_json(content, fname): 25 | fname = Path(fname) 26 | with fname.open('wt') as handle: 27 | json.dump(content, handle, indent=4, sort_keys=False) 28 | 29 | 30 | def inf_loop(data_loader): 31 | ''' wrapper function for endless data loader. ''' 32 | for loader in repeat(data_loader): 33 | yield from loader 34 | 35 | 36 | def prepare_device(n_gpu_use): 37 | """ 38 | Setup GPU device if available, move model into configured device 39 | """ 40 | n_gpu = torch.cuda.device_count() 41 | if n_gpu_use > 0 and n_gpu == 0: 42 | print("Warning: There\'s no GPU available on this machine," 43 | "training will be performed on CPU.") 44 | n_gpu_use = 0 45 | if n_gpu_use > n_gpu: 46 | print(f"Warning: The number of GPU\'s configured to use is {n_gpu_use}, but only {n_gpu} are " 47 | "available on this machine.") 48 | n_gpu_use = n_gpu 49 | device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') 50 | list_ids = list(range(n_gpu_use)) 51 | return device, list_ids 52 | 53 | 54 | def set_by_path(tree, keys, value): 55 | '''Set a value in a nested object in tree by sequence of keys.''' 56 | keys = keys.split(';') 57 | get_by_path(tree, keys[:-1])[keys[-1]] = value 58 | 59 | 60 | def get_by_path(tree, keys): 61 | '''Access a nested object in tree by sequence of keys.''' 62 | return reduce(getitem, keys, tree) 63 | 64 | 65 | def msg_box(msg): 66 | row = len(msg) 67 | h = ''.join(['+'] + ['-' * row] + ['+']) 68 | result = h + f"\n|{msg}|\n" + h 69 | return result 70 | 71 | def wandb_save_code(config, base_path='../'): 72 | """ 73 | Save all files associated with a run to a wandb run 74 | """ 75 | GLOBAL_FILES = ["models/loss.py", 76 | "models/metric.py", 77 | "utils/util.py", 78 | "base/base_dataloader.py", 79 | "base/base_model.py", 80 | "base/base_trainer.py", 81 | "parse_config.py"] 82 | 83 | # Save config file 84 | wandb.save(str(Path(config.run_args.config)), base_path=base_path) 85 | 86 | # Save top-level python scripts 87 | for file in os.listdir("./"): 88 | if file.endswith(".py"): 89 | wandb.save(file, base_path=base_path) 90 | 91 | # Save dataset code 92 | datasets = config.config['datasets'] 93 | for key, value in datasets['train'].items(): 94 | if 'module' in datasets['train'][key]: 95 | wandb.save(f"data_loaders{datasets['train'][key]['module']}".replace(".", "/") + ".py", base_path=base_path) 96 | for key, value in datasets['valid'].items(): 97 | if 'module' in datasets['valid'][key]: 98 | 
wandb.save(f"data_loaders{datasets['valid'][key]['module']}".replace(".", "/") + ".py", base_path=base_path) 99 | for key, value in datasets['test'].items(): 100 | if 'module' in datasets['test'][key]: 101 | wandb.save(f"data_loaders{datasets['test'][key]['module']}".replace(".", "/") + ".py", base_path=base_path) 102 | 103 | # Save dataloader code 104 | data_loaders = config.config['data_loaders'] 105 | for key, value in data_loaders['train'].items(): 106 | if 'module' in data_loaders['train'][key]: 107 | wandb.save(f"data_loaders{data_loaders['train'][key]['module']}".replace(".", "/") + ".py", base_path=base_path) 108 | for key, value in data_loaders['valid'].items(): 109 | if 'module' in data_loaders['valid'][key]: 110 | wandb.save(f"data_loaders{data_loaders['valid'][key]['module']}".replace(".", "/") + ".py", base_path=base_path) 111 | for key, value in data_loaders['test'].items(): 112 | if 'module' in data_loaders['test'][key]: 113 | wandb.save(f"data_loaders{data_loaders['test'][key]['module']}".replace(".", "/") + ".py", base_path=base_path) 114 | 115 | # Save model code 116 | models = config.config['models'] 117 | for key, value in models.items(): 118 | if 'module' in models[key]: 119 | wandb.save(f"models{models[key]['module']}".replace(".", "/") + ".py", base_path=base_path) 120 | 121 | # Save trainer 122 | wandb.save(f"trainers{config.config['trainer']['module']}".replace(".", "/") + ".py", base_path=base_path) 123 | 124 | # Save global files for all runs 125 | for file in GLOBAL_FILES: 126 | wandb.save(file, base_path=base_path) -------------------------------------------------------------------------------- /workspace/configs/wave_spec_fusion.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_gpu": 1, 3 | "root_dir": "./", 4 | "name": "dcase_task1b_1s", 5 | 6 | "wandb": { 7 | "project": "dcase_task1b", 8 | "notes": "DCASE Task1B fusion model.", 9 | "entity": "wandb" 10 | }, 11 | 12 | "datasets": { 13 | "train": { 14 | "data": { 15 | "module": ".1s_dataset", 16 | "type": "MultiModalAugmentationDataset", 17 | "kwargs": { 18 | "data_dir": "data/tau_audiovisual_2021/train_1sec.csv", 19 | "label_list": ["airport", 20 | "shopping_mall", 21 | "metro_station", 22 | "street_pedestrian", 23 | "public_square", 24 | "street_traffic", 25 | "tram", 26 | "bus", 27 | "metro", 28 | "park"], 29 | "wave_cache_dir": "data/tau_audiovisual_2021/waveform_cache/", 30 | "spec_cache_dir": "data/tau_audiovisual_2021/spectrogram_2048window_256hop_cache/", 31 | "w_shift": true, 32 | "s_shift": true 33 | } 34 | } 35 | }, 36 | "valid": { 37 | "data": { 38 | "module": ".1s_dataset", 39 | "type": "MultiModalAugmentationDataset", 40 | "kwargs": { 41 | "data_dir": "data/tau_audiovisual_2021/val_1sec.csv", 42 | "label_list": ["airport", 43 | "shopping_mall", 44 | "metro_station", 45 | "street_pedestrian", 46 | "public_square", 47 | "street_traffic", 48 | "tram", 49 | "bus", 50 | "metro", 51 | "park"], 52 | "wave_cache_dir": "data/tau_audiovisual_2021/waveform_cache/", 53 | "spec_cache_dir": "data/tau_audiovisual_2021/spectrogram_2048window_256hop_cache/" 54 | } 55 | } 56 | }, 57 | "test": { 58 | "data": { 59 | "module": ".1s_dataset", 60 | "type": "MultiModalAugmentationDataset", 61 | "kwargs": { 62 | "data_dir": "data/tau_audiovisual_2021/val_1sec.csv", 63 | "label_list": ["airport", 64 | "shopping_mall", 65 | "metro_station", 66 | "street_pedestrian", 67 | "public_square", 68 | "street_traffic", 69 | "tram", 70 | "bus", 71 | "metro", 72 | "park"], 73 |
"wave_cache_dir": "data/tau_audiovisual_2021/waveform_cache/", 74 | "spec_cache_dir": "data/tau_audiovisual_2021/spectrogram_2048window_256hop_cache/" 75 | } 76 | } 77 | } 78 | }, 79 | "data_loaders": { 80 | "train": { 81 | "data": { 82 | "module": ".1s_dataset", 83 | "type": "BaseDataLoader", 84 | "kwargs": { 85 | "validation_split": 0.0, 86 | "DataLoader_kwargs": { 87 | "batch_size": 128, 88 | "shuffle": true, 89 | "num_workers": 4 90 | } 91 | } 92 | } 93 | }, 94 | "valid": { 95 | "data": { 96 | "module": ".1s_dataset", 97 | "type": "BaseDataLoader", 98 | "kwargs": { 99 | "validation_split": 0.0, 100 | "DataLoader_kwargs": { 101 | "batch_size": 128, 102 | "shuffle": false, 103 | "num_workers": 4 104 | } 105 | } 106 | } 107 | }, 108 | "test": { 109 | "data": { 110 | "module": ".1s_dataset", 111 | "type": "BaseDataLoader", 112 | "kwargs": { 113 | "validation_split": 0.0, 114 | "DataLoader_kwargs": { 115 | "batch_size": 128, 116 | "shuffle": false, 117 | "num_workers": 4 118 | } 119 | } 120 | } 121 | } 122 | }, 123 | "models": { 124 | "model": { 125 | "module": ".model", 126 | "type": "SmallMultiModalFusionClassifier", 127 | "kwargs": { 128 | "input_length": 480000, 129 | "n_bins": 128, 130 | "n_frames": 188, 131 | "num_classes": 10, 132 | "fusion_method": "sum", 133 | "parameterization": "sinc", 134 | "non_linearity": "LeakyReLU" 135 | } 136 | } 137 | }, 138 | "losses": { 139 | "loss": { 140 | "type": "CrossEntropyLoss" 141 | } 142 | }, 143 | "metrics": { 144 | "per_iteration": [], 145 | "per_epoch": ["accuracy"] 146 | }, 147 | "optimizers": { 148 | "model": { 149 | "type": "SGD", 150 | "kwargs": { 151 | "lr": 0.0001, 152 | "momentum": 0.9 153 | } 154 | } 155 | }, 156 | "lr_schedulers": { 157 | }, 158 | "trainer": { 159 | "module": ".trainer_mixup", 160 | "type": "DCASETask1BTrainerWithMixup", 161 | "kwargs": { 162 | "finetune": false, 163 | "epochs": 50, 164 | "len_epoch": null, 165 | 166 | "find_lr": true, 167 | 168 | "mixup": true, 169 | "mixup_p": 0.5, 170 | 171 | "save_period": 50, 172 | "save_the_best": true, 173 | "verbosity": 2, 174 | 175 | "monitor": "min val_loss", 176 | "early_stop": 0, 177 | 178 | "tensorboard": false 179 | } 180 | } 181 | } 182 | -------------------------------------------------------------------------------- /workspace/base/base_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import abstractmethod 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from base import Cross_Valid 8 | from logger import get_logger, TensorboardWriter 9 | from utils import msg_box 10 | 11 | 12 | class BaseTrainer: 13 | """ 14 | Base class for all trainers 15 | """ 16 | def __init__(self, torch_args: dict, save_dir, **kwargs): 17 | # data_loaders 18 | self.train_data_loaders = torch_args['data_loaders']['train'] 19 | self.valid_data_loaders = torch_args['data_loaders']['valid'] 20 | # models 21 | self.models = torch_args['models'] 22 | # losses 23 | self.losses = torch_args['losses'] 24 | # metrics 25 | self.metrics_iter = torch_args['metrics']['iter'] 26 | self.metrics_epoch = torch_args['metrics']['epoch'] 27 | # optimizers 28 | self.optimizers = torch_args['optimizers'] 29 | # lr_schedulers 30 | self.lr_schedulers = torch_args['lr_schedulers'] 31 | 32 | self.model_dir = save_dir['model'] 33 | # set json kwargs to self.{kwargs} 34 | for key, value in kwargs.items(): 35 | setattr(self, key, value) 36 | 37 | self.logger = get_logger('trainer', verbosity=self.verbosity) 38 | if self.early_stop <= 0 or 
self.early_stop <= 0: 39 | self.early_stop = np.inf 40 | self.start_epoch = 1 41 | 42 | # configuration to monitor model performance and save best 43 | self.num_best = 0 44 | if self.monitor == 'off': 45 | self.mnt_mode = 'off' 46 | self.mnt_best = 0 47 | else: 48 | self.mnt_mode, self.mnt_metric = self.monitor.split() 49 | assert self.mnt_mode in ['min', 'max'] 50 | self.mnt_best = np.inf if self.mnt_mode == 'min' else -np.inf 51 | 52 | # setup visualization writer instance 53 | self.writer = TensorboardWriter(save_dir['log'], self.logger, self.tensorboard) 54 | 55 | @abstractmethod 56 | def _train_epoch(self, epoch): 57 | """ 58 | Training logic for an epoch 59 | 60 | :param epoch: Current epoch number 61 | """ 62 | raise NotImplementedError 63 | 64 | def train(self): 65 | """ 66 | Full training logic 67 | """ 68 | log_best = {} # keep log_best defined even if the monitored metric never improves 69 | not_improved_count = 0 70 | for epoch in range(self.start_epoch, self.epochs + 1): 71 | train_log = self._train_epoch(epoch) 72 | log_mean = train_log['mean'] 73 | 74 | # evaluate model performance according to configured metric, save best checkpoint as model_best 75 | best = False 76 | if self.mnt_mode != 'off': 77 | try: 78 | # check whether model performance strictly improved or not, according to mnt_metric 79 | improved = (self.mnt_mode == 'min' and log_mean[self.mnt_metric] < self.mnt_best) or \ 80 | (self.mnt_mode == 'max' and log_mean[self.mnt_metric] > self.mnt_best) 81 | except KeyError: 82 | self.logger.warning("Warning: Metric '{}' is not found. " 83 | "Model performance monitoring is disabled.".format(self.mnt_metric)) 84 | self.mnt_mode = 'off' 85 | improved = False 86 | 87 | if improved: 88 | self.mnt_best = log_mean[self.mnt_metric] 89 | log_best = log_mean 90 | not_improved_count = 0 91 | best = True 92 | else: 93 | not_improved_count += 1 94 | 95 | if not_improved_count > self.early_stop: 96 | self.logger.info("Validation performance didn't improve for {} epochs.
" 97 | "Training stops.".format(self.early_stop)) 98 | break 99 | 100 | if epoch % self.save_period == 0 or best: 101 | self.logger.info("Best {}: {:.5f}".format(self.mnt_metric, self.mnt_best)) 102 | self._save_checkpoint(epoch, save_best=best) 103 | 104 | return log_best 105 | 106 | def _save_checkpoint(self, epoch, save_best=False): 107 | """ 108 | Saving checkpoints 109 | 110 | :param epoch: current epoch number 111 | :param log: logging information of the epoch 112 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth' 113 | """ 114 | state = { 115 | 'epoch': epoch, 116 | 'models': {key: value.state_dict() for key, value in self.models.items()}, 117 | 'optimizers': {key: value.state_dict() for key, value in self.optimizers.items()}, 118 | 'monitor_best': self.mnt_best, 119 | } 120 | k_fold = Cross_Valid.k_fold 121 | fold_idx = Cross_Valid.fold_idx 122 | fold_prefix = f"fold_{fold_idx}_" if k_fold > 1 else '' 123 | 124 | if save_best: 125 | if self.save_the_best: 126 | filename = str(self.model_dir / f"{fold_prefix}model_best.pth") 127 | else: 128 | self.num_best += 1 129 | filename = str(self.model_dir / f"{fold_prefix}model_best{self.num_best}.pth") 130 | else: 131 | filename = str(self.model_dir / f"{fold_prefix}checkpoint-epoch{epoch}.pth") 132 | torch.save(state, filename) 133 | self.logger.info("Saving model: {} ...".format(filename)) 134 | 135 | def _resume_checkpoint(self, resume_path, finetune=False): 136 | """ 137 | Resume from saved checkpoints 138 | 139 | :param resume_path: Checkpoint path to be resumed 140 | """ 141 | resume_path = str(resume_path) 142 | self.logger.info("Loading checkpoint: {} ...".format(resume_path)) 143 | checkpoint = torch.load(resume_path) 144 | if not finetune: 145 | # resume training 146 | self.start_epoch = checkpoint['epoch'] + 1 147 | self.mnt_best = checkpoint['monitor_best'] 148 | 149 | # load each model params from checkpoint. 150 | for key, value in checkpoint['models'].items(): 151 | try: 152 | self.models[key].load_state_dict(value) 153 | except KeyError: 154 | print("models not match, can not resume.") 155 | 156 | # load each optimizer from checkpoint. 157 | for key, value in checkpoint['optimizers'].items(): 158 | try: 159 | self.optimizers[key].load_state_dict(value) 160 | except KeyError: 161 | print("optimizers not match, can not resume.") 162 | 163 | self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)) 164 | -------------------------------------------------------------------------------- /workspace/parse_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | from pathlib import Path 5 | from functools import partial 6 | from datetime import datetime 7 | import importlib 8 | 9 | from logger import setup_logging 10 | from utils import ensure_dir, read_json, write_json, set_by_path, get_by_path 11 | 12 | 13 | class ConfigParser: 14 | def __init__(self, run_args, modification=None): 15 | """ 16 | class to parse configuration json file. Handles hyperparameters for training, 17 | initializations of modules, checkpoint saving and logging module. 18 | :param run_args: Dict, running arguments including resume, mode, run_id, log_name. 19 | - config: String, path to the config file. 20 | - resume: String, path to the checkpoint being loaded. 21 | - mode: String, 'train', 'test' or 'inference'. 22 | - run_id: Unique Identifier for training processes. 
 27 |         # run_args
 28 |         self.run_args = run_args
 29 |         # load config file and apply modification
 30 |         config_json = run_args.config
 31 |         config = read_json(Path(config_json))
 32 |         self._config = _update_config(config, modification)
 33 |         self.resume = Path(run_args.resume) if run_args.resume is not None else None
 34 |         self.mode = run_args.mode
 35 |         log_name = run_args.log_name
 36 | 
 37 |         self.root_dir = self.config['root_dir']
 38 |         run_id = run_args.run_id
 39 | 
 40 |         save_name = {'train': 'saved/', 'test': 'output/'}
 41 |         save_dir = Path(self.root_dir) / save_name[self.mode]
 42 |         if run_id is None:  # use timestamp as default run-id
 43 |             run_id = datetime.now().strftime(r'%m%d_%H%M%S')
 44 |         exp_dir = save_dir / self.config['name'] / run_id
 45 | 
 46 |         dirs = {'train': ['log', 'model', 'metrics_best'], 'test': ['log', 'metric', 'fig']}
 47 |         self.save_dir = dict()
 48 |         for dir_name in dirs[self.mode]:
 49 |             dir_path = exp_dir / dir_name
 50 |             ensure_dir(dir_path)
 51 |             self.save_dir[dir_name] = dir_path
 52 | 
 53 |         log_config = {}
 54 |         if self.mode == 'train':
 55 |             fold_idx = self.config['trainer'].get('fold_idx', 0)
 56 |             if fold_idx > 0:
 57 |                 # multiprocessing is enabled.
 58 |                 log_config.update({'log_config': 'logger/logger_config_mp.json'})
 59 |             if fold_idx <= 1:
 60 |                 # backup the config file to the experiment directory
 61 |                 write_json(self.config, exp_dir / os.path.basename(config_json))
 62 | 
 63 |         # configure logging module
 64 |         setup_logging(self.save_dir['log'], root_dir=self.root_dir, filename=log_name, **log_config)
 65 | 
 66 |     @classmethod
 67 |     def from_args(cls, parser, options=''):
 68 |         """
 69 |         Initialize this class from some cli arguments. Used in train, test.
 70 |         """
 71 |         args = parser.parse_args()
 72 | 
 73 |         msg_no_cfg = "Configuration file needs to be specified. Add '-c config.json', for example."
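Concretely, each `mod_args` flag carries a ';'-separated keychain that is written into the nested config. A sketch of the effect (this assumes `utils.set_by_path`, which is not shown in this dump, splits the keychain on ';'):

```python
# Hypothetical illustration of the modification dict built below in from_args.
config = {"optimizers": {"model": {"kwargs": {"lr": 0.0001}}}}
modification = {"optimizers;model;kwargs;lr": 0.001}   # e.g. from `--lr 0.001`

# _update_config (bottom of this file) calls set_by_path for each entry,
# so afterwards: config["optimizers"]["model"]["kwargs"]["lr"] == 0.001
```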
74 | assert args.config is not None, msg_no_cfg 75 | 76 | if args.device is not None: 77 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 78 | 79 | modification = None 80 | for group in parser._action_groups: 81 | if group.title == 'mod_args': 82 | # parse custom cli options into dictionary 83 | # modification = {opt.target: getattr(args, _get_opt_name(opt.flags)) for opt in options} 84 | modification = {} 85 | for opt in options: 86 | if isinstance(opt.target, list): 87 | for target in opt.target: 88 | modification[target] = getattr(args, _get_opt_name(opt.flags)) 89 | else: 90 | modification[opt.target] = getattr(args, _get_opt_name(opt.flags)) 91 | else: 92 | group_dict = {g.dest: getattr(args, g.dest, None) for g in group._group_actions} 93 | arg_group = argparse.Namespace(**group_dict) 94 | if group.title == 'run_args': 95 | run_args = arg_group 96 | elif group.title == 'test_args': 97 | cls.test_args = arg_group 98 | 99 | return cls(run_args, modification) 100 | 101 | @staticmethod 102 | def _update_kwargs(_config, kwargs): 103 | try: 104 | _kwargs = dict(_config['kwargs']) 105 | except KeyError: # In case no arguments are specified 106 | _kwargs = dict() 107 | assert all([k not in _kwargs for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 108 | _kwargs.update(kwargs) 109 | return _kwargs 110 | 111 | def init_obj(self, keys, module, *args, **kwargs): 112 | """ 113 | Returns an object or a function, which is specified in config[keys[0]]...[keys[-1]]. 114 | In config[keys[0]]...[keys[-1]], 115 | 'is_ftn': If True, return a function. If False, return an object. 116 | 'module': The module of each instance. 117 | 'type': Class name. 118 | 'kwargs': Keyword arguments for the class initialization. 119 | keys is the list of config entries. 120 | module is the package module. 
121 |         Additional *args and **kwargs will be forwarded to obj()
122 |         Usage: `objects = config.init_obj(['A', 'B', 'C'], module, a, b=1)`
123 |         """
124 |         obj_config = get_by_path(self, keys)
125 |         try:
126 |             module_name = obj_config['module']
127 |             module_obj = importlib.import_module(module_name, package=module)
128 |         except KeyError:  # In case no 'module' is specified
129 |             module_obj = module
130 |         class_name = obj_config['type']
131 |         obj = getattr(module_obj, class_name)
132 |         kwargs_obj = self._update_kwargs(obj_config, kwargs)
133 | 
134 |         if obj_config.get('is_ftn', False):
135 |             return partial(obj, *args, **kwargs_obj)
136 |         return obj(*args, **kwargs_obj)
137 | 
138 |     def __getitem__(self, name):
139 |         """Access items like an ordinary dict."""
140 |         return self.config[name]
141 | 
142 |     # read-only attributes
143 |     @property
144 |     def config(self):
145 |         return self._config
146 | 
147 | 
148 | # helper functions to update config dict with custom cli options
149 | def _update_config(config, modification):
150 |     if modification is None:
151 |         return config
152 | 
153 |     for key, value in modification.items():
154 |         if value is not None:
155 |             set_by_path(config, key, value)
156 |     return config
157 | 
158 | 
159 | def _get_opt_name(flags):
160 |     for flg in flags:
161 |         if flg.startswith('--'):
162 |             return flg.replace('--', '')
163 |     return flags[0].replace('--', '')
164 | 
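With the `wave_spec_fusion.json` entries near the top of this section, `init_obj` reduces to an import plus a constructor call. A sketch of the equivalent steps for the model entry (kwargs abridged; not executed here, since it needs the repo's packages on the import path):

```python
import importlib

obj_config = {
    "module": ".model",
    "type": "SmallMultiModalFusionClassifier",
    "kwargs": {"num_classes": 10, "fusion_method": "sum"},
}

# what config.init_obj(['models', 'model'], 'models') boils down to:
module_obj = importlib.import_module(obj_config["module"], package="models")
cls = getattr(module_obj, obj_config["type"])
model = cls(**obj_config["kwargs"])
```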
--------------------------------------------------------------------------------
/workspace/base/base_dataloader.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | 
  3 | import pandas as pd
  4 | import numpy as np
  5 | from torch.utils.data import DataLoader
  6 | from torch.utils.data.sampler import SubsetRandomSampler
  7 | import torch
  8 | from logger import get_logger
  9 | from utils import msg_box
 10 | from torch.utils.data.dataloader import default_collate
 11 | 
 12 | 
 13 | class BaseDataLoader(DataLoader):
 14 |     """
 15 |     Split one dataset into a train data_loader and a valid data_loader
 16 |     """
 17 |     logger = get_logger('data_loader')
 18 | 
 19 |     def __init__(self, dataset, validation_split=0.0,
 20 |                  DataLoader_kwargs=None, do_transform=False):
 21 |         self.dataset = dataset
 22 |         self.n_samples = len(dataset)
 23 |         self.split = validation_split
 24 |         self.init_kwargs = DataLoader_kwargs if DataLoader_kwargs is not None else {}
 25 | 
 26 |         if Cross_Valid.k_fold > 1:
 27 |             fold_msg = msg_box(f"Fold {Cross_Valid.fold_idx}")
 28 |             self.logger.info(fold_msg)
 29 |             split_idx = dataset.get_split_idx(Cross_Valid.fold_idx - 1)
 30 |             train_sampler, valid_sampler = self._get_sampler(*split_idx)
 31 |             if do_transform:
 32 |                 dataset.transform(split_idx)
 33 |             super().__init__(dataset, sampler=train_sampler, **self.init_kwargs)
 34 |             self.valid_loader = DataLoader(dataset, sampler=valid_sampler, **self.init_kwargs)
 35 |         else:
 36 |             if validation_split > 0.0:
 37 |                 split_idx = self._split_sampler()
 38 |                 train_sampler, valid_sampler = self._get_sampler(*split_idx)
 39 |                 if do_transform:
 40 |                     dataset.transform(split_idx)
 41 |                 super().__init__(dataset, sampler=train_sampler, **self.init_kwargs)
 42 |                 self.valid_loader = DataLoader(dataset, sampler=valid_sampler, **self.init_kwargs)
 43 |             else:
 44 |                 super().__init__(self.dataset, **self.init_kwargs)
 45 |                 self.valid_loader = None
 46 | 
 47 |     def _get_sampler(self, train_idx, valid_idx):
 48 |         train_sampler = SubsetRandomSampler(train_idx)
 49 |         valid_sampler = SubsetRandomSampler(valid_idx)
 50 | 
 51 |         # turn off the shuffle option, which is mutually exclusive with sampler
 52 |         self.init_kwargs['shuffle'] = False
 53 |         self.n_samples = len(train_idx)
 54 | 
 55 |         return train_sampler, valid_sampler
 56 | 
 57 |     def _split_sampler(self):
 58 |         idx_full = np.arange(self.n_samples)
 59 | 
 60 |         np.random.seed(0)
 61 |         np.random.shuffle(idx_full)
 62 | 
 63 |         if isinstance(self.split, int):
 64 |             assert self.split > 0
 65 |             assert self.split < self.n_samples, \
 66 |                 "validation set size is configured to be larger than the entire dataset."
 67 |             len_valid = self.split
 68 |         else:
 69 |             len_valid = int(self.n_samples * self.split)
 70 | 
 71 |         train_idx, valid_idx = idx_full[len_valid:], idx_full[:len_valid]
 72 | 
 73 |         return (train_idx, valid_idx)
 74 | 
 75 | 
 76 | class Cross_Valid:
 77 |     @classmethod
 78 |     def create_CV(cls, k_fold=1, fold_idx=0):
 79 |         cls.k_fold = k_fold
 80 |         cls.fold_idx = 1 if fold_idx == 0 else fold_idx
 81 |         return cls()
 82 | 
 83 |     @classmethod
 84 |     def next_fold(cls):
 85 |         cls.fold_idx += 1
 86 | 
 87 | 
 88 | class MultiDatasetDataLoader(DataLoader):
 89 |     """Draw each batch from several datasets at once; see ConcatDatasetBatchSampler below."""
 90 |     def __init__(self, datasets, dataset_batches=None, DataLoader_kwargs=None):
 91 |         dss = [i for i in datasets.values()]
 92 |         batches = [i for i in dataset_batches.values()]
 93 |         self.init_kwargs = DataLoader_kwargs if DataLoader_kwargs is not None else {}
 94 | 
 95 |         dataset = torch.utils.data.ConcatDataset(dss)
 96 | 
 97 |         samplers = [torch.utils.data.RandomSampler(i) for i in dss]
 98 | 
 99 |         batch_sampler = ConcatDatasetBatchSampler(samplers, batches)
100 | 
101 |         super().__init__(dataset, batch_sampler=batch_sampler, collate_fn=lambda b: dict_split_collate(b, dataset_batches), **self.init_kwargs)
102 | 
103 | 
104 | def dict_split_collate(batch, dataset_batches):
105 |     final_data = {}
106 |     final_target = {}
107 |     offset = 0
108 | 
109 |     for key, value in dataset_batches.items():
110 |         collated = default_collate(batch[offset:offset+value])
111 | 
112 |         if isinstance(collated, list):
113 |             final_data[key] = collated[0]
114 |             final_target[key] = collated[1]
115 |         else:
116 |             final_data[key] = collated
117 |         offset += value
118 | 
119 |     return final_data, final_target
120 | 
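A toy illustration of how `MultiDatasetDataLoader` and `dict_split_collate` fit together with the batch sampler defined next (the two `TensorDataset`s are hypothetical stand-ins for this repo's audio datasets):

```python
import torch
from torch.utils.data import TensorDataset
from base import MultiDatasetDataLoader  # assumes workspace/ is on the path

ds_a = TensorDataset(torch.zeros(10, 4), torch.zeros(10))
ds_b = TensorDataset(torch.ones(20, 4), torch.ones(20))

loader = MultiDatasetDataLoader(
    datasets={"a": ds_a, "b": ds_b},
    dataset_batches={"a": 2, "b": 3},   # each batch: 2 samples from "a", then 3 from "b"
)
data, target = next(iter(loader))
# data["a"].shape == (2, 4) and data["b"].shape == (3, 4); targets split the same way
```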
121 | from torch.utils.data import Sampler
122 | import numpy as np
123 | 
124 | 
125 | class ConcatDatasetBatchSampler(Sampler):
126 |     """This sampler is built to work with a standard Pytorch ConcatDataset.
127 |     From SpeechBrain dataio; see https://github.com/speechbrain/
128 |     It is used to retrieve elements from the different concatenated datasets, placing them in the same batch
129 |     with proportions specified by batch_sizes, e.g. 8, 16 means each batch will
130 |     be of 24 elements, with the first 8 belonging to the first dataset in the ConcatDataset
131 |     object and the last 16 to the second.
132 |     More than two datasets are supported; in that case provide one batch
133 |     size per dataset.
134 |     Note
135 |     ----
136 |     Batches are drawn from the datasets until the one with the smallest length is exhausted.
137 |     Thus the number of examples in your training epoch is dictated by the dataset
138 |     whose length is the smallest.
139 |     Arguments
140 |     ---------
141 |     samplers : list
142 |         One Pytorch Sampler per concatenated dataset, in the same order as
143 |         the datasets inside the ConcatDataset.
144 |     batch_sizes: list
145 |         Batch sizes.
146 |     epoch : int
147 |         The epoch to start at.
148 |     """
149 | 
150 |     def __init__(self, samplers, batch_sizes: (tuple, list), epoch=0) -> None:
151 | 
152 |         if not isinstance(samplers, (list, tuple)):
153 |             raise ValueError(
154 |                 "samplers should be a list or tuple of Pytorch Samplers, "
155 |                 "but got samplers={}".format(samplers)
156 |             )
157 | 
158 |         if not isinstance(batch_sizes, (list, tuple)):
159 |             raise ValueError(
160 |                 "batch_sizes should be a list or tuple of integers, "
161 |                 "but got batch_sizes={}".format(batch_sizes)
162 |             )
163 | 
164 |         if not len(batch_sizes) == len(samplers):
165 |             raise ValueError("batch_sizes and samplers should have the same length")
166 | 
167 |         self.batch_sizes = batch_sizes
168 |         self.samplers = samplers
169 |         self.offsets = [0] + np.cumsum([len(x) for x in self.samplers]).tolist()[:-1]
170 | 
171 |         self.epoch = epoch
172 |         self.set_epoch(self.epoch)
173 | 
174 |     def _iter_one_dataset(self, c_batch_size, c_sampler, c_offset):
175 |         batch = []
176 |         for idx in c_sampler:
177 |             batch.append(c_offset + idx)
178 |             if len(batch) == c_batch_size:
179 |                 yield batch
180 | 
181 |     def set_epoch(self, epoch):
182 |         if hasattr(self.samplers[0], "epoch"):
183 |             for s in self.samplers:
184 |                 s.set_epoch(epoch)
185 | 
186 |     def __iter__(self):
187 | 
188 |         iterators = [iter(i) for i in self.samplers]
189 |         tot_batch = []
190 | 
191 |         for b_num in range(len(self)):
192 |             for samp_idx in range(len(self.samplers)):
193 |                 c_batch = []
194 |                 while len(c_batch) < self.batch_sizes[samp_idx]:
195 |                     c_batch.append(self.offsets[samp_idx] + next(iterators[samp_idx]))
196 |                 tot_batch.extend(c_batch)
197 |             yield tot_batch
198 |             tot_batch = []
199 | 
200 |     def __len__(self):
201 | 
202 |         min_len = float("inf")
203 |         for idx, sampler in enumerate(self.samplers):
204 |             c_len = (len(sampler)) // self.batch_sizes[idx]
205 | 
206 |             min_len = min(c_len, min_len)
207 |         return min_len
--------------------------------------------------------------------------------
/workspace/train.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import argparse
  3 | import collections
  4 | 
  5 | import torch
  6 | from sklearn.utils.class_weight import compute_class_weight
  7 | 
  8 | from base import Cross_Valid
  9 | from logger import get_logger
 10 | import models.loss as module_loss
 11 | import models.metric as module_metric
 12 | from parse_config import ConfigParser
 13 | from utils import ensure_dir, prepare_device, get_by_path, msg_box, wandb_save_code
 14 | import numpy as np
 15 | import wandb
 16 | 
 17 | os.environ['NUMEXPR_MAX_THREADS'] = '8'
 18 | os.environ['NUMEXPR_NUM_THREADS'] = '8'
 19 | import numexpr as ne
 20 | 
 21 | torch.manual_seed(0)
 22 | np.random.seed(0)
 23 | import random
 24 | random.seed(0)
 25 | 
 26 | torch.backends.cudnn.benchmark = False
 27 | torch.backends.cudnn.deterministic = True
 28 | 
 29 | # print(os.environ)
 30 | # print(f"{torch.get_num_threads()} CORES")
 31 | 
 32 | import warnings
 33 | warnings.filterwarnings('ignore')
 34 | 
 35 | def main(config):
 36 |     if config.config.get('wandb') is not None:
 37 |         # Initialize wandb if defined in config
 38 |         wandb.init(project=config['wandb']['project'],
 39 |                    notes=config['wandb']['notes'],
 40 |                    entity=config['wandb']['entity'],
 41 |                    config=config)
 42 | 
 43 |         # Update the run name if a custom name is flagged at runtime
 44 |         if hasattr(config.run_args, 'run_id'):
 45 |             wandb.run.name = config.run_args.run_id
 46 | 
 47 |         # Save wandb code
 48 |         wandb_save_code(config)
 49 | 
 50 | 
 51 |     k_fold = config['trainer'].get('k_fold', 1)
 52 |     fold_idx = 
config['trainer'].get('fold_idx', 0) 53 | 54 | if fold_idx > 0: 55 | # do on fold_idx, which is for multiprocessing cross validation 56 | # if multiprocessing, turn off debug logging to avoid messing up stdout 57 | config['trainer']['kwargs']['verbosity'] = 1 58 | verbosity = 1 59 | k_loop = 1 60 | else: 61 | # do full cross validation in single thread 62 | verbosity = 2 63 | k_loop = k_fold 64 | 65 | logger = get_logger('train', verbosity=verbosity) 66 | train_msg = msg_box("TRAIN") 67 | logger.debug(train_msg) 68 | 69 | # setup GPU device if available, move model into configured device 70 | device, device_ids = prepare_device(config['n_gpu']) 71 | 72 | # datasets 73 | train_datasets = dict() 74 | valid_datasets = dict() 75 | ## train 76 | keys = ['datasets', 'train'] 77 | for name in get_by_path(config, keys): 78 | train_datasets[name] = config.init_obj([*keys, name], 'data_loaders') 79 | ## valid 80 | valid_exist = False 81 | keys = ['datasets', 'valid'] 82 | for name in get_by_path(config, keys): 83 | valid_exist = True 84 | valid_datasets[name] = config.init_obj([*keys, name], 'data_loaders') 85 | 86 | # losses 87 | losses = dict() 88 | for name in config['losses']: 89 | kwargs = {} 90 | # TODO 91 | if config['losses'][name].get('balanced', False): 92 | target = train_datasets['data'].y_train 93 | weight = compute_class_weight(class_weight='balanced', 94 | classes=target.unique(), 95 | y=target) 96 | weight = torch.FloatTensor(weight).to(device) 97 | kwargs.update(pos_weight=weight[1]) 98 | losses[name] = config.init_obj(['losses', name], module_loss, **kwargs) 99 | 100 | # metrics 101 | metrics_iter = [getattr(module_metric, met) for met in config['metrics']['per_iteration']] 102 | metrics_epoch = [getattr(module_metric, met) for met in config['metrics']['per_epoch']] 103 | 104 | # unchanged objects in each fold 105 | torch_args = {'datasets': {'train': train_datasets, 'valid': valid_datasets}, 106 | 'losses': losses, 107 | 'metrics': {'iter': metrics_iter, 'epoch': metrics_epoch}} 108 | 109 | if k_fold > 1: # cross validation enabled 110 | train_datasets['data'].split_cv_indexes(k_fold) 111 | Cross_Valid.create_CV(k_fold, fold_idx) 112 | 113 | for k in range(k_loop): 114 | # data_loaders 115 | train_data_loaders = dict() 116 | valid_data_loaders = dict() 117 | ## train 118 | keys = ['data_loaders', 'train'] 119 | for name in get_by_path(config, keys): 120 | ### Concat dataset 121 | if get_by_path(config, keys)[name]['type'] == "MultiDatasetDataLoader": 122 | train_data_loaders[name] = config.init_obj([*keys, name], 'data_loaders', train_datasets) 123 | else: 124 | dataset = train_datasets[name] 125 | train_data_loaders[name] = config.init_obj([*keys, name], 'data_loaders', dataset) 126 | 127 | if not valid_exist: 128 | valid_data_loaders[name] = train_data_loaders[name].valid_loader 129 | ## valid 130 | keys = ['data_loaders', 'valid'] 131 | for name in get_by_path(config, keys): 132 | dataset = valid_datasets[name] 133 | valid_data_loaders[name] = config.init_obj([*keys, name], 'data_loaders', dataset) 134 | 135 | # models 136 | models = dict() 137 | logger_model = get_logger('model', verbosity=1) 138 | for name in config['models']: 139 | model = config.init_obj(['models', name], 'models') 140 | logger_model.info(model) 141 | logger.info(model) 142 | model = model.to(device) 143 | if len(device_ids) > 1: 144 | model = torch.nn.DataParallel(model, device_ids=device_ids) 145 | models[name] = model 146 | 147 | # optimizers 148 | optimizers = dict() 149 | for name in 
config['optimizers']:
150 |         trainable_params = filter(lambda p: p.requires_grad, models[name].parameters())
151 |         optimizers[name] = config.init_obj(['optimizers', name], torch.optim, trainable_params)
152 | 
153 |     # learning rate schedulers
154 |     lr_schedulers = dict()
155 |     for name in config['lr_schedulers']:
156 |         lr_schedulers[name] = config.init_obj(['lr_schedulers', name],
157 |                                               torch.optim.lr_scheduler, optimizers[name])
158 | 
159 |     # update objects for each fold
160 |     update_args = {'data_loaders': {'train': train_data_loaders, 'valid': valid_data_loaders},
161 |                    'models': models,
162 |                    'optimizers': optimizers,
163 |                    'lr_schedulers': lr_schedulers}
164 |     torch_args.update(update_args)
165 |     if k_fold > 1:
166 |         torch_args['fold_idx'] = Cross_Valid.fold_idx
167 | 
168 |     trainer = config.init_obj(['trainer'], 'trainers', torch_args,
169 |                               config.save_dir, config.resume, device)
170 |     log_best = trainer.train()
171 | 
172 |     # cross validation
173 |     if k_fold > 1:
174 |         idx = Cross_Valid.fold_idx
175 |         save_path = config.save_dir['metrics_best'] / f"fold_{idx}.pkl"
176 |         log_best.to_pickle(save_path)
177 |         Cross_Valid.next_fold()
178 |     else:
179 |         msg = msg_box("result")
180 |         logger.info(f"{msg}\n{log_best}")
181 | 
182 | 
183 | if __name__ == '__main__':
184 |     args = argparse.ArgumentParser(description='training')
185 |     run_args = args.add_argument_group('run_args')
186 |     run_args.add_argument('-c', '--config', default="configs/config.json", type=str)
187 |     run_args.add_argument('-d', '--device', default=None, type=str)
188 |     run_args.add_argument('-r', '--resume', default=None, type=str)
189 |     run_args.add_argument('--mode', default='train', type=str)
190 |     run_args.add_argument('--run_id', default=None, type=str)
191 |     run_args.add_argument('--log_name', default=None, type=str)
192 | 
193 |     # custom cli options to modify configuration from default values given in json file.
194 |     mod_args = args.add_argument_group('mod_args')
195 |     CustomArgs = collections.namedtuple('CustomArgs', "flags type target")
196 |     options = [
197 |         CustomArgs(['--fold_idx'], type=int, target="trainer;fold_idx"),  # fold_idx > 0 means multiprocessing is enabled
198 |         CustomArgs(['--num_workers'], type=int, target="data_loaders;train;data;kwargs;DataLoader_kwargs;num_workers"),
199 |         CustomArgs(['--lr', '--learning_rate'], type=float, target="optimizers;model;kwargs;lr"),
200 |         CustomArgs(['--bs', '--batch_size'], type=int,
201 |                    target="data_loaders;train;data;kwargs;DataLoader_kwargs;batch_size"),
202 |         CustomArgs(['--tp', '--transform_p'], type=float, target="datasets;train;data;kwargs;transform_p"),
203 |         CustomArgs(['--epochs'], type=int, target=["trainer;kwargs;epochs","trainer;kwargs;save_period"])
204 |     ]
205 |     for opt in options:
206 |         mod_args.add_argument(*opt.flags, default=None, type=opt.type)
207 | 
208 |     cfg = ConfigParser.from_args(args, options)
209 |     main(cfg)
210 | 
--------------------------------------------------------------------------------
/workspace/models/modules/pytorch_utils.py:
--------------------------------------------------------------------------------
  1 | import numpy as np
  2 | import time
  3 | import torch
  4 | import torch.nn as nn
  5 | 
  6 | 
  7 | def move_data_to_device(x, device):
  8 |     if 'float' in str(x.dtype):
  9 |         x = torch.Tensor(x)
 10 |     elif 'int' in str(x.dtype):
 11 |         x = torch.LongTensor(x)
 12 |     else:
 13 |         return x
 14 | 
 15 |     return x.to(device)
 16 | 
 17 | 
 18 | def do_mixup(x, mixup_lambda):
 19 |     """Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes
 20 |     (1, 3, 5, ...).
 21 |     Args:
 22 |       x: (batch_size * 2, ...)
 23 |       mixup_lambda: (batch_size * 2,)
 24 |     Returns:
 25 |       out: (batch_size, ...)
 26 |     """
 27 |     out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \
 28 |            x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1)
 29 |     return out
 30 | 
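A quick numeric check of `do_mixup` above: consecutive pairs (0, 1), (2, 3), ... are blended with their per-sample lambdas, halving the batch (toy values, assuming `do_mixup` from this file is in scope):

```python
import torch

x = torch.tensor([[1.0], [3.0], [10.0], [20.0]])   # batch_size * 2 = 4
lam = torch.tensor([0.7, 0.3, 0.5, 0.5])           # mixup_lambda, one weight per sample

out = do_mixup(x, lam)
# out[0] == 0.7 * 1.0 + 0.3 * 3.0 == 1.6
# out[1] == 0.5 * 10.0 + 0.5 * 20.0 == 15.0
```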
 31 | 
 32 | def append_to_dict(dict, key, value):
 33 |     if key in dict.keys():
 34 |         dict[key].append(value)
 35 |     else:
 36 |         dict[key] = [value]
 37 | 
 38 | 
 39 | def forward(model, generator, return_input=False,
 40 |             return_target=False):
 41 |     """Forward data to a model.
 42 | 
 43 |     Args:
 44 |       model: object
 45 |       generator: object
 46 |       return_input: bool
 47 |       return_target: bool
 48 |     Returns:
 49 |       audio_name: (audios_num,)
 50 |       clipwise_output: (audios_num, classes_num)
 51 |       (if exists) segmentwise_output: (audios_num, segments_num, classes_num)
 52 |       (if exists) framewise_output: (audios_num, frames_num, classes_num)
 53 |       (optional) return_input: (audios_num, segment_samples)
 54 |       (optional) return_target: (audios_num, classes_num)
 55 |     """
 56 |     output_dict = {}
 57 |     device = next(model.parameters()).device
 58 |     time1 = time.time()
 59 | 
 60 |     # Forward data to a model in mini-batches
 61 |     for n, batch_data_dict in enumerate(generator):
 62 |         print(n)
 63 |         batch_waveform = move_data_to_device(batch_data_dict['waveform'], device)
 64 | 
 65 |         with torch.no_grad():
 66 |             model.eval()
 67 |             batch_output = model(batch_waveform)
 68 | 
 69 |         append_to_dict(output_dict, 'audio_name', batch_data_dict['audio_name'])
 70 | 
 71 |         append_to_dict(output_dict, 'clipwise_output',
 72 |             batch_output['clipwise_output'].data.cpu().numpy())
 73 | 
 74 |         if 'segmentwise_output' in batch_output.keys():
 75 |             append_to_dict(output_dict, 'segmentwise_output',
 76 |                 batch_output['segmentwise_output'].data.cpu().numpy())
 77 | 
 78 |         if 'framewise_output' in batch_output.keys():
 79 |             append_to_dict(output_dict, 'framewise_output',
 80 |                 batch_output['framewise_output'].data.cpu().numpy())
 81 | 
 82 |         if return_input:
 83 |             append_to_dict(output_dict, 'waveform', batch_data_dict['waveform'])
 84 | 
 85 |         if return_target:
 86 |             if 'target' in batch_data_dict.keys():
 87 |                 append_to_dict(output_dict, 'target', batch_data_dict['target'])
 88 | 
 89 |         if n % 10 == 0:
 90 |             print(' --- Inference time: {:.3f} s / 10 iterations ---'.format(
 91 |                 time.time() - time1))
 92 |             time1 = time.time()
 93 | 
 94 |     for key in output_dict.keys():
 95 |         output_dict[key] = np.concatenate(output_dict[key], axis=0)
 96 | 
 97 |     return output_dict
 98 | 
 99 | 
100 | def interpolate(x, ratio):
101 |     """Interpolate data in the time domain. This is used to compensate for the
102 |     resolution reduction in the downsampling of a CNN.
103 | 
104 |     Args:
105 |       x: (batch_size, time_steps, classes_num)
106 |       ratio: int, ratio to interpolate
107 |     Returns:
108 |       upsampled: (batch_size, time_steps * ratio, classes_num)
109 |     """
110 |     (batch_size, time_steps, classes_num) = x.shape
111 |     upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
112 |     upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
113 |     return upsampled
114 | 
115 | 
116 | def pad_framewise_output(framewise_output, frames_num):
117 |     """Pad framewise_output to the same length as input frames. The pad value
118 |     is the same as the value of the last frame.
119 |     Args:
120 |       framewise_output: (batch_size, frames_num, classes_num)
121 |       frames_num: int, target number of frames after padding
122 |     Outputs:
123 |       output: (batch_size, frames_num, classes_num)
124 |     """
125 |     pad = framewise_output[:, -1 :, :].repeat(1, frames_num - framewise_output.shape[1], 1)
126 |     """tensor for padding"""
127 | 
128 |     output = torch.cat((framewise_output, pad), dim=1)
129 |     """(batch_size, frames_num, classes_num)"""
130 | 
131 |     return output
132 | 
133 | 
134 | def count_parameters(model):
135 |     return sum(p.numel() for p in model.parameters() if p.requires_grad)
136 | 
137 | 
138 | def count_flops(model, audio_length):
139 |     """Count flops. Code modified from others' implementation.
140 |     """
141 |     multiply_adds = True
142 |     list_conv2d=[]
143 |     def conv2d_hook(self, input, output):
144 |         batch_size, input_channels, input_height, input_width = input[0].size()
145 |         output_channels, output_height, output_width = output[0].size()
146 | 
147 |         kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (2 if multiply_adds else 1)
148 |         bias_ops = 1 if self.bias is not None else 0
149 | 
150 |         params = output_channels * (kernel_ops + bias_ops)
151 |         flops = batch_size * params * output_height * output_width
152 | 
153 |         list_conv2d.append(flops)
154 | 
155 |     list_conv1d=[]
156 |     def conv1d_hook(self, input, output):
157 |         batch_size, input_channels, input_length = input[0].size()
158 |         output_channels, output_length = output[0].size()
159 | 
160 |         kernel_ops = self.kernel_size[0] * (self.in_channels / self.groups) * (2 if multiply_adds else 1)
161 |         bias_ops = 1 if self.bias is not None else 0
162 | 
163 |         params = output_channels * (kernel_ops + bias_ops)
164 |         flops = batch_size * params * output_length
165 | 
166 |         list_conv1d.append(flops)
167 | 
168 |     list_linear=[]
169 |     def linear_hook(self, input, output):
170 |         batch_size = input[0].size(0) if input[0].dim() == 2 else 1
171 | 
172 |         weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
173 |         bias_ops = self.bias.nelement()
174 | 
175 |         flops = batch_size * (weight_ops + bias_ops)
176 |         list_linear.append(flops)
177 | 
178 |     list_bn=[]
179 |     def bn_hook(self, input, output):
180 |         list_bn.append(input[0].nelement() * 2)
181 | 
182 |     list_relu=[]
183 |     def relu_hook(self, input, output):
184 |         list_relu.append(input[0].nelement() * 2)
185 | 
186 |     list_pooling2d=[]
187 |     def pooling2d_hook(self, input, output):
188 |         batch_size, input_channels, input_height, input_width = input[0].size()
189 |         output_channels, output_height, output_width = output[0].size()
190 | 
191 |         kernel_ops = self.kernel_size * self.kernel_size
192 |         bias_ops = 0
193 |         params = output_channels * (kernel_ops + bias_ops)
194 |         flops = batch_size * params * output_height * output_width
195 | 
196 |         list_pooling2d.append(flops)
197 | 
198 |     list_pooling1d=[]
199 |     def pooling1d_hook(self, input, output):
200 |         batch_size, input_channels, input_length = input[0].size()
201 |         output_channels, output_length = output[0].size()
202 | 
203 |         kernel_ops = self.kernel_size[0]
204 |         bias_ops = 0
205 | 
206 |         params = output_channels * (kernel_ops + bias_ops)
207 |         flops = batch_size * params * output_length
208 | 
209 |         list_pooling1d.append(flops)
210 | 
211 |     def foo(net):
212 |         childrens = list(net.children())
213 |         if not childrens:
214 |             if isinstance(net, nn.Conv2d):
215 |                 net.register_forward_hook(conv2d_hook)
216 |             elif isinstance(net, nn.Conv1d):
217 |                 net.register_forward_hook(conv1d_hook)
218 |             elif isinstance(net, nn.Linear):
219 |                 net.register_forward_hook(linear_hook)
220 |             elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d):
221 |                 net.register_forward_hook(bn_hook)
222 |             elif isinstance(net, nn.ReLU):
223 |                 net.register_forward_hook(relu_hook)
224 |             elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d):
225 |                 net.register_forward_hook(pooling2d_hook)
226 |             elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d):
227 |                 net.register_forward_hook(pooling1d_hook)
228 |             else:
229 |                 print('Warning: flop of module {} is not counted!'.format(net))
230 |             return
231 |         for c in childrens:
232 |             foo(c)
233 | 
234 |     # Register hook
235 |     foo(model)
236 | 
237 |     device = next(model.parameters()).device
238 |     input = torch.rand(1, audio_length).to(device)
239 | 
240 |     out = model(input)
241 | 
242 |     total_flops = sum(list_conv2d) + sum(list_conv1d) + sum(list_linear) + \
243 |         sum(list_bn) + sum(list_relu) + sum(list_pooling2d) + sum(list_pooling1d)
244 | 
245 |     return total_flops
--------------------------------------------------------------------------------
/workspace/test_dcasetask1b.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import argparse
  3 | import collections
  4 | from operator import itemgetter
  5 | import torch
  6 | import torch.nn as nn
  7 | from torchvision.utils import make_grid, save_image
  8 | import pandas as pd
  9 | from sklearn.manifold import TSNE
 10 | from sklearn.utils.class_weight import compute_class_weight
 11 | import numpy as np
 12 | import matplotlib.pyplot as plt
 13 | from tqdm import tqdm
 14 | from sklearn.metrics import confusion_matrix
 15 | 
 16 | from base import Cross_Valid
 17 | from logger import get_logger
 18 | import models.loss as module_loss
 19 | import models.metric as module_metric
 20 | from models.metric import MetricTracker
 21 | from parse_config import ConfigParser
 22 | from utils import ensure_dir, prepare_device, get_by_path, msg_box
 23 | from sklearn.metrics import log_loss
 24 | 
 25 | os.environ['NUMEXPR_MAX_THREADS'] = '8'
 26 | os.environ['NUMEXPR_NUM_THREADS'] = '8'
 27 | import numexpr as ne
 28 | 
 29 | # fix random seeds for reproducibility
 30 | SEED = 123
 31 | torch.manual_seed(SEED)
 32 | torch.backends.cudnn.deterministic = True
 33 | torch.backends.cudnn.benchmark = False
 34 | np.random.seed(SEED)
 35 | 
 36 | 
 37 | def main(config):
 38 |     k_fold = config['trainer'].get('k_fold', 1)
 39 |     fold_idx = config['trainer'].get('fold_idx', 0)
 40 |     Cross_Valid.create_CV(k_fold, fold_idx)
 41 | 
 42 |     logger = get_logger('test')
 43 |     test_msg = msg_box("TEST")
 44 |     logger.debug(test_msg)
 45 | 
 46 |     # datasets
 47 |     test_datasets = dict()
 48 |     keys = ['datasets', 'test']
 49 |     for name in get_by_path(config, keys):
 50 |         test_datasets[name] = config.init_obj([*keys, name], 'data_loaders')
 51 | 
 52 |     # data_loaders
 53 |     test_data_loaders = dict()
 54 |     keys = ['data_loaders', 'test']
 55 |     for name in get_by_path(config, keys):
 56 |         dataset = test_datasets[name]
 57 |         do_transform = get_by_path(config, [*keys, name]).get('do_transform', False)
 58 |         if do_transform:
 59 |             dataset.transform()
 60 |         test_data_loaders[name] = config.init_obj([*keys, name], 'data_loaders', dataset)
 61 | 
 62 |     # prepare model for testing
 63 |     device, device_ids = prepare_device(config['n_gpu'])
 64 | 
 65 |     Cross_Valid.create_CV(k_fold, fold_idx)
 66 |     for fold_idx in range(1, k_fold + 1):
 67 |         # models
 68 |         if k_fold > 1:
 69 |             fold_prefix = f'fold_{fold_idx}_'
 70 |             dirname = os.path.dirname(config.resume)
 71 |             basename = 
os.path.basename(config.resume)
 72 |             resume = os.path.join(dirname, fold_prefix + basename)
 73 |         else:
 74 |             resume = config.resume
 75 |         logger.info(f"Loading model: {resume} ...")
 76 |         checkpoint = torch.load(resume)
 77 |         models = dict()
 78 |         logger_model = get_logger('model', verbosity=0)
 79 |         for name in config['models']:
 80 |             model = config.init_obj(['models', name], 'models')
 81 |             logger_model.info(model)
 82 |             state_dict = checkpoint['models'][name]
 83 |             if config['n_gpu'] > 1:
 84 |                 model = torch.nn.DataParallel(model)
 85 |             model.load_state_dict(state_dict)
 86 |             model = model.to(device)
 87 |             model.eval()
 88 |             models[name] = model
 89 | 
 90 |         # losses
 91 |         kwargs = {}
 92 |         # TODO
 93 |         if config['losses']['loss'].get('balanced', False):
 94 |             target = test_datasets['data'].y_test
 95 |             weight = compute_class_weight(class_weight='balanced',
 96 |                                           classes=target.unique(),
 97 |                                           y=target)
 98 |             weight = torch.FloatTensor(weight).to(device)
 99 |             kwargs.update(pos_weight=weight[1])
100 |         loss_fn = config.init_obj(['losses', 'loss'], module_loss, **kwargs)
101 | 
102 |         # metrics
103 |         metrics_iter = [getattr(module_metric, met) for met in config['metrics']['per_iteration']]
104 |         metrics_epoch = [getattr(module_metric, met) for met in config['metrics']['per_epoch']]
105 |         keys_loss = ['loss']
106 |         keys_iter = [m.__name__ for m in metrics_iter]
107 |         keys_epoch = [m.__name__ for m in metrics_epoch]
108 |         test_metrics = MetricTracker(keys_loss + keys_iter, keys_epoch)
109 | 
110 |         with torch.no_grad():
111 |             print("testing...")
112 |             model = models['model']
113 |             testloader = test_data_loaders['data']
114 |             if len(metrics_epoch) > 0:
115 |                 outputs = torch.FloatTensor().to(device)
116 |                 targets = torch.FloatTensor().to(device)
117 |             for batch_idx, (data, target) in tqdm(enumerate(testloader), total=len(testloader)):
118 |                 if isinstance(data, dict):
119 |                     data = {k: v.to(device) for k, v in data.items()}
120 |                 else:
121 |                     data = data.to(device)
122 | 
123 |                 target = target.to(device)
124 |                 output = model(data)
125 |                 if len(metrics_epoch) > 0:
126 |                     outputs = torch.cat((outputs, output))
127 |                     targets = torch.cat((targets, target))
128 | 
129 |                 #
130 |                 # save sample images, or do something with output here
131 |                 #
132 | 
133 |                 # computing loss, metrics on test set
134 |                 loss = loss_fn(output, target)
135 |                 test_metrics.iter_update('loss', loss.item())
136 |                 for met in metrics_iter:
137 |                     test_metrics.iter_update(met.__name__, met(output, target))
138 | 
139 |             for met in metrics_epoch:
140 |                 test_metrics.epoch_update(met.__name__, met(outputs, targets))
141 | 
142 |             print(outputs.cpu().numpy().shape)
143 | 
144 |             outputs_log = nn.LogSoftmax(dim=1)(outputs)
145 |             outputs = nn.Softmax(dim=1)(outputs)
146 | 
147 |             preds = torch.argmax(outputs, dim=1)
148 | 
149 |             cm = confusion_matrix(targets.cpu().numpy(), preds.cpu().numpy())
150 |             cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
151 |             # per-class accuracy lies on the diagonal of the normalized matrix
152 |             values = [np.round(i, decimals=3) for i in list(cm.diagonal())]
153 |             print(values)
154 |             print(config['datasets']['test']['data']['kwargs']['label_list'])
155 |             label_list = config['datasets']['test']['data']['kwargs']['label_list']
156 | 
157 |             print(outputs.cpu().numpy())
158 |             print(targets.cpu().numpy())
159 |             print("outputs shape", outputs.shape)
160 |             print("sklearn log loss: ", log_loss(targets.cpu().numpy(), outputs.cpu().numpy()))
161 |             print("pytorch log loss: ", nn.NLLLoss(reduction='none')(outputs_log.cpu(), targets.cpu().long()))
162 |             print(nn.NLLLoss(reduction='none')(outputs_log.cpu(), 
targets.cpu().long()).shape) 163 | print(torch.sum(nn.NLLLoss(reduction='none')(outputs_log.cpu(), targets.cpu().long()))) 164 | print(torch.sum(nn.NLLLoss(reduction='none')(outputs_log.cpu(), targets.cpu().long()))/36450) 165 | nll_loss = nn.NLLLoss(reduction='none')(outputs_log.cpu(), targets.cpu().long()) 166 | 167 | 168 | logloss_class_wise = {} 169 | y_true_list = targets.cpu().numpy().tolist() 170 | prob = outputs.cpu().numpy().tolist() 171 | for scene_label in label_list: 172 | scene_number = label_list.index(str(scene_label)) 173 | print("scene_number: ", scene_number) 174 | index_list = [] 175 | for i, e in enumerate(y_true_list): 176 | if e == scene_number: 177 | index_list.append(i) 178 | logloss_class_wise[scene_label] = 0.0 179 | for x in index_list: 180 | # print(x) 181 | logloss_class_wise[scene_label] += nll_loss[x] 182 | logloss_class_wise[scene_label] /= len(index_list) 183 | # T = list(itemgetter(*index_list)(y_true_list)) 184 | # print(len(prob), len(prob[0])) 185 | # P = list(itemgetter(*index_list)(prob))[:, scene_number] 186 | # logloss_class_wise[scene_label] = log_loss(y_true=T, y_pred=P, labels=list(range(len(keys)))) 187 | 188 | print(logloss_class_wise) 189 | 190 | test_log = test_metrics.result() 191 | logger.info(test_log) 192 | # cross validation is enabled 193 | if k_fold > 1: 194 | log_mean = test_log['mean'] 195 | idx = Cross_Valid.fold_idx 196 | save_path = config.save_dir['metric'] / f"fold_{idx}.pkl" 197 | log_mean.to_pickle(save_path) 198 | Cross_Valid.next_fold() 199 | 200 | 201 | if __name__ == '__main__': 202 | args = argparse.ArgumentParser(description='testing') 203 | run_args = args.add_argument_group('run_args') 204 | run_args.add_argument('-c', '--config', default="configs/examples/mnist.json", type=str) 205 | run_args.add_argument('-r', '--resume', default=None, type=str) 206 | run_args.add_argument('-d', '--device', default=None, type=str) 207 | run_args.add_argument('--mode', default='test', type=str) 208 | run_args.add_argument('--run_id', default=None, type=str) 209 | run_args.add_argument('--log_name', default=None, type=str) 210 | 211 | # custom cli options to modify configuration from default values given in json file. 
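For reference, the class-wise log-loss loop in `main()` above can be written far more compactly; a self-contained sketch with toy stand-ins for the `nll_loss` and `targets` tensors computed there:

```python
import torch

nll_loss = torch.tensor([0.2, 0.9, 0.4, 0.1])   # per-sample NLL, as above
targets = torch.tensor([0, 1, 0, 1])            # scene indices into label_list

per_class = {c: nll_loss[targets == c].mean().item() for c in targets.unique().tolist()}
# per_class == {0: 0.3, 1: 0.5}, i.e. logloss_class_wise keyed by label index
```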
212 |     mod_args = args.add_argument_group('mod_args')
213 |     CustomArgs = collections.namedtuple('CustomArgs', "flags default type target")
214 |     options = [
215 |     ]
216 |     for opt in options:
217 |         mod_args.add_argument(*opt.flags, default=opt.default, type=opt.type)
218 | 
219 |     # additional arguments for testing
220 |     test_args = args.add_argument_group('test_args')
221 |     test_args.add_argument('--output_path', default=None, type=str)
222 | 
223 |     cfg = ConfigParser.from_args(args, options)
224 |     main(cfg)
225 | 
--------------------------------------------------------------------------------
/workspace/data_loaders/1s_dataset.py:
--------------------------------------------------------------------------------
  1 | from torchvision import datasets
  2 | import torchvision.transforms as tv_transforms
  3 | from base import BaseDataLoader
  4 | import torch.utils.data as data
  5 | import librosa
  6 | import os
  7 | import pandas as pd
  8 | import torch
  9 | import numpy as np
 10 | import pathlib
 11 | from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift, TimeMask
 12 | from utils.spec_timeshift_transform import TimeShift
 13 | 
 14 | def load_dataframe(path):
 15 |     df = pd.read_csv(path)
 16 |     return df.to_dict('records')
 17 | 
 18 | def load_numpy(path):
 19 |     try:
 20 |         data = np.load(path)
 21 |     except Exception as e:
 22 |         print(e)
 23 |         return None
 24 |     return data
 25 | 
 26 | 
 27 | class SpectrogramAugmentationDataset(data.Dataset):
 28 |     """
 29 |     Loads cached log-mel spectrograms with an optional time-shift transform.
 30 |     """
 31 | 
 32 |     def __init__(self, data_dir, label_list, cache_dir, t_shift=False):
 33 |         self.data_arr = load_dataframe(data_dir)
 34 |         self.data_dir = data_dir
 35 |         self.label_list = label_list
 36 |         self.spectrogram_cachedir = cache_dir
 37 | 
 38 |         transforms = []
 39 |         if t_shift:
 40 |             transforms.append(TimeShift())
 41 | 
 42 |         if len(transforms)==0:
 43 |             self.transform = None
 44 |         else:
 45 |             self.transform = tv_transforms.Compose(transforms)
 46 | 
 47 |     def __len__(self):
 48 |         return len(self.data_arr)
 49 | 
 50 |     def __getitem__(self, idx):
 51 |         item = self.data_arr[idx]
 52 |         scene_label = item['scene_label']
 53 |         scene_encoded = self.label_list.index(scene_label)
 54 | 
 55 |         spec = load_numpy(item['filename_spec'])
 56 | 
 57 |         # if os.path.isfile(os.path.join(self.spectrogram_cachedir, item['filename_audio']).replace(".wav", ".npy")):
 58 |         #     spec = load_numpy(os.path.join(self.spectrogram_cachedir, item['filename_audio']).replace(".wav", ".npy"))
 59 |         # else:
 60 |         #     y, sr = librosa.load(os.path.join("/mnt/ssd/data/tau_audiovisual_2021/", item['filename_audio']), sr=48000)
 61 |         #     spec = librosa.feature.melspectrogram(y, n_fft=2048, hop_length=960, n_mels=256, sr=48000, fmin=0, fmax=22050)
 62 |         #     if spec.shape[1] == 501:
 63 |         #         spec = spec[:,:-1]
 64 | 
 65 |         #     spec = np.log(spec)
 66 | 
 67 |         #     # z-score normalization
 68 |         #     std = spec.std()
 69 |         #     mean = spec.mean()
 70 |         #     spec = (spec - mean) / std
 71 | 
 72 |         #     pathlib.Path(os.path.join(self.spectrogram_cachedir, item['filename_audio']).replace(".wav", ".npy")).parent.mkdir(parents=True, exist_ok=True)
 73 | 
 74 |         #     np.save(os.path.join(self.spectrogram_cachedir, item['filename_audio']).replace(".wav", ".npy"), spec)
 75 | 
 76 |         if self.transform is not None:
 77 |             spec = self.transform(spec)
 78 | 
 79 |         # return spec, scene_encoded
 80 |         return {"spec": spec}, scene_encoded
 81 | 
 82 | 
 83 | class WaveformAugmentationDataset(data.Dataset):
 84 |     """
 85 |     Loads cached waveforms with optional audiomentations transforms.
 86 |     """
 87 | 
 88 |     def __init__(self, data_dir, label_list, cache_dir,
 89 |                  t_gaussian_noise=False,
 90 |                  t_time_stretch=False,
 91 |                  t_pitch_shift=False,
 92 |                  t_shift=False,
 93 |                  t_time_mask=False):
 94 |         self.data_arr = load_dataframe(data_dir)
 95 |         self.data_dir = data_dir
 96 |         self.label_list = label_list
 97 |         self.waveform_cachedir = cache_dir
 98 | 
 99 |         transforms = []
100 |         if t_gaussian_noise:
101 |             transforms.append(AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=1.0))
102 |         if t_time_stretch:
103 |             transforms.append(TimeStretch(min_rate=0.8, max_rate=1.25, p=1.0))
104 |         if t_pitch_shift:
105 |             transforms.append(PitchShift(min_semitones=-4, max_semitones=4, p=1.0))
106 |         if t_shift:
107 |             transforms.append(Shift(min_fraction=-0.5, max_fraction=0.5, p=1.0))
108 |         if t_time_mask:
109 |             transforms.append(TimeMask(min_band_part=0.0, max_band_part=0.5, p=1.0))
110 | 
111 |         if len(transforms)==0:
112 |             self.transform = None
113 |         else:
114 |             self.transform = Compose(transforms)
115 | 
116 |     def __len__(self):
117 |         return len(self.data_arr)
118 | 
119 |     def __getitem__(self, idx):
120 |         item = self.data_arr[idx]
121 |         scene_label = item['scene_label']
122 |         scene_encoded = self.label_list.index(scene_label)
123 | 
124 |         wave = load_numpy(item['filename_wave'])
125 | 
126 |         # if os.path.isfile(os.path.join(self.waveform_cachedir, item['filename_audio']).replace(".wav", ".npy")):
127 |         #     wave = load_numpy(os.path.join(self.waveform_cachedir, item['filename_audio']).replace(".wav", ".npy"))
128 |         # else:
129 |         #     y, sr = librosa.load(os.path.join("/mnt/ssd/data/tau_audiovisual_2021/", item['filename_audio']), sr=48000)
130 | 
131 |         #     padded = np.zeros(480000, dtype='float32')
132 |         #     wave = y[:480000]
133 |         #     padded[0:len(wave)] = wave
134 |         #     wave = padded
135 | 
136 |         #     pathlib.Path(os.path.join(self.waveform_cachedir, item['filename_audio']).replace(".wav", ".npy")).parent.mkdir(parents=True, exist_ok=True)
137 | 
138 |         #     np.save(os.path.join(self.waveform_cachedir, item['filename_audio']).replace(".wav", ".npy"), wave)
139 | 
140 |         if self.transform is not None:
141 |             # wave = np.expand_dims(wave, axis=0)
142 |             wave = self.transform(wave, sample_rate=48000)
143 |             # wave = wave.squeeze(0)
144 | 
145 |         return {"wave": wave}, scene_encoded
146 | 
147 | 
148 | 
149 | class MultiModalAugmentationDataset(data.Dataset):
150 |     """
151 |     Loads paired cached spectrogram and waveform features for the fusion model.
152 |     """
153 | 
154 |     def __init__(self, data_dir, label_list, spec_cache_dir, wave_cache_dir, w_shift=False, s_shift=False):
155 |         self.data_arr = load_dataframe(data_dir)
156 |         self.data_dir = data_dir
157 |         self.label_list = label_list
158 |         self.w_shift = w_shift
159 |         self.s_shift = s_shift
160 |         self.spectrogram_cachedir = spec_cache_dir
161 |         self.waveform_cachedir = wave_cache_dir
162 | 
163 | 
164 |         spec_transforms = []
165 |         if s_shift:
166 |             spec_transforms.append(TimeShift())
167 | 
168 |         if len(spec_transforms)==0:
169 |             self.spec_transform = None
170 |         else:
171 |             self.spec_transform = tv_transforms.Compose(spec_transforms)
172 | 
173 |         wave_transforms = []
174 |         if w_shift:
175 |             wave_transforms.append(Shift(min_fraction=-0.5, max_fraction=0.5, p=1.0))
176 | 
177 |         if len(wave_transforms)==0:
178 |             self.wave_transform = None
179 |         else:
180 |             self.wave_transform = Compose(wave_transforms)
181 | 
182 |     def __len__(self):
183 |         return len(self.data_arr)
184 | 
185 |     def _get_spec(self, filename):
186 |         # if os.path.isfile(os.path.join(self.spectrogram_cachedir, filename).replace(".wav", ".npy")):
187 |         #     spec = load_numpy(os.path.join(self.spectrogram_cachedir, filename).replace(".wav", ".npy"))
188 |         # else:
189 |         #     y, sr = librosa.load(os.path.join("/mnt/ssd/data/tau_audiovisual_2021/", filename), sr=48000)
190 |         #     spec = librosa.feature.melspectrogram(y, n_fft=2048, hop_length=960, n_mels=40, sr=48000, fmin=0, fmax=22050)
191 |         #     if spec.shape[1] == 501:
192 |         #         spec = spec[:,:-1]
193 | 
194 |         #     spec = np.log(spec)
195 | 
196 |         #     pathlib.Path(os.path.join(self.spectrogram_cachedir, filename).replace(".wav", ".npy")).parent.mkdir(parents=True, exist_ok=True)
197 | 
198 |         #     np.save(os.path.join(self.spectrogram_cachedir, filename).replace(".wav", ".npy"), spec)
199 |         return load_numpy(filename)
200 | 
201 |     def _get_wave(self, filename):
202 |         # if os.path.isfile(os.path.join(self.waveform_cachedir, filename).replace(".wav", ".npy")):
203 |         #     wave = load_numpy(os.path.join(self.waveform_cachedir, filename).replace(".wav", ".npy"))
204 |         # else:
205 |         #     y, sr = librosa.load(os.path.join("/mnt/ssd/data/tau_audiovisual_2021/", filename), sr=48000)
206 | 
207 |         #     padded = np.zeros(480000, dtype='float32')
208 |         #     wave = y[:480000]
209 |         #     padded[0:len(wave)] = wave
210 |         #     wave = padded
211 | 
212 |         #     pathlib.Path(os.path.join(self.waveform_cachedir, filename).replace(".wav", ".npy")).parent.mkdir(parents=True, exist_ok=True)
213 | 
214 |         #     np.save(os.path.join(self.waveform_cachedir, filename).replace(".wav", ".npy"), wave)
215 |         return load_numpy(filename)
216 | 
217 |     def __getitem__(self, idx):
218 |         item = self.data_arr[idx]
219 |         scene_label = item['scene_label']
220 |         scene_encoded = self.label_list.index(scene_label)
221 | 
222 |         spec = self._get_spec(item['filename_spec'])
223 |         wave = self._get_wave(item['filename_wave'])
224 | 
225 |         # Add transforms
226 |         if self.wave_transform is not None:
227 |             wave = self.wave_transform(wave, sample_rate=48000)
228 | 
229 |         if self.spec_transform is not None:
230 |             spec = self.spec_transform(spec)
231 | 
232 |         # return spec, scene_encoded
233 |         return {"spec": spec, "wave": wave}, scene_encoded
234 | 
235 | 
236 | class Task1BEvaluationDataset(data.Dataset):
237 |     """
238 |     Loads evaluation audio from disk and computes spectrogram and waveform features on the fly.
239 |     """
240 | 
241 |     def __init__(self, data_dir, label_list):
242 |         self.data_arr = pd.read_csv(data_dir, sep="\t").to_dict('records')
243 |         self.data_dir = data_dir
244 |         self.label_list = label_list
245 |         self.audio_dir = "/mnt/ssd/data/tau_audiovisual_evaluation"
246 | 
247 |     def __len__(self):
248 |         return len(self.data_arr)
249 | 
250 |     def __getitem__(self, idx):
251 |         item = self.data_arr[idx]
252 |         item_filepath = item['filename_audio']
253 |         item_filename = item['filename_audio'].replace("audio/", "")
254 | 
255 | 
256 |         audio, sr = librosa.load(os.path.join(self.audio_dir, item_filepath), sr=48000)
257 |         padded = np.zeros(48000, dtype='float32')
258 |         wave = audio[:48000]
259 |         padded[0:len(wave)] = wave
260 | 
261 |         spec = librosa.feature.melspectrogram(padded, n_fft=2048, hop_length=256, n_mels=128, sr=48000, fmin=0, fmax=24000)
262 | 
263 |         spec = np.log(spec)
264 |         # print(spec)
265 | 
266 |         # z-score normalization
267 |         std = spec.std()
268 |         mean = spec.mean()
269 |         spec = (spec - mean) / std
270 | 
271 | 
272 |         # return spec, scene_encoded
273 |         return {"spec": spec, "wave": padded}, item_filename
--------------------------------------------------------------------------------
/workspace/trainers/trainer.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import time
  3 | 
  4 | import numpy as np
  5 | import pandas as pd
  6 | import torch
  7 | from torchvision.utils import make_grid
  8 | 
  9 | from base import BaseTrainer
 10 | from models.metric import MetricTracker
 11 | from utils import inf_loop
 12 | import copy
 13 | import wandb
 14 | 
 15 | 
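Before the trainer classes, a note on the batch format they consume: the datasets above yield `({"spec": ..., "wave": ...}, target)` tuples, which is why every training/validation loop below branches on `isinstance(data, dict)`. A toy batch with the shapes implied by the model kwargs in the config (shapes assumed from those kwargs, not verified against the model code):

```python
import torch

batch = {
    "spec": torch.randn(4, 128, 188),   # (batch, n_bins, n_frames)
    "wave": torch.randn(4, 480000),     # (batch, input_length)
}
targets = torch.randint(0, 10, (4,))    # num_classes = 10

# the trainers move each modality to the configured device before the forward pass
batch = {k: v.to("cpu") for k, v in batch.items()}
```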
 16 | class DCASETask1BTrainer(BaseTrainer):
 17 |     """
 18 |     Trainer class
 19 |     """
 20 |     def __init__(self, torch_args: dict, save_dir, resume, device, **kwargs):
 21 |         self.device = device
 22 |         super().__init__(torch_args, save_dir, **kwargs)
 23 | 
 24 |         if resume is not None:
 25 |             self._resume_checkpoint(resume, finetune=self.finetune)
 26 | 
 27 |         # data_loaders
 28 |         self.do_validation = self.valid_data_loaders['data'] is not None
 29 |         if self.len_epoch is None:
 30 |             # epoch-based training
 31 |             self.len_epoch = len(self.train_data_loaders['data'])
 32 |         else:
 33 |             # iteration-based training
 34 |             self.train_data_loaders['data'] = inf_loop(self.train_data_loaders['data'])
 35 |         self.log_step = int(np.sqrt(self.train_data_loaders['data'].batch_size))
 36 | 
 37 |         # losses
 38 |         self.criterion = self.losses['loss']
 39 | 
 40 |         # metrics
 41 |         keys_loss = ['loss']
 42 |         keys_iter = [m.__name__ for m in self.metrics_iter]
 43 |         keys_epoch = [m.__name__ for m in self.metrics_epoch]
 44 |         self.train_metrics = MetricTracker(keys_loss + keys_iter, keys_epoch, writer=self.writer)
 45 |         self.valid_metrics = MetricTracker(keys_loss + keys_iter, keys_epoch, writer=self.writer)
 46 | 
 47 |         # init wandb model watch
 48 |         wandb.watch(self.models['model'])
 49 | 
 50 |         # wandb.save("trainers/cr_trainer.py", base_path="../")
 51 | 
 52 |         # init history logging of best/worst metrics
 53 |         self.best_val_metrics = {}
 54 |         for index, row in self.valid_metrics.metrics_epoch.iterrows():
 55 |             self.best_val_metrics['val_'+index] = {'min': 99999999, 'max': -99999999}
 56 | 
 57 |         ### Find the learning rate
 58 |         if kwargs.get('find_lr'):
 59 |             print('Finding optimal LR...')
 60 |             self.backup_model = copy.deepcopy(self.models['model'].state_dict())
 61 |             self.backup_opt = copy.deepcopy(self.optimizers['model'].state_dict())
 62 |             max_lr = np.median([self._find_lr() for i in range(7)])
 63 | 
 64 |             self.lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizers['model'],
 65 |                                                                     max_lr=max_lr,
 66 |                                                                     steps_per_epoch=len(self.train_data_loaders['data']),
 67 |                                                                     epochs=self.epochs)
 68 |             print(f'Max LR: {max_lr}')
 69 |             wandb.run.summary['OneCycle Max LR'] = max_lr
 70 | 
 71 |         # learning rate schedulers
 72 |         self.do_lr_scheduling = kwargs.get('find_lr')
 73 |         # self.lr_scheduler = self.lr_schedulers['model']
 74 | 
 75 | 
 76 | 
 77 | 
 78 | 
 79 | 
 80 |     def _train_epoch(self, epoch):
 81 |         """
 82 |         Training logic for an epoch
 83 | 
 84 |         :param epoch: Integer, current training epoch.
 85 |         :return: A log that contains average loss and metric in this epoch.
86 | """ 87 | start = time.time() 88 | self.models['model'].train() 89 | self.train_metrics.reset() 90 | if len(self.metrics_epoch) > 0: 91 | outputs = torch.FloatTensor().to(self.device) 92 | targets = torch.FloatTensor().to(self.device) 93 | for batch_idx, (data, target) in enumerate(self.train_data_loaders['data']): 94 | if isinstance(data, dict): 95 | data = {k: v.to(self.device) for k, v in data.items()} 96 | else: 97 | data = data.to(self.device) 98 | 99 | target = target.to(self.device) 100 | 101 | self.optimizers['model'].zero_grad() 102 | output = self.models['model'](data) 103 | if len(self.metrics_epoch) > 0: 104 | outputs = torch.cat((outputs, output)) 105 | targets = torch.cat((targets, target)) 106 | loss = self.criterion(output, target) 107 | loss.backward() 108 | self.optimizers['model'].step() 109 | 110 | if self.do_lr_scheduling: 111 | self.lr_scheduler.step() 112 | 113 | # log loss and lr 114 | wandb.log({'loss': loss}) 115 | wandb.log({'learning_rate': self.optimizers['model'].param_groups[0]['lr']}) 116 | 117 | self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx) 118 | self.train_metrics.iter_update('loss', loss.item()) 119 | for met in self.metrics_iter: 120 | self.train_metrics.iter_update(met.__name__, met(output, target)) 121 | 122 | if batch_idx % self.log_step == 0: 123 | epoch_debug = f"Train Epoch: {epoch} {self._progress(batch_idx)} " 124 | current_metrics = self.train_metrics.current() 125 | metrics_debug = ", ".join(f"{key}: {value:.6f}" for key, value in current_metrics.items()) 126 | self.logger.debug(epoch_debug + metrics_debug) 127 | # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) 128 | 129 | if batch_idx == self.len_epoch: 130 | break 131 | 132 | for met in self.metrics_epoch: 133 | self.train_metrics.epoch_update(met.__name__, met(outputs, targets)) 134 | # log the training metrics 135 | wandb.log({met.__name__: met(outputs, targets)}) 136 | 137 | train_log = self.train_metrics.result() 138 | 139 | wandb.run.summary['epochs_trained'] = epoch 140 | 141 | if self.do_validation: 142 | valid_log = self._valid_epoch(epoch) 143 | valid_log.set_index('val_' + valid_log.index.astype(str), inplace=True) 144 | 145 | # log the validation metrics 146 | wandb.log({str(index) : row['mean'] for index, row in valid_log.iterrows()}) 147 | 148 | # update best/worst metric results 149 | for index, row in valid_log.iterrows(): 150 | if index in self.best_val_metrics: 151 | if row['mean'] < self.best_val_metrics[index]['min']: 152 | self.best_val_metrics[index]['min'] = row['mean'] 153 | if row['mean'] > self.best_val_metrics[index]['max']: 154 | self.best_val_metrics[index]['max'] = row['mean'] 155 | 156 | wandb.run.summary['lowest_'+index] = self.best_val_metrics[index]['min'] 157 | wandb.run.summary['highest_'+index] = self.best_val_metrics[index]['max'] 158 | 159 | # if self.do_lr_scheduling: 160 | # self.lr_scheduler.step() 161 | 162 | log = pd.concat([train_log, valid_log]) 163 | end = time.time() 164 | ty_res = time.gmtime(end - start) 165 | res = time.strftime("%H hours, %M minutes, %S seconds", ty_res) 166 | epoch_log = {'epochs': epoch, 167 | 'iterations': self.len_epoch * epoch, 168 | 'Runtime': res} 169 | epoch_info = ', '.join(f"{key}: {value}" for key, value in epoch_log.items()) 170 | logger_info = f"{epoch_info}\n{log}" 171 | self.logger.info(logger_info) 172 | 173 | return log 174 | 175 | def _valid_epoch(self, epoch): 176 | """ 177 | Validate after training an epoch 178 | 179 | :param epoch: Integer, current 
training epoch. 180 | :return: A log that contains information about validation 181 | """ 182 | self.models['model'].eval() 183 | self.valid_metrics.reset() 184 | with torch.no_grad(): 185 | if len(self.metrics_epoch) > 0: 186 | outputs = torch.FloatTensor().to(self.device) 187 | targets = torch.FloatTensor().to(self.device) 188 | for batch_idx, (data, target) in enumerate(self.valid_data_loaders['data']): 189 | if isinstance(data, dict): 190 | data = {k: v.to(self.device) for k, v in data.items()} 191 | else: 192 | data = data.to(self.device) 193 | 194 | target = target.to(self.device) 195 | 196 | output = self.models['model'](data) 197 | loss = self.criterion(output, target) 198 | if len(self.metrics_epoch) > 0: 199 | outputs = torch.cat((outputs, output)) 200 | targets = torch.cat((targets, target)) 201 | 202 | self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx, 'valid') 203 | self.valid_metrics.iter_update('loss', loss.item()) 204 | for met in self.metrics_iter: 205 | self.valid_metrics.iter_update(met.__name__, met(output, target)) 206 | # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) 207 | 208 | for met in self.metrics_epoch: 209 | self.valid_metrics.epoch_update(met.__name__, met(outputs, targets)) 210 | 211 | # # add histogram of model parameters to the tensorboard 212 | # for name, param in self.models['model'].named_parameters(): 213 | # self.writer.add_histogram(name, param, bins='auto') 214 | 215 | valid_log = self.valid_metrics.result() 216 | 217 | return valid_log 218 | 219 | def _find_lr(self): 220 | lrs = np.logspace(-7, 2, base=10, num=50) 221 | losses = [] 222 | 223 | lr_idx = 0 224 | 225 | self.models['model'].train() 226 | 227 | while lr_idx < len(lrs): 228 | for batch_idx, (data, target) in enumerate(self.train_data_loaders['data']): 229 | if lr_idx == len(lrs): 230 | break 231 | 232 | lr = lrs[lr_idx] 233 | self.optimizers['model'].param_groups[0]['lr'] = lr 234 | 235 | if isinstance(data, dict): 236 | data = {k: v.to(self.device) for k, v in data.items()} 237 | else: 238 | data = data.to(self.device) 239 | 240 | target = target.to(self.device) 241 | 242 | self.optimizers['model'].zero_grad() 243 | output = self.models['model'](data) 244 | loss = self.criterion(output, target) 245 | loss.backward() 246 | self.optimizers['model'].step() 247 | 248 | losses += [loss.item()] 249 | 250 | lr_idx += 1 251 | 252 | best_idx = np.argmin(losses) 253 | best_loss = losses[best_idx] 254 | best_lr = lrs[best_idx] 255 | rec_lr = best_lr / 10.0 256 | 257 | self.logger.debug(f'Best LR: {best_lr} Recommended 1-Cycle max_lr: {rec_lr}') 258 | 259 | # self.models['model'].eval() 260 | self.models['model'].load_state_dict(self.backup_model) 261 | self.optimizers['model'].load_state_dict(self.backup_opt) 262 | 263 | return rec_lr 264 | 265 | def _progress(self, batch_idx): 266 | ratio = '[{}/{} ({:.0f}%)]' 267 | return ratio.format(batch_idx, self.len_epoch, 100.0 * batch_idx / self.len_epoch) 268 | -------------------------------------------------------------------------------- /workspace/trainers/trainer_mixup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from torchvision.utils import make_grid 8 | 9 | from base import BaseTrainer 10 | from models.metric import MetricTracker 11 | from utils import inf_loop, mixup_data 12 | import copy 13 | import wandb 14 | 15 | 16 | class 
DCASETask1BTrainerWithMixup(BaseTrainer): 17 | """ 18 | Trainer class 19 | """ 20 | def __init__(self, torch_args: dict, save_dir, resume, device, **kwargs): 21 | self.device = device 22 | super().__init__(torch_args, save_dir, **kwargs) 23 | 24 | if resume is not None: 25 | self._resume_checkpoint(resume, finetune=self.finetune) 26 | 27 | # data_loaders 28 | self.do_validation = self.valid_data_loaders['data'] is not None 29 | if self.len_epoch is None: 30 | # epoch-based training 31 | self.len_epoch = len(self.train_data_loaders['data']) 32 | else: 33 | # iteration-based training 34 | self.train_data_loaders['data'] = inf_loop(self.train_data_loaders['data']) 35 | self.log_step = int(np.sqrt(self.train_data_loaders['data'].batch_size)) 36 | 37 | # losses 38 | self.criterion = self.losses['loss'] 39 | 40 | # metrics 41 | keys_loss = ['loss'] 42 | keys_iter = [m.__name__ for m in self.metrics_iter] 43 | keys_epoch = [m.__name__ for m in self.metrics_epoch] 44 | self.train_metrics = MetricTracker(keys_loss + keys_iter, keys_epoch, writer=self.writer) 45 | self.valid_metrics = MetricTracker(keys_loss + keys_iter, keys_epoch, writer=self.writer) 46 | 47 | # check whether mixup is being used 48 | self.mixup = kwargs.get('mixup', False) 49 | self.probability_mixup = kwargs.get('mixup_p', False) 50 | 51 | # init wandb model watch 52 | wandb.watch(self.models['model']) 53 | 54 | # init history logging of best/worst metrics 55 | self.best_val_metrics = {} 56 | self.best_val_metrics['val_loss'] = {'min': 99999999, 'max': -99999999} 57 | for index, row in self.valid_metrics.metrics_epoch.iterrows(): 58 | self.best_val_metrics['val_'+index] = {'min': 99999999, 'max': -99999999} 59 | 60 | ### Calculate the learning rate 61 | if kwargs.get('find_lr'): 62 | print('Finding optimal LR...') 63 | self.backup_model= copy.deepcopy(self.models['model'].state_dict()) 64 | self.backup_opt = copy.deepcopy(self.optimizers['model'].state_dict()) 65 | if kwargs.get('max_lr'): 66 | max_lr = kwargs['max_lr'] 67 | else: 68 | max_lr = np.median([self._find_lr() for i in range(7)]) 69 | 70 | self.lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizers['model'], 71 | max_lr=max_lr, 72 | steps_per_epoch=len(self.train_data_loaders['data']), 73 | epochs=self.epochs) 74 | print(f'Max LR: {max_lr}') 75 | wandb.run.summary['OneCycle Max LR'] = max_lr 76 | 77 | # learning rate schedulers 78 | self.do_lr_scheduling = kwargs.get('find_lr') 79 | 80 | 81 | def mixup_criterion(self, pred, y_a, y_b, lam): 82 | return lam * self.criterion(pred, y_a) + (1 - lam) * self.criterion(pred, y_b) 83 | 84 | 85 | def _train_epoch(self, epoch): 86 | """ 87 | Training logic for an epoch 88 | 89 | :param epoch: Integer, current training epoch. 90 | :return: A log that contains average loss and metric in this epoch. 
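Notes: with ``mixup`` enabled, a batch is mixed with probability ``1 - mixup_p`` (every batch when no probability is configured). A mixed batch is scored with ``mixup_criterion``, the lambda-weighted sum of the criterion against both target sets; unmixed batches fall through to the plain criterion.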
91 | """ 92 | start = time.time() 93 | self.models['model'].train() 94 | self.train_metrics.reset() 95 | if len(self.metrics_epoch) > 0: 96 | outputs = torch.FloatTensor().to(self.device) 97 | targets = torch.FloatTensor().to(self.device) 98 | for batch_idx, (data, target) in enumerate(self.train_data_loaders['data']): 99 | if isinstance(data, dict): 100 | data = {k: v.to(self.device) for k, v in data.items()} 101 | else: 102 | data = data.to(self.device) 103 | 104 | target = target.to(self.device) 105 | 106 | # mixup 107 | instance_mixup = 0 108 | if self.mixup: 109 | if self.probability_mixup is not None: 110 | if (1 - self.probability_mixup) > random.random(): 111 | data_mixup, targets_a, targets_b, lam = mixup_data(data, target, 0.2, True) 112 | # data = {"data": data_mixup} 113 | instance_mixup = 1 114 | else: 115 | instance_mixup = 0 116 | else: 117 | data_mixup, targets_a, targets_b, lam = mixup_data(data, target, 0.2, True) 118 | # data = {"data": data_mixup} 119 | instance_mixup = 1 120 | 121 | self.optimizers['model'].zero_grad() 122 | output = self.models['model'](data) 123 | if len(self.metrics_epoch) > 0: 124 | outputs = torch.cat((outputs, output)) 125 | targets = torch.cat((targets, target)) 126 | 127 | if self.mixup and instance_mixup: 128 | loss = self.mixup_criterion(output, targets_a, targets_b, lam) 129 | else: 130 | loss = self.criterion(output, target) 131 | 132 | loss.backward() 133 | self.optimizers['model'].step() 134 | 135 | if self.do_lr_scheduling: 136 | self.lr_scheduler.step() 137 | 138 | # log loss and lr 139 | wandb.log({'loss': loss}) 140 | wandb.log({'learning_rate': self.optimizers['model'].param_groups[0]['lr']}) 141 | 142 | self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx) 143 | self.train_metrics.iter_update('loss', loss.item()) 144 | for met in self.metrics_iter: 145 | self.train_metrics.iter_update(met.__name__, met(output, target)) 146 | 147 | if batch_idx % self.log_step == 0: 148 | epoch_debug = f"Train Epoch: {epoch} {self._progress(batch_idx)} " 149 | current_metrics = self.train_metrics.current() 150 | metrics_debug = ", ".join(f"{key}: {value:.6f}" for key, value in current_metrics.items()) 151 | self.logger.debug(epoch_debug + metrics_debug) 152 | # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) 153 | 154 | if batch_idx == self.len_epoch: 155 | break 156 | 157 | for met in self.metrics_epoch: 158 | self.train_metrics.epoch_update(met.__name__, met(outputs, targets)) 159 | # log the training metrics 160 | wandb.log({met.__name__: met(outputs, targets)}) 161 | 162 | train_log = self.train_metrics.result() 163 | 164 | wandb.run.summary['epochs_trained'] = epoch 165 | 166 | if self.do_validation: 167 | valid_log = self._valid_epoch(epoch) 168 | valid_log.set_index('val_' + valid_log.index.astype(str), inplace=True) 169 | 170 | # log the validation metrics 171 | wandb.log({str(index) : row['mean'] for index, row in valid_log.iterrows()}) 172 | 173 | # update best/worst metric results 174 | for index, row in valid_log.iterrows(): 175 | if index == 'val_loss': 176 | if row['mean'] < self.best_val_metrics['val_loss']['min']: 177 | self.best_val_metrics['val_loss']['min'] = row['mean'] 178 | if row['mean'] > self.best_val_metrics['val_loss']['max']: 179 | self.best_val_metrics['val_loss']['max'] = row['mean'] 180 | 181 | wandb.run.summary['lowest_val_loss'] = self.best_val_metrics['val_loss']['min'] 182 | wandb.run.summary['highest_val_loss'] = self.best_val_metrics['val_loss']['max'] 183 | 184 | if index in 
self.best_val_metrics: 185 | if row['mean'] < self.best_val_metrics[index]['min']: 186 | self.best_val_metrics[index]['min'] = row['mean'] 187 | if row['mean'] > self.best_val_metrics[index]['max']: 188 | self.best_val_metrics[index]['max'] = row['mean'] 189 | 190 | wandb.run.summary['lowest_'+index] = self.best_val_metrics[index]['min'] 191 | wandb.run.summary['highest_'+index] = self.best_val_metrics[index]['max'] 192 | 193 | # if self.do_lr_scheduling: 194 | # self.lr_scheduler.step() 195 | 196 | log = pd.concat([train_log, valid_log]) 197 | end = time.time() 198 | ty_res = time.gmtime(end - start) 199 | res = time.strftime("%H hours, %M minutes, %S seconds", ty_res) 200 | epoch_log = {'epochs': epoch, 201 | 'iterations': self.len_epoch * epoch, 202 | 'Runtime': res} 203 | epoch_info = ', '.join(f"{key}: {value}" for key, value in epoch_log.items()) 204 | logger_info = f"{epoch_info}\n{log}" 205 | self.logger.info(logger_info) 206 | 207 | return log 208 | 209 | def _valid_epoch(self, epoch): 210 | """ 211 | Validate after training an epoch 212 | 213 | :param epoch: Integer, current training epoch. 214 | :return: A log that contains information about validation 215 | """ 216 | self.models['model'].eval() 217 | self.valid_metrics.reset() 218 | with torch.no_grad(): 219 | if len(self.metrics_epoch) > 0: 220 | outputs = torch.FloatTensor().to(self.device) 221 | targets = torch.FloatTensor().to(self.device) 222 | for batch_idx, (data, target) in enumerate(self.valid_data_loaders['data']): 223 | if isinstance(data, dict): 224 | data = {k: v.to(self.device) for k, v in data.items()} 225 | else: 226 | data = data.to(self.device) 227 | 228 | target = target.to(self.device) 229 | 230 | output = self.models['model'](data) 231 | loss = self.criterion(output, target) 232 | if len(self.metrics_epoch) > 0: 233 | outputs = torch.cat((outputs, output)) 234 | targets = torch.cat((targets, target)) 235 | 236 | self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx, 'valid') 237 | self.valid_metrics.iter_update('loss', loss.item()) 238 | for met in self.metrics_iter: 239 | self.valid_metrics.iter_update(met.__name__, met(output, target)) 240 | # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) 241 | 242 | for met in self.metrics_epoch: 243 | self.valid_metrics.epoch_update(met.__name__, met(outputs, targets)) 244 | 245 | # # add histogram of model parameters to the tensorboard 246 | # for name, param in self.models['model'].named_parameters(): 247 | # self.writer.add_histogram(name, param, bins='auto') 248 | 249 | valid_log = self.valid_metrics.result() 250 | 251 | return valid_log 252 | 253 | def _find_lr(self): 254 | lrs = np.logspace(-7, 2, base=10, num=50) 255 | losses = [] 256 | 257 | lr_idx = 0 258 | 259 | self.models['model'].train() 260 | 261 | while lr_idx < len(lrs): 262 | for batch_idx, (data, target) in enumerate(self.train_data_loaders['data']): 263 | if lr_idx == len(lrs): 264 | break 265 | 266 | lr = lrs[lr_idx] 267 | self.optimizers['model'].param_groups[0]['lr'] = lr 268 | 269 | if isinstance(data, dict): 270 | data = {k: v.to(self.device) for k, v in data.items()} 271 | else: 272 | data = data.to(self.device) 273 | 274 | target = target.to(self.device) 275 | 276 | self.optimizers['model'].zero_grad() 277 | output = self.models['model'](data) 278 | loss = self.criterion(output, target) 279 | loss.backward() 280 | self.optimizers['model'].step() 281 | 282 | losses += [loss.item()] 283 | 284 | lr_idx += 1 285 | 286 | best_idx = np.argmin(losses) 287 | 
best_loss = losses[best_idx] 288 | best_lr = lrs[best_idx] 289 | rec_lr = best_lr / 10.0 290 | 291 | self.logger.debug(f'Best LR: {best_lr} Recommended 1-Cycle max_lr: {rec_lr}') 292 | 293 | # self.models['model'].eval() 294 | self.models['model'].load_state_dict(self.backup_model) 295 | self.optimizers['model'].load_state_dict(self.backup_opt) 296 | 297 | return rec_lr 298 | 299 | def _progress(self, batch_idx): 300 | ratio = '[{}/{} ({:.0f}%)]' 301 | return ratio.format(batch_idx, self.len_epoch, 100.0 * batch_idx / self.len_epoch) 302 | -------------------------------------------------------------------------------- /workspace/trainers/trainer_mixup_finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from torchvision.utils import make_grid 8 | 9 | from base import BaseTrainer 10 | from models.metric import MetricTracker 11 | from utils import inf_loop, mixup_data 12 | import copy 13 | import wandb 14 | 15 | 16 | class DCASETask1BTrainerWithMixup(BaseTrainer): 17 | """ 18 | Trainer class 19 | """ 20 | def __init__(self, torch_args: dict, save_dir, resume, device, **kwargs): 21 | self.device = device 22 | super().__init__(torch_args, save_dir, **kwargs) 23 | 24 | if resume is not None: 25 | self._resume_checkpoint(resume, finetune=self.finetune) 26 | 27 | # data_loaders 28 | self.do_validation = self.valid_data_loaders['data'] is not None 29 | if self.len_epoch is None: 30 | # epoch-based training 31 | self.len_epoch = len(self.train_data_loaders['data']) 32 | else: 33 | # iteration-based training 34 | self.train_data_loaders['data'] = inf_loop(self.train_data_loaders['data']) 35 | self.log_step = int(np.sqrt(self.train_data_loaders['data'].batch_size)) 36 | 37 | # losses 38 | self.criterion = self.losses['loss'] 39 | 40 | # metrics 41 | keys_loss = ['loss'] 42 | keys_iter = [m.__name__ for m in self.metrics_iter] 43 | keys_epoch = [m.__name__ for m in self.metrics_epoch] 44 | self.train_metrics = MetricTracker(keys_loss + keys_iter, keys_epoch, writer=self.writer) 45 | self.valid_metrics = MetricTracker(keys_loss + keys_iter, keys_epoch, writer=self.writer) 46 | 47 | # check whether mixup is being used 48 | self.mixup = kwargs.get('mixup', False) 49 | self.probability_mixup = kwargs.get('mixup_p', False) 50 | 51 | # init wandb model watch 52 | wandb.watch(self.models['model']) 53 | 54 | # init history logging of best/worst metrics 55 | self.best_val_metrics = {} 56 | self.best_val_metrics['val_loss'] = {'min': 99999999, 'max': -99999999} 57 | for index, row in self.valid_metrics.metrics_epoch.iterrows(): 58 | self.best_val_metrics['val_'+index] = {'min': 99999999, 'max': -99999999} 59 | 60 | ### Load model weights and set grad to false for feature extractors 61 | spec_weights = torch.load("saved/dcase_task1b_1sec_spec/Spec Static LR/model/model_best.pth")['models']['model'] 62 | wave_weights = torch.load("saved/dcase_task1b_1sec_wave/Wave Static LR/model/model_best.pth")['models']['model'] 63 | 64 | from collections import OrderedDict 65 | new_spec = OrderedDict() 66 | 67 | for key, value in spec_weights.items(): 68 | if "tail" not in key: 69 | new_key = "spec." + key 70 | new_spec[new_key] = value 71 | 72 | new_wave = OrderedDict() 73 | for key, value in wave_weights.items(): 74 | if "tail" not in key: 75 | new_key = "wave." 
+ key 76 | new_wave[new_key] = value 77 | 78 | self.models['model'].load_state_dict(new_spec, strict=False) 79 | self.models['model'].load_state_dict(new_wave, strict=False) 80 | 81 | for name, param in self.models['model'].named_parameters(): 82 | if "wave" in name or "spec" in name: 83 | param.requires_grad = False 84 | self.logger.debug("Set weights of feature extractors and set requires_grad to false.") 85 | 86 | ### Calculate the learning rate 87 | if kwargs.get('find_lr'): 88 | print('Finding optimal LR...') 89 | self.backup_model= copy.deepcopy(self.models['model'].state_dict()) 90 | self.backup_opt = copy.deepcopy(self.optimizers['model'].state_dict()) 91 | if kwargs.get('max_lr'): 92 | max_lr = kwargs['max_lr'] 93 | else: 94 | max_lr = np.median([self._find_lr() for i in range(7)]) 95 | 96 | self.lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizers['model'], 97 | max_lr=max_lr, 98 | steps_per_epoch=len(self.train_data_loaders['data']), 99 | epochs=self.epochs) 100 | print(f'Max LR: {max_lr}') 101 | wandb.run.summary['OneCycle Max LR'] = max_lr 102 | 103 | # learning rate schedulers 104 | self.do_lr_scheduling = kwargs.get('find_lr') 105 | 106 | 107 | def mixup_criterion(self, pred, y_a, y_b, lam): 108 | return lam * self.criterion(pred, y_a) + (1 - lam) * self.criterion(pred, y_b) 109 | 110 | 111 | def _train_epoch(self, epoch): 112 | """ 113 | Training logic for an epoch 114 | 115 | :param epoch: Integer, current training epoch. 116 | :return: A log that contains average loss and metric in this epoch. 117 | """ 118 | start = time.time() 119 | self.models['model'].train() 120 | self.train_metrics.reset() 121 | if len(self.metrics_epoch) > 0: 122 | outputs = torch.FloatTensor().to(self.device) 123 | targets = torch.FloatTensor().to(self.device) 124 | for batch_idx, (data, target) in enumerate(self.train_data_loaders['data']): 125 | if isinstance(data, dict): 126 | data = {k: v.to(self.device) for k, v in data.items()} 127 | else: 128 | data = data.to(self.device) 129 | 130 | target = target.to(self.device) 131 | 132 | # mixup 133 | instance_mixup = 0 134 | if self.mixup: 135 | if self.probability_mixup is not None: 136 | if (1 - self.probability_mixup) > random.random(): 137 | data_mixup, targets_a, targets_b, lam = mixup_data(data, target, 0.2, True) 138 | # data = {"data": data_mixup} 139 | instance_mixup = 1 140 | else: 141 | instance_mixup = 0 142 | else: 143 | data_mixup, targets_a, targets_b, lam = mixup_data(data, target, 0.2, True) 144 | # data = {"data": data_mixup} 145 | instance_mixup = 1 146 | 147 | self.optimizers['model'].zero_grad() 148 | output = self.models['model'](data) 149 | if len(self.metrics_epoch) > 0: 150 | outputs = torch.cat((outputs, output)) 151 | targets = torch.cat((targets, target)) 152 | 153 | if self.mixup and instance_mixup: 154 | loss = self.mixup_criterion(output, targets_a, targets_b, lam) 155 | else: 156 | loss = self.criterion(output, target) 157 | 158 | loss.backward() 159 | self.optimizers['model'].step() 160 | 161 | if self.do_lr_scheduling: 162 | self.lr_scheduler.step() 163 | 164 | # log loss and lr 165 | wandb.log({'loss': loss}) 166 | wandb.log({'learning_rate': self.optimizers['model'].param_groups[0]['lr']}) 167 | 168 | self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx) 169 | self.train_metrics.iter_update('loss', loss.item()) 170 | for met in self.metrics_iter: 171 | self.train_metrics.iter_update(met.__name__, met(output, target)) 172 | 173 | if batch_idx % self.log_step == 0: 174 | epoch_debug = f"Train 
Epoch: {epoch} {self._progress(batch_idx)} " 175 | current_metrics = self.train_metrics.current() 176 | metrics_debug = ", ".join(f"{key}: {value:.6f}" for key, value in current_metrics.items()) 177 | self.logger.debug(epoch_debug + metrics_debug) 178 | # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) 179 | 180 | if batch_idx == self.len_epoch: 181 | break 182 | 183 | for met in self.metrics_epoch: 184 | self.train_metrics.epoch_update(met.__name__, met(outputs, targets)) 185 | # log the training metrics 186 | wandb.log({met.__name__: met(outputs, targets)}) 187 | 188 | train_log = self.train_metrics.result() 189 | 190 | wandb.run.summary['epochs_trained'] = epoch 191 | 192 | if self.do_validation: 193 | valid_log = self._valid_epoch(epoch) 194 | valid_log.set_index('val_' + valid_log.index.astype(str), inplace=True) 195 | 196 | # log the validation metrics 197 | wandb.log({str(index) : row['mean'] for index, row in valid_log.iterrows()}) 198 | 199 | # update best/worst metric results 200 | for index, row in valid_log.iterrows(): 201 | if index == 'val_loss': 202 | if row['mean'] < self.best_val_metrics['val_loss']['min']: 203 | self.best_val_metrics['val_loss']['min'] = row['mean'] 204 | if row['mean'] > self.best_val_metrics['val_loss']['max']: 205 | self.best_val_metrics['val_loss']['max'] = row['mean'] 206 | 207 | wandb.run.summary['lowest_val_loss'] = self.best_val_metrics['val_loss']['min'] 208 | wandb.run.summary['highest_val_loss'] = self.best_val_metrics['val_loss']['max'] 209 | 210 | if index in self.best_val_metrics: 211 | if row['mean'] < self.best_val_metrics[index]['min']: 212 | self.best_val_metrics[index]['min'] = row['mean'] 213 | if row['mean'] > self.best_val_metrics[index]['max']: 214 | self.best_val_metrics[index]['max'] = row['mean'] 215 | 216 | wandb.run.summary['lowest_'+index] = self.best_val_metrics[index]['min'] 217 | wandb.run.summary['highest_'+index] = self.best_val_metrics[index]['max'] 218 | 219 | # if self.do_lr_scheduling: 220 | # self.lr_scheduler.step() 221 | 222 | log = pd.concat([train_log, valid_log]) 223 | end = time.time() 224 | ty_res = time.gmtime(end - start) 225 | res = time.strftime("%H hours, %M minutes, %S seconds", ty_res) 226 | epoch_log = {'epochs': epoch, 227 | 'iterations': self.len_epoch * epoch, 228 | 'Runtime': res} 229 | epoch_info = ', '.join(f"{key}: {value}" for key, value in epoch_log.items()) 230 | logger_info = f"{epoch_info}\n{log}" 231 | self.logger.info(logger_info) 232 | 233 | return log 234 | 235 | def _valid_epoch(self, epoch): 236 | """ 237 | Validate after training an epoch 238 | 239 | :param epoch: Integer, current training epoch. 
240 | :return: A log that contains information about validation 241 | """ 242 | self.models['model'].eval() 243 | self.valid_metrics.reset() 244 | with torch.no_grad(): 245 | if len(self.metrics_epoch) > 0: 246 | outputs = torch.FloatTensor().to(self.device) 247 | targets = torch.FloatTensor().to(self.device) 248 | for batch_idx, (data, target) in enumerate(self.valid_data_loaders['data']): 249 | if isinstance(data, dict): 250 | data = {k: v.to(self.device) for k, v in data.items()} 251 | else: 252 | data = data.to(self.device) 253 | 254 | target = target.to(self.device) 255 | 256 | output = self.models['model'](data) 257 | loss = self.criterion(output, target) 258 | if len(self.metrics_epoch) > 0: 259 | outputs = torch.cat((outputs, output)) 260 | targets = torch.cat((targets, target)) 261 | 262 | self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx, 'valid') 263 | self.valid_metrics.iter_update('loss', loss.item()) 264 | for met in self.metrics_iter: 265 | self.valid_metrics.iter_update(met.__name__, met(output, target)) 266 | # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) 267 | 268 | for met in self.metrics_epoch: 269 | self.valid_metrics.epoch_update(met.__name__, met(outputs, targets)) 270 | 271 | # # add histogram of model parameters to the tensorboard 272 | # for name, param in self.models['model'].named_parameters(): 273 | # self.writer.add_histogram(name, param, bins='auto') 274 | 275 | valid_log = self.valid_metrics.result() 276 | 277 | return valid_log 278 | 279 | def _find_lr(self): 280 | lrs = np.logspace(-7, 2, base=10, num=50) 281 | losses = [] 282 | 283 | lr_idx = 0 284 | 285 | self.models['model'].train() 286 | 287 | while lr_idx < len(lrs): 288 | for batch_idx, (data, target) in enumerate(self.train_data_loaders['data']): 289 | if lr_idx == len(lrs): 290 | break 291 | 292 | lr = lrs[lr_idx] 293 | self.optimizers['model'].param_groups[0]['lr'] = lr 294 | 295 | if isinstance(data, dict): 296 | data = {k: v.to(self.device) for k, v in data.items()} 297 | else: 298 | data = data.to(self.device) 299 | 300 | target = target.to(self.device) 301 | 302 | self.optimizers['model'].zero_grad() 303 | output = self.models['model'](data) 304 | loss = self.criterion(output, target) 305 | loss.backward() 306 | self.optimizers['model'].step() 307 | 308 | losses += [loss.item()] 309 | 310 | lr_idx += 1 311 | 312 | best_idx = np.argmin(losses) 313 | best_loss = losses[best_idx] 314 | best_lr = lrs[best_idx] 315 | rec_lr = best_lr / 10.0 316 | 317 | self.logger.debug(f'Best LR: {best_lr} Recommended 1-Cycle max_lr: {rec_lr}') 318 | 319 | # self.models['model'].eval() 320 | self.models['model'].load_state_dict(self.backup_model) 321 | self.optimizers['model'].load_state_dict(self.backup_opt) 322 | 323 | return rec_lr 324 | 325 | def _progress(self, batch_idx): 326 | ratio = '[{}/{} ({:.0f}%)]' 327 | return ratio.format(batch_idx, self.len_epoch, 100.0 * batch_idx / self.len_epoch) 328 | -------------------------------------------------------------------------------- /workspace/models/modules/sincnet/dnn_models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | import sys 6 | from torch.autograd import Variable 7 | import math 8 | 9 | def flip(x, dim): 10 | xsize = x.size() 11 | dim = x.dim() + dim if dim < 0 else dim 12 | x = x.contiguous() 13 | x = x.view(-1, *xsize[dim:]) 14 | x = x.view(x.size(0), x.size(1), 
-1)[:, getattr(torch.arange(x.size(1)-1, 15 | -1, -1), ('cpu','cuda')[x.is_cuda])().long(), :] 16 | return x.view(xsize) 17 | 18 | 19 | def sinc(band,t_right): 20 | y_right= torch.sin(2*math.pi*band*t_right)/(2*math.pi*band*t_right) 21 | y_left= flip(y_right,0) 22 | 23 | y=torch.cat([y_left,Variable(torch.ones(1)).cuda(),y_right]) 24 | 25 | return y 26 | 27 | 28 | class SincConv_fast(nn.Module): 29 | """Sinc-based convolution 30 | Parameters 31 | ---------- 32 | in_channels : `int` 33 | Number of input channels. Must be 1. 34 | out_channels : `int` 35 | Number of filters. 36 | kernel_size : `int` 37 | Filter length. 38 | sample_rate : `int`, optional 39 | Sample rate. Defaults to 16000. 40 | Usage 41 | ----- 42 | See `torch.nn.Conv1d` 43 | Reference 44 | --------- 45 | Mirco Ravanelli, Yoshua Bengio, 46 | "Speaker Recognition from raw waveform with SincNet". 47 | https://arxiv.org/abs/1808.00158 48 | """ 49 | 50 | @staticmethod 51 | def to_mel(hz): 52 | return 2595 * np.log10(1 + hz / 700) 53 | 54 | @staticmethod 55 | def to_hz(mel): 56 | return 700 * (10 ** (mel / 2595) - 1) 57 | 58 | def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1, 59 | stride=1, padding=0, dilation=1, bias=False, groups=1, min_low_hz=50, min_band_hz=50): 60 | 61 | super(SincConv_fast,self).__init__() 62 | 63 | if in_channels != 1: 64 | #msg = (f'SincConv only supports one input channel ' 65 | # f'(here, in_channels = {in_channels:d}).') 66 | msg = "SincConv only supports one input channel (here, in_channels = %i)" % (in_channels) 67 | raise ValueError(msg) 68 | 69 | self.out_channels = out_channels 70 | self.kernel_size = kernel_size 71 | 72 | # Forcing the filter length to be odd (i.e., perfectly symmetric) 73 | if kernel_size%2==0: 74 | self.kernel_size=self.kernel_size+1 75 | 76 | self.stride = stride 77 | self.padding = padding 78 | self.dilation = dilation 79 | 80 | if bias: 81 | raise ValueError('SincConv does not support bias.') 82 | if groups > 1: 83 | raise ValueError('SincConv does not support groups.') 84 | 85 | self.sample_rate = sample_rate 86 | self.min_low_hz = min_low_hz 87 | self.min_band_hz = min_band_hz 88 | 89 | # initialize filterbanks such that they are equally spaced in Mel scale 90 | low_hz = 30 91 | high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz) 92 | 93 | mel = np.linspace(self.to_mel(low_hz), 94 | self.to_mel(high_hz), 95 | self.out_channels + 1) 96 | hz = self.to_hz(mel) 97 | 98 | 99 | # filter lower frequency (out_channels, 1) 100 | self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1)) 101 | 102 | # filter frequency band (out_channels, 1) 103 | self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1)) 104 | 105 | # Hamming window 106 | #self.window_ = torch.hamming_window(self.kernel_size) 107 | n_lin=torch.linspace(0, (self.kernel_size/2)-1, steps=int((self.kernel_size/2))) # computing only half of the window 108 | self.window_=0.54-0.46*torch.cos(2*math.pi*n_lin/self.kernel_size); 109 | 110 | 111 | # (1, kernel_size/2) 112 | n = (self.kernel_size - 1) / 2.0 113 | self.n_ = 2*math.pi*torch.arange(-n, 0).view(1, -1) / self.sample_rate # Due to symmetry, I only need half of the time axis 114 | 115 | 116 | 117 | 118 | def forward(self, waveforms): 119 | """ 120 | Parameters 121 | ---------- 122 | waveforms : `torch.Tensor` (batch_size, 1, n_samples) 123 | Batch of waveforms. 124 | Returns 125 | ------- 126 | features : `torch.Tensor` (batch_size, out_channels, n_samples_out) 127 | Batch of sinc filter activations.
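Note: the band-pass filters are rebuilt from the learnable ``low_hz_`` and ``band_hz_`` parameters on every call, so only the cutoff frequencies are trained, never the filter taps themselves.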
128 | """ 129 | 130 | self.n_ = self.n_.to(waveforms.device) 131 | 132 | self.window_ = self.window_.to(waveforms.device) 133 | 134 | low = self.min_low_hz + torch.abs(self.low_hz_) 135 | 136 | high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),self.min_low_hz,self.sample_rate/2) 137 | band=(high-low)[:,0] 138 | 139 | f_times_t_low = torch.matmul(low, self.n_) 140 | f_times_t_high = torch.matmul(high, self.n_) 141 | 142 | band_pass_left=((torch.sin(f_times_t_high)-torch.sin(f_times_t_low))/(self.n_/2))*self.window_ # Equivalent of Eq.4 of the reference paper (SPEAKER RECOGNITION FROM RAW WAVEFORM WITH SINCNET). I just have expanded the sinc and simplified the terms. This way I avoid several useless computations. 143 | band_pass_center = 2*band.view(-1,1) 144 | band_pass_right= torch.flip(band_pass_left,dims=[1]) 145 | 146 | 147 | band_pass=torch.cat([band_pass_left,band_pass_center,band_pass_right],dim=1) 148 | 149 | 150 | band_pass = band_pass / (2*band[:,None]) 151 | 152 | 153 | self.filters = (band_pass).view( 154 | self.out_channels, 1, self.kernel_size) 155 | 156 | return F.conv1d(waveforms, self.filters, stride=self.stride, 157 | padding=self.padding, dilation=self.dilation, 158 | bias=None, groups=1) 159 | 160 | def impulse_response(self): 161 | self.n_ = self.n_.cpu() 162 | 163 | self.window_ = self.window_.cpu() 164 | 165 | low = self.min_low_hz + torch.abs(self.low_hz_) 166 | 167 | high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),self.min_low_hz,self.sample_rate/2) 168 | band=(high-low)[:,0] 169 | 170 | f_times_t_low = torch.matmul(low, self.n_) 171 | f_times_t_high = torch.matmul(high, self.n_) 172 | 173 | band_pass_left=((torch.sin(f_times_t_high)-torch.sin(f_times_t_low))/(self.n_/2))*self.window_ # Equivalent of Eq.4 of the reference paper (SPEAKER RECOGNITION FROM RAW WAVEFORM WITH SINCNET). I just have expanded the sinc and simplified the terms. This way I avoid several useless computations. 
174 | band_pass_center = 2*band.view(-1,1) 175 | band_pass_right= torch.flip(band_pass_left,dims=[1]) 176 | 177 | 178 | band_pass=torch.cat([band_pass_left,band_pass_center,band_pass_right],dim=1) 179 | 180 | 181 | band_pass = band_pass / (2*band[:,None]) 182 | 183 | self.filters = (band_pass).view( 184 | self.out_channels, 1, self.kernel_size) 185 | 186 | return self.filters 187 | 188 | 189 | 190 | class sinc_conv(nn.Module): 191 | 192 | def __init__(self, N_filt,Filt_dim,fs): 193 | super(sinc_conv,self).__init__() 194 | 195 | # Mel Initialization of the filterbanks 196 | low_freq_mel = 80 197 | high_freq_mel = (2595 * np.log10(1 + (fs / 2) / 700)) # Convert Hz to Mel 198 | mel_points = np.linspace(low_freq_mel, high_freq_mel, N_filt) # Equally spaced in Mel scale 199 | f_cos = (700 * (10**(mel_points / 2595) - 1)) # Convert Mel to Hz 200 | b1=np.roll(f_cos,1) 201 | b2=np.roll(f_cos,-1) 202 | b1[0]=30 203 | b2[-1]=(fs/2)-100 204 | 205 | self.freq_scale=fs*1.0 206 | self.filt_b1 = nn.Parameter(torch.from_numpy(b1/self.freq_scale)) 207 | self.filt_band = nn.Parameter(torch.from_numpy((b2-b1)/self.freq_scale)) 208 | 209 | 210 | self.N_filt=N_filt 211 | self.Filt_dim=Filt_dim 212 | self.fs=fs 213 | 214 | 215 | def forward(self, x): 216 | 217 | filters=Variable(torch.zeros((self.N_filt,self.Filt_dim))).cuda() 218 | N=self.Filt_dim 219 | t_right=Variable(torch.linspace(1, (N-1)/2, steps=int((N-1)/2))/self.fs).cuda() 220 | 221 | 222 | min_freq=50.0; 223 | min_band=50.0; 224 | 225 | filt_beg_freq=torch.abs(self.filt_b1)+min_freq/self.freq_scale 226 | filt_end_freq=filt_beg_freq+(torch.abs(self.filt_band)+min_band/self.freq_scale) 227 | 228 | n=torch.linspace(0, N, steps=N) 229 | 230 | # Filter window (Hamming) 231 | window=0.54-0.46*torch.cos(2*math.pi*n/N); 232 | window=Variable(window.float().cuda()) 233 | 234 | 235 | for i in range(self.N_filt): 236 | 237 | low_pass1 = 2*filt_beg_freq[i].float()*sinc(filt_beg_freq[i].float()*self.freq_scale,t_right) 238 | low_pass2 = 2*filt_end_freq[i].float()*sinc(filt_end_freq[i].float()*self.freq_scale,t_right) 239 | band_pass=(low_pass2-low_pass1) 240 | 241 | band_pass=band_pass/torch.max(band_pass) 242 | 243 | filters[i,:]=band_pass.cuda()*window 244 | 245 | out=F.conv1d(x, filters.view(self.N_filt,1,self.Filt_dim)) 246 | 247 | return out 248 | 249 | 250 | def act_fun(act_type): 251 | 252 | if act_type=="relu": 253 | return nn.ReLU() 254 | 255 | if act_type=="tanh": 256 | return nn.Tanh() 257 | 258 | if act_type=="sigmoid": 259 | return nn.Sigmoid() 260 | 261 | if act_type=="leaky_relu": 262 | return nn.LeakyReLU(0.2) 263 | 264 | if act_type=="elu": 265 | return nn.ELU() 266 | 267 | if act_type=="softmax": 268 | return nn.LogSoftmax(dim=1) 269 | 270 | if act_type=="linear": 271 | return nn.LeakyReLU(1) # initialized like this, but not used in forward!
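A minimal smoke test of the sinc front end above may help orientation; this is a hedged sketch (the 251-tap / 48 kHz values mirror what the fusion models in this repo pass in, but any kernel size works, since even lengths are bumped to odd):

import torch
from models.modules.sincnet.dnn_models import SincConv_fast

# 32 mel-initialized, learnable band-pass filters; padding=125 gives 'same'-length output
conv = SincConv_fast(32, kernel_size=251, sample_rate=48000, padding=125)
wave = torch.randn(4, 1, 48000)    # (batch, in_channels=1, samples): 1 s at 48 kHz
feats = conv(wave)                 # filters are rebuilt from low_hz_/band_hz_ per call
assert feats.shape == (4, 32, 48000)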
272 | 273 | 274 | class LayerNorm(nn.Module): 275 | 276 | def __init__(self, features, eps=1e-6): 277 | super(LayerNorm,self).__init__() 278 | self.gamma = nn.Parameter(torch.ones(features)) 279 | self.beta = nn.Parameter(torch.zeros(features)) 280 | self.eps = eps 281 | 282 | def forward(self, x): 283 | mean = x.mean(-1, keepdim=True) 284 | std = x.std(-1, keepdim=True) 285 | return self.gamma * (x - mean) / (std + self.eps) + self.beta 286 | 287 | 288 | class MLP(nn.Module): 289 | def __init__(self, options): 290 | super(MLP, self).__init__() 291 | 292 | self.input_dim=int(options['input_dim']) 293 | self.fc_lay=options['fc_lay'] 294 | self.fc_drop=options['fc_drop'] 295 | self.fc_use_batchnorm=options['fc_use_batchnorm'] 296 | self.fc_use_laynorm=options['fc_use_laynorm'] 297 | self.fc_use_laynorm_inp=options['fc_use_laynorm_inp'] 298 | self.fc_use_batchnorm_inp=options['fc_use_batchnorm_inp'] 299 | self.fc_act=options['fc_act'] 300 | 301 | 302 | self.wx = nn.ModuleList([]) 303 | self.bn = nn.ModuleList([]) 304 | self.ln = nn.ModuleList([]) 305 | self.act = nn.ModuleList([]) 306 | self.drop = nn.ModuleList([]) 307 | 308 | 309 | 310 | # input layer normalization 311 | if self.fc_use_laynorm_inp: 312 | self.ln0=LayerNorm(self.input_dim) 313 | 314 | # input batch normalization 315 | if self.fc_use_batchnorm_inp: 316 | self.bn0=nn.BatchNorm1d([self.input_dim],momentum=0.05) 317 | 318 | 319 | self.N_fc_lay=len(self.fc_lay) 320 | 321 | current_input=self.input_dim 322 | 323 | # Initialization of hidden layers 324 | 325 | for i in range(self.N_fc_lay): 326 | 327 | # dropout 328 | self.drop.append(nn.Dropout(p=self.fc_drop[i])) 329 | 330 | # activation 331 | if self.fc_act is not None: 332 | self.act.append(act_fun(self.fc_act[i])) 333 | 334 | 335 | add_bias=True 336 | 337 | # layer norm initialization 338 | self.ln.append(LayerNorm(self.fc_lay[i])) 339 | self.bn.append(nn.BatchNorm1d(self.fc_lay[i],momentum=0.05)) 340 | 341 | if self.fc_use_laynorm[i] or self.fc_use_batchnorm[i]: 342 | add_bias=False 343 | 344 | 345 | # Linear operations 346 | self.wx.append(nn.Linear(current_input, self.fc_lay[i],bias=add_bias)) 347 | 348 | # weight initialization 349 | self.wx[i].weight = torch.nn.Parameter(torch.Tensor(self.fc_lay[i],current_input).uniform_(-np.sqrt(0.01/(current_input+self.fc_lay[i])),np.sqrt(0.01/(current_input+self.fc_lay[i])))) 350 | self.wx[i].bias = torch.nn.Parameter(torch.zeros(self.fc_lay[i])) 351 | 352 | current_input=self.fc_lay[i] 353 | 354 | 355 | def forward(self, x): 356 | 357 | # Applying Layer/Batch Norm 358 | if bool(self.fc_use_laynorm_inp): 359 | x=self.ln0((x)) 360 | 361 | if bool(self.fc_use_batchnorm_inp): 362 | x=self.bn0((x)) 363 | 364 | for i in range(self.N_fc_lay): 365 | 366 | if self.fc_act is not None and self.fc_act[i]!='linear': 367 | 368 | if self.fc_use_laynorm[i]: 369 | x = self.drop[i](self.act[i](self.ln[i](self.wx[i](x)))) 370 | 371 | if self.fc_use_batchnorm[i]: 372 | x = self.drop[i](self.act[i](self.bn[i](self.wx[i](x)))) 373 | 374 | if self.fc_use_batchnorm[i]==False and self.fc_use_laynorm[i]==False: 375 | x = self.drop[i](self.act[i](self.wx[i](x))) 376 | 377 | else: 378 | if self.fc_use_laynorm[i]: 379 | x = self.drop[i](self.ln[i](self.wx[i](x))) 380 | 381 | if self.fc_use_batchnorm[i]: 382 | x = self.drop[i](self.bn[i](self.wx[i](x))) 383 | 384 | if self.fc_use_batchnorm[i]==False and self.fc_use_laynorm[i]==False: 385 | x = self.drop[i](self.wx[i](x)) 386 | 387 | return x 388 | 389 | 390 | 391 | class SincNet(nn.Module): 392 | 393 | def 
__init__(self,options): 394 | super(SincNet,self).__init__() 395 | 396 | self.cnn_N_filt=options['cnn_N_filt'] 397 | self.cnn_len_filt=options['cnn_len_filt'] 398 | self.cnn_max_pool_len=options['cnn_max_pool_len'] 399 | 400 | 401 | self.cnn_act=options['cnn_act'] 402 | self.cnn_drop=options['cnn_drop'] 403 | 404 | self.cnn_use_laynorm=options['cnn_use_laynorm'] 405 | self.cnn_use_batchnorm=options['cnn_use_batchnorm'] 406 | self.cnn_use_laynorm_inp=options['cnn_use_laynorm_inp'] 407 | self.cnn_use_batchnorm_inp=options['cnn_use_batchnorm_inp'] 408 | 409 | self.input_dim=int(options['input_dim']) 410 | 411 | self.fs=options['fs'] 412 | 413 | self.N_cnn_lay=len(options['cnn_N_filt']) 414 | self.conv = nn.ModuleList([]) 415 | self.bn = nn.ModuleList([]) 416 | self.ln = nn.ModuleList([]) 417 | self.act = nn.ModuleList([]) 418 | self.drop = nn.ModuleList([]) 419 | 420 | 421 | if self.cnn_use_laynorm_inp: 422 | self.ln0=LayerNorm(self.input_dim) 423 | 424 | if self.cnn_use_batchnorm_inp: 425 | self.bn0=nn.BatchNorm1d([self.input_dim],momentum=0.05) 426 | 427 | current_input=self.input_dim 428 | 429 | for i in range(self.N_cnn_lay): 430 | 431 | N_filt=int(self.cnn_N_filt[i]) 432 | len_filt=int(self.cnn_len_filt[i]) 433 | 434 | # dropout 435 | self.drop.append(nn.Dropout(p=self.cnn_drop[i])) 436 | 437 | # activation 438 | self.act.append(act_fun(self.cnn_act[i])) 439 | 440 | # layer norm initialization 441 | self.ln.append(LayerNorm([N_filt,int((current_input-self.cnn_len_filt[i]+1)/self.cnn_max_pool_len[i])])) 442 | 443 | self.bn.append(nn.BatchNorm1d(N_filt,int((current_input-self.cnn_len_filt[i]+1)/self.cnn_max_pool_len[i]),momentum=0.05)) 444 | 445 | 446 | if i==0: 447 | self.conv.append(SincConv_fast(self.cnn_N_filt[0],self.cnn_len_filt[0],self.fs)) 448 | 449 | else: 450 | self.conv.append(nn.Conv1d(self.cnn_N_filt[i-1], self.cnn_N_filt[i], self.cnn_len_filt[i])) 451 | 452 | current_input=int((current_input-self.cnn_len_filt[i]+1)/self.cnn_max_pool_len[i]) 453 | 454 | 455 | self.out_dim=current_input*N_filt 456 | 457 | 458 | 459 | def forward(self, x): 460 | batch=x.shape[0] 461 | seq_len=x.shape[1] 462 | 463 | if bool(self.cnn_use_laynorm_inp): 464 | x=self.ln0((x)) 465 | 466 | if bool(self.cnn_use_batchnorm_inp): 467 | x=self.bn0((x)) 468 | 469 | x=x.view(batch,1,seq_len) 470 | 471 | 472 | for i in range(self.N_cnn_lay): 473 | 474 | if self.cnn_use_laynorm[i]: 475 | if i==0: 476 | x = self.drop[i](self.act[i](self.ln[i](F.max_pool1d(torch.abs(self.conv[i](x)), self.cnn_max_pool_len[i])))) 477 | else: 478 | x = self.drop[i](self.act[i](self.ln[i](F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i])))) 479 | 480 | if self.cnn_use_batchnorm[i]: 481 | x = self.drop[i](self.act[i](self.bn[i](F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i])))) 482 | 483 | if self.cnn_use_batchnorm[i]==False and self.cnn_use_laynorm[i]==False: 484 | x = self.drop[i](self.act[i](F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i]))) 485 | 486 | 487 | x = x.view(batch,-1) 488 | 489 | return x 490 | 491 | 492 | 493 | 494 | -------------------------------------------------------------------------------- /workspace/models/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from base import BaseModel 4 | import torch 5 | import wandb 6 | # import leaf_audio_pytorch.frontend as frontend 7 | # from leaf_audio_pytorch.postprocessing import log_compression 8 | from models.modules.fusion import MFB 9 | from 
models.modules.sincnet.dnn_models import * 10 | # from leaf_audio_pytorch.convolution import GaborConv1D 11 | # from leaf_audio_pytorch.frontend import SquaredModulus 12 | # import leaf_audio_pytorch.initializers as initializers 13 | 14 | 15 | class SpectrogramClassifier(BaseModel): 16 | def __init__(self, 17 | n_classes, 18 | n_in_channels=1, 19 | non_linearity='LeakyReLU', 20 | dropout=0.3, 21 | latent_size=2048, 22 | cn_feature_n=[32, 64, 128, 256, 512], 23 | kernel_size=3, 24 | max_pool_kernel=(2,2), 25 | fc_layer_n=[1024, 512], 26 | use_leaf=False, 27 | leaf_sample_rate=48000, 28 | leaf_window_len=42, 29 | leaf_window_stride=40): 30 | super().__init__() 31 | 32 | # test_input_spectrogram = torch.zeros(3, n_in_channels, n_bins, n_frames) 33 | self.use_leaf = use_leaf 34 | 35 | if self.use_leaf: 36 | self.leaf = frontend.Leaf(n_filters=40, sample_rate=leaf_sample_rate, window_len=leaf_window_len, window_stride=leaf_window_stride) 37 | 38 | cn = [] 39 | 40 | for ilb, n_out in enumerate(cn_feature_n): 41 | if ilb == 0: 42 | cn.append(nn.Conv2d(n_in_channels, n_out, kernel_size=kernel_size, padding=kernel_size//2)) 43 | cn.append(nn.BatchNorm2d(n_out)) 44 | cn.append(getattr(nn, non_linearity)()) 45 | cn.append(nn.MaxPool2d(kernel_size=max_pool_kernel)) 46 | else: 47 | cn.append(nn.Conv2d( 48 | cn_feature_n[ilb-1], n_out, 49 | kernel_size=kernel_size, padding=kernel_size//2) 50 | ), 51 | cn.append(nn.BatchNorm2d(n_out)) 52 | cn.append(getattr(nn, non_linearity)()) 53 | cn.append(nn.MaxPool2d(kernel_size=max_pool_kernel)) 54 | 55 | cn.append(nn.AdaptiveAvgPool2d((1,latent_size//cn_feature_n[-1]))) 56 | 57 | self.cnn = nn.Sequential(*cn) 58 | 59 | # _, cn_channels, cn_bins, cn_frames = self.cnn(test_input_spectrogram).shape 60 | # self.fc_in = cn_channels * cn_frames * cn_bins 61 | self.fc_in = latent_size 62 | 63 | fc = [ 64 | nn.Flatten(), 65 | nn.Dropout(dropout) 66 | ] 67 | 68 | for il, n_out in enumerate(fc_layer_n): 69 | n_in = self.fc_in if il == 0 else fc_layer_n[il-1] 70 | fc.append(nn.Linear(n_in, n_out)) 71 | fc.append(getattr(nn, non_linearity)()) 72 | fc.append(nn.BatchNorm1d(n_out)) 73 | fc.append(nn.Dropout(dropout)) 74 | 75 | fc.append(nn.Linear(fc_layer_n[-1], n_classes)) 76 | 77 | self.tail = nn.Sequential(*fc) 78 | 79 | def forward(self, x): 80 | x = x['spec'] 81 | 82 | if self.use_leaf: 83 | x = x.unsqueeze(1) 84 | x = self.leaf(x) 85 | 86 | x = x.unsqueeze(1) 87 | x = self.cnn(x) 88 | x = self.tail(x) 89 | return x 90 | 91 | 92 | class WaveformClassifier(BaseModel): 93 | def __init__(self, 94 | input_length, 95 | num_classes=12, 96 | n_in_channels=1, 97 | non_linearity='LeakyReLU', 98 | dropout=0.3, 99 | latent_size=2048, 100 | cn_feature_n=[32, 64, 128, 256, 512], 101 | max_pool_kernel=8, 102 | kernel_0_size=7, 103 | kernel_size=7, 104 | fc_layer_n=[1024,512]): 105 | 106 | super().__init__() 107 | test_input_waveform = torch.zeros(3, n_in_channels, input_length) 108 | 109 | cn = [] 110 | for ilb, n_out in enumerate(cn_feature_n): 111 | if ilb == 0: 112 | cn.append(nn.Conv1d(n_in_channels, n_out, kernel_size=kernel_0_size, padding=kernel_size//2)) 113 | cn.append(nn.BatchNorm1d(n_out)) 114 | cn.append(getattr(nn, non_linearity)()) 115 | cn.append(nn.MaxPool1d(kernel_size=max_pool_kernel)) 116 | else: 117 | cn.append(nn.Conv1d( 118 | cn_feature_n[ilb-1], n_out, 119 | kernel_size=kernel_size, padding=kernel_size // 2) 120 | ), 121 | cn.append(nn.BatchNorm1d(n_out)) 122 | cn.append(getattr(nn, non_linearity)()) 123 | cn.append(nn.MaxPool1d(kernel_size=max_pool_kernel)) 124 
| 125 | cn.append(nn.AdaptiveAvgPool1d(latent_size//cn_feature_n[-1])) 126 | 127 | self.frontend = nn.Sequential(*cn) 128 | 129 | # Build the FC head 130 | _, cn_features, cn_slices = self.frontend(test_input_waveform).shape 131 | 132 | self.fc_in = cn_features * cn_slices 133 | 134 | fc = [ 135 | nn.Flatten(), 136 | nn.Dropout(dropout) 137 | ] 138 | 139 | for il, n_out in enumerate(fc_layer_n): 140 | n_in = self.fc_in if il == 0 else fc_layer_n[il-1] 141 | fc.append(nn.Linear(n_in, n_out)) 142 | fc.append(getattr(nn, non_linearity)()) 143 | fc.append(nn.BatchNorm1d(n_out)) 144 | fc.append(nn.Dropout(dropout)) 145 | 146 | fc.append(nn.Linear(fc_layer_n[-1], num_classes)) 147 | 148 | self.tail = nn.Sequential(*fc) 149 | 150 | def forward(self, x): 151 | x = x['data'] 152 | x = x.unsqueeze(1) 153 | x = self.frontend(x) 154 | x = self.tail(x) 155 | 156 | return x 157 | 158 | 159 | class WaveformExtractor(nn.Module): 160 | def __init__(self, 161 | n_in_channels=1, 162 | non_linearity='LeakyReLU', 163 | latent_size=2048, 164 | parameterization='normal', 165 | kernel_0_size=251, 166 | cn_feature_n=[32, 64, 128, 256, 512], 167 | max_pool_kernel=6, 168 | kernel_size=7, 169 | layernorm_fusion=False): 170 | 171 | super(WaveformExtractor, self).__init__() 172 | # test_input_waveform = torch.zeros(3, n_in_channels, input_length) 173 | 174 | cn = [] 175 | for ilb, n_out in enumerate(cn_feature_n): 176 | if ilb == 0: 177 | if parameterization == 'normal': 178 | cn.append(nn.Conv1d(n_in_channels, n_out, kernel_size=kernel_0_size, padding=kernel_0_size//2)) 179 | elif parameterization == 'sinc': 180 | cn.append(SincConv_fast(n_out, kernel_size=kernel_0_size, sample_rate=48000, padding=kernel_0_size//2)) 181 | elif parameterization == 'leaf': 182 | cn.append(GaborConv1D(n_out*2, kernel_size=kernel_0_size, strides=1, padding=kernel_0_size//2, use_bias=False, 183 | input_shape=(None, None, 1), 184 | kernel_initializer=initializers.GaborInit, 185 | kernel_regularizer=None, 186 | name='complex_conv', 187 | trainable=True)) 188 | cn.append(SquaredModulus()) 189 | cn.append(nn.BatchNorm1d(n_out)) 190 | cn.append(getattr(nn, non_linearity)()) 191 | cn.append(nn.MaxPool1d(kernel_size=max_pool_kernel)) 192 | else: 193 | cn.append(nn.Conv1d( 194 | cn_feature_n[ilb-1], n_out, 195 | kernel_size=kernel_size, padding=kernel_size // 2) 196 | ), 197 | cn.append(nn.BatchNorm1d(n_out)) 198 | cn.append(getattr(nn, non_linearity)()) 199 | cn.append(nn.MaxPool1d(kernel_size=max_pool_kernel)) 200 | 201 | cn.append(nn.AdaptiveAvgPool1d(latent_size//cn_feature_n[-1])) 202 | 203 | if layernorm_fusion: 204 | cn.append(nn.Flatten()) 205 | cn.append(nn.LayerNorm(latent_size)) 206 | 207 | self.frontend = nn.Sequential(*cn) 208 | 209 | def forward(self, x): 210 | x = x.unsqueeze(1) 211 | x = self.frontend(x) 212 | return x 213 | 214 | 215 | class SpectrogramExtractor(nn.Module): 216 | def __init__(self, 217 | n_in_channels=1, 218 | non_linearity='LeakyReLU', 219 | latent_size=2048, 220 | cn_feature_n=[32, 64, 128, 256, 512], 221 | kernel_size=3, 222 | max_pool_kernel=(2,2), 223 | layernorm_fusion=False): 224 | super(SpectrogramExtractor, self).__init__() 225 | 226 | # test_input_spectrogram = torch.zeros(3, n_in_channels, n_bins, n_frames) 227 | 228 | cn = [] 229 | 230 | for ilb, n_out in enumerate(cn_feature_n): 231 | if ilb == 0: 232 | cn.append(nn.Conv2d(n_in_channels, n_out, kernel_size=kernel_size, padding=kernel_size//2)) 233 | cn.append(nn.BatchNorm2d(n_out)) 234 | cn.append(getattr(nn, non_linearity)()) 235 | 
cn.append(nn.MaxPool2d(kernel_size=max_pool_kernel)) 236 | else: 237 | cn.append(nn.Conv2d( 238 | cn_feature_n[ilb-1], n_out, 239 | kernel_size=kernel_size, padding=kernel_size//2) 240 | ), 241 | cn.append(nn.BatchNorm2d(n_out)) 242 | cn.append(getattr(nn, non_linearity)()) 243 | cn.append(nn.MaxPool2d(kernel_size=max_pool_kernel)) 244 | 245 | cn.append(nn.AdaptiveAvgPool2d((1,latent_size//cn_feature_n[-1]))) 246 | 247 | if layernorm_fusion: 248 | cn.append(nn.Flatten()) 249 | cn.append(nn.LayerNorm(latent_size)) 250 | 251 | self.cnn = nn.Sequential(*cn) 252 | 253 | def forward(self, x): 254 | x = x.unsqueeze(1) 255 | x = self.cnn(x) 256 | return x 257 | 258 | 259 | class MultiModalFusionClassifier(BaseModel): 260 | def init_weights(self, m): 261 | if type(m) == nn.Linear: 262 | m.weight.data.fill_(0.0) 263 | m.bias.data.fill_(0.0) 264 | 265 | def __init__(self, 266 | input_length, 267 | n_bins, 268 | n_frames, 269 | num_classes, 270 | fusion_method='sum', 271 | parameterization='normal', 272 | non_linearity='ReLU', 273 | dropout=0.3, 274 | fc_layer_n=[1024,512]): 275 | super().__init__() 276 | 277 | self.fusion_method = fusion_method 278 | 279 | self.wave = WaveformExtractor(parameterization=parameterization) 280 | self.spec = SpectrogramExtractor(1) 281 | 282 | test_input_waveform = torch.zeros(3, input_length) 283 | test_input_spectrogram = torch.zeros(3, n_bins, n_frames) 284 | 285 | # Build the FC head 286 | _, w_cn_features, w_cn_slices = self.wave(test_input_waveform).shape 287 | self.wave_fc_in = w_cn_features * w_cn_slices 288 | _, s_cn_channels, s_cn_bins, s_cn_frames = self.spec(test_input_spectrogram).shape 289 | self.spec_fc_in = s_cn_channels * s_cn_frames * s_cn_bins 290 | 291 | assert self.wave_fc_in == self.spec_fc_in 292 | 293 | self.flat = nn.Flatten() 294 | 295 | fc = [ 296 | nn.Dropout(dropout) 297 | ] 298 | 299 | if fusion_method == 'concat': 300 | self.wave_fc_in = self.wave_fc_in + self.spec_fc_in 301 | 302 | if fusion_method == 'mfb': 303 | self.mfb = MFB(self.wave_fc_in, self.spec_fc_in) 304 | 305 | if fusion_method == 'sum-attention-noinit': 306 | self.attention = nn.Sequential( 307 | nn.Linear(self.wave_fc_in*2, self.wave_fc_in*4), 308 | getattr(nn, non_linearity)(), 309 | nn.Linear(self.wave_fc_in*4, self.wave_fc_in//2), 310 | getattr(nn, non_linearity)(), 311 | nn.Linear(self.wave_fc_in//2, 2), 312 | nn.Softmax(dim=1) 313 | ) 314 | if fusion_method == 'sum-attention-init': 315 | self.attention = nn.Sequential( 316 | nn.Linear(self.wave_fc_in*2, self.wave_fc_in*4), 317 | getattr(nn, non_linearity)(), 318 | nn.Linear(self.wave_fc_in*4, self.wave_fc_in//2), 319 | getattr(nn, non_linearity)(), 320 | nn.Linear(self.wave_fc_in//2, 2), 321 | nn.Softmax(dim=1) 322 | ) 323 | self.init_weights(self.attention) 324 | 325 | for il, n_out in enumerate(fc_layer_n): 326 | n_in = self.wave_fc_in if il == 0 else fc_layer_n[il-1] 327 | fc.append(nn.Linear(n_in, n_out)) 328 | fc.append(getattr(nn, non_linearity)()) 329 | fc.append(nn.BatchNorm1d(n_out)) 330 | fc.append(nn.Dropout(dropout)) 331 | 332 | fc.append(nn.Linear(fc_layer_n[-1], num_classes)) 333 | 334 | self.tail = nn.Sequential(*fc) 335 | 336 | def forward(self, x): 337 | w = x['wave'] 338 | s = x['spec'] 339 | w = self.wave(w) 340 | s = self.spec(s) 341 | 342 | w_flat = self.flat(w) 343 | s_flat = self.flat(s) 344 | 345 | # Select fusion method 346 | if self.fusion_method == 'sum': 347 | combined_features = w_flat.add(s_flat) 348 | elif self.fusion_method == 'concat': 349 | combined_features = torch.cat((w_flat, 
s_flat), dim=1) 350 | elif self.fusion_method == 'mfb': 351 | combined_features, _ = self.mfb(w_flat.unsqueeze(1), s_flat.unsqueeze(1)) 352 | combined_features = combined_features.squeeze(1) 353 | elif self.fusion_method == 'sum-attention-noinit' or self.fusion_method == 'sum-attention-init': 354 | concat_features = torch.cat((w_flat, s_flat), dim=1) 355 | att = self.attention(concat_features) 356 | 357 | att_1, att_2 = torch.split(att, 1, dim=1) 358 | 359 | combined_features = (w_flat*att_1).add((s_flat*att_2)) 360 | 361 | res = self.tail(combined_features) 362 | 363 | return res 364 | 365 | 366 | class LeafClassifier(BaseModel): 367 | def __init__(self, 368 | n_classes, 369 | learn_pooling=False, 370 | learn_filters=False, 371 | n_in_channels=1, 372 | non_linearity='LeakyReLU', 373 | dropout=0.3, 374 | latent_size=1024, 375 | cn_feature_n=[32, 64, 128, 256, 512], 376 | kernel_size=3, 377 | max_pool_kernel=(2,2), 378 | fc_layer_n=[256]): 379 | super().__init__() 380 | 381 | # test_input_spectrogram = torch.zeros(3, n_in_channels, n_bins, n_frames) 382 | 383 | self.leaf = frontend.Leaf(n_filters=40, sample_rate=48000, window_len=42, window_stride=40, learn_pooling=learn_pooling, learn_filters=learn_filters, compression_fn=log_compression) 384 | 385 | cn = [] 386 | 387 | for ilb, n_out in enumerate(cn_feature_n): 388 | if ilb == 0: 389 | cn.append(nn.Conv2d(n_in_channels, n_out, kernel_size=kernel_size, padding=kernel_size//2)) 390 | cn.append(nn.BatchNorm2d(n_out)) 391 | cn.append(getattr(nn, non_linearity)()) 392 | cn.append(nn.MaxPool2d(kernel_size=max_pool_kernel)) 393 | else: 394 | cn.append(nn.Conv2d( 395 | cn_feature_n[ilb-1], n_out, 396 | kernel_size=kernel_size, padding=kernel_size//2) 397 | ), 398 | cn.append(nn.BatchNorm2d(n_out)) 399 | cn.append(getattr(nn, non_linearity)()) 400 | cn.append(nn.MaxPool2d(kernel_size=max_pool_kernel)) 401 | 402 | cn.append(nn.AdaptiveAvgPool2d((1,latent_size//cn_feature_n[-1]))) 403 | 404 | self.cnn = nn.Sequential(*cn) 405 | 406 | # _, cn_channels, cn_bins, cn_frames = self.cnn(test_input_spectrogram).shape 407 | # self.fc_in = cn_channels * cn_frames * cn_bins 408 | self.fc_in = latent_size 409 | 410 | fc = [ 411 | nn.Flatten(), 412 | nn.Dropout(dropout) 413 | ] 414 | 415 | for il, n_out in enumerate(fc_layer_n): 416 | n_in = self.fc_in if il == 0 else fc_layer_n[il-1] 417 | fc.append(nn.Linear(n_in, n_out)) 418 | fc.append(getattr(nn, non_linearity)()) 419 | fc.append(nn.BatchNorm1d(n_out)) 420 | fc.append(nn.Dropout(dropout)) 421 | 422 | fc.append(nn.Linear(fc_layer_n[-1], n_classes)) 423 | 424 | self.tail = nn.Sequential(*fc) 425 | 426 | def forward(self, x): 427 | x = x['data'] 428 | x = x.unsqueeze(1) 429 | x = self.leaf(x) 430 | x = x.unsqueeze(1) 431 | x = self.cnn(x) 432 | x = self.tail(x) 433 | return x 434 | 435 | 436 | class WaveformParameterizedClassifier(BaseModel): 437 | def __init__(self, 438 | input_length, 439 | num_classes=12, 440 | parameterization='normal', # 'normal', 'sinc', 'leaf' 441 | n_in_channels=1, 442 | non_linearity='LeakyReLU', 443 | dropout=0.3, 444 | latent_size=2048, 445 | cn_feature_n=[32, 64, 128, 256, 512], 446 | max_pool_kernel=8, 447 | kernel_0_size=7, 448 | kernel_size=7, 449 | fc_layer_n=[1024,512]): 450 | 451 | super().__init__() 452 | test_input_waveform = torch.zeros(3, n_in_channels, input_length) 453 | 454 | cn = [] 455 | for ilb, n_out in enumerate(cn_feature_n): 456 | if ilb == 0: 457 | if parameterization == 'normal': 458 | cn.append(nn.Conv1d(n_in_channels, n_out, kernel_size=kernel_0_size, 
padding=kernel_0_size//2)) 459 | elif parameterization == 'sinc': 460 | cn.append(SincConv_fast(n_out, kernel_size=kernel_0_size, sample_rate=48000, padding=kernel_0_size//2)) 461 | elif parameterization == 'leaf': 462 | cn.append(GaborConv1D(n_out*2, kernel_size=kernel_0_size, strides=1, padding=kernel_0_size//2, use_bias=False, 463 | input_shape=(None, None, 1), 464 | kernel_initializer=initializers.GaborInit, 465 | kernel_regularizer=None, 466 | name='complex_conv', 467 | trainable=True)) 468 | cn.append(SquaredModulus()) 469 | cn.append(nn.BatchNorm1d(n_out)) 470 | cn.append(getattr(nn, non_linearity)()) 471 | cn.append(nn.MaxPool1d(kernel_size=max_pool_kernel)) 472 | else: 473 | cn.append(nn.Conv1d( 474 | cn_feature_n[ilb-1], n_out, 475 | kernel_size=kernel_size, padding=kernel_size // 2) 476 | ), 477 | cn.append(nn.BatchNorm1d(n_out)) 478 | cn.append(getattr(nn, non_linearity)()) 479 | cn.append(nn.MaxPool1d(kernel_size=max_pool_kernel)) 480 | 481 | cn.append(nn.AdaptiveAvgPool1d(latent_size//cn_feature_n[-1])) 482 | 483 | self.frontend = nn.Sequential(*cn) 484 | 485 | # Build the FC head 486 | _, cn_features, cn_slices = self.frontend(test_input_waveform).shape 487 | 488 | self.fc_in = cn_features * cn_slices 489 | 490 | fc = [ 491 | nn.Flatten(), 492 | nn.Dropout(dropout) 493 | ] 494 | 495 | for il, n_out in enumerate(fc_layer_n): 496 | n_in = self.fc_in if il == 0 else fc_layer_n[il-1] 497 | fc.append(nn.Linear(n_in, n_out)) 498 | fc.append(getattr(nn, non_linearity)()) 499 | fc.append(nn.BatchNorm1d(n_out)) 500 | fc.append(nn.Dropout(dropout)) 501 | 502 | fc.append(nn.Linear(fc_layer_n[-1], num_classes)) 503 | 504 | self.tail = nn.Sequential(*fc) 505 | 506 | def forward(self, x): 507 | x = x['wave'] 508 | x = x.unsqueeze(1) 509 | x = self.frontend(x) 510 | x = self.tail(x) 511 | # print(self.frontend[0].low_hz_) 512 | 513 | return x 514 | 515 | 516 | class SmallMultiModalFusionClassifier(BaseModel): 517 | def init_weights(self, m): 518 | if type(m) == nn.Linear: 519 | m.weight.data.fill_(0.0) 520 | m.bias.data.fill_(0.0) 521 | 522 | def __init__(self, 523 | input_length, 524 | n_bins, 525 | n_frames, 526 | num_classes, 527 | fusion_method='sum', 528 | parameterization='normal', 529 | non_linearity='LeakyReLU', 530 | layernorm_fusion=False, 531 | dropout=0.3, 532 | fc_layer_n=[512,256], 533 | kernel_0_size=251): 534 | super().__init__() 535 | 536 | self.fusion_method = fusion_method 537 | self.layernorm_fusion = layernorm_fusion 538 | 539 | self.wave = WaveformExtractor( 540 | n_in_channels=1, 541 | non_linearity='LeakyReLU', 542 | latent_size=1024, 543 | parameterization=parameterization, 544 | kernel_0_size=kernel_0_size, 545 | cn_feature_n=[32, 64, 128, 256], 546 | max_pool_kernel=8, 547 | kernel_size=7, 548 | layernorm_fusion=layernorm_fusion) 549 | self.spec = SpectrogramExtractor( 550 | n_in_channels=1, 551 | non_linearity='LeakyReLU', 552 | latent_size=1024, 553 | cn_feature_n=[32, 64, 128, 256], 554 | kernel_size=3, 555 | max_pool_kernel=(2,2), 556 | layernorm_fusion=layernorm_fusion) 557 | 558 | test_input_waveform = torch.zeros(3, input_length) 559 | test_input_spectrogram = torch.zeros(3, n_bins, n_frames) 560 | 561 | # Build the FC head 562 | # print(self.wave(test_input_waveform).shape) 563 | _, w_cn_features, w_cn_slices = self.wave(test_input_waveform).shape 564 | self.wave_fc_in = w_cn_features * w_cn_slices 565 | _, s_cn_channels, s_cn_bins, s_cn_frames = self.spec(test_input_spectrogram).shape 566 | self.spec_fc_in = s_cn_channels * s_cn_frames * s_cn_bins 567 | 
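# Both extractors are built with latent_size=1024, so their flattened embeddings line up:
# the element-wise fusion paths ('sum' and the attention-weighted sum) require equal sizes,
# and the assert below guards that invariant (concat and MFB merely benefit from it).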
568 |         assert self.wave_fc_in == self.spec_fc_in, "wave and spec branches must flatten to the same feature size"
569 | 
570 |         self.flat = nn.Flatten()
571 | 
572 |         fc = [
573 |             nn.Dropout(dropout)
574 |         ]
575 | 
576 |         if fusion_method == 'concat':
577 |             self.wave_fc_in = self.wave_fc_in + self.spec_fc_in  # concatenation doubles the FC head's input size
578 | 
579 |         if fusion_method == 'mfb':
580 |             self.mfb = MFB(self.wave_fc_in, self.spec_fc_in, MFB_O=self.wave_fc_in, MFB_K=3)
581 | 
582 |         if fusion_method == 'sum-attention-noinit':
583 |             self.attention = nn.Sequential(
584 |                 nn.Linear(self.wave_fc_in*2, self.wave_fc_in*4),
585 |                 getattr(nn, non_linearity)(),
586 |                 nn.Linear(self.wave_fc_in*4, self.wave_fc_in//2),
587 |                 getattr(nn, non_linearity)(),
588 |                 nn.Linear(self.wave_fc_in//2, 2),
589 |                 nn.Softmax(dim=1)
590 |             )
591 |         if fusion_method == 'sum-attention-init':
592 |             # identical head, but zero-initialized: the softmax then starts at equal 0.5/0.5 modality weights
593 |             self.attention = nn.Sequential(
594 |                 nn.Linear(self.wave_fc_in*2, self.wave_fc_in*4),
595 |                 getattr(nn, non_linearity)(),
596 |                 nn.Linear(self.wave_fc_in*4, self.wave_fc_in//2),
597 |                 getattr(nn, non_linearity)(),
598 |                 nn.Linear(self.wave_fc_in//2, 2),
599 |                 nn.Softmax(dim=1))
600 |             self.attention.apply(self.init_weights)
601 | 
602 |         for il, n_out in enumerate(fc_layer_n):
603 |             n_in = self.wave_fc_in if il == 0 else fc_layer_n[il-1]
604 |             fc.append(nn.Linear(n_in, n_out))
605 |             fc.append(getattr(nn, non_linearity)())
606 |             fc.append(nn.BatchNorm1d(n_out))
607 |             fc.append(nn.Dropout(dropout))
608 | 
609 |         fc.append(nn.Linear(fc_layer_n[-1], num_classes))
610 | 
611 |         self.tail = nn.Sequential(*fc)
612 | 
613 |     def forward(self, x):
614 | 
615 |         w = x['wave']
616 |         s = x['spec']
617 |         w = self.wave(w)
618 |         s = self.spec(s)
619 | 
620 |         # Flatten each branch unless layer-normalized fusion expects the unflattened shapes
621 |         if not self.layernorm_fusion:
622 |             w_flat = self.flat(w)
623 |             s_flat = self.flat(s)
624 |         else:
625 |             w_flat = w
626 |             s_flat = s
627 | 
628 |         # Select fusion method
629 |         if self.fusion_method == 'sum':
630 |             combined_features = w_flat + s_flat
631 | 
632 |         elif self.fusion_method == 'concat':
633 |             w_flat = F.normalize(w_flat, p=2.0, dim=1, eps=1e-12)
634 |             s_flat = F.normalize(s_flat, p=2.0, dim=1, eps=1e-12)
635 | 
636 |             combined_features = torch.cat((w_flat, s_flat), dim=1)
637 |         elif self.fusion_method == 'mfb':
638 |             combined_features, _ = self.mfb(w_flat.unsqueeze(1), s_flat.unsqueeze(1))
639 |             combined_features = combined_features.squeeze(1)
640 |         elif self.fusion_method in ('sum-attention-noinit', 'sum-attention-init'):
641 |             concat_features = torch.cat((w_flat, s_flat), dim=1)
642 |             att = self.attention(concat_features)
643 | 
644 |             att_1, att_2 = torch.split(att, 1, dim=1)  # per-sample scalar weight for each modality
645 | 
646 |             combined_features = w_flat * att_1 + s_flat * att_2
647 | 
648 |         res = self.tail(combined_features)
649 | 
650 |         return res
651 | 
652 | 
653 | class ScratchEnsemble(BaseModel):
654 |     def __init__(self):
655 |         super().__init__()
656 |         self.spec_model_kwargs = {
657 |             "n_classes": 10,
658 |             "latent_size": 1024,
659 |             "fc_layer_n": [512, 256],
660 |             "cn_feature_n": [32, 64, 128, 256]
661 |         }
662 | 
663 |         self.spec_model = SpectrogramClassifier(**self.spec_model_kwargs)
664 | 
665 |         self.wave_model_kwargs = {
666 |             "input_length": 48000,
667 |             "num_classes": 10,
668 |             "kernel_0_size": 251,
669 |             "parameterization": "sinc",
670 |             "max_pool_kernel": 8,
671 |             "latent_size": 1024,
672 |             "cn_feature_n": [32, 64, 128, 256],
673 |             "fc_layer_n": [512, 256]
674 |         }
675 | 
676 |         self.wave_model = WaveformParameterizedClassifier(**self.wave_model_kwargs)
677 | 
678 |     def forward(self, x):
679 |         # Late fusion: average the two unimodal logit vectors
680 |         wave_out = self.wave_model(x)
681 |         spec_out = self.spec_model(x)
682 | 
683 |         output = torch.stack([wave_out, spec_out]).mean(dim=0)
684 | 
685 |         return output
--------------------------------------------------------------------------------
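A quick sanity check for the fusion model (a minimal sketch, not part of the repository): the import path and the spectrogram dimensions below are assumptions for illustration. Any n_bins/n_frames for which both extractors flatten to the same latent size (enforced by the assert in __init__) will work; input_length=48000 matches the 1-second, 48 kHz clips implied by SincConv_fast(sample_rate=48000).

    import torch
    from models.model import SmallMultiModalFusionClassifier  # assumed module path, with workspace/ on PYTHONPATH

    model = SmallMultiModalFusionClassifier(
        input_length=48000,        # 1 s of audio at 48 kHz
        n_bins=64,                 # assumed spectrogram bin count
        n_frames=96,               # assumed spectrogram frame count
        num_classes=10,
        fusion_method='sum-attention-init',
    )
    model.eval()

    # The model consumes a dict with both modalities, as in forward()
    batch = {
        'wave': torch.randn(4, 48000),   # (batch, samples)
        'spec': torch.randn(4, 64, 96),  # (batch, bins, frames)
    }
    with torch.no_grad():
        logits = model(batch)
    print(logits.shape)  # expected: torch.Size([4, 10])

With fusion_method='sum-attention-init', the zero-initialized attention head makes the softmax emit equal 0.5/0.5 modality weights at the start of training, so fusion begins as a plain average of the two branches and the weighting is learned from there.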