├── ignite_trainer
│   ├── version.py
│   ├── README.md
│   ├── __init__.py
│   ├── _interfaces.py
│   ├── _utils.py
│   ├── _visdom.py
│   └── _trainer.py
├── protocols
│   ├── README.md
│   ├── esc50
│   │   ├── adcnn5-esc50-cv1.json
│   │   └── esresnet-esc50-cv1.json
│   ├── us8k
│   │   ├── adcnn5-us8k-cv1.json
│   │   ├── lmcnet-us8k-cv1.json
│   │   ├── esresnet-us8k-mono-cv1.json
│   │   └── esresnet-us8k-stereo-cv1.json
│   └── esc10
│       └── esresnet-esc10-cv1.json
├── main.py
├── requirements.txt
├── utils
│   ├── __init__.py
│   ├── lr_scheduler.py
│   ├── transforms.py
│   ├── datasets.py
│   └── features.py
├── reproduced
│   ├── README.md
│   ├── TFNet
│   │   └── README.md
│   ├── lmcnet.py
│   └── adcnn.py
├── model
│   ├── attention.py
│   └── esresnet.py
└── README.md

/ignite_trainer/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.2.5b4'
2 | 
--------------------------------------------------------------------------------
/protocols/README.md:
--------------------------------------------------------------------------------
1 | # Protocols
2 | 
3 | Here are the JSON files that describe the configurations of the experiments.
4 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3.7
2 | 
3 | import ignite_trainer as it
4 | 
5 | if __name__ == '__main__':
6 |     it.main()
7 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | librosa==0.7.2
2 | numpy==1.18.1
3 | pandas==1.0.3
4 | pytorch-ignite==0.3.0
5 | scikit-learn==0.22.1
6 | scipy==1.4.1
7 | termcolor==1.1.0
8 | torch==1.4.0
9 | torchvision==0.5.0
10 | tqdm==4.43.0
11 | visdom==0.1.8.9
12 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from . import datasets
2 | from . import features
3 | from . import lr_scheduler
4 | from . import transforms
5 | 
6 | __all__ = [
7 |     'datasets',
8 |     'features',
9 |     'lr_scheduler',
10 |     'transforms'
11 | ]
12 | 
--------------------------------------------------------------------------------
/reproduced/README.md:
--------------------------------------------------------------------------------
1 | # Reproduced
2 | 
3 | Here are the models that were reproduced:
4 | 
5 | 1. The LMCNet model, which is part of the [TSCNN-DS model](https://www.mdpi.com/1424-8220/19/7/1733/pdf).
6 | 2. The [TFNet model](https://arxiv.org/abs/1912.06808).
7 | 3. The [ADCNN-5 model](https://arxiv.org/abs/1908.11219) (excluded from the paper).
8 | 
--------------------------------------------------------------------------------
/ignite_trainer/README.md:
--------------------------------------------------------------------------------
1 | # Ignite Trainer
2 | 
3 | Ignite Trainer is a framework built on top of [PyTorch Ignite](https://github.com/pytorch/ignite) and [visdom](https://github.com/facebookresearch/visdom).
4 | It was developed to wrap the training and logging of PyTorch models.
5 | Development is frozen due to the switch to [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning).
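
The package exports `main` and `run` as entry points (see `main.py` for the command-line usage), plus the `AbstractNet` and `AbstractTransform` interfaces from `_interfaces.py` that models and transforms must implement. A minimal sketch of a conforming model (`TinyNet` is hypothetical, but it implements the declared interface exactly):

    import torch
    import torch.nn.functional as F
    import ignite_trainer as it

    class TinyNet(it.AbstractNet):
        def __init__(self, num_classes: int = 10):
            super(TinyNet, self).__init__()
            self.fc = torch.nn.Linear(128, num_classes)

        def forward(self, x, y=None):
            y_pred = self.fc(x)
            # repo-wide convention: return (predictions, loss) when targets are given
            return y_pred if y is None else (y_pred, self.loss_fn(y_pred, y))

        def loss_fn(self, y_pred, y):
            return F.cross_entropy(y_pred, y)

        @property
        def loss_fn_name(self):
            return 'Cross Entropy'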
6 | -------------------------------------------------------------------------------- /ignite_trainer/__init__.py: -------------------------------------------------------------------------------- 1 | import os as _os 2 | import sys as _sys 3 | 4 | from ignite_trainer.version import __version__ 5 | from ._trainer import main, run 6 | from ._utils import load_class 7 | from ._interfaces import AbstractNet, AbstractTransform 8 | 9 | __all__ = [ 10 | '__version__', 11 | 'main', 'run', 'load_class', 12 | 'AbstractNet', 'AbstractTransform' 13 | ] 14 | 15 | _sys.path.extend([_os.getcwd()]) 16 | -------------------------------------------------------------------------------- /reproduced/TFNet/README.md: -------------------------------------------------------------------------------- 1 | ## Reproduced: TFNet 2 | 3 | The TFNet model's results were reproduced using 4 | [temporarily available source code (inactive)](https://github.com/WangHelin1997/TFNet-for-Environmental-Sound-Classification) 5 | provided by the authors of the following [paper](https://arxiv.org/abs/1912.06808). 6 | The original repository was forked and is now available [here](https://github.com/AndreyGuzhov/TFNet-for-Environmental-Sound-Classification). 7 | -------------------------------------------------------------------------------- /ignite_trainer/_interfaces.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | 4 | from typing import Tuple 5 | from typing import Union 6 | from typing import Callable 7 | from typing import Optional 8 | 9 | 10 | TensorPair = Tuple[torch.Tensor, torch.Tensor] 11 | TensorOrTwo = Union[torch.Tensor, TensorPair] 12 | 13 | 14 | class AbstractNet(abc.ABC, torch.nn.Module): 15 | 16 | @abc.abstractmethod 17 | def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> TensorOrTwo: 18 | pass 19 | 20 | @abc.abstractmethod 21 | def loss_fn(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 22 | pass 23 | 24 | @property 25 | @abc.abstractmethod 26 | def loss_fn_name(self) -> str: 27 | pass 28 | 29 | 30 | class AbstractTransform(abc.ABC, Callable[[torch.Tensor], torch.Tensor]): 31 | 32 | @abc.abstractmethod 33 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 34 | pass 35 | 36 | def __repr__(self): 37 | return self.__class__.__name__ + '()' 38 | -------------------------------------------------------------------------------- /model/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from typing import Tuple 5 | 6 | 7 | class Attention2d(torch.nn.Module): 8 | 9 | def __init__(self, 10 | in_channels: int, 11 | out_channels: int, 12 | num_kernels: int, 13 | kernel_size: Tuple[int, int], 14 | padding_size: Tuple[int, int]): 15 | 16 | super(Attention2d, self).__init__() 17 | 18 | self.conv_depth = torch.nn.Conv2d( 19 | in_channels=in_channels, 20 | out_channels=in_channels * num_kernels, 21 | kernel_size=kernel_size, 22 | padding=padding_size, 23 | groups=in_channels 24 | ) 25 | self.conv_point = torch.nn.Conv2d( 26 | in_channels=in_channels * num_kernels, 27 | out_channels=out_channels, 28 | kernel_size=(1, 1) 29 | ) 30 | self.bn = torch.nn.BatchNorm2d(num_features=out_channels) 31 | self.activation = torch.nn.Sigmoid() 32 | 33 | def forward(self, x: torch.Tensor, size: torch.Size) -> torch.Tensor: 34 | x = F.adaptive_max_pool2d(x, size) 35 | x = self.conv_depth(x) 36 | x = self.conv_point(x) 37 | x = self.bn(x) 38 | 
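        # The grouped conv above is depthwise and the 1x1 conv is pointwise, i.e.
        # together they form a separable convolution; the sigmoid below squashes
        # the normalized response into (0, 1), so the returned map can be used as
        # a soft element-wise attention gate.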
x = self.activation(x) 39 | 40 | return x 41 | -------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class WarmUpStepLR(torch.optim.lr_scheduler._LRScheduler): 5 | 6 | def __init__(self, 7 | optimizer: torch.optim.Optimizer, 8 | cold_epochs: int, 9 | warm_epochs: int, 10 | step_size: int, 11 | gamma: float = 0.1, 12 | last_epoch: int = -1): 13 | 14 | self.cold_epochs = cold_epochs 15 | self.warm_epochs = warm_epochs 16 | self.step_size = step_size 17 | self.gamma = gamma 18 | 19 | super(WarmUpStepLR, self).__init__(optimizer=optimizer, last_epoch=last_epoch) 20 | 21 | def get_lr(self): 22 | if self.last_epoch < self.cold_epochs: 23 | return [base_lr * 0.1 for base_lr in self.base_lrs] 24 | elif self.last_epoch < self.cold_epochs + self.warm_epochs: 25 | return [ 26 | base_lr * 0.1 + (1 + self.last_epoch - self.cold_epochs) * 0.9 * base_lr / self.warm_epochs 27 | for base_lr in self.base_lrs 28 | ] 29 | else: 30 | return [ 31 | base_lr * self.gamma ** ((self.last_epoch - self.cold_epochs - self.warm_epochs) // self.step_size) 32 | for base_lr in self.base_lrs 33 | ] 34 | 35 | 36 | class WarmUpExponentialLR(WarmUpStepLR): 37 | 38 | def __init__(self, 39 | optimizer: torch.optim.Optimizer, 40 | cold_epochs: int, 41 | warm_epochs: int, 42 | gamma: float = 0.1, 43 | last_epoch: int = -1): 44 | 45 | self.cold_epochs = cold_epochs 46 | self.warm_epochs = warm_epochs 47 | self.step_size = 1 48 | self.gamma = gamma 49 | 50 | super(WarmUpStepLR, self).__init__(optimizer=optimizer, last_epoch=last_epoch) -------------------------------------------------------------------------------- /protocols/esc50/adcnn5-esc50-cv1.json: -------------------------------------------------------------------------------- 1 | { 2 | "Visdom": { 3 | "host": null, 4 | "port": null, 5 | "env_path": null 6 | }, 7 | "Setup": { 8 | "name": "MFCC", 9 | "suffix": "CV1", 10 | "batch_train": 64, 11 | "batch_test": 64, 12 | "workers_train": 0, 13 | "workers_test": 0, 14 | "epochs": 500, 15 | "log_interval": 5, 16 | "saved_models_path": null 17 | }, 18 | "Model": { 19 | "class": "reproduced.adcnn.ADCNN5", 20 | "args": { 21 | "num_channels": 1, 22 | "n_fft": 1024, 23 | "hop_length": 512, 24 | "window": "blackmanharris", 25 | "num_classes": 50 26 | } 27 | }, 28 | "Optimizer": { 29 | "class": "torch.optim.Adam", 30 | "args": { 31 | "lr": 1e-2, 32 | "betas": [0.9, 0.999], 33 | "eps": 1e-7, 34 | "weight_decay": 1e-4 35 | } 36 | }, 37 | "Scheduler": { 38 | "class": "torch.optim.lr_scheduler.StepLR", 39 | "args": { 40 | "gamma": 0.1, 41 | "step_size": 100 42 | } 43 | }, 44 | "Dataset": { 45 | "class": "utils.datasets.ESC50", 46 | "args": { 47 | "root": "/path/to/ESC50", 48 | "sample_rate": 32000, 49 | "fold": 1, 50 | "training": {"key": "train", "yes": true, "no": false} 51 | } 52 | }, 53 | "Transforms": [ 54 | { 55 | "class": "utils.transforms.ToTensor1D", 56 | "args": {} 57 | }, 58 | { 59 | "class": "utils.transforms.RandomPadding", 60 | "args": {"out_len": 160000, "train": false} 61 | }, 62 | { 63 | "class": "utils.transforms.RandomCrop", 64 | "args": {"out_len": 160000, "train": false} 65 | } 66 | ], 67 | "Metrics": { 68 | "Performance": { 69 | "window_name": null, 70 | "x_label": "#Epochs", 71 | "y_label": "Accuracy", 72 | "width": 1890, 73 | "height": 416, 74 | "lines": [ 75 | { 76 | "line_label": "Val. 
Acc.", 77 | "class": "ignite.metrics.Accuracy", 78 | "args": {}, 79 | "is_checkpoint": true 80 | } 81 | ] 82 | } 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /protocols/us8k/adcnn5-us8k-cv1.json: -------------------------------------------------------------------------------- 1 | { 2 | "Visdom": { 3 | "host": null, 4 | "port": null, 5 | "env_path": null 6 | }, 7 | "Setup": { 8 | "name": "MFCC", 9 | "suffix": "CV1", 10 | "batch_train": 64, 11 | "batch_test": 64, 12 | "workers_train": 0, 13 | "workers_test": 0, 14 | "epochs": 500, 15 | "log_interval": 5, 16 | "saved_models_path": null 17 | }, 18 | "Model": { 19 | "class": "reproduced.adcnn.ADCNN5", 20 | "args": { 21 | "num_channels": 1, 22 | "n_fft": 1024, 23 | "hop_length": 512, 24 | "window": "hann", 25 | "num_classes": 10 26 | } 27 | }, 28 | "Optimizer": { 29 | "class": "torch.optim.Adam", 30 | "args": { 31 | "lr": 1e-2, 32 | "betas": [0.9, 0.999], 33 | "eps": 1e-7, 34 | "weight_decay": 1e-4 35 | } 36 | }, 37 | "Scheduler": { 38 | "class": "torch.optim.lr_scheduler.StepLR", 39 | "args": { 40 | "gamma": 0.1, 41 | "step_size": 100 42 | } 43 | }, 44 | "Dataset": { 45 | "class": "utils.datasets.UrbanSound8K", 46 | "args": { 47 | "root": "/path/to/UrbanSound8K", 48 | "sample_rate": 32000, 49 | "fold": 1, 50 | "random_split_seed": 42, 51 | "mono": true, 52 | "training": {"key": "train", "yes": true, "no": false} 53 | } 54 | }, 55 | "Transforms": [ 56 | { 57 | "class": "utils.transforms.ToTensor1D", 58 | "args": {} 59 | }, 60 | { 61 | "class": "utils.transforms.RandomPadding", 62 | "args": {"out_len": 128000, "train": false} 63 | }, 64 | { 65 | "class": "utils.transforms.RandomCrop", 66 | "args": {"out_len": 128000, "train": false} 67 | } 68 | ], 69 | "Metrics": { 70 | "Performance": { 71 | "window_name": null, 72 | "x_label": "#Epochs", 73 | "y_label": "Accuracy", 74 | "width": 1890, 75 | "height": 416, 76 | "lines": [ 77 | { 78 | "line_label": "Val. 
Acc.", 79 | "class": "ignite.metrics.Accuracy", 80 | "args": {}, 81 | "is_checkpoint": true 82 | } 83 | ] 84 | } 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /protocols/us8k/lmcnet-us8k-cv1.json: -------------------------------------------------------------------------------- 1 | { 2 | "Visdom": { 3 | "host": null, 4 | "port": null, 5 | "env_path": null 6 | }, 7 | "Setup": { 8 | "name": "LMC", 9 | "suffix": "CV1", 10 | "batch_train": 32, 11 | "batch_test": 32, 12 | "workers_train": 0, 13 | "workers_test": 0, 14 | "epochs": 300, 15 | "log_interval": 10, 16 | "saved_models_path": null 17 | }, 18 | "Model": { 19 | "class": "reproduced.lmcnet.LMCNet", 20 | "args": { 21 | "num_channels": 1, 22 | "num_classes": 10, 23 | "sample_rate": 22050, 24 | "norm": "inf", 25 | "n_fft": 8192, 26 | "hop_length": 2205, 27 | "win_length": 4410, 28 | "window": "hann", 29 | "n_mels": 60, 30 | "tuning": 0.0, 31 | "n_chroma": 7, 32 | "ctroct": 5.0, 33 | "octwidth": 2.0, 34 | "base_c": true, 35 | "freq": null, 36 | "fmin": 10.0, 37 | "fmax": null, 38 | "n_bands": 11, 39 | "quantile": 0.02, 40 | "linear": false 41 | } 42 | }, 43 | "Optimizer": { 44 | "class": "torch.optim.Adam", 45 | "args": { 46 | "lr": 1e-3, 47 | "betas": [0.9, 0.999], 48 | "eps": 1e-8, 49 | "weight_decay": 1e-3 50 | } 51 | }, 52 | "Dataset": { 53 | "class": "utils.datasets.UrbanSound8K", 54 | "args": { 55 | "root": "/path/to/UrbanSound8K", 56 | "sample_rate": 22050, 57 | "fold": 1, 58 | "random_split_seed": null, 59 | "mono": true, 60 | "training": {"key": "train", "yes": true, "no": false} 61 | } 62 | }, 63 | "Transforms": [ 64 | { 65 | "class": "utils.transforms.ToTensor1D", 66 | "args": {} 67 | }, 68 | { 69 | "class": "utils.transforms.RandomPadding", 70 | "args": {"out_len": 88200, "train": false} 71 | }, 72 | { 73 | "class": "utils.transforms.RandomCrop", 74 | "args": {"out_len": 88200, "train": false} 75 | } 76 | ], 77 | "Metrics": { 78 | "Performance": { 79 | "window_name": null, 80 | "x_label": "#Epochs", 81 | "y_label": "Accuracy", 82 | "width": 1890, 83 | "height": 416, 84 | "lines": [ 85 | { 86 | "line_label": "Val. Acc.", 87 | "class": "ignite.metrics.Accuracy", 88 | "args": {}, 89 | "is_checkpoint": true 90 | } 91 | ] 92 | } 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ESResNet 2 | ## Environmental Sound Classification Based on Visual Domain Models 3 | 4 | This repository contains implementation of the models described in the paper [arXiv:2004.07301](https://arxiv.org/abs/2004.07301) (submitted to ICPR 2020). 5 | 6 | ### Abstract 7 | Environmental Sound Classification (ESC) is an active research area in the audio domain and has seen a lot of progress in the past years. However, many of the existing approaches achieve high accuracy by relying on domain-specific features and architectures, making it harder to benefit from advances in other fields (e.g., the image domain). Additionally, some of the past successes have been attributed to a discrepancy of how results are evaluated (i.e., on unofficial splits of the UrbanSound8K (US8K) dataset), distorting the overall progression of the field. 8 | The contribution of this paper is twofold. First, we present a model that is inherently compatible with mono and stereo sound inputs. 
Our model is based on simple log-power Short-Time Fourier Transform (STFT) spectrograms and combines them with several well-known approaches from the image domain (i.e., ResNet, Siamese-like networks and attention). We investigate the influence of cross-domain pre-training and architectural changes, and we evaluate our model on standard datasets. We find that our model outperforms all previously known approaches in a fair comparison by achieving accuracies of 97.0 % (ESC-10), 91.5 % (ESC-50) and 84.2 % / 85.4 % (US8K mono / stereo).
9 | Second, we provide a comprehensive overview of the current state of the field by distinguishing whether previously reported results on the US8K dataset were obtained on the official split or on unofficial ones. For better reproducibility, our code (including any re-implementations) is made available.
10 | 
11 | ### How to run the model
12 | 
13 | The required Python version is >= 3.7.
14 | 
15 | #### ESResNet
16 | 
17 | ##### On the [ESC-10](https://github.com/karolpiczak/ESC-50) dataset
18 |     python main.py --config protocols/esc10/esresnet-esc10-cv1.json --Dataset.args.root /path/to/ESC10
19 | 
20 | ##### On the [ESC-50](https://github.com/karolpiczak/ESC-50) dataset
21 |     python main.py --config protocols/esc50/esresnet-esc50-cv1.json --Dataset.args.root /path/to/ESC50
22 | 
23 | ##### On the [UrbanSound8K](https://urbansounddataset.weebly.com/) dataset (stereo)
24 |     python main.py --config protocols/us8k/esresnet-us8k-stereo-cv1.json --Dataset.args.root /path/to/UrbanSound8K
25 | 
26 | #### Reproduced results
27 | 
28 | ##### [LMCNet](https://www.mdpi.com/1424-8220/19/7/1733/pdf) on the [UrbanSound8K](https://urbansounddataset.weebly.com/) dataset
29 |     python main.py --config protocols/us8k/lmcnet-us8k-cv1.json --Dataset.args.root /path/to/UrbanSound8K
30 | 
--------------------------------------------------------------------------------
/protocols/esc10/esresnet-esc10-cv1.json:
--------------------------------------------------------------------------------
1 | {
2 |     "Visdom": {
3 |         "host": null,
4 |         "port": null,
5 |         "env_path": null
6 |     },
7 |     "Setup": {
8 |         "name": "STFT",
9 |         "suffix": "CV1",
10 |         "batch_train": 16,
11 |         "batch_test": 16,
12 |         "workers_train": 4,
13 |         "workers_test": 4,
14 |         "epochs": 300,
15 |         "log_interval": 10,
16 |         "saved_models_path": null
17 |     },
18 |     "Model": {
19 |         "class": "model.esresnet.ESResNet",
20 |         "args": {
21 |             "n_fft": 2048,
22 |             "hop_length": 561,
23 |             "win_length": 1654,
24 |             "window": "blackmanharris",
25 |             "normalized": true,
26 |             "onesided": true,
27 |             "spec_height": -1,
28 |             "spec_width": -1,
29 |             "num_classes": 10,
30 |             "pretrained": true,
31 |             "lock_pretrained": false
32 |         }
33 |     },
34 |     "Optimizer": {
35 |         "class": "torch.optim.Adam",
36 |         "args": {
37 |             "lr": 2.5e-4,
38 |             "betas": [0.9, 0.999],
39 |             "eps": 1e-8,
40 |             "weight_decay": 5e-4
41 |         }
42 |     },
43 |     "Scheduler": {
44 |         "class": "utils.lr_scheduler.WarmUpExponentialLR",
45 |         "args": {
46 |             "gamma": 0.985,
47 |             "cold_epochs": 5,
48 |             "warm_epochs": 10
49 |         }
50 |     },
51 |     "Dataset": {
52 |         "class": "utils.datasets.ESC10",
53 |         "args": {
54 |             "root": "/path/to/ESC10",
55 |             "sample_rate": 44100,
56 |             "fold": 1,
57 |             "training": {"key": "train", "yes": true, "no": false}
58 |         }
59 |     },
60 |     "Transforms": [
61 |         {
62 |             "class": "utils.transforms.ToTensor1D",
63 |             "args": {}
64 |         },
65 |         {
66 |             "class": "utils.transforms.RandomFlip",
67 |             "args": {"p": 0.5},
68 |             "test": false
69 |         },
70 |         {
71 |             "class": "utils.transforms.RandomScale",
72 |             "args": {"max_scale": 1.25},
73 |             "test": false
74 | }, 75 | { 76 | "class": "utils.transforms.RandomPadding", 77 | "args": {"out_len": 220500}, 78 | "test": false 79 | }, 80 | { 81 | "class": "utils.transforms.RandomCrop", 82 | "args": {"out_len": 220500}, 83 | "test": false 84 | }, 85 | { 86 | "class": "utils.transforms.RandomPadding", 87 | "args": {"out_len": 220500, "train": false}, 88 | "train": false 89 | }, 90 | { 91 | "class": "utils.transforms.RandomCrop", 92 | "args": {"out_len": 220500, "train": false}, 93 | "train": false 94 | } 95 | ], 96 | "Metrics": { 97 | "Performance": { 98 | "window_name": null, 99 | "x_label": "#Epochs", 100 | "y_label": "Accuracy", 101 | "width": 1890, 102 | "height": 416, 103 | "lines": [ 104 | { 105 | "line_label": "Val. Acc.", 106 | "class": "ignite.metrics.Accuracy", 107 | "args": {}, 108 | "is_checkpoint": true 109 | } 110 | ] 111 | } 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /protocols/esc50/esresnet-esc50-cv1.json: -------------------------------------------------------------------------------- 1 | { 2 | "Visdom": { 3 | "host": null, 4 | "port": null, 5 | "env_path": null 6 | }, 7 | "Setup": { 8 | "name": "STFT", 9 | "suffix": "CV1", 10 | "batch_train": 16, 11 | "batch_test": 16, 12 | "workers_train": 4, 13 | "workers_test": 4, 14 | "epochs": 300, 15 | "log_interval": 10, 16 | "saved_models_path": null 17 | }, 18 | "Model": { 19 | "class": "model.esresnet.ESResNet", 20 | "args": { 21 | "n_fft": 2048, 22 | "hop_length": 561, 23 | "win_length": 1654, 24 | "window": "blackmanharris", 25 | "normalized": true, 26 | "onesided": true, 27 | "spec_height": -1, 28 | "spec_width": -1, 29 | "num_classes": 50, 30 | "pretrained": true, 31 | "lock_pretrained": false 32 | } 33 | }, 34 | "Optimizer": { 35 | "class": "torch.optim.Adam", 36 | "args": { 37 | "lr": 2.5e-4, 38 | "betas": [0.9, 0.999], 39 | "eps": 1e-8, 40 | "weight_decay": 5e-4 41 | } 42 | }, 43 | "Scheduler": { 44 | "class": "utils.lr_scheduler.WarmUpExponentialLR", 45 | "args": { 46 | "gamma": 0.985, 47 | "cold_epochs": 5, 48 | "warm_epochs": 10 49 | } 50 | }, 51 | "Dataset": { 52 | "class": "utils.datasets.ESC50", 53 | "args": { 54 | "root": "/path/to/ESC50", 55 | "sample_rate": 44100, 56 | "fold": 1, 57 | "training": {"key": "train", "yes": true, "no": false} 58 | } 59 | }, 60 | "Transforms": [ 61 | { 62 | "class": "utils.transforms.ToTensor1D", 63 | "args": {} 64 | }, 65 | { 66 | "class": "utils.transforms.RandomFlip", 67 | "args": {"p": 0.5}, 68 | "test": false 69 | }, 70 | { 71 | "class": "utils.transforms.RandomScale", 72 | "args": {"max_scale": 1.25}, 73 | "test": false 74 | }, 75 | { 76 | "class": "utils.transforms.RandomPadding", 77 | "args": {"out_len": 220500}, 78 | "test": false 79 | }, 80 | { 81 | "class": "utils.transforms.RandomCrop", 82 | "args": {"out_len": 220500}, 83 | "test": false 84 | }, 85 | { 86 | "class": "utils.transforms.RandomPadding", 87 | "args": {"out_len": 220500, "train": false}, 88 | "train": false 89 | }, 90 | { 91 | "class": "utils.transforms.RandomCrop", 92 | "args": {"out_len": 220500, "train": false}, 93 | "train": false 94 | } 95 | ], 96 | "Metrics": { 97 | "Performance": { 98 | "window_name": null, 99 | "x_label": "#Epochs", 100 | "y_label": "Accuracy", 101 | "width": 1890, 102 | "height": 416, 103 | "lines": [ 104 | { 105 | "line_label": "Val. 
Acc.", 106 | "class": "ignite.metrics.Accuracy", 107 | "args": {}, 108 | "is_checkpoint": true 109 | } 110 | ] 111 | } 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /protocols/us8k/esresnet-us8k-mono-cv1.json: -------------------------------------------------------------------------------- 1 | { 2 | "Visdom": { 3 | "host": null, 4 | "port": null, 5 | "env_path": null 6 | }, 7 | "Setup": { 8 | "name": "STFT", 9 | "suffix": "CV1", 10 | "batch_train": 16, 11 | "batch_test": 16, 12 | "workers_train": 2, 13 | "workers_test": 2, 14 | "epochs": 300, 15 | "log_interval": 50, 16 | "saved_models_path": null 17 | }, 18 | "Model": { 19 | "class": "model.esresnet.ESResNet", 20 | "args": { 21 | "n_fft": 2048, 22 | "hop_length": 561, 23 | "win_length": 1654, 24 | "window": "blackmanharris", 25 | "normalized": true, 26 | "onesided": true, 27 | "spec_height": -1, 28 | "spec_width": -1, 29 | "num_classes": 10, 30 | "pretrained": true, 31 | "lock_pretrained": false 32 | } 33 | }, 34 | "Optimizer": { 35 | "class": "torch.optim.Adam", 36 | "args": { 37 | "lr": 2.5e-4, 38 | "betas": [0.9, 0.999], 39 | "eps": 1e-8, 40 | "weight_decay": 5e-4 41 | } 42 | }, 43 | "Scheduler": { 44 | "class": "utils.lr_scheduler.WarmUpExponentialLR", 45 | "args": { 46 | "gamma": 0.985, 47 | "cold_epochs": 5, 48 | "warm_epochs": 10 49 | } 50 | }, 51 | "Dataset": { 52 | "class": "utils.datasets.UrbanSound8K", 53 | "args": { 54 | "root": "/path/to/UrbanSound8K", 55 | "sample_rate": 44100, 56 | "fold": 1, 57 | "random_split_seed": null, 58 | "mono": true, 59 | "training": {"key": "train", "yes": true, "no": false} 60 | } 61 | }, 62 | "Transforms": [ 63 | { 64 | "class": "utils.transforms.ToTensor1D", 65 | "args": {} 66 | }, 67 | { 68 | "class": "utils.transforms.RandomFlip", 69 | "args": {"p": 0.5}, 70 | "test": false 71 | }, 72 | { 73 | "class": "utils.transforms.RandomScale", 74 | "args": {"max_scale": 1.25}, 75 | "test": false 76 | }, 77 | { 78 | "class": "utils.transforms.RandomPadding", 79 | "args": {"out_len": 176400}, 80 | "test": false 81 | }, 82 | { 83 | "class": "utils.transforms.RandomCrop", 84 | "args": {"out_len": 176400}, 85 | "test": false 86 | }, 87 | { 88 | "class": "utils.transforms.RandomCrop", 89 | "args": {"out_len": 176400, "train": false}, 90 | "train": false 91 | }, 92 | { 93 | "class": "utils.transforms.RandomPadding", 94 | "args": {"out_len": 176400, "train": false}, 95 | "train": false 96 | } 97 | ], 98 | "Metrics": { 99 | "Performance": { 100 | "window_name": null, 101 | "x_label": "#Epochs", 102 | "y_label": "Accuracy", 103 | "width": 1890, 104 | "height": 416, 105 | "lines": [ 106 | { 107 | "line_label": "Val. 
Acc.",
108 |                     "class": "ignite.metrics.Accuracy",
109 |                     "args": {},
110 |                     "is_checkpoint": true
111 |                 }
112 |             ]
113 |         }
114 |     }
115 | }
116 | 
--------------------------------------------------------------------------------
/protocols/us8k/esresnet-us8k-stereo-cv1.json:
--------------------------------------------------------------------------------
1 | {
2 |     "Visdom": {
3 |         "host": null,
4 |         "port": null,
5 |         "env_path": null
6 |     },
7 |     "Setup": {
8 |         "name": "STFT",
9 |         "suffix": "CV1",
10 |         "batch_train": 16,
11 |         "batch_test": 16,
12 |         "workers_train": 2,
13 |         "workers_test": 2,
14 |         "epochs": 300,
15 |         "log_interval": 50,
16 |         "saved_models_path": null
17 |     },
18 |     "Model": {
19 |         "class": "model.esresnet.ESResNet",
20 |         "args": {
21 |             "n_fft": 2048,
22 |             "hop_length": 561,
23 |             "win_length": 1654,
24 |             "window": "blackmanharris",
25 |             "normalized": true,
26 |             "onesided": true,
27 |             "spec_height": -1,
28 |             "spec_width": -1,
29 |             "num_classes": 10,
30 |             "pretrained": true,
31 |             "lock_pretrained": false
32 |         }
33 |     },
34 |     "Optimizer": {
35 |         "class": "torch.optim.Adam",
36 |         "args": {
37 |             "lr": 2.5e-4,
38 |             "betas": [0.9, 0.999],
39 |             "eps": 1e-8,
40 |             "weight_decay": 5e-4
41 |         }
42 |     },
43 |     "Scheduler": {
44 |         "class": "utils.lr_scheduler.WarmUpExponentialLR",
45 |         "args": {
46 |             "gamma": 0.985,
47 |             "cold_epochs": 5,
48 |             "warm_epochs": 10
49 |         }
50 |     },
51 |     "Dataset": {
52 |         "class": "utils.datasets.UrbanSound8K",
53 |         "args": {
54 |             "root": "/path/to/UrbanSound8K",
55 |             "sample_rate": 44100,
56 |             "fold": 1,
57 |             "random_split_seed": null,
58 |             "mono": false,
59 |             "training": {"key": "train", "yes": true, "no": false}
60 |         }
61 |     },
62 |     "Transforms": [
63 |         {
64 |             "class": "utils.transforms.ToTensor1D",
65 |             "args": {}
66 |         },
67 |         {
68 |             "class": "utils.transforms.RandomFlip",
69 |             "args": {"p": 0.5},
70 |             "test": false
71 |         },
72 |         {
73 |             "class": "utils.transforms.RandomScale",
74 |             "args": {"max_scale": 1.25},
75 |             "test": false
76 |         },
77 |         {
78 |             "class": "utils.transforms.RandomPadding",
79 |             "args": {"out_len": 176400},
80 |             "test": false
81 |         },
82 |         {
83 |             "class": "utils.transforms.RandomCrop",
84 |             "args": {"out_len": 176400},
85 |             "test": false
86 |         },
87 |         {
88 |             "class": "utils.transforms.RandomCrop",
89 |             "args": {"out_len": 176400, "train": false},
90 |             "train": false
91 |         },
92 |         {
93 |             "class": "utils.transforms.RandomPadding",
94 |             "args": {"out_len": 176400, "train": false},
95 |             "train": false
96 |         }
97 |     ],
98 |     "Metrics": {
99 |         "Performance": {
100 |             "window_name": null,
101 |             "x_label": "#Epochs",
102 |             "y_label": "Accuracy",
103 |             "width": 1890,
104 |             "height": 416,
105 |             "lines": [
106 |                 {
107 |                     "line_label": "Val. 
Acc.", 108 | "class": "ignite.metrics.Accuracy", 109 | "args": {}, 110 | "is_checkpoint": true 111 | } 112 | ] 113 | } 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torchvision as tv 5 | 6 | import ignite_trainer as it 7 | 8 | 9 | def scale(old_value, old_min, old_max, new_min, new_max): 10 | old_range = (old_max - old_min) 11 | new_range = (new_max - new_min) 12 | new_value = (((old_value - old_min) * new_range) / old_range) + new_min 13 | 14 | return new_value 15 | 16 | 17 | class ToTensor1D(tv.transforms.ToTensor): 18 | 19 | def __call__(self, tensor: np.ndarray): 20 | tensor_2d = super(ToTensor1D, self).__call__(tensor[..., np.newaxis]) 21 | 22 | return tensor_2d.squeeze_(0) 23 | 24 | 25 | class RandomFlip(it.AbstractTransform): 26 | 27 | def __init__(self, p: float = 0.5): 28 | super(RandomFlip, self).__init__() 29 | 30 | self.p = p 31 | 32 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 33 | if x.dim() > 2: 34 | flip_mask = torch.rand(x.shape[0], device=x.device) <= self.p 35 | x[flip_mask] = x[flip_mask].flip(-1) 36 | else: 37 | if torch.rand(1) <= self.p: 38 | x = x.flip(0) 39 | 40 | return x 41 | 42 | 43 | class RandomScale(it.AbstractTransform): 44 | 45 | def __init__(self, max_scale: float = 1.25): 46 | super(RandomScale, self).__init__() 47 | 48 | self.max_scale = max_scale 49 | 50 | @staticmethod 51 | def random_scale(max_scale: float, signal: torch.Tensor) -> torch.Tensor: 52 | scaling = np.power(max_scale, np.random.uniform(-1, 1)) 53 | output_size = int(signal.shape[-1] * scaling) 54 | ref = torch.arange(output_size, device=signal.device, dtype=signal.dtype).div_(scaling) 55 | 56 | ref1 = ref.clone().type(torch.int64) 57 | ref2 = torch.min(ref1 + 1, torch.full_like(ref1, signal.shape[-1] - 1, dtype=torch.int64)) 58 | r = ref - ref1.type(ref.type()) 59 | scaled_signal = signal[..., ref1] * (1 - r) + signal[..., ref2] * r 60 | 61 | return scaled_signal 62 | 63 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 64 | return self.random_scale(self.max_scale, x) 65 | 66 | 67 | class RandomCrop(it.AbstractTransform): 68 | 69 | def __init__(self, out_len: int = 44100, train: bool = True): 70 | super(RandomCrop, self).__init__() 71 | 72 | self.out_len = out_len 73 | self.train = train 74 | 75 | def random_crop(self, signal: torch.Tensor) -> torch.Tensor: 76 | if self.train: 77 | left = np.random.randint(0, signal.shape[-1] - self.out_len) 78 | else: 79 | left = int(round(0.5 * (signal.shape[-1] - self.out_len))) 80 | 81 | orig_std = signal.float().std() * 0.5 82 | output = signal[..., left:left + self.out_len] 83 | 84 | out_std = output.float().std() 85 | if out_std < orig_std: 86 | output = signal[..., :self.out_len] 87 | 88 | new_out_std = output.float().std() 89 | if orig_std > new_out_std > out_std: 90 | output = signal[..., -self.out_len:] 91 | 92 | return output 93 | 94 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 95 | return self.random_crop(x) if x.shape[-1] > self.out_len else x 96 | 97 | 98 | class RandomPadding(it.AbstractTransform): 99 | 100 | def __init__(self, out_len: int = 88200, train: bool = True): 101 | super(RandomPadding, self).__init__() 102 | 103 | self.out_len = out_len 104 | self.train = train 105 | 106 | def random_pad(self, signal: torch.Tensor) -> torch.Tensor: 107 | if self.train: 108 | left = np.random.randint(0, 
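            # training: the left pad width is drawn uniformly at random;
            # the branch below instead centers the signal deterministically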
self.out_len - signal.shape[-1]) 109 | else: 110 | left = int(round(0.5 * (self.out_len - signal.shape[-1]))) 111 | 112 | right = self.out_len - (left + signal.shape[-1]) 113 | 114 | pad_value_left = signal[..., 0].float().mean().to(signal.dtype) 115 | pad_value_right = signal[..., -1].float().mean().to(signal.dtype) 116 | output = torch.cat(( 117 | torch.zeros(signal.shape[:-1] + (left,), dtype=signal.dtype, device=signal.device).fill_(pad_value_left), 118 | signal, 119 | torch.zeros(signal.shape[:-1] + (right,), dtype=signal.dtype, device=signal.device).fill_(pad_value_right) 120 | ), dim=-1) 121 | 122 | return output 123 | 124 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 125 | return self.random_pad(x) if x.shape[-1] < self.out_len else x 126 | -------------------------------------------------------------------------------- /reproduced/lmcnet.py: -------------------------------------------------------------------------------- 1 | import scipy.signal as sps 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import ignite_trainer as it 7 | 8 | from utils import features 9 | 10 | from typing import Tuple 11 | from typing import Union 12 | from typing import Optional 13 | 14 | 15 | class LMCNet(it.AbstractNet): 16 | 17 | def __init__(self, 18 | num_channels: int = 1, 19 | num_classes: int = 10, 20 | sample_rate: int = 44100, 21 | norm: Union[str, float] = 'inf', 22 | n_fft: int = 2048, 23 | hop_length: int = 1024, 24 | win_length: int = 2048, 25 | window: str = 'hann', 26 | n_mels: int = 128, 27 | tuning: float = 0.0, 28 | n_chroma: int = 12, 29 | ctroct: float = 5.0, 30 | octwidth: float = 2.0, 31 | base_c: bool = True, 32 | freq: Optional[torch.Tensor] = None, 33 | fmin: float = 200.0, 34 | fmax: Optional[float] = None, 35 | n_bands: int = 6, 36 | quantile: float = 0.02, 37 | linear: bool = False): 38 | 39 | super(LMCNet, self).__init__() 40 | 41 | norm = float(norm) 42 | 43 | self.lmc = features.LMC( 44 | sample_rate=sample_rate, 45 | norm=norm, 46 | n_fft=n_fft, 47 | n_mels=n_mels, 48 | tuning=tuning, 49 | n_chroma=n_chroma, 50 | ctroct=ctroct, 51 | octwidth=octwidth, 52 | base_c=base_c, 53 | freq=freq, 54 | fmin=fmin, 55 | fmax=fmax, 56 | n_bands=n_bands, 57 | quantile=quantile, 58 | linear=linear 59 | ) 60 | 61 | self.n_fft = n_fft 62 | self.win_length = win_length 63 | self.hop_length = hop_length 64 | 65 | window_buf = sps.get_window(window, win_length, False) 66 | self.register_buffer('window', torch.from_numpy(window_buf).to(torch.get_default_dtype())) 67 | 68 | self.conv1 = torch.nn.Conv2d( 69 | in_channels=num_channels, 70 | out_channels=32, 71 | kernel_size=(3, 3), 72 | stride=(1, 1), 73 | padding=(1, 1) 74 | ) 75 | self.bn1 = torch.nn.BatchNorm2d(num_features=self.conv1.out_channels) 76 | self.activation1 = torch.nn.ReLU() 77 | 78 | self.conv2 = torch.nn.Conv2d( 79 | in_channels=self.conv1.out_channels, 80 | out_channels=self.conv1.out_channels, 81 | kernel_size=(3, 3), 82 | stride=(1, 1), 83 | padding=(1, 1) 84 | ) 85 | self.bn2 = torch.nn.BatchNorm2d(num_features=self.conv2.out_channels) 86 | self.activation2 = torch.nn.ReLU() 87 | self.pool2 = torch.nn.MaxPool2d(kernel_size=(2, 2), padding=(1, 1)) 88 | 89 | self.conv3 = torch.nn.Conv2d( 90 | in_channels=self.conv2.out_channels, 91 | out_channels=64, 92 | kernel_size=(3, 3), 93 | stride=(1, 1), 94 | padding=(1, 1) 95 | ) 96 | self.bn3 = torch.nn.BatchNorm2d(num_features=self.conv3.out_channels) 97 | self.activation3 = torch.nn.ReLU() 98 | 99 | self.conv4 = torch.nn.Conv2d( 100 | 
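            # conv3/conv4 form the second, 64-channel stage; after pool4 the
            # feature map is flattened into fc1, whose in_features below assume
            # an 11 x 22 spatial size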
in_channels=self.conv3.out_channels, 101 | out_channels=64, 102 | kernel_size=(3, 3), 103 | stride=(1, 1), 104 | padding=(1, 1) 105 | ) 106 | self.bn4 = torch.nn.BatchNorm2d(num_features=self.conv4.out_channels) 107 | self.activation4 = torch.nn.ReLU() 108 | self.pool4 = torch.nn.MaxPool2d(kernel_size=(2, 2), padding=(1, 1)) 109 | 110 | self.fc1 = torch.nn.Linear(in_features=11 * 22 * self.conv4.out_channels, out_features=1024) 111 | self.activation5 = torch.nn.Sigmoid() 112 | 113 | self.fc2 = torch.nn.Linear(in_features=self.fc1.out_features, out_features=num_classes) 114 | 115 | def spectrogram(self, x: torch.Tensor) -> torch.Tensor: 116 | spectrogram = torch.stft( 117 | x.view(x.shape[0], -1), 118 | n_fft=self.n_fft, 119 | hop_length=self.hop_length, 120 | win_length=self.win_length, 121 | window=self.window, 122 | normalized=True 123 | ) 124 | spectrogram = spectrogram[..., 0] ** 2 + spectrogram[..., 1] ** 2 125 | spectrogram = spectrogram.view(x.shape[0], -1, *spectrogram.shape[1:]) 126 | spectrogram = torch.where(spectrogram == 0.0, spectrogram + 1e-10, spectrogram) 127 | 128 | return spectrogram 129 | 130 | def forward(self, 131 | x: torch.Tensor, 132 | y: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 133 | 134 | x = self.spectrogram(x) 135 | x = self.lmc(x) 136 | 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.activation1(x) 140 | 141 | x = F.dropout2d(x, p=0.5, training=self.training) 142 | 143 | x = self.conv2(x) 144 | x = self.bn2(x) 145 | x = self.activation2(x) 146 | x = self.pool2(x) 147 | 148 | x = self.conv3(x) 149 | x = self.bn3(x) 150 | x = self.activation3(x) 151 | 152 | x = F.dropout2d(x, p=0.5, training=self.training) 153 | 154 | x = self.conv4(x) 155 | x = self.bn4(x) 156 | x = self.activation4(x) 157 | x = self.pool4(x) 158 | 159 | x = x.view(x.shape[0], -1) 160 | 161 | x = F.dropout(x, p=0.5, training=self.training) 162 | 163 | x = self.fc1(x) 164 | x = self.activation5(x) 165 | y_pred = self.fc2(x) 166 | 167 | loss = None 168 | if y is not None: 169 | loss = self.loss_fn(y_pred, y).mean() 170 | 171 | return y_pred if loss is None else (y_pred, loss) 172 | 173 | def loss_fn(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 174 | loss_pred = F.cross_entropy(y_pred, y) 175 | 176 | return loss_pred 177 | 178 | @property 179 | def loss_fn_name(self) -> str: 180 | return 'Cross Entropy' 181 | -------------------------------------------------------------------------------- /ignite_trainer/_utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | import sys 3 | import json 4 | import tqdm 5 | import datetime 6 | import importlib 7 | import contextlib 8 | 9 | import numpy as np 10 | 11 | import torch 12 | 13 | from collections import OrderedDict 14 | 15 | from typing import Any 16 | from typing import Dict 17 | from typing import List 18 | from typing import Type 19 | from typing import Tuple 20 | from typing import Union 21 | from typing import Callable 22 | from typing import Optional 23 | 24 | 25 | @contextlib.contextmanager 26 | def tqdm_stdout(orig_stdout: Optional[io.TextIOBase] = None): 27 | 28 | class DummyFile(object): 29 | file = None 30 | 31 | def __init__(self, file): 32 | self.file = file 33 | 34 | def write(self, x): 35 | if len(x.rstrip()) > 0: 36 | tqdm.tqdm.write(x, file=self.file) 37 | 38 | def flush(self): 39 | return getattr(self.file, 'flush', lambda: None)() 40 | 41 | orig_out_err = sys.stdout, sys.stderr 42 | 43 | try: 44 | if 
orig_stdout is None: 45 | sys.stdout, sys.stderr = map(DummyFile, orig_out_err) 46 | yield orig_out_err[0] 47 | else: 48 | yield orig_stdout 49 | except Exception as exc: 50 | raise exc 51 | finally: 52 | sys.stdout, sys.stderr = orig_out_err 53 | 54 | 55 | def load_class(package_name: str, class_name: Optional[str] = None) -> Type: 56 | if class_name is None: 57 | package_name, class_name = package_name.rsplit('.', 1) 58 | 59 | importlib.invalidate_caches() 60 | 61 | package = importlib.import_module(package_name) 62 | cls = getattr(package, class_name) 63 | 64 | return cls 65 | 66 | 67 | def arg_selector(arg_cmd: Optional[Any], arg_conf: Optional[Any], arg_const: Any) -> Any: 68 | if arg_cmd is not None: 69 | return arg_cmd 70 | else: 71 | if arg_conf is not None: 72 | return arg_conf 73 | else: 74 | return arg_const 75 | 76 | 77 | def get_data_loaders(Dataset: Type, 78 | dataset_args: Dict[str, Any], 79 | batch_train: int = 64, 80 | batch_test: int = 1024, 81 | workers_train: int = 0, 82 | workers_test: int = 0, 83 | transforms_train: Optional[Callable[ 84 | [Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor] 85 | ]] = None, 86 | transforms_test: Optional[Callable[ 87 | [Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor] 88 | ]] = None) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: 89 | 90 | dataset_mode_train = {dataset_args['training']['key']: dataset_args['training']['yes']} 91 | dataset_mode_test = {dataset_args['training']['key']: dataset_args['training']['no']} 92 | 93 | dataset_args_train = {**{k: v for k, v in dataset_args.items() if k != 'training'}, **dataset_mode_train} 94 | dataset_args_test = {**{k: v for k, v in dataset_args.items() if k != 'training'}, **dataset_mode_test} 95 | 96 | train_loader = torch.utils.data.DataLoader( 97 | Dataset(**{**dataset_args_train, **{'transform': transforms_train}}), 98 | batch_size=batch_train, 99 | shuffle=True, 100 | num_workers=workers_train, 101 | pin_memory=True 102 | ) 103 | eval_loader = torch.utils.data.DataLoader( 104 | Dataset(**{**dataset_args_test, **{'transform': transforms_test}}), 105 | batch_size=batch_test, 106 | num_workers=workers_test, 107 | pin_memory=True 108 | ) 109 | 110 | return train_loader, eval_loader 111 | 112 | 113 | def build_summary_str(experiment_name: str, 114 | model_short_name: str, 115 | model_class: str, 116 | model_args: Dict[str, Any], 117 | optimizer_class: str, 118 | optimizer_args: Dict[str, Any], 119 | dataset_class: str, 120 | dataset_args: Dict[str, Any], 121 | transforms: List[Dict[str, Union[str, Dict[str, Any]]]], 122 | epochs: int, 123 | batch_train: int, 124 | log_interval: int, 125 | saved_models_path: str, 126 | scheduler_class: Optional[str] = None, 127 | scheduler_args: Optional[Dict[str, Any]] = None) -> str: 128 | 129 | setup_title = '{}-{}'.format(experiment_name, model_short_name) 130 | 131 | summary_window_text = '
<hr>'
132 |     summary_window_text += '<h4 id="{}">'.format(setup_title)
133 | 
134 |     summary_window_text += setup_title
135 | 
136 |     summary_window_text += '</h4>'
137 |     summary_window_text += '<br>'
138 |     summary_window_text += '<hr>'
139 |     summary_window_text += '<pre>'
140 | 
141 |     summary = OrderedDict([
142 |         ('Setup', OrderedDict([
143 |             ('epochs', epochs),
144 |             ('batch_train', batch_train),
145 |             ('log_interval', log_interval),
146 |             ('saved_models_path', saved_models_path)
147 |         ])),
148 |         ('Model', OrderedDict([('class', model_class), ('args', model_args)])),
149 |         ('Optimizer', OrderedDict([('class', optimizer_class), ('args', optimizer_args)])),
150 |         ('Dataset', OrderedDict([('class', dataset_class), ('args', dataset_args)])),
151 |         ('Transforms', transforms)
152 |     ])
153 | 
154 |     if scheduler_class is not None:
155 |         summary['Scheduler'] = OrderedDict([('class', scheduler_class), ('args', scheduler_args)])
156 | 
157 |     for section_name, section_content in summary.items():
158 |         summary_window_text += '<b>{}</b>\n'.format(section_name)
159 |         summary_window_text += '{}\n'.format(json.dumps(section_content, indent=2, default=str))
160 |         summary_window_text += '\n'
161 | 
162 |     summary_window_text += '</pre>'
163 | 
164 |     summary_window_text += '<hr>
' 165 | 166 | return summary_window_text 167 | -------------------------------------------------------------------------------- /ignite_trainer/_visdom.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import time 5 | import tqdm 6 | import socket 7 | import subprocess 8 | import numpy as np 9 | 10 | import visdom 11 | 12 | from typing import Tuple 13 | from typing import Optional 14 | 15 | 16 | def calc_ytick_range(vis: visdom.Visdom, window_name: str, env: Optional[str] = None) -> Tuple[float, float]: 17 | lower_bound, upper_bound = -1.0, 1.0 18 | 19 | stats = vis.get_window_data(win=window_name, env=env) 20 | 21 | if stats: 22 | stats = json.loads(stats) 23 | 24 | stats = [np.array(item['y']) for item in stats['content']['data']] 25 | stats = [item[item != np.array([None])].astype(np.float16) for item in stats] 26 | 27 | if stats: 28 | q25s = np.array([np.quantile(item, 0.25) for item in stats if len(item) > 0]) 29 | q75s = np.array([np.quantile(item, 0.75) for item in stats if len(item) > 0]) 30 | 31 | if q25s.shape == q75s.shape and len(q25s) > 0: 32 | iqrs = q75s - q25s 33 | 34 | lower_bounds = q25s - 1.5 * iqrs 35 | upper_bounds = q75s + 1.5 * iqrs 36 | 37 | stats_sanitized = list() 38 | idx = 0 39 | for item in stats: 40 | if len(item) > 0: 41 | item_sanitized = item[(item >= lower_bounds[idx]) & (item <= upper_bounds[idx])] 42 | stats_sanitized.append(item_sanitized) 43 | 44 | idx += 1 45 | 46 | stats_sanitized = np.array(stats_sanitized) 47 | 48 | q25_sanitized = np.array([np.quantile(item, 0.25) for item in stats_sanitized]) 49 | q75_sanitized = np.array([np.quantile(item, 0.75) for item in stats_sanitized]) 50 | 51 | iqr_sanitized = np.sum(q75_sanitized - q25_sanitized) 52 | lower_bound = np.min(q25_sanitized) - 1.5 * iqr_sanitized 53 | upper_bound = np.max(q75_sanitized) + 1.5 * iqr_sanitized 54 | 55 | return lower_bound, upper_bound 56 | 57 | 58 | def plot_line(vis: visdom.Visdom, 59 | window_name: str, 60 | env: Optional[str] = None, 61 | line_label: Optional[str] = None, 62 | x: Optional[np.ndarray] = None, 63 | y: Optional[np.ndarray] = None, 64 | x_label: Optional[str] = None, 65 | y_label: Optional[str] = None, 66 | width: int = 576, 67 | height: int = 416, 68 | draw_marker: bool = False) -> str: 69 | 70 | empty_call = not vis.win_exists(window_name) 71 | 72 | if empty_call and (x is not None or y is not None): 73 | return window_name 74 | 75 | if x is None: 76 | x = np.ones(1) 77 | empty_call = empty_call & True 78 | 79 | if y is None: 80 | y = np.full(1, np.nan) 81 | empty_call = empty_call & True 82 | 83 | if x.shape != y.shape: 84 | x = np.ones_like(y) 85 | 86 | opts = { 87 | 'showlegend': True, 88 | 'markers': draw_marker, 89 | 'markersize': 5, 90 | } 91 | 92 | if empty_call: 93 | opts['title'] = window_name 94 | opts['width'] = width 95 | opts['height'] = height 96 | 97 | window_name = vis.line( 98 | X=x, 99 | Y=y, 100 | win=window_name, 101 | env=env, 102 | update='append', 103 | name=line_label, 104 | opts=opts 105 | ) 106 | 107 | xtickmin, xtickmax = 0.0, np.max(x) * 1.05 108 | ytickmin, ytickmax = calc_ytick_range(vis, window_name, env) 109 | 110 | opts = { 111 | 'showlegend': True, 112 | 'xtickmin': xtickmin, 113 | 'xtickmax': xtickmax, 114 | 'ytickmin': ytickmin, 115 | 'ytickmax': ytickmax, 116 | 'xlabel': x_label, 117 | 'ylabel': y_label 118 | } 119 | 120 | window_name = vis.update_window_opts(win=window_name, opts=opts, env=env) 121 | 122 | return window_name 123 | 124 | 125 
| # TODO: implement remove experiment callback 126 | 127 | 128 | def create_summary_window(vis: visdom.Visdom, 129 | visdom_env_name: str, 130 | experiment_name: str, 131 | summary: str) -> str: 132 | 133 | return vis.text( 134 | text=summary, 135 | win=experiment_name, 136 | env=visdom_env_name, 137 | opts={'title': 'Summary', 'width': 576, 'height': 416}, 138 | append=vis.win_exists(experiment_name, visdom_env_name) 139 | ) 140 | 141 | 142 | def connection_is_alive(host: str, port: int) -> bool: 143 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: 144 | try: 145 | sock.connect((host, port)) 146 | sock.shutdown(socket.SHUT_RDWR) 147 | 148 | return True 149 | except socket.error: 150 | return False 151 | 152 | 153 | def get_visdom_instance(host: str = 'localhost', 154 | port: int = 8097, 155 | env_name: str = 'main', 156 | env_path: str = 'visdom_env') -> Tuple[visdom.Visdom, Optional[int]]: 157 | 158 | vis_pid = None 159 | 160 | if not connection_is_alive(host, port): 161 | if any(host.strip('/').endswith(lh) for lh in ['127.0.0.1', 'localhost']): 162 | os.makedirs(env_path, exist_ok=True) 163 | 164 | tqdm.tqdm.write('Starting visdom on port {}'.format(port), end='') 165 | 166 | vis_args = [ 167 | sys.executable, 168 | '-m', 'visdom.server', 169 | '-port', str(port), 170 | '-env_path', os.path.join(os.getcwd(), env_path) 171 | ] 172 | vis_proc = subprocess.Popen(vis_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 173 | time.sleep(2.0) 174 | 175 | vis_pid = vis_proc.pid 176 | tqdm.tqdm.write('PID -> {}'.format(vis_pid)) 177 | 178 | trials_left = 5 179 | while not connection_is_alive(host, port): 180 | time.sleep(1.0) 181 | 182 | tqdm.tqdm.write('Trying to connect ({} left)...'.format(trials_left)) 183 | 184 | trials_left -= 1 185 | if trials_left < 1: 186 | raise RuntimeError('Visdom server is not running. 
Please run "python -m visdom.server".') 187 | 188 | vis = visdom.Visdom( 189 | server='http://{}'.format(host), 190 | port=port, 191 | env=env_name 192 | ) 193 | 194 | return vis, vis_pid 195 | -------------------------------------------------------------------------------- /reproduced/adcnn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import ignite_trainer as it 7 | 8 | from utils import transforms, features 9 | 10 | from typing import Tuple 11 | from typing import Union 12 | from typing import Optional 13 | 14 | 15 | class Block(torch.nn.Module): 16 | 17 | def __init__(self, 18 | in_channels: int, 19 | out_channels: int, 20 | kernel_size: Tuple[int, int], 21 | pooling_size: Tuple[int, int]): 22 | 23 | super(Block, self).__init__() 24 | 25 | self.conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size) 26 | self.conv2 = torch.nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size) 27 | self.conv1x1 = torch.nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(1, 1)) 28 | self.bn = torch.nn.BatchNorm2d(num_features=out_channels) 29 | self.activation = torch.nn.LeakyReLU() 30 | self.pooling = torch.nn.MaxPool2d(kernel_size=pooling_size) 31 | 32 | def forward(self, x: torch.Tensor) -> torch.Tensor: 33 | x = self.conv1(x) 34 | x = self.conv2(x) 35 | x = self.conv1x1(x) 36 | x = self.bn(x) 37 | x = self.activation(x) 38 | x = self.pooling(x) 39 | 40 | return x 41 | 42 | 43 | class Attention(torch.nn.Module): 44 | 45 | def __init__(self, 46 | in_channels: int, 47 | out_channels: int, 48 | kernel_size: Tuple[int, int], 49 | pooling_size: Tuple[int, int]): 50 | 51 | super(Attention, self).__init__() 52 | 53 | self.pool = torch.nn.MaxPool2d(kernel_size=pooling_size) 54 | self.conv_depth = torch.nn.Conv2d( 55 | in_channels=in_channels, 56 | out_channels=out_channels, 57 | kernel_size=kernel_size, 58 | groups=in_channels 59 | ) 60 | self.conv_point = torch.nn.Conv2d( 61 | in_channels=out_channels, 62 | out_channels=out_channels, 63 | kernel_size=(1, 1) 64 | ) 65 | self.bn = torch.nn.BatchNorm2d(num_features=out_channels) 66 | self.activation = torch.nn.ReLU() 67 | 68 | def forward(self, x: torch.Tensor) -> torch.Tensor: 69 | x = self.pool(x) 70 | x = self.conv_depth(x) 71 | x = self.conv_point(x) 72 | x = self.bn(x) 73 | x = self.activation(x) 74 | 75 | return x 76 | 77 | 78 | class DCNN5(it.AbstractNet): 79 | 80 | def __init__(self, 81 | num_channels: int = 1, 82 | sample_rate: int = 32000, 83 | n_fft: int = 256, 84 | hop_length: Optional[int] = None, 85 | window: Optional[str] = None, 86 | num_classes: int = 10): 87 | 88 | super(DCNN5, self).__init__() 89 | 90 | self.num_channels = num_channels 91 | self.num_classes = num_classes 92 | 93 | if hop_length is None: 94 | hop_length = int(math.floor(n_fft / 4)) 95 | 96 | if window is None: 97 | window = 'boxcar' 98 | 99 | self.log10_eps = 1e-18 100 | 101 | self.mfcc = features.MFCC( 102 | sample_rate=sample_rate, 103 | n_mfcc=128, 104 | n_fft=n_fft, 105 | hop_length=hop_length, 106 | window=window 107 | ) 108 | 109 | self.block1 = Block(self.num_channels, 32, (3, 1), (2, 1)) 110 | self.block2 = Block(32, 32, (1, 5), (1, 4)) 111 | self.block3 = Block(32, 64, (3, 1), (2, 1)) 112 | self.block4 = Block(64, 64, (1, 5), (1, 4)) 113 | self.block5 = Block(64, 128, (3, 5), (1, 1)) 114 | self.max_pool = torch.nn.MaxPool2d(kernel_size=(2, 4)) 115 | 
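        # classifier head: fc1 assumes that the five blocks plus the final
        # 2 x 4 max-pooling reduce the MFCC map to 128 channels of size 12 x 2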
116 | self.drop1 = torch.nn.Dropout(p=0.25) 117 | self.fc1 = torch.nn.Linear(in_features=128 * 12 * 2, out_features=256) 118 | self.fc2 = torch.nn.Linear(in_features=self.fc1.out_features, out_features=self.num_classes) 119 | 120 | self.activation = torch.nn.LeakyReLU() 121 | 122 | self.l2_lambda = 0.1 123 | 124 | def forward(self, 125 | x: torch.Tensor, 126 | y: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 127 | 128 | x = self.mfcc(x) 129 | x = transforms.scale( 130 | x, 131 | x.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values.min(dim=-3, keepdim=True).values, 132 | x.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values.max(dim=-3, keepdim=True).values, 133 | 0.0, 134 | 1.0 135 | ) 136 | 137 | x = self.block1(x) 138 | x = self.block2(x) 139 | x = self.block3(x) 140 | x = self.block4(x) 141 | x = self.max_pool(self.block5(x)) 142 | 143 | x = x.view(x.shape[0], -1) 144 | x = self.drop1(x) 145 | 146 | x = self.fc1(x) 147 | x = self.activation(x) 148 | 149 | y_pred = self.fc2(x) 150 | 151 | loss = None 152 | if y is not None: 153 | loss = self.loss_fn(y_pred, y).sum() 154 | 155 | return y_pred if loss is None else (y_pred, loss) 156 | 157 | def loss_fn(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 158 | loss_pred = F.cross_entropy(y_pred, y) 159 | 160 | loss_l2 = 0.0 161 | loss_l2_params = list(self.fc1.parameters()) 162 | for p in loss_l2_params: 163 | loss_l2 = p.norm(2) + loss_l2 164 | 165 | loss_pred = loss_pred + self.l2_lambda * loss_l2 166 | 167 | return loss_pred 168 | 169 | @property 170 | def loss_fn_name(self) -> str: 171 | return 'Cross Entropy' 172 | 173 | 174 | class ADCNN5(DCNN5): 175 | 176 | def __init__(self, 177 | num_channels: int = 1, 178 | n_fft: int = 1024, 179 | hop_length: Optional[int] = None, 180 | window: Optional[str] = None, 181 | num_classes: int = 10): 182 | 183 | super(ADCNN5, self).__init__( 184 | num_channels=num_channels, 185 | n_fft=n_fft, 186 | hop_length=hop_length, 187 | window=window, 188 | num_classes=num_classes 189 | ) 190 | 191 | self.attn1 = Attention(self.num_channels, 32, (3, 1), (2, 1)) 192 | self.attn2 = Attention(32, 32, (1, 3), (1, 4)) 193 | self.attn3 = Attention(32, 64, (3, 1), (2, 1)) 194 | self.attn4 = Attention(64, 64, (1, 3), (1, 4)) 195 | self.attn5 = Attention(64, 128, (3, 3), (2, 4)) 196 | self.attn5.pool = torch.nn.Identity() 197 | self.attn5 = torch.nn.Sequential( 198 | self.attn5, 199 | torch.nn.AdaptiveMaxPool2d(output_size=(12, 2)) 200 | ) 201 | 202 | def forward(self, 203 | x: torch.Tensor, 204 | y: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 205 | 206 | x = self.mfcc(x) 207 | x = transforms.scale( 208 | x, 209 | x.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values.min(dim=-3, keepdim=True).values, 210 | x.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values.max(dim=-3, keepdim=True).values, 211 | 0.0, 212 | 1.0 213 | ) 214 | 215 | x = self.attn1(x) * self.block1(x) 216 | x = self.attn2(x) * self.block2(x) 217 | x = self.attn3(x) * self.block3(x) 218 | x = self.attn4(x) * self.block4(x) 219 | x = self.attn5(x) * self.max_pool(self.block5(x)) 220 | 221 | x = x.view(x.shape[0], -1) 222 | x = self.drop1(x) 223 | 224 | x = self.fc1(x) 225 | x = self.activation(x) 226 | 227 | y_pred = self.fc2(x) 228 | 229 | loss = None 230 | if y is not None: 231 | loss = self.loss_fn(y_pred, y).sum() 232 | 233 | return y_pred if loss is None else (y_pred, loss) 234 | 
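
# Illustrative smoke test, not part of the original module: argument values
# mirror protocols/us8k/adcnn5-us8k-cv1.json, and the input length matches the
# 128000-sample crops used there. The networks consume raw waveforms shaped
# (batch, channels, samples) and compute MFCC features internally; when targets
# are passed, forward() also returns the loss.
if __name__ == '__main__':
    model = ADCNN5(num_channels=1, n_fft=1024, hop_length=512, window='hann', num_classes=10)
    batch = torch.randn(4, 1, 128000)
    targets = torch.randint(0, 10, (4,))
    y_pred, loss = model(batch, targets)
    print(y_pred.shape, loss.item())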
-------------------------------------------------------------------------------- /utils/datasets.py: --------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | import multiprocessing as mp
4 | 
5 | import numpy as np
6 | import pandas as pd
7 | import sklearn.model_selection as skms
8 | 
9 | import tqdm
10 | import librosa
11 | 
12 | import torch.utils.data as td
13 | 
14 | from utils import transforms
15 | 
16 | from typing import Tuple
17 | from typing import Optional
18 | 
19 | 
20 | class ESC50(td.Dataset):
21 | 
22 | def __init__(self,
23 | root: str,
24 | sample_rate: int = 22050,
25 | train: bool = True,
26 | fold: Optional[int] = None,
27 | transform=None,
28 | target_transform=None):
29 | 
30 | super(ESC50, self).__init__()
31 | 
32 | self.sample_rate = sample_rate
33 | 
34 | meta = self.load_meta(os.path.join(root, 'meta', 'esc50.csv'))
35 | 
36 | if fold is None:
37 | fold = 5
38 | 
39 | self.folds_to_load = set(meta['fold'])
40 | 
41 | if fold not in self.folds_to_load:
42 | raise ValueError(f'fold {fold} does not exist')
43 | 
44 | self.train = train
45 | self.transform = transform
46 | 
47 | if self.train:
48 | self.folds_to_load -= {fold}  # in training mode, keep all folds but the test one
49 | else:
50 | self.folds_to_load -= self.folds_to_load - {fold}  # in evaluation mode, keep the test fold only
51 | 
52 | self.data = dict()
53 | self.load_data(meta, os.path.join(root, 'audio'))
54 | self.indices = list(self.data.keys())
55 | 
56 | self.target_transform = target_transform
57 | 
58 | @staticmethod
59 | def load_meta(path_to_csv: str) -> pd.DataFrame:
60 | meta = pd.read_csv(path_to_csv)
61 | 
62 | return meta
63 | 
64 | @staticmethod
65 | def _load_worker(idx: int, filename: str, sample_rate: Optional[int] = None) -> Tuple[int, int, np.ndarray]:
66 | wav, sample_rate = librosa.load(filename, sr=sample_rate, mono=True)
67 | 
68 | if wav.ndim == 1:
69 | wav = wav[:, np.newaxis]
70 | 
71 | if np.abs(wav).max() > 1.0:  # check the peak magnitude, not only the signed maximum
72 | wav = transforms.scale(wav, wav.min(), wav.max(), -1.0, 1.0)
73 | 
74 | wav = wav.T * 32768.0  # scale to the 16-bit sample range
75 | 
76 | return idx, sample_rate, wav.astype(np.float32)
77 | 
78 | def load_data(self, meta: pd.DataFrame, base_path: str):
79 | items_to_load = dict()
80 | 
81 | for idx, row in meta.iterrows():
82 | if row['fold'] in self.folds_to_load:
83 | items_to_load[idx] = os.path.join(base_path, row['filename']), self.sample_rate
84 | 
85 | items_to_load = [(idx, path, sample_rate) for idx, (path, sample_rate) in items_to_load.items()]
86 | 
87 | warnings.filterwarnings('ignore')
88 | with mp.Pool(processes=mp.cpu_count()) as pool:
89 | chunksize = int(np.ceil(len(items_to_load) / pool._processes)) or 1
90 | tqdm.tqdm.write(f'Loading {self.__class__.__name__} (train={self.train})')
91 | for idx, sample_rate, wav in pool.starmap(
92 | func=self._load_worker,
93 | iterable=items_to_load,
94 | chunksize=chunksize
95 | ):
96 | row = meta.loc[idx]
97 | 
98 | self.data[idx] = {
99 | 'audio': wav,
100 | 'sample_rate': sample_rate,
101 | 'target': row['target'],
102 | 'fold': row['fold'],
103 | 'esc10': row['esc10']
104 | }
105 | 
106 | def __getitem__(self, index: int) -> Tuple[np.ndarray, int]:
107 | if not (0 <= index < len(self)):
108 | raise IndexError
109 | 
110 | audio: np.ndarray = self.data[self.indices[index]]['audio']
111 | target: int = self.data[self.indices[index]]['target']
112 | 
113 | if self.transform is not None:
114 | audio = self.transform(audio)
115 | if self.target_transform is not None:
116 | target = self.target_transform(target)
117 | 
118 | return audio, target
119 | 
120 | def __len__(self) -> int:
121 | return len(self.indices)
122 | 
123 | 
124 | class ESC10(ESC50):
125 | 
126 | def __init__(self,
127 | root: str,
128 | sample_rate: int = 22050,
129 | train: bool = True,
130 | fold: Optional[int] = None,
131 | transform=None,
132 | target_transform=None):
133 | 
134 | super(ESC10, self).__init__(
135 | root=root,
136 | sample_rate=sample_rate,
137 | train=train,
138 | fold=fold,
139 | transform=transform,
140 | target_transform=target_transform
141 | )
142 | 
143 | self.classes = {
144 | old_target: new_target
145 | for new_target, old_target
146 | in enumerate(sorted({item['target'] for item in self.data.values()}))  # sort so the target remapping is deterministic across runs
147 | }
148 | 
149 | @staticmethod
150 | def load_meta(path_to_csv: str) -> pd.DataFrame:
151 | meta = ESC50.load_meta(path_to_csv)
152 | meta.drop(index=meta[~meta['esc10']].index, inplace=True)
153 | 
154 | return meta
155 | 
156 | def __getitem__(self, index: int) -> Tuple[np.ndarray, int]:
157 | audio, target = super(ESC10, self).__getitem__(index)
158 | 
159 | target = self.classes[target]
160 | 
161 | return audio, target
162 | 
163 | 
164 | class UrbanSound8K(td.Dataset):
165 | 
166 | def __init__(self,
167 | root: str,
168 | sample_rate: int = 22050,
169 | train: bool = True,
170 | fold: Optional[int] = None,
171 | random_split_seed: Optional[int] = None,
172 | mono: bool = False,
173 | transform=None,
174 | target_transform=None):
175 | 
176 | super(UrbanSound8K, self).__init__()
177 | 
178 | self.root = root
179 | self.sample_rate = sample_rate
180 | self.train = train
181 | 
182 | if fold is None:
183 | fold = 1
184 | 
185 | if not (1 <= fold <= 10):
186 | raise ValueError(f'Expected fold in range [1, 10], got {fold}')
187 | 
188 | self.fold = fold
189 | self.folds_to_load = set(range(1, 11))
190 | 
191 | if self.fold not in self.folds_to_load:
192 | raise ValueError(f'fold {fold} does not exist')
193 | 
194 | if self.train:
195 | # if in training mode, keep all but the test fold
196 | self.folds_to_load -= {self.fold}
197 | else:
198 | # if in evaluation mode, keep the test samples only
199 | self.folds_to_load -= self.folds_to_load - {self.fold}
200 | 
201 | self.random_split_seed = random_split_seed
202 | self.mono = mono
203 | 
204 | self.transform = transform
205 | self.target_transform = target_transform
206 | 
207 | self.data = dict()
208 | self.indices = dict()
209 | self.load_data()
210 | 
211 | @staticmethod
212 | def _load_worker(fn: str, path_to_file: str, sample_rate: int, mono: bool = False) -> Tuple[str, int, np.ndarray]:
213 | wav, sample_rate = librosa.load(path_to_file, sr=sample_rate, mono=mono)
214 | 
215 | if wav.ndim == 1:
216 | wav = wav[np.newaxis, :]
217 | 
218 | if not mono and wav.shape[0] == 1:  # duplicate the channel only if the recording is actually mono
219 | wav = np.concatenate((wav, wav), axis=0)
220 | 
221 | wav = wav.T
222 | wav = wav[:sample_rate * 4]  # keep at most 4 seconds
223 | 
224 | if np.abs(wav).max() > 1.0:  # check the peak magnitude, not only the signed maximum
225 | wav = transforms.scale(wav, wav.min(), wav.max(), -1.0, 1.0)
226 | 
227 | wav = transforms.scale(wav, wav.min(), wav.max(), -32768.0, 32767.0).T
228 | 
229 | return fn, sample_rate, wav.astype(np.float32)
230 | 
231 | def load_data(self):
232 | # read metadata
233 | meta = pd.read_csv(
234 | os.path.join(self.root, 'metadata', 'UrbanSound8K.csv'),
235 | sep=',',
236 | index_col='slice_file_name'
237 | )
238 | 
239 | for row_idx, (fn, row) in enumerate(meta.iterrows()):
240 | path = os.path.join(self.root, 'audio', 'fold{}'.format(row['fold']), fn)
241 | self.data[fn] = path, self.sample_rate, self.mono
242 | 
243 | # by default, the official split from the metadata is used
244 | files_to_load = list()
245 | # if the random seed is not None, the random split is used
246
| if self.random_split_seed is not None: 247 | # given an integer random seed 248 | skf = skms.StratifiedKFold(n_splits=10, shuffle=True, random_state=self.random_split_seed) 249 | 250 | # split the US8K samples into 10 folds 251 | for fold_idx, (train_ids, test_ids) in enumerate(skf.split( 252 | np.zeros(len(meta)), meta['classID'].values.astype(int) 253 | ), 1): 254 | # if this is the fold we want to load, add the corresponding files to the list 255 | if fold_idx == self.fold: 256 | ids = train_ids if self.train else test_ids 257 | filenames = meta.iloc[ids].index 258 | files_to_load.extend(filenames) 259 | break 260 | else: 261 | # if the random seed is None, use the official split 262 | for fn, row in meta.iterrows(): 263 | if int(row['fold']) in self.folds_to_load: 264 | files_to_load.append(fn) 265 | 266 | self.data = {fn: vals for fn, vals in self.data.items() if fn in files_to_load} 267 | self.indices = {idx: fn for idx, fn in enumerate(self.data)} 268 | 269 | warnings.filterwarnings('ignore') 270 | with mp.Pool(processes=mp.cpu_count()) as pool: 271 | chunksize = int(np.ceil(len(meta) / pool._processes)) or 1 272 | 273 | tqdm.tqdm.write(f'Loading {self.__class__.__name__} (train={self.train})') 274 | 275 | for fn, sample_rate, wav in pool.starmap( 276 | func=self._load_worker, 277 | iterable=[(fn, path, sr, mono) for fn, (path, sr, mono) in self.data.items()], 278 | chunksize=chunksize 279 | ): 280 | self.data[fn] = { 281 | 'audio': wav, 282 | 'sample_rate': sample_rate, 283 | 'target': meta.loc[fn, 'classID'] 284 | } 285 | 286 | def __getitem__(self, index: int) -> Tuple[np.ndarray, int]: 287 | if not (0 <= index < len(self)): 288 | raise IndexError 289 | 290 | audio: np.ndarray = self.data[self.indices[index]]['audio'] 291 | target: int = self.data[self.indices[index]]['target'] 292 | 293 | if self.transform is not None: 294 | audio = self.transform(audio) 295 | if self.target_transform is not None: 296 | target = self.target_transform(target) 297 | 298 | return audio, target 299 | 300 | def __len__(self) -> int: 301 | return len(self.data) 302 | -------------------------------------------------------------------------------- /utils/features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.fft as spf 3 | import scipy.signal as sps 4 | 5 | import librosa 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from utils import transforms 11 | 12 | from typing import Optional 13 | 14 | 15 | def fft_frequencies(sample_rate: int = 22050, n_fft: int = 2048) -> torch.Tensor: 16 | return torch.linspace(0, sample_rate * 0.5, int(1 + n_fft // 2)) 17 | 18 | 19 | def power_to_db(spectrogram: torch.Tensor, ref: float = 1.0, amin: float = 1e-10, top_db: float = 80.0) -> torch.Tensor: 20 | log_spec = 10.0 * torch.log10(torch.max(torch.full_like(spectrogram, amin), spectrogram)) 21 | log_spec -= 10.0 * torch.log10(torch.full_like(spectrogram, max(amin, ref))) 22 | 23 | log_spec = torch.max( 24 | log_spec, 25 | log_spec.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values - top_db 26 | ) 27 | 28 | return log_spec 29 | 30 | 31 | class MFCC(torch.nn.Module): 32 | 33 | def __init__(self, 34 | sample_rate: int = 22050, 35 | n_mfcc: int = 128, 36 | n_fft: int = 1024, 37 | hop_length: int = 512, 38 | window: str = 'hann'): 39 | 40 | super(MFCC, self).__init__() 41 | 42 | mel_filterbank = librosa.filters.mel( 43 | sr=sample_rate, 44 | n_fft=n_fft, 45 | n_mels=n_mfcc 46 | ) 47 | mel_filterbank = 
torch.from_numpy(mel_filterbank).to(torch.get_default_dtype()) 48 | self.register_buffer('mel', mel_filterbank) 49 | 50 | dct_buf = spf.dct(np.eye(n_mfcc), type=2, norm='ortho').T 51 | dct_buf = torch.from_numpy(dct_buf).to(torch.get_default_dtype()) 52 | self.register_buffer('dct_mat', dct_buf) 53 | 54 | window_buffer: torch.Tensor = torch.from_numpy( 55 | sps.get_window(window=window, Nx=n_fft, fftbins=True) 56 | ).to(torch.get_default_dtype()) 57 | self.register_buffer('window', window_buffer) 58 | 59 | self.sample_rate = sample_rate 60 | self.n_fft = n_fft 61 | self.n_mfcc = n_mfcc 62 | self.hop_length = hop_length 63 | 64 | def dct2(self, x): 65 | x_dct = self.dct_mat.view(1, *self.dct_mat.shape) @ x 66 | 67 | return x_dct 68 | 69 | def forward(self, x: torch.Tensor) -> torch.Tensor: 70 | spec = torch.stft( 71 | x.view(-1, x.shape[-1]), 72 | n_fft=self.n_fft, 73 | hop_length=self.hop_length, 74 | win_length=self.n_fft, 75 | window=self.window, 76 | normalized=True 77 | ) 78 | 79 | power_spec = spec[..., 0] ** 2 + spec[..., 1] ** 2 80 | log_power_spec = 10 * torch.log10(power_spec.add(1e-18)) 81 | 82 | mel_spec = self.mel.view(1, *self.mel.shape) @ log_power_spec 83 | mfcc = self.dct2(mel_spec) 84 | mfcc = mfcc.view(x.shape[0], 1, *mfcc.shape[-2:]) 85 | 86 | return mfcc 87 | 88 | 89 | class Chroma(torch.nn.Module): 90 | 91 | def __init__(self, 92 | sample_rate: int = 22050, 93 | norm: float = float('inf'), 94 | n_fft: int = 2048, 95 | tuning: float = 0.0, 96 | n_chroma: int = 12, 97 | ctroct: float = 5.0, 98 | octwidth: float = 2.0, 99 | base_c: bool = True): 100 | 101 | super(Chroma, self).__init__() 102 | 103 | chroma_fb_buf = librosa.filters.chroma( 104 | sr=sample_rate, 105 | n_fft=n_fft, 106 | n_chroma=n_chroma, 107 | tuning=tuning, 108 | ctroct=ctroct, 109 | octwidth=octwidth, 110 | norm=norm, 111 | base_c=base_c 112 | ) 113 | self.register_buffer('chroma_fb', torch.from_numpy(chroma_fb_buf).to(torch.get_default_dtype())) 114 | 115 | self.norm = norm 116 | 117 | def forward(self, spectrogram: torch.Tensor) -> torch.Tensor: 118 | chroma = self.chroma_fb @ spectrogram 119 | chroma = chroma / torch.norm(chroma, p=self.norm, dim=-2, keepdim=True) 120 | 121 | return chroma 122 | 123 | 124 | class Tonnetz(Chroma): 125 | 126 | def __init__(self, 127 | sample_rate: int = 22050, 128 | norm: float = float('inf'), 129 | n_fft: int = 2048, 130 | tuning: float = 0.0, 131 | n_chroma: int = 12, 132 | ctroct: float = 5.0, 133 | octwidth: float = 2.0, 134 | base_c: bool = True): 135 | 136 | super(Tonnetz, self).__init__( 137 | sample_rate=sample_rate, 138 | norm=norm, 139 | n_fft=n_fft, 140 | tuning=tuning, 141 | n_chroma=n_chroma, 142 | ctroct=ctroct, 143 | octwidth=octwidth, 144 | base_c=base_c 145 | ) 146 | 147 | # Generate Transformation matrix 148 | dim_map = np.linspace(0, 12, n_chroma, endpoint=False) 149 | 150 | scale = np.asarray([7. / 6, 7. / 6, 151 | 3. / 2, 3. / 2, 152 | 2. / 3, 2. 
/ 3]) 153 | 154 | V = np.multiply.outer(scale, dim_map) 155 | 156 | # Even rows compute sin() 157 | V[::2] -= 0.5 158 | 159 | R = np.array([1, 1, # Fifths 160 | 1, 1, # Minor 161 | 0.5, 0.5]) # Major 162 | 163 | phi_buf = R[:, np.newaxis] * np.cos(np.pi * V) 164 | 165 | self.register_buffer('phi', torch.from_numpy(phi_buf).to(torch.get_default_dtype())) 166 | 167 | def forward(self, spectrogram: torch.Tensor) -> torch.Tensor: 168 | chroma = super(Tonnetz, self).forward(spectrogram) 169 | chroma = chroma / torch.norm(chroma, p=1, dim=-2, keepdim=True) 170 | tonnetz = self.phi @ chroma 171 | 172 | return tonnetz 173 | 174 | 175 | class SpectralContrast(torch.nn.Module): 176 | 177 | def __init__(self, 178 | sample_rate: int = 22050, 179 | n_fft: int = 2048, 180 | freq: Optional[torch.Tensor] = None, 181 | fmin: float = 200.0, 182 | n_bands: int = 6, 183 | quantile: float = 0.02, 184 | linear: bool = False): 185 | 186 | super(SpectralContrast, self).__init__() 187 | 188 | # Compute the center frequencies of each bin 189 | if freq is None: 190 | freq = fft_frequencies(sample_rate=sample_rate, n_fft=n_fft) 191 | 192 | self.register_buffer('freq', freq) 193 | 194 | if n_bands < 1 or not isinstance(n_bands, int): 195 | raise ValueError('n_bands must be a positive integer') 196 | 197 | self.n_bands = n_bands 198 | 199 | if not 0.0 < quantile < 1.0: 200 | raise ValueError('quantile must lie in the range (0, 1)') 201 | 202 | self.quantile = quantile 203 | 204 | if fmin <= 0: 205 | raise ValueError('fmin must be a positive number') 206 | 207 | octa_buf = torch.zeros(n_bands + 2) 208 | octa_buf[1:] = fmin * (2.0 ** torch.arange(0, n_bands + 1, dtype=torch.float32)) 209 | 210 | if torch.any(octa_buf[:-1] >= 0.5 * sample_rate): 211 | raise ValueError('Frequency band exceeds Nyquist. 
Reduce either fmin or n_bands.') 212 | 213 | self.register_buffer('octa', octa_buf) 214 | 215 | self.linear = linear 216 | 217 | def forward(self, spectrogram: torch.Tensor) -> torch.Tensor: 218 | valley = torch.zeros( 219 | *spectrogram.shape[:-2], self.n_bands + 1, spectrogram.shape[-1], 220 | dtype=spectrogram.dtype, 221 | device=spectrogram.device 222 | ) 223 | peak = torch.zeros_like(valley) 224 | 225 | for k, (f_low, f_high) in enumerate(zip(self.octa[:-1], self.octa[1:])): 226 | current_band: torch.Tensor = (self.freq >= f_low) & (self.freq <= f_high) 227 | 228 | idx = torch.nonzero(torch.flatten(current_band)) 229 | 230 | if k > 0: 231 | current_band[idx[0] - 1] = True 232 | 233 | if k == self.n_bands: 234 | current_band[idx[-1] + 1:] = True 235 | 236 | sub_band = spectrogram[..., current_band, :] 237 | 238 | if k < self.n_bands: 239 | sub_band = sub_band[..., :-1, :] 240 | 241 | # Always take at least one bin from each side 242 | idx = np.rint(self.quantile * torch.sum(current_band).item()) 243 | idx = int(np.maximum(idx, 1)) 244 | 245 | sortedr, _ = torch.sort(sub_band, dim=-2) 246 | 247 | valley[..., k, :] = torch.mean(sortedr[..., :idx, :], dim=-2) 248 | peak[..., k, :] = torch.mean(sortedr[..., -idx:, :], dim=-2) 249 | 250 | if self.linear: 251 | return peak - valley 252 | else: 253 | return power_to_db(peak) - power_to_db(valley) 254 | 255 | 256 | class Melspectrogram(torch.nn.Module): 257 | 258 | def __init__(self, 259 | sample_rate: int = 22050, 260 | n_fft: int = 2048, 261 | n_mels: int = 128, 262 | fmin: float = 0.0, 263 | fmax: Optional[float] = None): 264 | 265 | super(Melspectrogram, self).__init__() 266 | 267 | mel_fb_buf = librosa.filters.mel( 268 | sr=sample_rate, 269 | n_fft=n_fft, 270 | n_mels=n_mels, 271 | fmin=fmin, 272 | fmax=fmax 273 | ) 274 | self.register_buffer('mel_fb', torch.from_numpy(mel_fb_buf).to(torch.get_default_dtype())) 275 | 276 | def forward(self, spectrogram: torch.Tensor) -> torch.Tensor: 277 | lm = self.mel_fb @ spectrogram 278 | lm = power_to_db(lm) 279 | 280 | lm = transforms.scale( 281 | lm, 282 | lm.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values, 283 | lm.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values, 284 | -1.0, 285 | 1.0 286 | ) 287 | 288 | return lm 289 | 290 | 291 | class CST(torch.nn.Module): 292 | 293 | def __init__(self, 294 | sample_rate: int = 22050, 295 | norm: float = float('inf'), 296 | n_fft: int = 2048, 297 | tuning: float = 0.0, 298 | n_chroma: int = 12, 299 | ctroct: float = 5.0, 300 | octwidth: float = 2.0, 301 | base_c: bool = True, 302 | freq: Optional[torch.Tensor] = None, 303 | fmin: float = 200.0, 304 | n_bands: int = 6, 305 | quantile: float = 0.02, 306 | linear: bool = False): 307 | 308 | super(CST, self).__init__() 309 | 310 | self.chroma = Chroma( 311 | sample_rate=sample_rate, 312 | norm=norm, 313 | n_fft=n_fft, 314 | tuning=tuning, 315 | n_chroma=n_chroma, 316 | ctroct=ctroct, 317 | octwidth=octwidth, 318 | base_c=base_c 319 | ) 320 | self.spectral_contrast = SpectralContrast( 321 | sample_rate=sample_rate, 322 | n_fft=n_fft, 323 | freq=freq, 324 | fmin=fmin, 325 | n_bands=n_bands, 326 | quantile=quantile, 327 | linear=linear 328 | ) 329 | self.tonnetz = Tonnetz( 330 | sample_rate=sample_rate, 331 | norm=norm, 332 | n_fft=n_fft, 333 | tuning=tuning, 334 | n_chroma=n_chroma, 335 | ctroct=ctroct, 336 | octwidth=octwidth, 337 | base_c=base_c 338 | ) 339 | 340 | def forward(self, spectrogram: torch.Tensor) -> torch.Tensor: 341 | chroma = self.chroma(spectrogram) 342 | 
spectral_contrast = self.spectral_contrast(spectrogram) 343 | tonnetz = self.tonnetz(spectrogram) 344 | 345 | chroma = transforms.scale( 346 | chroma, 347 | chroma.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values, 348 | chroma.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values, 349 | -1.0, 350 | 1.0 351 | ) 352 | spectral_contrast = transforms.scale( 353 | spectral_contrast, 354 | spectral_contrast.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values, 355 | spectral_contrast.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values, 356 | -1.0, 357 | 1.0 358 | ) 359 | tonnetz = transforms.scale( 360 | tonnetz, 361 | tonnetz.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values, 362 | tonnetz.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values, 363 | -1.0, 364 | 1.0 365 | ) 366 | 367 | cst = torch.cat(( 368 | tonnetz, 369 | spectral_contrast, 370 | chroma 371 | ), dim=-2) 372 | 373 | return cst 374 | 375 | 376 | class LMC(torch.nn.Module): 377 | 378 | def __init__(self, 379 | sample_rate: int = 22050, 380 | norm: float = float('inf'), 381 | n_fft: int = 2048, 382 | n_mels: int = 128, 383 | tuning: float = 0.0, 384 | n_chroma: int = 12, 385 | ctroct: float = 5.0, 386 | octwidth: float = 2.0, 387 | base_c: bool = True, 388 | freq: Optional[torch.Tensor] = None, 389 | fmin: float = 200.0, 390 | fmax: Optional[float] = None, 391 | n_bands: int = 6, 392 | quantile: float = 0.02, 393 | linear: bool = False): 394 | 395 | super(LMC, self).__init__() 396 | 397 | self.lm = Melspectrogram( 398 | sample_rate=sample_rate, 399 | n_fft=n_fft, 400 | n_mels=n_mels, 401 | fmin=fmin, 402 | fmax=fmax 403 | ) 404 | 405 | self.cst = CST( 406 | sample_rate=sample_rate, 407 | norm=norm, 408 | n_fft=n_fft, 409 | tuning=tuning, 410 | n_chroma=n_chroma, 411 | ctroct=ctroct, 412 | octwidth=octwidth, 413 | base_c=base_c, 414 | freq=freq, 415 | fmin=fmin, 416 | n_bands=n_bands, 417 | quantile=quantile, 418 | linear=linear 419 | ) 420 | 421 | def forward(self, spectrogram: torch.Tensor) -> torch.Tensor: 422 | lm = self.lm(spectrogram) 423 | cst = self.cst(spectrogram) 424 | 425 | lmc = torch.cat(( 426 | cst, 427 | lm 428 | ), dim=-2) 429 | 430 | return lmc 431 | -------------------------------------------------------------------------------- /model/esresnet.py: -------------------------------------------------------------------------------- 1 | import termcolor 2 | 3 | import numpy as np 4 | import scipy.signal as sps 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | import torchvision as tv 10 | 11 | import ignite_trainer as it 12 | 13 | from model import attention 14 | 15 | from typing import Tuple 16 | from typing import Union 17 | from typing import Optional 18 | from typing import Sequence 19 | 20 | 21 | def conv3x3(in_planes, out_planes, stride=1): 22 | """3x3 convolution with padding""" 23 | return torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 24 | 25 | 26 | def conv1x1(in_planes, out_planes, stride=1): 27 | """1x1 convolution""" 28 | return torch.nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 29 | 30 | 31 | class BasicBlock(torch.nn.Module): 32 | 33 | expansion = 1 34 | 35 | def __init__(self, inplanes, planes, stride=1, downsample=None): 36 | super(BasicBlock, self).__init__() 37 | self.conv1 = conv3x3(inplanes, planes, stride) 38 | self.bn1 = torch.nn.BatchNorm2d(planes) 39 | self.relu = torch.nn.ReLU() 40 | self.conv2 = conv3x3(planes, planes) 
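# second 3x3 convolution; its output is batch-normalized below and summed with the (optionally downsampled) identity in forward()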
41 | self.bn2 = torch.nn.BatchNorm2d(planes) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | identity = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | 55 | if self.downsample is not None: 56 | identity = self.downsample(x) 57 | 58 | out += identity 59 | out = self.relu(out) 60 | 61 | return out 62 | 63 | 64 | class Bottleneck(torch.nn.Module): 65 | 66 | expansion = 4 67 | 68 | def __init__(self, inplanes, planes, stride=1, downsample=None): 69 | super(Bottleneck, self).__init__() 70 | self.conv1 = conv1x1(inplanes, planes) 71 | self.bn1 = torch.nn.BatchNorm2d(planes) 72 | self.conv2 = conv3x3(planes, planes, stride) 73 | self.bn2 = torch.nn.BatchNorm2d(planes) 74 | self.conv3 = conv1x1(planes, planes * self.expansion) 75 | self.bn3 = torch.nn.BatchNorm2d(planes * self.expansion) 76 | self.relu = torch.nn.ReLU() 77 | self.downsample = downsample 78 | self.stride = stride 79 | 80 | def forward(self, x): 81 | identity = x 82 | 83 | out = self.conv1(x) 84 | out = self.bn1(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv2(out) 88 | out = self.bn2(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv3(out) 92 | out = self.bn3(out) 93 | 94 | if self.downsample is not None: 95 | identity = self.downsample(x) 96 | 97 | out += identity 98 | out = self.relu(out) 99 | 100 | return out 101 | 102 | 103 | class ResNet(it.AbstractNet): 104 | 105 | def __init__(self, 106 | block: Union[BasicBlock, Bottleneck], 107 | layers: Sequence[int], 108 | num_channels: int = 3, 109 | num_classes: int = 1000): 110 | 111 | super(ResNet, self).__init__() 112 | 113 | self.inplanes = 64 114 | 115 | self.conv1 = torch.nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 116 | self.bn1 = torch.nn.BatchNorm2d(64) 117 | self.relu = torch.nn.ReLU() 118 | self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 119 | self.layer1 = self._make_layer(block, 64, layers[0]) 120 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 121 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 122 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 123 | self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) 124 | self.fc = torch.nn.Linear(512 * block.expansion, num_classes) 125 | 126 | def _make_layer(self, block, planes, blocks, stride=1): 127 | downsample = None 128 | if stride != 1 or self.inplanes != planes * block.expansion: 129 | downsample = torch.nn.Sequential( 130 | conv1x1(self.inplanes, planes * block.expansion, stride), 131 | torch.nn.BatchNorm2d(planes * block.expansion) 132 | ) 133 | 134 | layers = list() 135 | layers.append(block(self.inplanes, planes, stride, downsample)) 136 | self.inplanes = planes * block.expansion 137 | for _ in range(1, blocks): 138 | layers.append(block(self.inplanes, planes)) 139 | 140 | return torch.nn.Sequential(*layers) 141 | 142 | def forward(self, 143 | x: torch.Tensor, 144 | y: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 145 | 146 | x = self.conv1(x) 147 | x = self.bn1(x) 148 | x = self.relu(x) 149 | x = self.maxpool(x) 150 | 151 | x = self.layer1(x) 152 | x = self.layer2(x) 153 | x = self.layer3(x) 154 | x = self.layer4(x) 155 | 156 | x = self.avgpool(x) 157 | x = x.view(x.size(0), -1) 158 | 159 | y_pred = self.fc(x) 160 | 161 | loss = None 162 | if y is not None: 163 | loss = self.loss_fn(y_pred, y).sum() 164 | 165 | 
return y_pred if loss is None else (y_pred, loss) 166 | 167 | def loss_fn(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 168 | loss_pred = F.cross_entropy(y_pred, y) 169 | 170 | return loss_pred 171 | 172 | @property 173 | def loss_fn_name(self) -> str: 174 | return 'Cross Entropy' 175 | 176 | 177 | class ResNet50(ResNet): 178 | 179 | def __init__(self, num_channels: int = 3, num_classes: int = 1000): 180 | super(ResNet50, self).__init__( 181 | block=Bottleneck, 182 | layers=[3, 4, 6, 3], 183 | num_channels=num_channels, 184 | num_classes=num_classes 185 | ) 186 | 187 | 188 | class ResNetWithAttention(ResNet): 189 | 190 | def __init__(self, 191 | block: Union[BasicBlock, Bottleneck], 192 | layers: Sequence[int], 193 | num_channels: int = 3, 194 | num_classes: int = 1000): 195 | 196 | super(ResNetWithAttention, self).__init__( 197 | block=block, 198 | layers=layers, 199 | num_channels=num_channels, 200 | num_classes=num_classes 201 | ) 202 | 203 | self.att1 = attention.Attention2d( 204 | in_channels=64, 205 | out_channels=64 * block.expansion, 206 | num_kernels=1, 207 | kernel_size=(3, 1), 208 | padding_size=(1, 0) 209 | ) 210 | self.att2 = attention.Attention2d( 211 | in_channels=64 * block.expansion, 212 | out_channels=128 * block.expansion, 213 | num_kernels=1, 214 | kernel_size=(1, 5), 215 | padding_size=(0, 2) 216 | ) 217 | self.att3 = attention.Attention2d( 218 | in_channels=128 * block.expansion, 219 | out_channels=256 * block.expansion, 220 | num_kernels=1, 221 | kernel_size=(3, 1), 222 | padding_size=(1, 0) 223 | ) 224 | self.att4 = attention.Attention2d( 225 | in_channels=256 * block.expansion, 226 | out_channels=512 * block.expansion, 227 | num_kernels=1, 228 | kernel_size=(1, 5), 229 | padding_size=(0, 2) 230 | ) 231 | self.att5 = attention.Attention2d( 232 | in_channels=512 * block.expansion, 233 | out_channels=512 * block.expansion, 234 | num_kernels=1, 235 | kernel_size=(3, 5), 236 | padding_size=(1, 2) 237 | ) 238 | 239 | def forward(self, 240 | x: torch.Tensor, 241 | y: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 242 | 243 | x = self.conv1(x) 244 | x = self.bn1(x) 245 | x = self.relu(x) 246 | x = self.maxpool(x) 247 | 248 | x_att = x.clone() 249 | x = self.layer1(x) 250 | x_att = self.att1(x_att, x.shape[-2:]) 251 | x = x * x_att 252 | 253 | x_att = x.clone() 254 | x = self.layer2(x) 255 | x_att = self.att2(x_att, x.shape[-2:]) 256 | x = x * x_att 257 | 258 | x_att = x.clone() 259 | x = self.layer3(x) 260 | x_att = self.att3(x_att, x.shape[-2:]) 261 | x = x * x_att 262 | 263 | x_att = x.clone() 264 | x = self.layer4(x) 265 | x_att = self.att4(x_att, x.shape[-2:]) 266 | x = x * x_att 267 | 268 | x_att = x.clone() 269 | x = self.avgpool(x) 270 | x_att = self.att5(x_att, x.shape[-2:]) 271 | x = x * x_att 272 | x = x.view(x.size(0), -1) 273 | 274 | y_pred = self.fc(x) 275 | 276 | loss = None 277 | if y is not None: 278 | loss = self.loss_fn(y_pred, y).sum() 279 | 280 | return y_pred if loss is None else (y_pred, loss) 281 | 282 | 283 | class ResNet50WithAttention(ResNetWithAttention): 284 | 285 | def __init__(self, num_channels: int = 3, num_classes: int = 1000): 286 | super(ResNet50WithAttention, self).__init__( 287 | block=Bottleneck, 288 | layers=[3, 4, 6, 3], 289 | num_channels=num_channels, 290 | num_classes=num_classes 291 | ) 292 | 293 | 294 | class _ESResNet(ResNet): 295 | 296 | def __init__(self, 297 | block: Union[BasicBlock, Bottleneck], 298 | layers: Sequence[int], 299 | n_fft: int = 256, 300 | hop_length: 
Optional[int] = None,
301 | win_length: Optional[int] = None,
302 | window: Optional[str] = None,
303 | normalized: bool = False,
304 | onesided: bool = True,
305 | spec_height: int = 224,
306 | spec_width: int = 224,
307 | num_classes: int = 1000,
308 | pretrained: Union[bool, str] = False,
309 | lock_pretrained: Optional[bool] = None):
310 | 
311 | super(_ESResNet, self).__init__(
312 | block=block,
313 | layers=layers,
314 | num_channels=3,
315 | num_classes=num_classes
316 | )
317 | 
318 | self.num_classes = num_classes
319 | 
320 | self.fc = torch.nn.Identity()
321 | self.classifier = torch.nn.Linear(
322 | in_features=512 * block.expansion,
323 | out_features=self.num_classes
324 | )
325 | 
326 | if hop_length is None:
327 | hop_length = int(np.floor(n_fft / 4))
328 | 
329 | if win_length is None:
330 | win_length = n_fft
331 | 
332 | if window is None:
333 | window = 'boxcar'
334 | 
335 | self.n_fft = n_fft
336 | self.win_length = win_length
337 | self.hop_length = hop_length
338 | 
339 | self.normalized = normalized
340 | self.onesided = onesided
341 | 
342 | self.spec_height = spec_height
343 | self.spec_width = spec_width
344 | 
345 | self.pretrained = pretrained
346 | if pretrained:
347 | err_msg = self.load_pretrained()
348 | 
349 | unlocked_weights = list()
350 | 
351 | for name, p in self.named_parameters():
352 | if lock_pretrained and name not in err_msg:  # freeze only the parameters that were loaded successfully
353 | p.requires_grad_(False)
354 | else:
355 | unlocked_weights.append(name)
356 | 
357 | print(f'The following weights are unlocked: {unlocked_weights}')
358 | 
359 | window_buffer: torch.Tensor = torch.from_numpy(
360 | sps.get_window(window=window, Nx=win_length, fftbins=True)
361 | ).to(torch.get_default_dtype())
362 | self.register_buffer('window', window_buffer)
363 | 
364 | self.log10_eps = 1e-18
365 | 
366 | def load_pretrained(self) -> str:
367 | if isinstance(self.pretrained, bool):
368 | state_dict = self.loading_func(pretrained=True).state_dict()
369 | else:
370 | state_dict = torch.load(self.pretrained, map_location='cpu')
371 | 
372 | err_msg = ''
373 | try:
374 | self.load_state_dict(state_dict=state_dict, strict=True)
375 | except RuntimeError as ex:
376 | err_msg += f'Some errors occurred while loading the pretrained weights.\n{ex}'
377 | print(termcolor.colored(err_msg, 'red'))
378 | 
379 | return err_msg
380 | 
381 | def forward(self,
382 | x: torch.Tensor,
383 | y: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
384 | 
385 | pow_spec = self.spectrogram(x)
386 | x_db = torch.log10(pow_spec).mul(10.0)
387 | 
388 | outputs = list()
389 | for ch_idx in range(x_db.shape[1]):  # run the backbone over each audio channel separately
390 | ch = x_db[:, ch_idx]
391 | out = super(_ESResNet, self).forward(ch)
392 | outputs.append(out)
393 | 
394 | outputs = torch.stack(outputs, dim=-1).sum(dim=-1)
395 | y_pred = self.classifier(outputs)
396 | 
397 | loss = None
398 | if y is not None:
399 | loss = self.loss_fn(y_pred, y).mean()
400 | 
401 | return y_pred if loss is None else (y_pred, loss)
402 | 
403 | def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
404 | spec = torch.stft(
405 | x.view(-1, x.shape[-1]),
406 | n_fft=self.n_fft,
407 | hop_length=self.hop_length,
408 | win_length=self.win_length,
409 | window=self.window,
410 | pad_mode='reflect',
411 | normalized=self.normalized,
412 | onesided=True
413 | )
414 | 
415 | if not self.onesided:
416 | spec = torch.cat((torch.flip(spec, dims=(-3,)), spec), dim=-3)
417 | 
418 | spec_height_per_band = spec.shape[-3] // 3  # the frequency axis is split into three equal bands that act as RGB channels
419 | spec_height_all_bands = 3 * spec_height_per_band
420 | spec = spec[:, :spec_height_all_bands]  # drop remainder bins so all three bands have the same height
421 | 
422
| spec = spec.reshape(x.shape[0], -1, spec.shape[-3] // 3, *spec.shape[-2:]) 423 | 424 | spec_height = spec.shape[-3] if self.spec_height < 1 else self.spec_height 425 | spec_width = spec.shape[-2] if self.spec_width < 1 else self.spec_width 426 | 427 | pow_spec = spec[..., 0] ** 2 + spec[..., 1] ** 2 428 | 429 | if spec_height != pow_spec.shape[-2] or spec_width != pow_spec.shape[-1]: 430 | pow_spec = F.interpolate( 431 | pow_spec, 432 | size=(spec_height, spec_width), 433 | mode='bilinear', 434 | align_corners=True 435 | ) 436 | 437 | pow_spec = torch.where(pow_spec > 0.0, pow_spec, torch.full_like(pow_spec, self.log10_eps)) 438 | 439 | pow_spec = pow_spec.view(x.shape[0], -1, 3, *pow_spec.shape[-2:]) 440 | 441 | return pow_spec 442 | 443 | 444 | class ESResNet(_ESResNet): 445 | 446 | loading_func = staticmethod(tv.models.resnet50) 447 | 448 | def __init__(self, 449 | n_fft: int = 256, 450 | hop_length: Optional[int] = None, 451 | win_length: Optional[int] = None, 452 | window: Optional[str] = None, 453 | normalized: bool = False, 454 | onesided: bool = True, 455 | spec_height: int = 224, 456 | spec_width: int = 224, 457 | num_classes: int = 1000, 458 | pretrained: bool = False, 459 | lock_pretrained: Optional[bool] = None): 460 | 461 | super(ESResNet, self).__init__( 462 | block=Bottleneck, 463 | layers=[3, 4, 6, 3], 464 | n_fft=n_fft, 465 | hop_length=hop_length, 466 | win_length=win_length, 467 | window=window, 468 | normalized=normalized, 469 | onesided=onesided, 470 | spec_height=spec_height, 471 | spec_width=spec_width, 472 | num_classes=num_classes, 473 | pretrained=pretrained, 474 | lock_pretrained=lock_pretrained 475 | ) 476 | 477 | 478 | class ESResNetAttention(_ESResNet, ResNetWithAttention): 479 | 480 | loading_func = staticmethod(tv.models.resnet50) 481 | 482 | def __init__(self, 483 | n_fft: int = 256, 484 | hop_length: Optional[int] = None, 485 | win_length: Optional[int] = None, 486 | window: Optional[str] = None, 487 | normalized: bool = False, 488 | onesided: bool = True, 489 | spec_height: int = 224, 490 | spec_width: int = 224, 491 | num_classes: int = 1000, 492 | pretrained: bool = False, 493 | lock_pretrained: Optional[bool] = None): 494 | 495 | super(ESResNetAttention, self).__init__( 496 | block=Bottleneck, 497 | layers=[3, 4, 6, 3], 498 | n_fft=n_fft, 499 | hop_length=hop_length, 500 | win_length=win_length, 501 | window=window, 502 | normalized=normalized, 503 | onesided=onesided, 504 | spec_height=spec_height, 505 | spec_width=spec_width, 506 | num_classes=num_classes, 507 | pretrained=pretrained, 508 | lock_pretrained=lock_pretrained 509 | ) 510 | 511 | def forward(self, 512 | x: torch.Tensor, 513 | y: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 514 | 515 | pow_spec = self.spectrogram(x) 516 | x_db = torch.log10(pow_spec).mul(10.0) 517 | 518 | outputs = list() 519 | for ch_idx in range(x_db.shape[1]): 520 | ch = x_db[:, ch_idx] 521 | out = super(_ESResNet, self).forward(ch) 522 | outputs.append(out) 523 | 524 | outputs = torch.stack(outputs, dim=-1).sum(dim=-1) 525 | y_pred = self.classifier(outputs) 526 | 527 | loss = None 528 | if y is not None: 529 | loss = self.loss_fn(y_pred, y).mean() 530 | 531 | return y_pred if loss is None else (y_pred, loss) 532 | -------------------------------------------------------------------------------- /ignite_trainer/_trainer.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import glob 4 | import json 5 | import 
time
6 | import tqdm
7 | import signal
8 | import argparse
9 | import numpy as np
10 | 
11 | import torch
12 | import torch.utils.data
13 | import torch.nn.functional
14 | 
15 | import torchvision as tv
16 | 
17 | import ignite.engine as ieng
18 | import ignite.metrics as imet
19 | import ignite.handlers as ihan
20 | 
21 | from typing import Any
22 | from typing import Dict
23 | from typing import List
24 | from typing import Type
25 | from typing import Union
26 | from typing import Optional
27 | 
28 | from termcolor import colored
29 | 
30 | from collections import defaultdict
31 | from collections.abc import Iterable
32 | 
33 | from ignite_trainer import _utils
34 | from ignite_trainer import _visdom
35 | from ignite_trainer import _interfaces
36 | 
37 | VISDOM_HOST = 'localhost'
38 | VISDOM_PORT = 8097
39 | VISDOM_ENV_PATH = 'visdom_env'
40 | BATCH_TRAIN = 128
41 | BATCH_TEST = 1024
42 | WORKERS_TRAIN = 0
43 | WORKERS_TEST = 0
44 | EPOCHS = 100
45 | LOG_INTERVAL = 50
46 | SAVED_MODELS_PATH = os.path.join(os.path.expanduser('~'), 'saved_models')
47 | 
48 | 
49 | def run(experiment_name: str,
50 | visdom_host: str,
51 | visdom_port: int,
52 | visdom_env_path: str,
53 | model_class: str,
54 | model_args: Dict[str, Any],
55 | optimizer_class: str,
56 | optimizer_args: Dict[str, Any],
57 | dataset_class: str,
58 | dataset_args: Dict[str, Any],
59 | batch_train: int,
60 | batch_test: int,
61 | workers_train: int,
62 | workers_test: int,
63 | transforms: List[Dict[str, Union[str, Dict[str, Any]]]],
64 | epochs: int,
65 | log_interval: int,
66 | saved_models_path: str,
67 | performance_metrics: Optional[Dict[str, Any]] = None,
68 | scheduler_class: Optional[str] = None,
69 | scheduler_args: Optional[Dict[str, Any]] = None,
70 | model_suffix: Optional[str] = None,
71 | setup_suffix: Optional[str] = None,
72 | orig_stdout: Optional[io.TextIOBase] = None):
73 | 
74 | with _utils.tqdm_stdout(orig_stdout) as orig_stdout:
75 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
76 | 
77 | transforms_train = list()
78 | transforms_test = list()
79 | 
80 | for idx, transform in enumerate(transforms):
81 | use_train = transform.get('train', True)
82 | use_test = transform.get('test', True)
83 | 
84 | transform = _utils.load_class(transform['class'])(**transform['args'])  # instantiate from the fully-qualified class name
85 | 
86 | if use_train:
87 | transforms_train.append(transform)
88 | if use_test:
89 | transforms_test.append(transform)
90 | 
91 | transforms[idx]['train'] = use_train
92 | transforms[idx]['test'] = use_test
93 | 
94 | transforms_train = tv.transforms.Compose(transforms_train)
95 | transforms_test = tv.transforms.Compose(transforms_test)
96 | 
97 | Dataset: Type = _utils.load_class(dataset_class)
98 | 
99 | train_loader, eval_loader = _utils.get_data_loaders(
100 | Dataset,
101 | dataset_args,
102 | batch_train,
103 | batch_test,
104 | workers_train,
105 | workers_test,
106 | transforms_train,
107 | transforms_test
108 | )
109 | 
110 | Network: Type = _utils.load_class(model_class)
111 | model: _interfaces.AbstractNet = Network(**model_args)
112 | model = model.to(device)
113 | 
114 | Optimizer: Type = _utils.load_class(optimizer_class)
115 | optimizer: torch.optim.Optimizer = Optimizer(model.parameters(), **optimizer_args)
116 | 
117 | if scheduler_class is not None:
118 | Scheduler: Type = _utils.load_class(scheduler_class)
119 | 
120 | if scheduler_args is None:
121 | scheduler_args = dict()
122 | 
123 | scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = Scheduler(optimizer, **scheduler_args)
124 | else:
125 | scheduler = None
126 |
127 | model_short_name = ''.join([c for c in Network.__name__ if c == c.upper()])  # abbreviate the class name to its upper-case characters
128 | model_name = '{}{}'.format(
129 | model_short_name,
130 | '-{}'.format(model_suffix) if model_suffix is not None else ''
131 | )
132 | visdom_env_name = '{}_{}_{}{}'.format(
133 | Dataset.__name__,
134 | experiment_name,
135 | model_name,
136 | '-{}'.format(setup_suffix) if setup_suffix is not None else ''
137 | )
138 | 
139 | vis, vis_pid = _visdom.get_visdom_instance(visdom_host, visdom_port, visdom_env_name, visdom_env_path)
140 | 
141 | prog_bar_epochs = tqdm.tqdm(total=epochs, desc='Epochs', file=orig_stdout, dynamic_ncols=True, unit='epoch')
142 | prog_bar_iters = tqdm.tqdm(desc='Batches', file=orig_stdout, dynamic_ncols=True)
143 | 
144 | tqdm.tqdm.write(f'\n{repr(model)}\n')
145 | tqdm.tqdm.write('Total number of parameters: {:.2f}M'.format(sum(p.numel() for p in model.parameters()) / 1e6))
146 | 
147 | def training_step(_: ieng.Engine, batch: _interfaces.TensorPair) -> float:
148 | model.train()
149 | 
150 | optimizer.zero_grad()
151 | 
152 | x, y = batch
153 | 
154 | x = x.to(device)
155 | y = y.to(device)
156 | 
157 | _, loss = model(x, y)
158 | 
159 | loss.backward(retain_graph=False)
160 | optimizer.step()
161 | 
162 | return loss.item()
163 | 
164 | def eval_step(_: ieng.Engine, batch: _interfaces.TensorPair) -> _interfaces.TensorPair:
165 | model.eval()
166 | 
167 | with torch.no_grad():
168 | x, y = batch
169 | 
170 | x = x.to(device)
171 | y = y.to(device)
172 | 
173 | y_pred = model(x)
174 | 
175 | return y_pred, y
176 | 
177 | trainer = ieng.Engine(training_step)
178 | validator_train = ieng.Engine(eval_step)
179 | validator_eval = ieng.Engine(eval_step)
180 | 
181 | # placeholder for summary window
182 | vis.text(
183 | text='',
184 | win=experiment_name,
185 | env=visdom_env_name,
186 | opts={'title': 'Summary', 'width': 940, 'height': 416},
187 | append=vis.win_exists(experiment_name, visdom_env_name)
188 | )
189 | 
190 | default_metrics = {
191 | "Loss": {
192 | "window_name": None,
193 | "x_label": "#Epochs",
194 | "y_label": model.loss_fn_name,
195 | "width": 940,
196 | "height": 416,
197 | "lines": [
198 | {
199 | "line_label": "SMA",
200 | "object": imet.RunningAverage(output_transform=lambda x: x),
201 | "test": False,
202 | "update_rate": "iteration"
203 | },
204 | {
205 | "line_label": "Val.",
206 | "object": imet.Loss(model.loss_fn)
207 | }
208 | ]
209 | }
210 | }
211 | 
212 | performance_metrics = {**default_metrics, **performance_metrics}
213 | checkpoint_metrics = list()
214 | 
215 | for scope_name, scope in performance_metrics.items():
216 | scope['window_name'] = scope.get('window_name', scope_name) or scope_name
217 | 
218 | for line in scope['lines']:
219 | if 'object' not in line:
220 | line['object']: imet.Metric = _utils.load_class(line['class'])(**line['args'])
221 | 
222 | line['metric_label'] = '{}: {}'.format(scope['window_name'], line['line_label'])
223 | 
224 | line['update_rate'] = line.get('update_rate', 'epoch')
225 | line_suffixes = list()
226 | if line['update_rate'] == 'iteration':
227 | line['object'].attach(trainer, line['metric_label'])
228 | line['train'] = False
229 | line['test'] = False
230 | 
231 | line_suffixes.append(' Train.')
232 | 
233 | if line.get('train', True):
234 | line['object'].attach(validator_train, line['metric_label'])
235 | line_suffixes.append(' Train.')
236 | if line.get('test', True):
237 | line['object'].attach(validator_eval, line['metric_label'])
238 | line_suffixes.append(' Eval.')
239 | 
240 | if line.get('is_checkpoint', False):
241 | 
checkpoint_metrics.append(line['metric_label']) 242 | 243 | for line_suffix in line_suffixes: 244 | _visdom.plot_line( 245 | vis=vis, 246 | window_name=scope['window_name'], 247 | env=visdom_env_name, 248 | line_label=line['line_label'] + line_suffix, 249 | x_label=scope['x_label'], 250 | y_label=scope['y_label'], 251 | width=scope['width'], 252 | height=scope['height'], 253 | draw_marker=(line['update_rate'] == 'epoch') 254 | ) 255 | 256 | if checkpoint_metrics: 257 | score_name = 'performance' 258 | 259 | def get_score(engine: ieng.Engine) -> float: 260 | current_mode = getattr(engine.state.dataloader.iterable.dataset, dataset_args['training']['key']) 261 | val_mode = dataset_args['training']['no'] 262 | 263 | score = 0.0 264 | if current_mode == val_mode: 265 | for metric_name in checkpoint_metrics: 266 | try: 267 | score += engine.state.metrics[metric_name] 268 | except KeyError: 269 | pass 270 | 271 | return score 272 | 273 | model_saver = ihan.ModelCheckpoint( 274 | os.path.join(saved_models_path, visdom_env_name), 275 | filename_prefix=visdom_env_name, 276 | score_name=score_name, 277 | score_function=get_score, 278 | n_saved=3, 279 | save_as_state_dict=True, 280 | require_empty=False, 281 | create_dir=True 282 | ) 283 | 284 | validator_eval.add_event_handler(ieng.Events.EPOCH_COMPLETED, model_saver, {model_name: model}) 285 | 286 | @trainer.on(ieng.Events.EPOCH_STARTED) 287 | def reset_progress_iterations(engine: ieng.Engine): 288 | prog_bar_iters.clear() 289 | prog_bar_iters.n = 0 290 | prog_bar_iters.last_print_n = 0 291 | prog_bar_iters.start_t = time.time() 292 | prog_bar_iters.last_print_t = time.time() 293 | prog_bar_iters.total = len(engine.state.dataloader) 294 | 295 | @trainer.on(ieng.Events.ITERATION_COMPLETED) 296 | def log_training(engine: ieng.Engine): 297 | prog_bar_iters.update(1) 298 | 299 | num_iter = (engine.state.iteration - 1) % len(train_loader) + 1 300 | 301 | early_stop = np.isnan(engine.state.output) or np.isinf(engine.state.output) 302 | 303 | if num_iter % log_interval == 0 or num_iter == len(train_loader) or early_stop: 304 | tqdm.tqdm.write( 305 | 'Epoch[{}] Iteration[{}/{}] Loss: {:.4f}'.format( 306 | engine.state.epoch, num_iter, len(train_loader), engine.state.output 307 | ) 308 | ) 309 | 310 | x_pos = engine.state.epoch + num_iter / len(train_loader) - 1 311 | for scope_name, scope in performance_metrics.items(): 312 | for line in scope['lines']: 313 | if line['update_rate'] == 'iteration': 314 | line_label = '{} Train.'.format(line['line_label']) 315 | line_value = engine.state.metrics[line['metric_label']] 316 | 317 | if engine.state.epoch > 1: 318 | _visdom.plot_line( 319 | vis=vis, 320 | window_name=scope['window_name'], 321 | env=visdom_env_name, 322 | line_label=line_label, 323 | x_label=scope['x_label'], 324 | y_label=scope['y_label'], 325 | x=np.full(1, x_pos), 326 | y=np.full(1, line_value) 327 | ) 328 | 329 | if early_stop: 330 | tqdm.tqdm.write(colored('Early stopping due to invalid loss value.', 'red')) 331 | trainer.terminate() 332 | 333 | def log_validation(engine: ieng.Engine, 334 | train: bool = True): 335 | 336 | if train: 337 | run_type = 'Train.' 338 | data_loader = train_loader 339 | validator = validator_train 340 | else: 341 | run_type = 'Eval.' 
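# evaluation branch: score the same metrics on the held-out test loader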
342 | data_loader = eval_loader 343 | validator = validator_eval 344 | 345 | prog_bar_validation = tqdm.tqdm( 346 | data_loader, 347 | desc=f'Validation {run_type}', 348 | file=orig_stdout, 349 | dynamic_ncols=True, 350 | leave=False 351 | ) 352 | validator.run(prog_bar_validation) 353 | prog_bar_validation.clear() 354 | prog_bar_validation.close() 355 | 356 | tqdm_info = [ 357 | 'Epoch: {}'.format(engine.state.epoch) 358 | ] 359 | for scope_name, scope in performance_metrics.items(): 360 | for line in scope['lines']: 361 | if line['update_rate'] == 'epoch': 362 | try: 363 | line_label = '{} {}'.format(line['line_label'], run_type) 364 | line_value = validator.state.metrics[line['metric_label']] 365 | 366 | _visdom.plot_line( 367 | vis=vis, 368 | window_name=scope['window_name'], 369 | env=visdom_env_name, 370 | line_label=line_label, 371 | x_label=scope['x_label'], 372 | y_label=scope['y_label'], 373 | x=np.full(1, engine.state.epoch), 374 | y=np.full(1, line_value), 375 | draw_marker=True 376 | ) 377 | 378 | tqdm_info.append('{}: {:.4f}'.format(line_label, line_value)) 379 | except KeyError: 380 | pass 381 | 382 | tqdm.tqdm.write('{} results - {}'.format(run_type, '; '.join(tqdm_info))) 383 | 384 | @trainer.on(ieng.Events.EPOCH_COMPLETED) 385 | def log_validation_train(engine: ieng.Engine): 386 | log_validation(engine, True) 387 | 388 | @trainer.on(ieng.Events.EPOCH_COMPLETED) 389 | def log_validation_eval(engine: ieng.Engine): 390 | log_validation(engine, False) 391 | 392 | if engine.state.epoch == 1: 393 | summary = _utils.build_summary_str( 394 | experiment_name=experiment_name, 395 | model_short_name=model_name, 396 | model_class=model_class, 397 | model_args=model_args, 398 | optimizer_class=optimizer_class, 399 | optimizer_args=optimizer_args, 400 | dataset_class=dataset_class, 401 | dataset_args=dataset_args, 402 | transforms=transforms, 403 | epochs=epochs, 404 | batch_train=batch_train, 405 | log_interval=log_interval, 406 | saved_models_path=saved_models_path, 407 | scheduler_class=scheduler_class, 408 | scheduler_args=scheduler_args 409 | ) 410 | _visdom.create_summary_window( 411 | vis=vis, 412 | visdom_env_name=visdom_env_name, 413 | experiment_name=experiment_name, 414 | summary=summary 415 | ) 416 | 417 | vis.save([visdom_env_name]) 418 | 419 | prog_bar_epochs.update(1) 420 | 421 | if scheduler is not None: 422 | scheduler.step(engine.state.epoch) 423 | 424 | trainer.run(train_loader, max_epochs=epochs) 425 | 426 | if vis_pid is not None: 427 | tqdm.tqdm.write('Stopping visdom') 428 | os.kill(vis_pid, signal.SIGTERM) 429 | 430 | del vis 431 | del train_loader 432 | del eval_loader 433 | 434 | prog_bar_iters.clear() 435 | prog_bar_iters.close() 436 | 437 | prog_bar_epochs.clear() 438 | prog_bar_epochs.close() 439 | 440 | tqdm.tqdm.write('\n') 441 | 442 | 443 | def main(): 444 | with _utils.tqdm_stdout() as orig_stdout: 445 | parser = argparse.ArgumentParser() 446 | 447 | parser.add_argument('-c', '--config', type=str, required=True) 448 | parser.add_argument('-H', '--visdom-host', type=str, required=False) 449 | parser.add_argument('-P', '--visdom-port', type=int, required=False) 450 | parser.add_argument('-E', '--visdom-env-path', type=str, required=False) 451 | parser.add_argument('-b', '--batch-train', type=int, required=False) 452 | parser.add_argument('-B', '--batch-test', type=int, required=False) 453 | parser.add_argument('-w', '--workers-train', type=int, required=False) 454 | parser.add_argument('-W', '--workers-test', type=int, required=False) 455 | 
parser.add_argument('-e', '--epochs', type=int, required=False)
456 | parser.add_argument('-L', '--log-interval', type=int, required=False)
457 | parser.add_argument('-M', '--saved-models-path', type=str, required=False)
458 | parser.add_argument('-R', '--random-seed', type=int, required=False)
459 | parser.add_argument('-s', '--suffix', type=str, required=False)
460 | 
461 | args, unknown_args = parser.parse_known_args()
462 | 
463 | if args.batch_test is None:
464 | args.batch_test = args.batch_train
465 | 
466 | if args.random_seed is not None:
467 | args.suffix = '{}r-{}'.format(
468 | '{}_'.format(args.suffix) if args.suffix is not None else '',
469 | args.random_seed
470 | )
471 | 
472 | np.random.seed(args.random_seed)
473 | torch.random.manual_seed(args.random_seed)
474 | if torch.cuda.is_available():
475 | torch.cuda.manual_seed(args.random_seed)
476 | torch.backends.cudnn.deterministic = True
477 | torch.backends.cudnn.benchmark = False
478 | 
479 | configs_found = sorted(glob.glob(os.path.expanduser(args.config)))
480 | prog_bar_exps = tqdm.tqdm(
481 | configs_found,
482 | desc='Experiments',
483 | unit='setup',
484 | file=orig_stdout,
485 | dynamic_ncols=True
486 | )
487 | 
488 | for config_path in prog_bar_exps:
489 | with open(config_path) as config_file:
490 | config = json.load(config_file)
491 | if unknown_args:
492 | tqdm.tqdm.write('\nParsing additional arguments...')
493 | 
494 | args_not_found = list()
495 | for arg in unknown_args:
496 | if arg.startswith('--'):
497 | keys = arg.strip('-').split('.')
498 | 
499 | section = config
500 | found = True
501 | for key in keys:
502 | if key in section:
503 | section = section[key]
504 | else:
505 | found = False
506 | break
507 | 
508 | if found:
509 | override_parser = argparse.ArgumentParser()
510 | 
511 | section_nargs = None
512 | section_type = type(section) if section is not None else str
513 | 
514 | if section_type is bool:
515 | # booleans are parsed from their common string spellings
516 | def infer_bool(x: str) -> bool:
517 | return x.lower() not in ('0', 'false', 'no')
518 | 
519 | section_type = infer_bool
520 | 
521 | if isinstance(section, Iterable) and section_type is not str:
522 | section_nargs = '+'
523 | section_type = {type(value) for value in section}
524 | 
525 | if len(section_type) == 1:
526 | section_type = section_type.pop()
527 | else:
528 | section_type = str
529 | 
530 | override_parser.add_argument(arg, nargs=section_nargs, type=section_type)
531 | overridden_args, _ = override_parser.parse_known_args(unknown_args)
532 | overridden_args = vars(overridden_args)
533 | 
534 | overridden_key = arg.strip('-')
535 | overriding_value = overridden_args[overridden_key]
536 | 
537 | section = config
538 | old_value = None
539 | for i, key in enumerate(keys, 1):
540 | if i == len(keys):
541 | old_value = section[key]
542 | section[key] = overriding_value
543 | else:
544 | section = section[key]
545 | 
546 | tqdm.tqdm.write(
547 | colored(f'Overriding "{overridden_key}": {old_value} -> {overriding_value}', 'magenta')
548 | )
549 | else:
550 | args_not_found.append(arg)
551 | 
552 | if args_not_found:
553 | tqdm.tqdm.write(
554 | colored(
555 | '\nThere are unrecognized arguments to override: {}'.format(
556 | ', '.join(args_not_found)
557 | ),
558 | 'red'
559 | )
560 | )
561 | 
562 | config = defaultdict(None, config)
563 | 
564 | experiment_name = config['Setup']['name']
565 | 
566 | visdom_host = _utils.arg_selector(
567 | args.visdom_host, config['Visdom']['host'], VISDOM_HOST
568 | )
569 | visdom_port = int(_utils.arg_selector(
570 | args.visdom_port, config['Visdom']['port'],
VISDOM_PORT 571 | )) 572 | visdom_env_path = _utils.arg_selector( 573 | args.visdom_env_path, config['Visdom']['env_path'], VISDOM_ENV_PATH 574 | ) 575 | batch_train = int(_utils.arg_selector( 576 | args.batch_train, config['Setup']['batch_train'], BATCH_TRAIN 577 | )) 578 | batch_test = int(_utils.arg_selector( 579 | args.batch_test, config['Setup']['batch_test'], BATCH_TEST 580 | )) 581 | workers_train = _utils.arg_selector( 582 | args.workers_train, config['Setup']['workers_train'], WORKERS_TRAIN 583 | ) 584 | workers_test = _utils.arg_selector( 585 | args.workers_test, config['Setup']['workers_test'], WORKERS_TEST 586 | ) 587 | epochs = _utils.arg_selector( 588 | args.epochs, config['Setup']['epochs'], EPOCHS 589 | ) 590 | log_interval = _utils.arg_selector( 591 | args.log_interval, config['Setup']['log_interval'], LOG_INTERVAL 592 | ) 593 | saved_models_path = _utils.arg_selector( 594 | args.saved_models_path, config['Setup']['saved_models_path'], SAVED_MODELS_PATH 595 | ) 596 | 597 | model_class = config['Model']['class'] 598 | model_args = config['Model']['args'] 599 | 600 | optimizer_class = config['Optimizer']['class'] 601 | optimizer_args = config['Optimizer']['args'] 602 | 603 | if 'Scheduler' in config: 604 | scheduler_class = config['Scheduler']['class'] 605 | scheduler_args = config['Scheduler']['args'] 606 | else: 607 | scheduler_class = None 608 | scheduler_args = None 609 | 610 | dataset_class = config['Dataset']['class'] 611 | dataset_args = config['Dataset']['args'] 612 | 613 | transforms = config['Transforms'] 614 | performance_metrics = config['Metrics'] 615 | 616 | tqdm.tqdm.write(f'\nStarting experiment "{experiment_name}"\n') 617 | 618 | run( 619 | experiment_name=experiment_name, 620 | visdom_host=visdom_host, 621 | visdom_port=visdom_port, 622 | visdom_env_path=visdom_env_path, 623 | model_class=model_class, 624 | model_args=model_args, 625 | optimizer_class=optimizer_class, 626 | optimizer_args=optimizer_args, 627 | dataset_class=dataset_class, 628 | dataset_args=dataset_args, 629 | batch_train=batch_train, 630 | batch_test=batch_test, 631 | workers_train=workers_train, 632 | workers_test=workers_test, 633 | transforms=transforms, 634 | epochs=epochs, 635 | log_interval=log_interval, 636 | saved_models_path=saved_models_path, 637 | performance_metrics=performance_metrics, 638 | scheduler_class=scheduler_class, 639 | scheduler_args=scheduler_args, 640 | model_suffix=config['Setup']['suffix'], 641 | setup_suffix=args.suffix, 642 | orig_stdout=orig_stdout 643 | ) 644 | 645 | prog_bar_exps.close() 646 | 647 | tqdm.tqdm.write('\n') 648 | --------------------------------------------------------------------------------
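For reference, a configuration skeleton in the shape that main() consumes from the protocol JSON files. The class paths and argument values below are illustrative placeholders rather than a protocol shipped with the repository (the transform class name in particular is hypothetical); only the section names and keys read above ('Setup', 'Visdom', 'Model', 'Optimizer', 'Dataset', 'Transforms', 'Metrics', plus the optional 'Scheduler') are dictated by _trainer.py. The Dataset "training" entry supplies the attribute name and value that get_score() uses to detect the evaluation loader when a metric line sets "is_checkpoint".

{
  "Setup": {
    "name": "example-experiment",
    "suffix": null,
    "batch_train": 128,
    "batch_test": 1024,
    "workers_train": 0,
    "workers_test": 0,
    "epochs": 100,
    "log_interval": 50,
    "saved_models_path": "~/saved_models"
  },
  "Visdom": {"host": "localhost", "port": 8097, "env_path": "visdom_env"},
  "Model": {"class": "model.esresnet.ESResNet", "args": {"num_classes": 50, "pretrained": true}},
  "Optimizer": {"class": "torch.optim.Adam", "args": {"lr": 0.001}},
  "Dataset": {
    "class": "utils.datasets.ESC50",
    "args": {"root": "/path/to/ESC-50", "training": {"key": "train", "no": false}}
  },
  "Transforms": [
    {"class": "utils.transforms.SomeTransform", "args": {}, "train": true, "test": true}
  ],
  "Metrics": {
    "Accuracy": {
      "window_name": null,
      "x_label": "#Epochs",
      "y_label": "Accuracy",
      "width": 940,
      "height": 416,
      "lines": [
        {"line_label": "Val.", "class": "ignite.metrics.Accuracy", "args": {}, "is_checkpoint": true}
      ]
    }
  }
}

A run would then be launched as, e.g., "python main.py -c path/to/config.json", with optional dotted overrides of config values such as "--Setup.epochs 50".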