├── ignite_trainer
│   ├── version.py
│   ├── README.md
│   ├── __init__.py
│   ├── _interfaces.py
│   ├── _utils.py
│   ├── _visdom.py
│   └── _trainer.py
├── protocols
│   ├── README.md
│   ├── esc50
│   │   ├── adcnn5-esc50-cv1.json
│   │   └── esresnet-esc50-cv1.json
│   ├── us8k
│   │   ├── adcnn5-us8k-cv1.json
│   │   ├── lmcnet-us8k-cv1.json
│   │   ├── esresnet-us8k-mono-cv1.json
│   │   └── esresnet-us8k-stereo-cv1.json
│   └── esc10
│       └── esresnet-esc10-cv1.json
├── main.py
├── requirements.txt
├── utils
│   ├── __init__.py
│   ├── lr_scheduler.py
│   ├── transforms.py
│   ├── datasets.py
│   └── features.py
├── reproduced
│   ├── README.md
│   ├── TFNet
│   │   └── README.md
│   ├── lmcnet.py
│   └── adcnn.py
├── model
│   ├── attention.py
│   └── esresnet.py
└── README.md
/ignite_trainer/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.2.5b4'
2 |
--------------------------------------------------------------------------------
/protocols/README.md:
--------------------------------------------------------------------------------
1 | # Protocols
2 |
3 | Here are the JSON files that describe the configurations of the experiments.
4 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3.7
2 |
3 | import ignite_trainer as it
4 |
5 | if __name__ == '__main__':
6 | it.main()
7 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | librosa==0.7.2
2 | numpy==1.18.1
3 | pandas==1.0.3
4 | pytorch-ignite==0.3.0
5 | scikit-learn==0.22.1
6 | scipy==1.4.1
7 | termcolor==1.1.0
8 | torch==1.4.0
9 | torchvision==0.5.0
10 | tqdm==4.43.0
11 | visdom==0.1.8.9
12 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from . import datasets
2 | from . import features
3 | from . import lr_scheduler
4 | from . import transforms
5 |
6 | __all__ = [
7 | 'datasets',
8 | 'features',
9 | 'lr_scheduler',
10 | 'transforms'
11 | ]
12 |
--------------------------------------------------------------------------------
/reproduced/README.md:
--------------------------------------------------------------------------------
1 | # Reproduced
2 |
3 | Here are the models that were reproduced:
4 |
5 | 1. The LMCNet model, which is part of the [TSCNN-DS model](https://www.mdpi.com/1424-8220/19/7/1733/pdf).
6 | 2. The [TFNet model](https://arxiv.org/abs/1912.06808).
7 | 3. The [ADCNN-5 model](https://arxiv.org/abs/1908.11219) (excluded from the paper).
8 |
--------------------------------------------------------------------------------
/ignite_trainer/README.md:
--------------------------------------------------------------------------------
1 | # Ignite Trainer
2 |
3 | Ignite Trainer is a framework built on top of [PyTorch Ignite](https://github.com/pytorch/ignite) and [visdom](https://github.com/facebookresearch/visdom).
4 | It was developed to wrap the training and logging of PyTorch models.
5 | Its development is frozen due to the switch to [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning).
6 |
--------------------------------------------------------------------------------
/ignite_trainer/__init__.py:
--------------------------------------------------------------------------------
1 | import os as _os
2 | import sys as _sys
3 |
4 | from ignite_trainer.version import __version__
5 | from ._trainer import main, run
6 | from ._utils import load_class
7 | from ._interfaces import AbstractNet, AbstractTransform
8 |
9 | __all__ = [
10 | '__version__',
11 | 'main', 'run', 'load_class',
12 | 'AbstractNet', 'AbstractTransform'
13 | ]
14 |
15 | _sys.path.extend([_os.getcwd()])
16 |
--------------------------------------------------------------------------------
/reproduced/TFNet/README.md:
--------------------------------------------------------------------------------
1 | ## Reproduced: TFNet
2 |
3 | The TFNet results were reproduced using the
4 | [temporarily available source code (link now inactive)](https://github.com/WangHelin1997/TFNet-for-Environmental-Sound-Classification)
5 | provided by the authors of the [paper](https://arxiv.org/abs/1912.06808).
6 | The original repository was forked and is now available [here](https://github.com/AndreyGuzhov/TFNet-for-Environmental-Sound-Classification).
7 |
--------------------------------------------------------------------------------
/ignite_trainer/_interfaces.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import torch
3 |
4 | from typing import Tuple
5 | from typing import Union
6 | from typing import Callable
7 | from typing import Optional
8 |
9 |
10 | TensorPair = Tuple[torch.Tensor, torch.Tensor]
11 | TensorOrTwo = Union[torch.Tensor, TensorPair]
12 |
13 |
14 | class AbstractNet(abc.ABC, torch.nn.Module):
15 |
16 | @abc.abstractmethod
17 | def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> TensorOrTwo:
18 | pass
19 |
20 | @abc.abstractmethod
21 | def loss_fn(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
22 | pass
23 |
24 | @property
25 | @abc.abstractmethod
26 | def loss_fn_name(self) -> str:
27 | pass
28 |
29 |
30 | class AbstractTransform(abc.ABC, Callable[[torch.Tensor], torch.Tensor]):
31 |
32 | @abc.abstractmethod
33 | def __call__(self, x: torch.Tensor) -> torch.Tensor:
34 | pass
35 |
36 | def __repr__(self):
37 | return self.__class__.__name__ + '()'
38 |
--------------------------------------------------------------------------------
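For orientation, a minimal, hypothetical model satisfying this interface could look as follows; it mirrors the pattern used by the reproduced models later in this listing, where `forward` returns only the predictions at inference time and a `(predictions, loss)` pair when targets are passed in. The class and its sizes are illustrative, not part of the repository.

```python
from typing import Optional

import torch
import torch.nn.functional as F

import ignite_trainer as it


class TinyNet(it.AbstractNet):
    """Minimal illustration of the AbstractNet contract (not part of the repository)."""

    def __init__(self, num_features: int = 16, num_classes: int = 10):
        super(TinyNet, self).__init__()
        self.fc = torch.nn.Linear(in_features=num_features, out_features=num_classes)

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None):
        y_pred = self.fc(x)

        # Inference: return predictions only; training: return (predictions, loss).
        if y is None:
            return y_pred
        return y_pred, self.loss_fn(y_pred, y)

    def loss_fn(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return F.cross_entropy(y_pred, y)

    @property
    def loss_fn_name(self) -> str:
        return 'Cross Entropy'
```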
/model/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from typing import Tuple
5 |
6 |
7 | class Attention2d(torch.nn.Module):
8 |
9 | def __init__(self,
10 | in_channels: int,
11 | out_channels: int,
12 | num_kernels: int,
13 | kernel_size: Tuple[int, int],
14 | padding_size: Tuple[int, int]):
15 |
16 | super(Attention2d, self).__init__()
17 |
18 | self.conv_depth = torch.nn.Conv2d(
19 | in_channels=in_channels,
20 | out_channels=in_channels * num_kernels,
21 | kernel_size=kernel_size,
22 | padding=padding_size,
23 | groups=in_channels
24 | )
25 | self.conv_point = torch.nn.Conv2d(
26 | in_channels=in_channels * num_kernels,
27 | out_channels=out_channels,
28 | kernel_size=(1, 1)
29 | )
30 | self.bn = torch.nn.BatchNorm2d(num_features=out_channels)
31 | self.activation = torch.nn.Sigmoid()
32 |
33 | def forward(self, x: torch.Tensor, size: torch.Size) -> torch.Tensor:
34 | x = F.adaptive_max_pool2d(x, size)
35 | x = self.conv_depth(x)
36 | x = self.conv_point(x)
37 | x = self.bn(x)
38 | x = self.activation(x)
39 |
40 | return x
41 |
--------------------------------------------------------------------------------
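A rough usage sketch (shapes and variable names are illustrative, not taken from the repository): `Attention2d` pools its input to a requested spatial size and emits a sigmoid mask in (0, 1) that can be multiplied element-wise with a feature map of that size, similar to how the reproduced `ADCNN5` later in this listing gates its convolutional blocks with attention branches.

```python
import torch

from model.attention import Attention2d

# Hypothetical batch of feature maps to be gated (shapes are illustrative).
features = torch.randn(8, 64, 32, 32)

attention = Attention2d(
    in_channels=64,
    out_channels=64,
    num_kernels=2,
    kernel_size=(3, 3),
    padding_size=(1, 1)
)

# The mask is computed at the spatial size of `features` and lies in (0, 1).
mask = attention(features, size=features.shape[-2:])
gated = features * mask
print(gated.shape)  # torch.Size([8, 64, 32, 32])
```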
/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class WarmUpStepLR(torch.optim.lr_scheduler._LRScheduler):
5 |
6 | def __init__(self,
7 | optimizer: torch.optim.Optimizer,
8 | cold_epochs: int,
9 | warm_epochs: int,
10 | step_size: int,
11 | gamma: float = 0.1,
12 | last_epoch: int = -1):
13 |
14 | self.cold_epochs = cold_epochs
15 | self.warm_epochs = warm_epochs
16 | self.step_size = step_size
17 | self.gamma = gamma
18 |
19 | super(WarmUpStepLR, self).__init__(optimizer=optimizer, last_epoch=last_epoch)
20 |
21 | def get_lr(self):
22 | if self.last_epoch < self.cold_epochs:
23 | return [base_lr * 0.1 for base_lr in self.base_lrs]
24 | elif self.last_epoch < self.cold_epochs + self.warm_epochs:
25 | return [
26 | base_lr * 0.1 + (1 + self.last_epoch - self.cold_epochs) * 0.9 * base_lr / self.warm_epochs
27 | for base_lr in self.base_lrs
28 | ]
29 | else:
30 | return [
31 | base_lr * self.gamma ** ((self.last_epoch - self.cold_epochs - self.warm_epochs) // self.step_size)
32 | for base_lr in self.base_lrs
33 | ]
34 |
35 |
36 | class WarmUpExponentialLR(WarmUpStepLR):
37 |
38 | def __init__(self,
39 | optimizer: torch.optim.Optimizer,
40 | cold_epochs: int,
41 | warm_epochs: int,
42 | gamma: float = 0.1,
43 | last_epoch: int = -1):
44 |
45 | self.cold_epochs = cold_epochs
46 | self.warm_epochs = warm_epochs
47 | self.step_size = 1
48 | self.gamma = gamma
49 |
50 | super(WarmUpStepLR, self).__init__(optimizer=optimizer, last_epoch=last_epoch)
--------------------------------------------------------------------------------
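The ESResNet protocols in this repository use `WarmUpExponentialLR` with `cold_epochs=5`, `warm_epochs=10`, `gamma=0.985` and a base learning rate of 2.5e-4. A small usage sketch (the model and optimizer below are dummies, used only to drive the scheduler): the learning rate stays at 10% of the base LR during the cold epochs, ramps linearly up to the base LR over the warm-up epochs, and then decays exponentially per epoch.

```python
import torch

from utils.lr_scheduler import WarmUpExponentialLR

# Dummy model and optimizer, used here only to drive the scheduler.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=2.5e-4)
scheduler = WarmUpExponentialLR(optimizer, cold_epochs=5, warm_epochs=10, gamma=0.985)

for epoch in range(20):
    # ... training and validation for one epoch (optimizer.step() calls) would go here ...
    print(epoch, optimizer.param_groups[0]['lr'])
    scheduler.step()
```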
/protocols/esc50/adcnn5-esc50-cv1.json:
--------------------------------------------------------------------------------
1 | {
2 | "Visdom": {
3 | "host": null,
4 | "port": null,
5 | "env_path": null
6 | },
7 | "Setup": {
8 | "name": "MFCC",
9 | "suffix": "CV1",
10 | "batch_train": 64,
11 | "batch_test": 64,
12 | "workers_train": 0,
13 | "workers_test": 0,
14 | "epochs": 500,
15 | "log_interval": 5,
16 | "saved_models_path": null
17 | },
18 | "Model": {
19 | "class": "reproduced.adcnn.ADCNN5",
20 | "args": {
21 | "num_channels": 1,
22 | "n_fft": 1024,
23 | "hop_length": 512,
24 | "window": "blackmanharris",
25 | "num_classes": 50
26 | }
27 | },
28 | "Optimizer": {
29 | "class": "torch.optim.Adam",
30 | "args": {
31 | "lr": 1e-2,
32 | "betas": [0.9, 0.999],
33 | "eps": 1e-7,
34 | "weight_decay": 1e-4
35 | }
36 | },
37 | "Scheduler": {
38 | "class": "torch.optim.lr_scheduler.StepLR",
39 | "args": {
40 | "gamma": 0.1,
41 | "step_size": 100
42 | }
43 | },
44 | "Dataset": {
45 | "class": "utils.datasets.ESC50",
46 | "args": {
47 | "root": "/path/to/ESC50",
48 | "sample_rate": 32000,
49 | "fold": 1,
50 | "training": {"key": "train", "yes": true, "no": false}
51 | }
52 | },
53 | "Transforms": [
54 | {
55 | "class": "utils.transforms.ToTensor1D",
56 | "args": {}
57 | },
58 | {
59 | "class": "utils.transforms.RandomPadding",
60 | "args": {"out_len": 160000, "train": false}
61 | },
62 | {
63 | "class": "utils.transforms.RandomCrop",
64 | "args": {"out_len": 160000, "train": false}
65 | }
66 | ],
67 | "Metrics": {
68 | "Performance": {
69 | "window_name": null,
70 | "x_label": "#Epochs",
71 | "y_label": "Accuracy",
72 | "width": 1890,
73 | "height": 416,
74 | "lines": [
75 | {
76 | "line_label": "Val. Acc.",
77 | "class": "ignite.metrics.Accuracy",
78 | "args": {},
79 | "is_checkpoint": true
80 | }
81 | ]
82 | }
83 | }
84 | }
85 |
--------------------------------------------------------------------------------
/protocols/us8k/adcnn5-us8k-cv1.json:
--------------------------------------------------------------------------------
1 | {
2 | "Visdom": {
3 | "host": null,
4 | "port": null,
5 | "env_path": null
6 | },
7 | "Setup": {
8 | "name": "MFCC",
9 | "suffix": "CV1",
10 | "batch_train": 64,
11 | "batch_test": 64,
12 | "workers_train": 0,
13 | "workers_test": 0,
14 | "epochs": 500,
15 | "log_interval": 5,
16 | "saved_models_path": null
17 | },
18 | "Model": {
19 | "class": "reproduced.adcnn.ADCNN5",
20 | "args": {
21 | "num_channels": 1,
22 | "n_fft": 1024,
23 | "hop_length": 512,
24 | "window": "hann",
25 | "num_classes": 10
26 | }
27 | },
28 | "Optimizer": {
29 | "class": "torch.optim.Adam",
30 | "args": {
31 | "lr": 1e-2,
32 | "betas": [0.9, 0.999],
33 | "eps": 1e-7,
34 | "weight_decay": 1e-4
35 | }
36 | },
37 | "Scheduler": {
38 | "class": "torch.optim.lr_scheduler.StepLR",
39 | "args": {
40 | "gamma": 0.1,
41 | "step_size": 100
42 | }
43 | },
44 | "Dataset": {
45 | "class": "utils.datasets.UrbanSound8K",
46 | "args": {
47 | "root": "/path/to/UrbanSound8K",
48 | "sample_rate": 32000,
49 | "fold": 1,
50 | "random_split_seed": 42,
51 | "mono": true,
52 | "training": {"key": "train", "yes": true, "no": false}
53 | }
54 | },
55 | "Transforms": [
56 | {
57 | "class": "utils.transforms.ToTensor1D",
58 | "args": {}
59 | },
60 | {
61 | "class": "utils.transforms.RandomPadding",
62 | "args": {"out_len": 128000, "train": false}
63 | },
64 | {
65 | "class": "utils.transforms.RandomCrop",
66 | "args": {"out_len": 128000, "train": false}
67 | }
68 | ],
69 | "Metrics": {
70 | "Performance": {
71 | "window_name": null,
72 | "x_label": "#Epochs",
73 | "y_label": "Accuracy",
74 | "width": 1890,
75 | "height": 416,
76 | "lines": [
77 | {
78 | "line_label": "Val. Acc.",
79 | "class": "ignite.metrics.Accuracy",
80 | "args": {},
81 | "is_checkpoint": true
82 | }
83 | ]
84 | }
85 | }
86 | }
87 |
--------------------------------------------------------------------------------
/protocols/us8k/lmcnet-us8k-cv1.json:
--------------------------------------------------------------------------------
1 | {
2 | "Visdom": {
3 | "host": null,
4 | "port": null,
5 | "env_path": null
6 | },
7 | "Setup": {
8 | "name": "LMC",
9 | "suffix": "CV1",
10 | "batch_train": 32,
11 | "batch_test": 32,
12 | "workers_train": 0,
13 | "workers_test": 0,
14 | "epochs": 300,
15 | "log_interval": 10,
16 | "saved_models_path": null
17 | },
18 | "Model": {
19 | "class": "reproduced.lmcnet.LMCNet",
20 | "args": {
21 | "num_channels": 1,
22 | "num_classes": 10,
23 | "sample_rate": 22050,
24 | "norm": "inf",
25 | "n_fft": 8192,
26 | "hop_length": 2205,
27 | "win_length": 4410,
28 | "window": "hann",
29 | "n_mels": 60,
30 | "tuning": 0.0,
31 | "n_chroma": 7,
32 | "ctroct": 5.0,
33 | "octwidth": 2.0,
34 | "base_c": true,
35 | "freq": null,
36 | "fmin": 10.0,
37 | "fmax": null,
38 | "n_bands": 11,
39 | "quantile": 0.02,
40 | "linear": false
41 | }
42 | },
43 | "Optimizer": {
44 | "class": "torch.optim.Adam",
45 | "args": {
46 | "lr": 1e-3,
47 | "betas": [0.9, 0.999],
48 | "eps": 1e-8,
49 | "weight_decay": 1e-3
50 | }
51 | },
52 | "Dataset": {
53 | "class": "utils.datasets.UrbanSound8K",
54 | "args": {
55 | "root": "/path/to/UrbanSound8K",
56 | "sample_rate": 22050,
57 | "fold": 1,
58 | "random_split_seed": null,
59 | "mono": true,
60 | "training": {"key": "train", "yes": true, "no": false}
61 | }
62 | },
63 | "Transforms": [
64 | {
65 | "class": "utils.transforms.ToTensor1D",
66 | "args": {}
67 | },
68 | {
69 | "class": "utils.transforms.RandomPadding",
70 | "args": {"out_len": 88200, "train": false}
71 | },
72 | {
73 | "class": "utils.transforms.RandomCrop",
74 | "args": {"out_len": 88200, "train": false}
75 | }
76 | ],
77 | "Metrics": {
78 | "Performance": {
79 | "window_name": null,
80 | "x_label": "#Epochs",
81 | "y_label": "Accuracy",
82 | "width": 1890,
83 | "height": 416,
84 | "lines": [
85 | {
86 | "line_label": "Val. Acc.",
87 | "class": "ignite.metrics.Accuracy",
88 | "args": {},
89 | "is_checkpoint": true
90 | }
91 | ]
92 | }
93 | }
94 | }
95 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ESResNet
2 | ## Environmental Sound Classification Based on Visual Domain Models
3 |
4 | This repository contains the implementation of the models described in the paper [arXiv:2004.07301](https://arxiv.org/abs/2004.07301) (submitted to ICPR 2020).
5 |
6 | ### Abstract
7 | Environmental Sound Classification (ESC) is an active research area in the audio domain and has seen a lot of progress in recent years. However, many of the existing approaches achieve high accuracy by relying on domain-specific features and architectures, making it harder to benefit from advances in other fields (e.g., the image domain). Additionally, some of the past successes have been attributed to a discrepancy in how results are evaluated (i.e., on unofficial splits of the UrbanSound8K (US8K) dataset), distorting the overall progression of the field.
8 | The contribution of this paper is twofold. First, we present a model that is inherently compatible with mono and stereo sound inputs. Our model is based on simple log-power Short-Time Fourier Transform (STFT) spectrograms and combines them with several well-known approaches from the image domain (i.e., ResNet, Siamese-like networks and attention). We investigate the influence of cross-domain pre-training, architectural changes, and evaluate our model on standard datasets. We find that our model outperforms all previously known approaches in a fair comparison by achieving accuracies of 97.0 % (ESC-10), 91.5 % (ESC-50) and 84.2 % / 85.4 % (US8K mono / stereo).
9 | Second, we provide a comprehensive overview of the actual state of the field, by differentiating several previously reported results on the US8K dataset between official or unofficial splits. For better reproducibility, our code (including any re-implementations) is made available.
10 |
11 | ### How to run the model
12 |
13 | The required Python version is >= 3.7.
14 |
15 | #### ESResNet
16 |
17 | ##### On the [ESC-10](https://github.com/karolpiczak/ESC-50) dataset
18 | python main.py --config protocols/esc10/esresnet-esc10-cv1.json --Dataset.args.root /path/to/ESC10
19 |
20 | ##### On the [ESC-50](https://github.com/karolpiczak/ESC-50) dataset
21 | python main.py --config protocols/esc50/esresnet-esc50-cv1.json --Dataset.args.root /path/to/ESC50
22 |
23 | ##### On the [UrbanSound8K](https://urbansounddataset.weebly.com/) dataset (stereo)
24 | python main.py --config protocols/us8k/esresnet-us8k-stereo-cv1.json --Dataset.args.root /path/to/UrbanSound8K
25 |
26 | #### Reproduced results
27 |
28 | ##### [LMCNet](https://www.mdpi.com/1424-8220/19/7/1733/pdf) on the [UrbanSound8K](https://urbansounddataset.weebly.com/) dataset
29 | python main.py --config protocols/us8k/lmcnet-us8k-cv1.json --Dataset.args.root /path/to/UrbanSound8K
30 |
--------------------------------------------------------------------------------
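Before running any of the configurations above, the pinned dependencies from `requirements.txt` have to be installed. `ignite_trainer` starts a Visdom server on `localhost` automatically if none is reachable, but it can also be launched manually. A minimal setup sketch, assuming a fresh Python 3.7 environment (the exact commands are illustrative):

    pip install -r requirements.txt
    python -m visdom.server -port 8097
    python main.py --config protocols/esc50/esresnet-esc50-cv1.json --Dataset.args.root /path/to/ESC50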
/protocols/esc10/esresnet-esc10-cv1.json:
--------------------------------------------------------------------------------
1 | {
2 | "Visdom": {
3 | "host": null,
4 | "port": null,
5 | "env_path": null
6 | },
7 | "Setup": {
8 | "name": "STFT",
9 | "suffix": "CV1",
10 | "batch_train": 16,
11 | "batch_test": 16,
12 | "workers_train": 4,
13 | "workers_test": 4,
14 | "epochs": 300,
15 | "log_interval": 10,
16 | "saved_models_path": null
17 | },
18 | "Model": {
19 | "class": "model.esresnet.ESResNet",
20 | "args": {
21 | "n_fft": 2048,
22 | "hop_length": 561,
23 | "win_length": 1654,
24 | "window": "blackmanharris",
25 | "normalized": true,
26 | "onesided": true,
27 | "spec_height": -1,
28 | "spec_width": -1,
29 | "num_classes": 10,
30 | "pretrained": true,
31 | "lock_pretrained": false
32 | }
33 | },
34 | "Optimizer": {
35 | "class": "torch.optim.Adam",
36 | "args": {
37 | "lr": 2.5e-4,
38 | "betas": [0.9, 0.999],
39 | "eps": 1e-8,
40 | "weight_decay": 5e-4
41 | }
42 | },
43 | "Scheduler": {
44 | "class": "utils.lr_scheduler.WarmUpExponentialLR",
45 | "args": {
46 | "gamma": 0.985,
47 | "cold_epochs": 5,
48 | "warm_epochs": 10
49 | }
50 | },
51 | "Dataset": {
52 | "class": "utils.datasets.ESC10",
53 | "args": {
54 | "root": "/path/to/ESC10",
55 | "sample_rate": 44100,
56 | "fold": 1,
57 | "training": {"key": "train", "yes": true, "no": false}
58 | }
59 | },
60 | "Transforms": [
61 | {
62 | "class": "utils.transforms.ToTensor1D",
63 | "args": {}
64 | },
65 | {
66 | "class": "utils.transforms.RandomFlip",
67 | "args": {"p": 0.5},
68 | "test": false
69 | },
70 | {
71 | "class": "utils.transforms.RandomScale",
72 | "args": {"max_scale": 1.25},
73 | "test": false
74 | },
75 | {
76 | "class": "utils.transforms.RandomPadding",
77 | "args": {"out_len": 220500},
78 | "test": false
79 | },
80 | {
81 | "class": "utils.transforms.RandomCrop",
82 | "args": {"out_len": 220500},
83 | "test": false
84 | },
85 | {
86 | "class": "utils.transforms.RandomPadding",
87 | "args": {"out_len": 220500, "train": false},
88 | "train": false
89 | },
90 | {
91 | "class": "utils.transforms.RandomCrop",
92 | "args": {"out_len": 220500, "train": false},
93 | "train": false
94 | }
95 | ],
96 | "Metrics": {
97 | "Performance": {
98 | "window_name": null,
99 | "x_label": "#Epochs",
100 | "y_label": "Accuracy",
101 | "width": 1890,
102 | "height": 416,
103 | "lines": [
104 | {
105 | "line_label": "Val. Acc.",
106 | "class": "ignite.metrics.Accuracy",
107 | "args": {},
108 | "is_checkpoint": true
109 | }
110 | ]
111 | }
112 | }
113 | }
114 |
--------------------------------------------------------------------------------
/protocols/esc50/esresnet-esc50-cv1.json:
--------------------------------------------------------------------------------
1 | {
2 | "Visdom": {
3 | "host": null,
4 | "port": null,
5 | "env_path": null
6 | },
7 | "Setup": {
8 | "name": "STFT",
9 | "suffix": "CV1",
10 | "batch_train": 16,
11 | "batch_test": 16,
12 | "workers_train": 4,
13 | "workers_test": 4,
14 | "epochs": 300,
15 | "log_interval": 10,
16 | "saved_models_path": null
17 | },
18 | "Model": {
19 | "class": "model.esresnet.ESResNet",
20 | "args": {
21 | "n_fft": 2048,
22 | "hop_length": 561,
23 | "win_length": 1654,
24 | "window": "blackmanharris",
25 | "normalized": true,
26 | "onesided": true,
27 | "spec_height": -1,
28 | "spec_width": -1,
29 | "num_classes": 50,
30 | "pretrained": true,
31 | "lock_pretrained": false
32 | }
33 | },
34 | "Optimizer": {
35 | "class": "torch.optim.Adam",
36 | "args": {
37 | "lr": 2.5e-4,
38 | "betas": [0.9, 0.999],
39 | "eps": 1e-8,
40 | "weight_decay": 5e-4
41 | }
42 | },
43 | "Scheduler": {
44 | "class": "utils.lr_scheduler.WarmUpExponentialLR",
45 | "args": {
46 | "gamma": 0.985,
47 | "cold_epochs": 5,
48 | "warm_epochs": 10
49 | }
50 | },
51 | "Dataset": {
52 | "class": "utils.datasets.ESC50",
53 | "args": {
54 | "root": "/path/to/ESC50",
55 | "sample_rate": 44100,
56 | "fold": 1,
57 | "training": {"key": "train", "yes": true, "no": false}
58 | }
59 | },
60 | "Transforms": [
61 | {
62 | "class": "utils.transforms.ToTensor1D",
63 | "args": {}
64 | },
65 | {
66 | "class": "utils.transforms.RandomFlip",
67 | "args": {"p": 0.5},
68 | "test": false
69 | },
70 | {
71 | "class": "utils.transforms.RandomScale",
72 | "args": {"max_scale": 1.25},
73 | "test": false
74 | },
75 | {
76 | "class": "utils.transforms.RandomPadding",
77 | "args": {"out_len": 220500},
78 | "test": false
79 | },
80 | {
81 | "class": "utils.transforms.RandomCrop",
82 | "args": {"out_len": 220500},
83 | "test": false
84 | },
85 | {
86 | "class": "utils.transforms.RandomPadding",
87 | "args": {"out_len": 220500, "train": false},
88 | "train": false
89 | },
90 | {
91 | "class": "utils.transforms.RandomCrop",
92 | "args": {"out_len": 220500, "train": false},
93 | "train": false
94 | }
95 | ],
96 | "Metrics": {
97 | "Performance": {
98 | "window_name": null,
99 | "x_label": "#Epochs",
100 | "y_label": "Accuracy",
101 | "width": 1890,
102 | "height": 416,
103 | "lines": [
104 | {
105 | "line_label": "Val. Acc.",
106 | "class": "ignite.metrics.Accuracy",
107 | "args": {},
108 | "is_checkpoint": true
109 | }
110 | ]
111 | }
112 | }
113 | }
114 |
--------------------------------------------------------------------------------
/protocols/us8k/esresnet-us8k-mono-cv1.json:
--------------------------------------------------------------------------------
1 | {
2 | "Visdom": {
3 | "host": null,
4 | "port": null,
5 | "env_path": null
6 | },
7 | "Setup": {
8 | "name": "STFT",
9 | "suffix": "CV1",
10 | "batch_train": 16,
11 | "batch_test": 16,
12 | "workers_train": 2,
13 | "workers_test": 2,
14 | "epochs": 300,
15 | "log_interval": 50,
16 | "saved_models_path": null
17 | },
18 | "Model": {
19 | "class": "model.esresnet.ESResNet",
20 | "args": {
21 | "n_fft": 2048,
22 | "hop_length": 561,
23 | "win_length": 1654,
24 | "window": "blackmanharris",
25 | "normalized": true,
26 | "onesided": true,
27 | "spec_height": -1,
28 | "spec_width": -1,
29 | "num_classes": 10,
30 | "pretrained": true,
31 | "lock_pretrained": false
32 | }
33 | },
34 | "Optimizer": {
35 | "class": "torch.optim.Adam",
36 | "args": {
37 | "lr": 2.5e-4,
38 | "betas": [0.9, 0.999],
39 | "eps": 1e-8,
40 | "weight_decay": 5e-4
41 | }
42 | },
43 | "Scheduler": {
44 | "class": "utils.lr_scheduler.WarmUpExponentialLR",
45 | "args": {
46 | "gamma": 0.985,
47 | "cold_epochs": 5,
48 | "warm_epochs": 10
49 | }
50 | },
51 | "Dataset": {
52 | "class": "utils.datasets.UrbanSound8K",
53 | "args": {
54 | "root": "/path/to/UrbanSound8K",
55 | "sample_rate": 44100,
56 | "fold": 1,
57 | "random_split_seed": null,
58 | "mono": true,
59 | "training": {"key": "train", "yes": true, "no": false}
60 | }
61 | },
62 | "Transforms": [
63 | {
64 | "class": "utils.transforms.ToTensor1D",
65 | "args": {}
66 | },
67 | {
68 | "class": "utils.transforms.RandomFlip",
69 | "args": {"p": 0.5},
70 | "test": false
71 | },
72 | {
73 | "class": "utils.transforms.RandomScale",
74 | "args": {"max_scale": 1.25},
75 | "test": false
76 | },
77 | {
78 | "class": "utils.transforms.RandomPadding",
79 | "args": {"out_len": 176400},
80 | "test": false
81 | },
82 | {
83 | "class": "utils.transforms.RandomCrop",
84 | "args": {"out_len": 176400},
85 | "test": false
86 | },
87 | {
88 | "class": "utils.transforms.RandomCrop",
89 | "args": {"out_len": 176400, "train": false},
90 | "train": false
91 | },
92 | {
93 | "class": "utils.transforms.RandomPadding",
94 | "args": {"out_len": 176400, "train": false},
95 | "train": false
96 | }
97 | ],
98 | "Metrics": {
99 | "Performance": {
100 | "window_name": null,
101 | "x_label": "#Epochs",
102 | "y_label": "Accuracy",
103 | "width": 1890,
104 | "height": 416,
105 | "lines": [
106 | {
107 | "line_label": "Val. Acc.",
108 | "class": "ignite.metrics.Accuracy",
109 | "args": {},
110 | "is_checkpoint": true
111 | }
112 | ]
113 | }
114 | }
115 | }
116 |
--------------------------------------------------------------------------------
/protocols/us8k/esresnet-us8k-stereo-cv1.json:
--------------------------------------------------------------------------------
1 | {
2 | "Visdom": {
3 | "host": null,
4 | "port": null,
5 | "env_path": null
6 | },
7 | "Setup": {
8 | "name": "STFT",
9 | "suffix": "CV1",
10 | "batch_train": 16,
11 | "batch_test": 16,
12 | "workers_train": 2,
13 | "workers_test": 2,
14 | "epochs": 300,
15 | "log_interval": 50,
16 | "saved_models_path": null
17 | },
18 | "Model": {
19 |         "class": "model.esresnet.ESResNet",
20 | "args": {
21 | "n_fft": 2048,
22 | "hop_length": 561,
23 | "win_length": 1654,
24 | "window": "blackmanharris",
25 | "normalized": true,
26 | "onesided": true,
27 | "spec_height": -1,
28 | "spec_width": -1,
29 | "num_classes": 10,
30 | "pretrained": true,
31 | "lock_pretrained": false
32 | }
33 | },
34 | "Optimizer": {
35 | "class": "torch.optim.Adam",
36 | "args": {
37 | "lr": 2.5e-4,
38 | "betas": [0.9, 0.999],
39 | "eps": 1e-8,
40 | "weight_decay": 5e-4
41 | }
42 | },
43 | "Scheduler": {
44 | "class": "utils.lr_scheduler.WarmUpExponentialLR",
45 | "args": {
46 | "gamma": 0.985,
47 | "cold_epochs": 5,
48 | "warm_epochs": 10
49 | }
50 | },
51 | "Dataset": {
52 | "class": "utils.datasets.UrbanSound8K",
53 | "args": {
54 | "root": "/path/to/UrbanSound8K",
55 | "sample_rate": 44100,
56 | "fold": 1,
57 | "random_split_seed": null,
58 | "mono": false,
59 | "training": {"key": "train", "yes": true, "no": false}
60 | }
61 | },
62 | "Transforms": [
63 | {
64 | "class": "utils.transforms.ToTensor1D",
65 | "args": {}
66 | },
67 | {
68 | "class": "utils.transforms.RandomFlip",
69 | "args": {"p": 0.5},
70 | "test": false
71 | },
72 | {
73 | "class": "utils.transforms.RandomScale",
74 | "args": {"max_scale": 1.25},
75 | "test": false
76 | },
77 | {
78 | "class": "utils.transforms.RandomPadding",
79 | "args": {"out_len": 176400},
80 | "test": false
81 | },
82 | {
83 | "class": "utils.transforms.RandomCrop",
84 | "args": {"out_len": 176400},
85 | "test": false
86 | },
87 | {
88 | "class": "utils.transforms.RandomCrop",
89 | "args": {"out_len": 176400, "train": false},
90 | "train": false
91 | },
92 | {
93 | "class": "utils.transforms.RandomPadding",
94 | "args": {"out_len": 176400, "train": false},
95 | "train": false
96 | }
97 | ],
98 | "Metrics": {
99 | "Performance": {
100 | "window_name": null,
101 | "x_label": "#Epochs",
102 | "y_label": "Accuracy",
103 | "width": 1890,
104 | "height": 416,
105 | "lines": [
106 | {
107 | "line_label": "Val. Acc.",
108 | "class": "ignite.metrics.Accuracy",
109 | "args": {},
110 | "is_checkpoint": true
111 | }
112 | ]
113 | }
114 | }
115 | }
116 |
--------------------------------------------------------------------------------
/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torchvision as tv
5 |
6 | import ignite_trainer as it
7 |
8 |
9 | def scale(old_value, old_min, old_max, new_min, new_max):
10 | old_range = (old_max - old_min)
11 | new_range = (new_max - new_min)
12 | new_value = (((old_value - old_min) * new_range) / old_range) + new_min
13 |
14 | return new_value
15 |
16 |
17 | class ToTensor1D(tv.transforms.ToTensor):
18 |
19 | def __call__(self, tensor: np.ndarray):
20 | tensor_2d = super(ToTensor1D, self).__call__(tensor[..., np.newaxis])
21 |
22 | return tensor_2d.squeeze_(0)
23 |
24 |
25 | class RandomFlip(it.AbstractTransform):
26 |
27 | def __init__(self, p: float = 0.5):
28 | super(RandomFlip, self).__init__()
29 |
30 | self.p = p
31 |
32 | def __call__(self, x: torch.Tensor) -> torch.Tensor:
33 | if x.dim() > 2:
34 | flip_mask = torch.rand(x.shape[0], device=x.device) <= self.p
35 | x[flip_mask] = x[flip_mask].flip(-1)
36 | else:
37 | if torch.rand(1) <= self.p:
38 | x = x.flip(0)
39 |
40 | return x
41 |
42 |
43 | class RandomScale(it.AbstractTransform):
44 |
45 | def __init__(self, max_scale: float = 1.25):
46 | super(RandomScale, self).__init__()
47 |
48 | self.max_scale = max_scale
49 |
50 | @staticmethod
51 | def random_scale(max_scale: float, signal: torch.Tensor) -> torch.Tensor:
52 | scaling = np.power(max_scale, np.random.uniform(-1, 1))
53 | output_size = int(signal.shape[-1] * scaling)
54 | ref = torch.arange(output_size, device=signal.device, dtype=signal.dtype).div_(scaling)
55 |
56 | ref1 = ref.clone().type(torch.int64)
57 | ref2 = torch.min(ref1 + 1, torch.full_like(ref1, signal.shape[-1] - 1, dtype=torch.int64))
58 | r = ref - ref1.type(ref.type())
59 | scaled_signal = signal[..., ref1] * (1 - r) + signal[..., ref2] * r
60 |
61 | return scaled_signal
62 |
63 | def __call__(self, x: torch.Tensor) -> torch.Tensor:
64 | return self.random_scale(self.max_scale, x)
65 |
66 |
67 | class RandomCrop(it.AbstractTransform):
68 |
69 | def __init__(self, out_len: int = 44100, train: bool = True):
70 | super(RandomCrop, self).__init__()
71 |
72 | self.out_len = out_len
73 | self.train = train
74 |
75 | def random_crop(self, signal: torch.Tensor) -> torch.Tensor:
76 | if self.train:
77 | left = np.random.randint(0, signal.shape[-1] - self.out_len)
78 | else:
79 | left = int(round(0.5 * (signal.shape[-1] - self.out_len)))
80 |
81 | orig_std = signal.float().std() * 0.5
82 | output = signal[..., left:left + self.out_len]
83 |
84 | out_std = output.float().std()
85 | if out_std < orig_std:
86 | output = signal[..., :self.out_len]
87 |
88 | new_out_std = output.float().std()
89 | if orig_std > new_out_std > out_std:
90 | output = signal[..., -self.out_len:]
91 |
92 | return output
93 |
94 | def __call__(self, x: torch.Tensor) -> torch.Tensor:
95 | return self.random_crop(x) if x.shape[-1] > self.out_len else x
96 |
97 |
98 | class RandomPadding(it.AbstractTransform):
99 |
100 | def __init__(self, out_len: int = 88200, train: bool = True):
101 | super(RandomPadding, self).__init__()
102 |
103 | self.out_len = out_len
104 | self.train = train
105 |
106 | def random_pad(self, signal: torch.Tensor) -> torch.Tensor:
107 | if self.train:
108 | left = np.random.randint(0, self.out_len - signal.shape[-1])
109 | else:
110 | left = int(round(0.5 * (self.out_len - signal.shape[-1])))
111 |
112 | right = self.out_len - (left + signal.shape[-1])
113 |
114 | pad_value_left = signal[..., 0].float().mean().to(signal.dtype)
115 | pad_value_right = signal[..., -1].float().mean().to(signal.dtype)
116 | output = torch.cat((
117 | torch.zeros(signal.shape[:-1] + (left,), dtype=signal.dtype, device=signal.device).fill_(pad_value_left),
118 | signal,
119 | torch.zeros(signal.shape[:-1] + (right,), dtype=signal.dtype, device=signal.device).fill_(pad_value_right)
120 | ), dim=-1)
121 |
122 | return output
123 |
124 | def __call__(self, x: torch.Tensor) -> torch.Tensor:
125 | return self.random_pad(x) if x.shape[-1] < self.out_len else x
126 |
--------------------------------------------------------------------------------
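The protocols reference these transforms by name; presumably they are chained in listing order (e.g., via `torchvision.transforms.Compose`) before being handed to the data loaders in `ignite_trainer._utils.get_data_loaders`. A sketch of the deterministic, test-time chain used by the ESC protocols (`out_len=220500`, i.e. 5 s at 44.1 kHz; the waveform length below is made up):

```python
import numpy as np
import torchvision as tv

from utils.transforms import ToTensor1D, RandomPadding, RandomCrop

# Deterministic (test-time) chain, mirroring the ESC protocol entries.
transform = tv.transforms.Compose([
    ToTensor1D(),
    RandomPadding(out_len=220500, train=False),  # center-pad shorter clips
    RandomCrop(out_len=220500, train=False),     # center-crop longer clips
])

wav = np.random.randn(1, 200000).astype(np.float32)  # mono waveform, illustrative length
out = transform(wav)
print(out.shape)  # torch.Size([1, 220500])
```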
/reproduced/lmcnet.py:
--------------------------------------------------------------------------------
1 | import scipy.signal as sps
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | import ignite_trainer as it
7 |
8 | from utils import features
9 |
10 | from typing import Tuple
11 | from typing import Union
12 | from typing import Optional
13 |
14 |
15 | class LMCNet(it.AbstractNet):
16 |
17 | def __init__(self,
18 | num_channels: int = 1,
19 | num_classes: int = 10,
20 | sample_rate: int = 44100,
21 | norm: Union[str, float] = 'inf',
22 | n_fft: int = 2048,
23 | hop_length: int = 1024,
24 | win_length: int = 2048,
25 | window: str = 'hann',
26 | n_mels: int = 128,
27 | tuning: float = 0.0,
28 | n_chroma: int = 12,
29 | ctroct: float = 5.0,
30 | octwidth: float = 2.0,
31 | base_c: bool = True,
32 | freq: Optional[torch.Tensor] = None,
33 | fmin: float = 200.0,
34 | fmax: Optional[float] = None,
35 | n_bands: int = 6,
36 | quantile: float = 0.02,
37 | linear: bool = False):
38 |
39 | super(LMCNet, self).__init__()
40 |
41 | norm = float(norm)
42 |
43 | self.lmc = features.LMC(
44 | sample_rate=sample_rate,
45 | norm=norm,
46 | n_fft=n_fft,
47 | n_mels=n_mels,
48 | tuning=tuning,
49 | n_chroma=n_chroma,
50 | ctroct=ctroct,
51 | octwidth=octwidth,
52 | base_c=base_c,
53 | freq=freq,
54 | fmin=fmin,
55 | fmax=fmax,
56 | n_bands=n_bands,
57 | quantile=quantile,
58 | linear=linear
59 | )
60 |
61 | self.n_fft = n_fft
62 | self.win_length = win_length
63 | self.hop_length = hop_length
64 |
65 | window_buf = sps.get_window(window, win_length, False)
66 | self.register_buffer('window', torch.from_numpy(window_buf).to(torch.get_default_dtype()))
67 |
68 | self.conv1 = torch.nn.Conv2d(
69 | in_channels=num_channels,
70 | out_channels=32,
71 | kernel_size=(3, 3),
72 | stride=(1, 1),
73 | padding=(1, 1)
74 | )
75 | self.bn1 = torch.nn.BatchNorm2d(num_features=self.conv1.out_channels)
76 | self.activation1 = torch.nn.ReLU()
77 |
78 | self.conv2 = torch.nn.Conv2d(
79 | in_channels=self.conv1.out_channels,
80 | out_channels=self.conv1.out_channels,
81 | kernel_size=(3, 3),
82 | stride=(1, 1),
83 | padding=(1, 1)
84 | )
85 | self.bn2 = torch.nn.BatchNorm2d(num_features=self.conv2.out_channels)
86 | self.activation2 = torch.nn.ReLU()
87 | self.pool2 = torch.nn.MaxPool2d(kernel_size=(2, 2), padding=(1, 1))
88 |
89 | self.conv3 = torch.nn.Conv2d(
90 | in_channels=self.conv2.out_channels,
91 | out_channels=64,
92 | kernel_size=(3, 3),
93 | stride=(1, 1),
94 | padding=(1, 1)
95 | )
96 | self.bn3 = torch.nn.BatchNorm2d(num_features=self.conv3.out_channels)
97 | self.activation3 = torch.nn.ReLU()
98 |
99 | self.conv4 = torch.nn.Conv2d(
100 | in_channels=self.conv3.out_channels,
101 | out_channels=64,
102 | kernel_size=(3, 3),
103 | stride=(1, 1),
104 | padding=(1, 1)
105 | )
106 | self.bn4 = torch.nn.BatchNorm2d(num_features=self.conv4.out_channels)
107 | self.activation4 = torch.nn.ReLU()
108 | self.pool4 = torch.nn.MaxPool2d(kernel_size=(2, 2), padding=(1, 1))
109 |
110 | self.fc1 = torch.nn.Linear(in_features=11 * 22 * self.conv4.out_channels, out_features=1024)
111 | self.activation5 = torch.nn.Sigmoid()
112 |
113 | self.fc2 = torch.nn.Linear(in_features=self.fc1.out_features, out_features=num_classes)
114 |
115 | def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
116 | spectrogram = torch.stft(
117 | x.view(x.shape[0], -1),
118 | n_fft=self.n_fft,
119 | hop_length=self.hop_length,
120 | win_length=self.win_length,
121 | window=self.window,
122 | normalized=True
123 | )
124 | spectrogram = spectrogram[..., 0] ** 2 + spectrogram[..., 1] ** 2
125 | spectrogram = spectrogram.view(x.shape[0], -1, *spectrogram.shape[1:])
126 | spectrogram = torch.where(spectrogram == 0.0, spectrogram + 1e-10, spectrogram)
127 |
128 | return spectrogram
129 |
130 | def forward(self,
131 | x: torch.Tensor,
132 | y: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
133 |
134 | x = self.spectrogram(x)
135 | x = self.lmc(x)
136 |
137 | x = self.conv1(x)
138 | x = self.bn1(x)
139 | x = self.activation1(x)
140 |
141 | x = F.dropout2d(x, p=0.5, training=self.training)
142 |
143 | x = self.conv2(x)
144 | x = self.bn2(x)
145 | x = self.activation2(x)
146 | x = self.pool2(x)
147 |
148 | x = self.conv3(x)
149 | x = self.bn3(x)
150 | x = self.activation3(x)
151 |
152 | x = F.dropout2d(x, p=0.5, training=self.training)
153 |
154 | x = self.conv4(x)
155 | x = self.bn4(x)
156 | x = self.activation4(x)
157 | x = self.pool4(x)
158 |
159 | x = x.view(x.shape[0], -1)
160 |
161 | x = F.dropout(x, p=0.5, training=self.training)
162 |
163 | x = self.fc1(x)
164 | x = self.activation5(x)
165 | y_pred = self.fc2(x)
166 |
167 | loss = None
168 | if y is not None:
169 | loss = self.loss_fn(y_pred, y).mean()
170 |
171 | return y_pred if loss is None else (y_pred, loss)
172 |
173 | def loss_fn(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
174 | loss_pred = F.cross_entropy(y_pred, y)
175 |
176 | return loss_pred
177 |
178 | @property
179 | def loss_fn_name(self) -> str:
180 | return 'Cross Entropy'
181 |
--------------------------------------------------------------------------------
/ignite_trainer/_utils.py:
--------------------------------------------------------------------------------
1 | import io
2 | import sys
3 | import json
4 | import tqdm
5 | import datetime
6 | import importlib
7 | import contextlib
8 |
9 | import numpy as np
10 |
11 | import torch
12 |
13 | from collections import OrderedDict
14 |
15 | from typing import Any
16 | from typing import Dict
17 | from typing import List
18 | from typing import Type
19 | from typing import Tuple
20 | from typing import Union
21 | from typing import Callable
22 | from typing import Optional
23 |
24 |
25 | @contextlib.contextmanager
26 | def tqdm_stdout(orig_stdout: Optional[io.TextIOBase] = None):
27 |
28 | class DummyFile(object):
29 | file = None
30 |
31 | def __init__(self, file):
32 | self.file = file
33 |
34 | def write(self, x):
35 | if len(x.rstrip()) > 0:
36 | tqdm.tqdm.write(x, file=self.file)
37 |
38 | def flush(self):
39 | return getattr(self.file, 'flush', lambda: None)()
40 |
41 | orig_out_err = sys.stdout, sys.stderr
42 |
43 | try:
44 | if orig_stdout is None:
45 | sys.stdout, sys.stderr = map(DummyFile, orig_out_err)
46 | yield orig_out_err[0]
47 | else:
48 | yield orig_stdout
49 | except Exception as exc:
50 | raise exc
51 | finally:
52 | sys.stdout, sys.stderr = orig_out_err
53 |
54 |
55 | def load_class(package_name: str, class_name: Optional[str] = None) -> Type:
56 | if class_name is None:
57 | package_name, class_name = package_name.rsplit('.', 1)
58 |
59 | importlib.invalidate_caches()
60 |
61 | package = importlib.import_module(package_name)
62 | cls = getattr(package, class_name)
63 |
64 | return cls
65 |
66 |
67 | def arg_selector(arg_cmd: Optional[Any], arg_conf: Optional[Any], arg_const: Any) -> Any:
68 | if arg_cmd is not None:
69 | return arg_cmd
70 | else:
71 | if arg_conf is not None:
72 | return arg_conf
73 | else:
74 | return arg_const
75 |
76 |
77 | def get_data_loaders(Dataset: Type,
78 | dataset_args: Dict[str, Any],
79 | batch_train: int = 64,
80 | batch_test: int = 1024,
81 | workers_train: int = 0,
82 | workers_test: int = 0,
83 | transforms_train: Optional[Callable[
84 | [Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor]
85 | ]] = None,
86 | transforms_test: Optional[Callable[
87 | [Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor]
88 | ]] = None) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
89 |
90 | dataset_mode_train = {dataset_args['training']['key']: dataset_args['training']['yes']}
91 | dataset_mode_test = {dataset_args['training']['key']: dataset_args['training']['no']}
92 |
93 | dataset_args_train = {**{k: v for k, v in dataset_args.items() if k != 'training'}, **dataset_mode_train}
94 | dataset_args_test = {**{k: v for k, v in dataset_args.items() if k != 'training'}, **dataset_mode_test}
95 |
96 | train_loader = torch.utils.data.DataLoader(
97 | Dataset(**{**dataset_args_train, **{'transform': transforms_train}}),
98 | batch_size=batch_train,
99 | shuffle=True,
100 | num_workers=workers_train,
101 | pin_memory=True
102 | )
103 | eval_loader = torch.utils.data.DataLoader(
104 | Dataset(**{**dataset_args_test, **{'transform': transforms_test}}),
105 | batch_size=batch_test,
106 | num_workers=workers_test,
107 | pin_memory=True
108 | )
109 |
110 | return train_loader, eval_loader
111 |
112 |
113 | def build_summary_str(experiment_name: str,
114 | model_short_name: str,
115 | model_class: str,
116 | model_args: Dict[str, Any],
117 | optimizer_class: str,
118 | optimizer_args: Dict[str, Any],
119 | dataset_class: str,
120 | dataset_args: Dict[str, Any],
121 | transforms: List[Dict[str, Union[str, Dict[str, Any]]]],
122 | epochs: int,
123 | batch_train: int,
124 | log_interval: int,
125 | saved_models_path: str,
126 | scheduler_class: Optional[str] = None,
127 | scheduler_args: Optional[Dict[str, Any]] = None) -> str:
128 |
129 | setup_title = '{}-{}'.format(experiment_name, model_short_name)
130 |
131 |     # NOTE: the HTML markup embedded in the following strings was stripped in this
132 |     #       listing; the tags used below are an assumed, minimal reconstruction.
133 |     summary_window_text = '<h3>{}</h3>'.format(setup_title)
134 |     summary_window_text += '<hr>'
135 |     summary_window_text += '<pre>'
140 |
141 | summary = OrderedDict({
142 | 'Date started': datetime.datetime.now().strftime('%Y-%m-%d @ %H:%M:%S'),
143 | 'Model': OrderedDict({model_class: model_args}),
144 | 'Setup': OrderedDict({
145 | 'epochs': epochs,
146 | 'batch': batch_train,
147 | 'log_interval': log_interval,
148 | 'saved_models_path': saved_models_path
149 | }),
150 | 'Optimizer': OrderedDict({optimizer_class: optimizer_args}),
151 | 'Dataset': OrderedDict({dataset_class: dataset_args}),
152 | 'Transforms': OrderedDict({
153 | 'Training': OrderedDict({tr['class']: tr['args'] for tr in transforms if tr['train']}),
154 | 'Validation': OrderedDict({tr['class']: tr['args'] for tr in transforms if tr['test']})
155 | })
156 | })
157 | if scheduler_class is not None:
158 | summary['Scheduler'] = {scheduler_class: scheduler_args}
159 | summary_window_text += '{}'.format(
160 | json.dumps(summary, indent=2)
161 | )
162 |
163 |     summary_window_text += '</pre>'  # assumed closing tags; original markup was stripped
164 |     summary_window_text += '<hr>'
165 |
166 | return summary_window_text
167 |
--------------------------------------------------------------------------------
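The `class`/`args` pairs that appear throughout the protocol files map directly onto `load_class` above. A rough sketch of how a protocol section could be turned into live objects (the actual wiring is done inside `ignite_trainer._trainer`, which is not part of this listing):

```python
import json

from ignite_trainer import load_class

with open('protocols/esc50/esresnet-esc50-cv1.json') as fp:
    config = json.load(fp)

# Resolve the dotted class name from the protocol and instantiate it with its args.
Model = load_class(config['Model']['class'])          # e.g. model.esresnet.ESResNet
model = Model(**config['Model']['args'])

Optimizer = load_class(config['Optimizer']['class'])  # e.g. torch.optim.Adam
optimizer = Optimizer(model.parameters(), **config['Optimizer']['args'])
```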
/ignite_trainer/_visdom.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import time
5 | import tqdm
6 | import socket
7 | import subprocess
8 | import numpy as np
9 |
10 | import visdom
11 |
12 | from typing import Tuple
13 | from typing import Optional
14 |
15 |
16 | def calc_ytick_range(vis: visdom.Visdom, window_name: str, env: Optional[str] = None) -> Tuple[float, float]:
17 | lower_bound, upper_bound = -1.0, 1.0
18 |
19 | stats = vis.get_window_data(win=window_name, env=env)
20 |
21 | if stats:
22 | stats = json.loads(stats)
23 |
24 | stats = [np.array(item['y']) for item in stats['content']['data']]
25 | stats = [item[item != np.array([None])].astype(np.float16) for item in stats]
26 |
27 | if stats:
28 | q25s = np.array([np.quantile(item, 0.25) for item in stats if len(item) > 0])
29 | q75s = np.array([np.quantile(item, 0.75) for item in stats if len(item) > 0])
30 |
31 | if q25s.shape == q75s.shape and len(q25s) > 0:
32 | iqrs = q75s - q25s
33 |
34 | lower_bounds = q25s - 1.5 * iqrs
35 | upper_bounds = q75s + 1.5 * iqrs
36 |
37 | stats_sanitized = list()
38 | idx = 0
39 | for item in stats:
40 | if len(item) > 0:
41 | item_sanitized = item[(item >= lower_bounds[idx]) & (item <= upper_bounds[idx])]
42 | stats_sanitized.append(item_sanitized)
43 |
44 | idx += 1
45 |
46 | stats_sanitized = np.array(stats_sanitized)
47 |
48 | q25_sanitized = np.array([np.quantile(item, 0.25) for item in stats_sanitized])
49 | q75_sanitized = np.array([np.quantile(item, 0.75) for item in stats_sanitized])
50 |
51 | iqr_sanitized = np.sum(q75_sanitized - q25_sanitized)
52 | lower_bound = np.min(q25_sanitized) - 1.5 * iqr_sanitized
53 | upper_bound = np.max(q75_sanitized) + 1.5 * iqr_sanitized
54 |
55 | return lower_bound, upper_bound
56 |
57 |
58 | def plot_line(vis: visdom.Visdom,
59 | window_name: str,
60 | env: Optional[str] = None,
61 | line_label: Optional[str] = None,
62 | x: Optional[np.ndarray] = None,
63 | y: Optional[np.ndarray] = None,
64 | x_label: Optional[str] = None,
65 | y_label: Optional[str] = None,
66 | width: int = 576,
67 | height: int = 416,
68 | draw_marker: bool = False) -> str:
69 |
70 | empty_call = not vis.win_exists(window_name)
71 |
72 | if empty_call and (x is not None or y is not None):
73 | return window_name
74 |
75 | if x is None:
76 | x = np.ones(1)
77 | empty_call = empty_call & True
78 |
79 | if y is None:
80 | y = np.full(1, np.nan)
81 | empty_call = empty_call & True
82 |
83 | if x.shape != y.shape:
84 | x = np.ones_like(y)
85 |
86 | opts = {
87 | 'showlegend': True,
88 | 'markers': draw_marker,
89 | 'markersize': 5,
90 | }
91 |
92 | if empty_call:
93 | opts['title'] = window_name
94 | opts['width'] = width
95 | opts['height'] = height
96 |
97 | window_name = vis.line(
98 | X=x,
99 | Y=y,
100 | win=window_name,
101 | env=env,
102 | update='append',
103 | name=line_label,
104 | opts=opts
105 | )
106 |
107 | xtickmin, xtickmax = 0.0, np.max(x) * 1.05
108 | ytickmin, ytickmax = calc_ytick_range(vis, window_name, env)
109 |
110 | opts = {
111 | 'showlegend': True,
112 | 'xtickmin': xtickmin,
113 | 'xtickmax': xtickmax,
114 | 'ytickmin': ytickmin,
115 | 'ytickmax': ytickmax,
116 | 'xlabel': x_label,
117 | 'ylabel': y_label
118 | }
119 |
120 | window_name = vis.update_window_opts(win=window_name, opts=opts, env=env)
121 |
122 | return window_name
123 |
124 |
125 | # TODO: implement remove experiment callback
126 |
127 |
128 | def create_summary_window(vis: visdom.Visdom,
129 | visdom_env_name: str,
130 | experiment_name: str,
131 | summary: str) -> str:
132 |
133 | return vis.text(
134 | text=summary,
135 | win=experiment_name,
136 | env=visdom_env_name,
137 | opts={'title': 'Summary', 'width': 576, 'height': 416},
138 | append=vis.win_exists(experiment_name, visdom_env_name)
139 | )
140 |
141 |
142 | def connection_is_alive(host: str, port: int) -> bool:
143 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
144 | try:
145 | sock.connect((host, port))
146 | sock.shutdown(socket.SHUT_RDWR)
147 |
148 | return True
149 | except socket.error:
150 | return False
151 |
152 |
153 | def get_visdom_instance(host: str = 'localhost',
154 | port: int = 8097,
155 | env_name: str = 'main',
156 | env_path: str = 'visdom_env') -> Tuple[visdom.Visdom, Optional[int]]:
157 |
158 | vis_pid = None
159 |
160 | if not connection_is_alive(host, port):
161 | if any(host.strip('/').endswith(lh) for lh in ['127.0.0.1', 'localhost']):
162 | os.makedirs(env_path, exist_ok=True)
163 |
164 | tqdm.tqdm.write('Starting visdom on port {}'.format(port), end='')
165 |
166 | vis_args = [
167 | sys.executable,
168 | '-m', 'visdom.server',
169 | '-port', str(port),
170 | '-env_path', os.path.join(os.getcwd(), env_path)
171 | ]
172 | vis_proc = subprocess.Popen(vis_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
173 | time.sleep(2.0)
174 |
175 | vis_pid = vis_proc.pid
176 | tqdm.tqdm.write('PID -> {}'.format(vis_pid))
177 |
178 | trials_left = 5
179 | while not connection_is_alive(host, port):
180 | time.sleep(1.0)
181 |
182 | tqdm.tqdm.write('Trying to connect ({} left)...'.format(trials_left))
183 |
184 | trials_left -= 1
185 | if trials_left < 1:
186 | raise RuntimeError('Visdom server is not running. Please run "python -m visdom.server".')
187 |
188 | vis = visdom.Visdom(
189 | server='http://{}'.format(host),
190 | port=port,
191 | env=env_name
192 | )
193 |
194 | return vis, vis_pid
195 |
--------------------------------------------------------------------------------
/reproduced/adcnn.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | import ignite_trainer as it
7 |
8 | from utils import transforms, features
9 |
10 | from typing import Tuple
11 | from typing import Union
12 | from typing import Optional
13 |
14 |
15 | class Block(torch.nn.Module):
16 |
17 | def __init__(self,
18 | in_channels: int,
19 | out_channels: int,
20 | kernel_size: Tuple[int, int],
21 | pooling_size: Tuple[int, int]):
22 |
23 | super(Block, self).__init__()
24 |
25 | self.conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size)
26 | self.conv2 = torch.nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size)
27 | self.conv1x1 = torch.nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(1, 1))
28 | self.bn = torch.nn.BatchNorm2d(num_features=out_channels)
29 | self.activation = torch.nn.LeakyReLU()
30 | self.pooling = torch.nn.MaxPool2d(kernel_size=pooling_size)
31 |
32 | def forward(self, x: torch.Tensor) -> torch.Tensor:
33 | x = self.conv1(x)
34 | x = self.conv2(x)
35 | x = self.conv1x1(x)
36 | x = self.bn(x)
37 | x = self.activation(x)
38 | x = self.pooling(x)
39 |
40 | return x
41 |
42 |
43 | class Attention(torch.nn.Module):
44 |
45 | def __init__(self,
46 | in_channels: int,
47 | out_channels: int,
48 | kernel_size: Tuple[int, int],
49 | pooling_size: Tuple[int, int]):
50 |
51 | super(Attention, self).__init__()
52 |
53 | self.pool = torch.nn.MaxPool2d(kernel_size=pooling_size)
54 | self.conv_depth = torch.nn.Conv2d(
55 | in_channels=in_channels,
56 | out_channels=out_channels,
57 | kernel_size=kernel_size,
58 | groups=in_channels
59 | )
60 | self.conv_point = torch.nn.Conv2d(
61 | in_channels=out_channels,
62 | out_channels=out_channels,
63 | kernel_size=(1, 1)
64 | )
65 | self.bn = torch.nn.BatchNorm2d(num_features=out_channels)
66 | self.activation = torch.nn.ReLU()
67 |
68 | def forward(self, x: torch.Tensor) -> torch.Tensor:
69 | x = self.pool(x)
70 | x = self.conv_depth(x)
71 | x = self.conv_point(x)
72 | x = self.bn(x)
73 | x = self.activation(x)
74 |
75 | return x
76 |
77 |
78 | class DCNN5(it.AbstractNet):
79 |
80 | def __init__(self,
81 | num_channels: int = 1,
82 | sample_rate: int = 32000,
83 | n_fft: int = 256,
84 | hop_length: Optional[int] = None,
85 | window: Optional[str] = None,
86 | num_classes: int = 10):
87 |
88 | super(DCNN5, self).__init__()
89 |
90 | self.num_channels = num_channels
91 | self.num_classes = num_classes
92 |
93 | if hop_length is None:
94 | hop_length = int(math.floor(n_fft / 4))
95 |
96 | if window is None:
97 | window = 'boxcar'
98 |
99 | self.log10_eps = 1e-18
100 |
101 | self.mfcc = features.MFCC(
102 | sample_rate=sample_rate,
103 | n_mfcc=128,
104 | n_fft=n_fft,
105 | hop_length=hop_length,
106 | window=window
107 | )
108 |
109 | self.block1 = Block(self.num_channels, 32, (3, 1), (2, 1))
110 | self.block2 = Block(32, 32, (1, 5), (1, 4))
111 | self.block3 = Block(32, 64, (3, 1), (2, 1))
112 | self.block4 = Block(64, 64, (1, 5), (1, 4))
113 | self.block5 = Block(64, 128, (3, 5), (1, 1))
114 | self.max_pool = torch.nn.MaxPool2d(kernel_size=(2, 4))
115 |
116 | self.drop1 = torch.nn.Dropout(p=0.25)
117 | self.fc1 = torch.nn.Linear(in_features=128 * 12 * 2, out_features=256)
118 | self.fc2 = torch.nn.Linear(in_features=self.fc1.out_features, out_features=self.num_classes)
119 |
120 | self.activation = torch.nn.LeakyReLU()
121 |
122 | self.l2_lambda = 0.1
123 |
124 | def forward(self,
125 | x: torch.Tensor,
126 | y: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
127 |
128 | x = self.mfcc(x)
129 | x = transforms.scale(
130 | x,
131 | x.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values.min(dim=-3, keepdim=True).values,
132 | x.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values.max(dim=-3, keepdim=True).values,
133 | 0.0,
134 | 1.0
135 | )
136 |
137 | x = self.block1(x)
138 | x = self.block2(x)
139 | x = self.block3(x)
140 | x = self.block4(x)
141 | x = self.max_pool(self.block5(x))
142 |
143 | x = x.view(x.shape[0], -1)
144 | x = self.drop1(x)
145 |
146 | x = self.fc1(x)
147 | x = self.activation(x)
148 |
149 | y_pred = self.fc2(x)
150 |
151 | loss = None
152 | if y is not None:
153 | loss = self.loss_fn(y_pred, y).sum()
154 |
155 | return y_pred if loss is None else (y_pred, loss)
156 |
157 | def loss_fn(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
158 | loss_pred = F.cross_entropy(y_pred, y)
159 |
160 | loss_l2 = 0.0
161 | loss_l2_params = list(self.fc1.parameters())
162 | for p in loss_l2_params:
163 | loss_l2 = p.norm(2) + loss_l2
164 |
165 | loss_pred = loss_pred + self.l2_lambda * loss_l2
166 |
167 | return loss_pred
168 |
169 | @property
170 | def loss_fn_name(self) -> str:
171 | return 'Cross Entropy'
172 |
173 |
174 | class ADCNN5(DCNN5):
175 |
176 | def __init__(self,
177 | num_channels: int = 1,
178 | n_fft: int = 1024,
179 | hop_length: Optional[int] = None,
180 | window: Optional[str] = None,
181 | num_classes: int = 10):
182 |
183 | super(ADCNN5, self).__init__(
184 | num_channels=num_channels,
185 | n_fft=n_fft,
186 | hop_length=hop_length,
187 | window=window,
188 | num_classes=num_classes
189 | )
190 |
191 | self.attn1 = Attention(self.num_channels, 32, (3, 1), (2, 1))
192 | self.attn2 = Attention(32, 32, (1, 3), (1, 4))
193 | self.attn3 = Attention(32, 64, (3, 1), (2, 1))
194 | self.attn4 = Attention(64, 64, (1, 3), (1, 4))
195 | self.attn5 = Attention(64, 128, (3, 3), (2, 4))
196 | self.attn5.pool = torch.nn.Identity()
197 | self.attn5 = torch.nn.Sequential(
198 | self.attn5,
199 | torch.nn.AdaptiveMaxPool2d(output_size=(12, 2))
200 | )
201 |
202 | def forward(self,
203 | x: torch.Tensor,
204 | y: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
205 |
206 | x = self.mfcc(x)
207 | x = transforms.scale(
208 | x,
209 | x.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values.min(dim=-3, keepdim=True).values,
210 | x.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values.max(dim=-3, keepdim=True).values,
211 | 0.0,
212 | 1.0
213 | )
214 |
215 | x = self.attn1(x) * self.block1(x)
216 | x = self.attn2(x) * self.block2(x)
217 | x = self.attn3(x) * self.block3(x)
218 | x = self.attn4(x) * self.block4(x)
219 | x = self.attn5(x) * self.max_pool(self.block5(x))
220 |
221 | x = x.view(x.shape[0], -1)
222 | x = self.drop1(x)
223 |
224 | x = self.fc1(x)
225 | x = self.activation(x)
226 |
227 | y_pred = self.fc2(x)
228 |
229 | loss = None
230 | if y is not None:
231 | loss = self.loss_fn(y_pred, y).sum()
232 |
233 | return y_pred if loss is None else (y_pred, loss)
234 |
--------------------------------------------------------------------------------
/utils/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | import multiprocessing as mp
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import sklearn.model_selection as skms
8 |
9 | import tqdm
10 | import librosa
11 |
12 | import torch.utils.data as td
13 |
14 | from utils import transforms
15 |
16 | from typing import Tuple
17 | from typing import Optional
18 |
19 |
20 | class ESC50(td.Dataset):
21 |
22 | def __init__(self,
23 | root: str,
24 | sample_rate: int = 22050,
25 | train: bool = True,
26 | fold: Optional[int] = None,
27 | transform=None,
28 | target_transform=None):
29 |
30 | super(ESC50, self).__init__()
31 |
32 | self.sample_rate = sample_rate
33 |
34 | meta = self.load_meta(os.path.join(root, 'meta', 'esc50.csv'))
35 |
36 | if fold is None:
37 | fold = 5
38 |
39 | self.folds_to_load = set(meta['fold'])
40 |
41 | if fold not in self.folds_to_load:
42 | raise ValueError(f'fold {fold} does not exist')
43 |
44 | self.train = train
45 | self.transform = transform
46 |
47 | if self.train:
48 | self.folds_to_load -= {fold}
49 | else:
50 | self.folds_to_load -= self.folds_to_load - {fold}
51 |
52 | self.data = dict()
53 | self.load_data(meta, os.path.join(root, 'audio'))
54 | self.indices = list(self.data.keys())
55 |
56 | self.target_transform = target_transform
57 |
58 | @staticmethod
59 | def load_meta(path_to_csv: str) -> pd.DataFrame:
60 | meta = pd.read_csv(path_to_csv)
61 |
62 | return meta
63 |
64 | @staticmethod
65 | def _load_worker(idx: int, filename: str, sample_rate: Optional[int] = None) -> Tuple[int, int, np.ndarray]:
66 | wav, sample_rate = librosa.load(filename, sr=sample_rate, mono=True)
67 |
68 | if wav.ndim == 1:
69 | wav = wav[:, np.newaxis]
70 |
71 |         if np.abs(wav).max() > 1.0:
72 | wav = transforms.scale(wav, wav.min(), wav.max(), -1.0, 1.0)
73 |
74 | wav = wav.T * 32768.0
75 |
76 | return idx, sample_rate, wav.astype(np.float32)
77 |
78 | def load_data(self, meta: pd.DataFrame, base_path: str):
79 | items_to_load = dict()
80 |
81 | for idx, row in meta.iterrows():
82 | if row['fold'] in self.folds_to_load:
83 | items_to_load[idx] = os.path.join(base_path, row['filename']), self.sample_rate
84 |
85 | items_to_load = [(idx, path, sample_rate) for idx, (path, sample_rate) in items_to_load.items()]
86 |
87 | warnings.filterwarnings('ignore')
88 | with mp.Pool(processes=mp.cpu_count()) as pool:
89 | chunksize = int(np.ceil(len(items_to_load) / pool._processes)) or 1
90 | tqdm.tqdm.write(f'Loading {self.__class__.__name__} (train={self.train})')
91 | for idx, sample_rate, wav in pool.starmap(
92 | func=self._load_worker,
93 | iterable=items_to_load,
94 | chunksize=chunksize
95 | ):
96 | row = meta.loc[idx]
97 |
98 | self.data[idx] = {
99 | 'audio': wav,
100 | 'sample_rate': sample_rate,
101 | 'target': row['target'],
102 | 'fold': row['fold'],
103 | 'esc10': row['esc10']
104 | }
105 |
106 | def __getitem__(self, index: int) -> Tuple[np.ndarray, int]:
107 | if not (0 <= index < len(self)):
108 | raise IndexError
109 |
110 | audio: np.ndarray = self.data[self.indices[index]]['audio']
111 | target: int = self.data[self.indices[index]]['target']
112 |
113 | if self.transform is not None:
114 | audio = self.transform(audio)
115 | if self.target_transform is not None:
116 | target = self.target_transform(target)
117 |
118 | return audio, target
119 |
120 | def __len__(self) -> int:
121 | return len(self.indices)
122 |
123 |
124 | class ESC10(ESC50):
125 |
126 | def __init__(self,
127 | root: str,
128 | sample_rate: int = 22050,
129 | train: bool = True,
130 | fold: Optional[int] = None,
131 | transform=None,
132 | target_transform=None):
133 |
134 | super(ESC10, self).__init__(
135 | root=root,
136 | sample_rate=sample_rate,
137 | train=train,
138 | fold=fold,
139 | transform=transform,
140 | target_transform=target_transform
141 | )
142 |
143 |         self.classes = {  # remap the surviving ESC-50 target ids to contiguous 0..9 labels
144 | old_target: new_target
145 | for new_target, old_target
146 | in enumerate({item['target'] for item in self.data.values()})
147 | }
148 |
149 | @staticmethod
150 | def load_meta(path_to_csv: str) -> pd.DataFrame:
151 | meta = ESC50.load_meta(path_to_csv)
152 | meta.drop(index=meta[~meta['esc10']].index, inplace=True)
153 |
154 | return meta
155 |
156 | def __getitem__(self, index: int) -> Tuple[np.ndarray, int]:
157 | audio, target = super(ESC10, self).__getitem__(index)
158 |
159 | target = self.classes[target]
160 |
161 | return audio, target
162 |
163 |
164 | class UrbanSound8K(td.Dataset):
165 |
166 | def __init__(self,
167 | root: str,
168 | sample_rate: int = 22050,
169 | train: bool = True,
170 | fold: Optional[int] = None,
171 | random_split_seed: Optional[int] = None,
172 | mono: bool = False,
173 | transform=None,
174 | target_transform=None):
175 |
176 | super(UrbanSound8K, self).__init__()
177 |
178 | self.root = root
179 | self.sample_rate = sample_rate
180 | self.train = train
181 |
182 | if fold is None:
183 | fold = 1
184 |
185 | if not (1 <= fold <= 10):
186 | raise ValueError(f'Expected fold in range [1, 10], got {fold}')
187 |
188 | self.fold = fold
189 | self.folds_to_load = set(range(1, 11))
190 |
191 | if self.fold not in self.folds_to_load:
192 | raise ValueError(f'fold {fold} does not exist')
193 |
194 | if self.train:
195 | # if in training mode, keep all but test fold
196 | self.folds_to_load -= {self.fold}
197 | else:
198 | # if in evaluation mode, keep the test samples only
199 | self.folds_to_load -= self.folds_to_load - {self.fold}
200 |
201 | self.random_split_seed = random_split_seed
202 | self.mono = mono
203 |
204 | self.transform = transform
205 | self.target_transform = target_transform
206 |
207 | self.data = dict()
208 | self.indices = dict()
209 | self.load_data()
210 |
211 | @staticmethod
212 | def _load_worker(fn: str, path_to_file: str, sample_rate: int, mono: bool = False) -> Tuple[str, int, np.ndarray]:
213 | wav, sample_rate = librosa.load(path_to_file, sr=sample_rate, mono=mono)
214 |
215 |         if wav.ndim == 1:
216 |             wav = wav[np.newaxis, :]
217 |             # mono recording: duplicate the channel when stereo output is requested
218 |             if not mono:
219 |                 wav = np.concatenate((wav, wav), axis=0)
220 |
221 | wav = wav.T
222 | wav = wav[:sample_rate * 4]
223 |
224 |         if np.abs(wav).max() > 1.0:
225 | wav = transforms.scale(wav, wav.min(), wav.max(), -1.0, 1.0)
226 |
227 | wav = transforms.scale(wav, wav.min(), wav.max(), -32768.0, 32767.0).T
228 |
229 | return fn, sample_rate, wav.astype(np.float32)
230 |
231 | def load_data(self):
232 | # read metadata
233 | meta = pd.read_csv(
234 | os.path.join(self.root, 'metadata', 'UrbanSound8K.csv'),
235 | sep=',',
236 | index_col='slice_file_name'
237 | )
238 |
239 | for row_idx, (fn, row) in enumerate(meta.iterrows()):
240 | path = os.path.join(self.root, 'audio', 'fold{}'.format(row['fold']), fn)
241 | self.data[fn] = path, self.sample_rate, self.mono
242 |
243 | # by default, the official split from the metadata is used
244 | files_to_load = list()
245 | # if the random seed is not None, the random split is used
246 | if self.random_split_seed is not None:
247 |             # build a stratified 10-fold split from the given integer seed
248 | skf = skms.StratifiedKFold(n_splits=10, shuffle=True, random_state=self.random_split_seed)
249 |
250 | # split the US8K samples into 10 folds
251 | for fold_idx, (train_ids, test_ids) in enumerate(skf.split(
252 | np.zeros(len(meta)), meta['classID'].values.astype(int)
253 | ), 1):
254 | # if this is the fold we want to load, add the corresponding files to the list
255 | if fold_idx == self.fold:
256 | ids = train_ids if self.train else test_ids
257 | filenames = meta.iloc[ids].index
258 | files_to_load.extend(filenames)
259 | break
260 | else:
261 | # if the random seed is None, use the official split
262 | for fn, row in meta.iterrows():
263 | if int(row['fold']) in self.folds_to_load:
264 | files_to_load.append(fn)
265 |
266 | self.data = {fn: vals for fn, vals in self.data.items() if fn in files_to_load}
267 | self.indices = {idx: fn for idx, fn in enumerate(self.data)}
268 |
269 | warnings.filterwarnings('ignore')
270 | with mp.Pool(processes=mp.cpu_count()) as pool:
271 |             chunksize = int(np.ceil(len(self.data) / pool._processes)) or 1
272 |
273 | tqdm.tqdm.write(f'Loading {self.__class__.__name__} (train={self.train})')
274 |
275 | for fn, sample_rate, wav in pool.starmap(
276 | func=self._load_worker,
277 | iterable=[(fn, path, sr, mono) for fn, (path, sr, mono) in self.data.items()],
278 | chunksize=chunksize
279 | ):
280 | self.data[fn] = {
281 | 'audio': wav,
282 | 'sample_rate': sample_rate,
283 | 'target': meta.loc[fn, 'classID']
284 | }
285 |
286 | def __getitem__(self, index: int) -> Tuple[np.ndarray, int]:
287 | if not (0 <= index < len(self)):
288 | raise IndexError
289 |
290 | audio: np.ndarray = self.data[self.indices[index]]['audio']
291 | target: int = self.data[self.indices[index]]['target']
292 |
293 | if self.transform is not None:
294 | audio = self.transform(audio)
295 | if self.target_transform is not None:
296 | target = self.target_transform(target)
297 |
298 | return audio, target
299 |
300 | def __len__(self) -> int:
301 | return len(self.data)
302 |
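A minimal consumption sketch for the datasets above, assuming the standard ESC-50 layout (`meta/esc50.csv` plus `audio/`) under a placeholder `root` path; the transform is left as `None` here, whereas the training protocols compose one from `utils.transforms`. Loading happens eagerly in a multiprocessing pool, so the entry-point guard matters on spawn-based platforms.

```python
import torch.utils.data as td

from utils.datasets import ESC50

if __name__ == '__main__':
    # Hold out fold 5 for evaluation (also the default when fold is None).
    train_set = ESC50(root='/path/to/ESC-50', train=True, fold=5)
    eval_set = ESC50(root='/path/to/ESC-50', train=False, fold=5)

    train_loader = td.DataLoader(train_set, batch_size=16, shuffle=True, drop_last=True)

    audio, target = train_set[0]
    print(audio.shape, target)  # (1, 110250): 5 s of mono audio at 22050 Hz, plus an integer class id
```

`UrbanSound8K` follows the same pattern, with the additional `mono` flag and an optional `random_split_seed` that swaps the official fold split for a seeded stratified one.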
--------------------------------------------------------------------------------
/utils/features.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.fft as spf
3 | import scipy.signal as sps
4 |
5 | import librosa
6 |
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from utils import transforms
11 |
12 | from typing import Optional
13 |
14 |
15 | def fft_frequencies(sample_rate: int = 22050, n_fft: int = 2048) -> torch.Tensor:
16 | return torch.linspace(0, sample_rate * 0.5, int(1 + n_fft // 2))
17 |
18 |
19 | def power_to_db(spectrogram: torch.Tensor, ref: float = 1.0, amin: float = 1e-10, top_db: float = 80.0) -> torch.Tensor:
20 | log_spec = 10.0 * torch.log10(torch.max(torch.full_like(spectrogram, amin), spectrogram))
21 | log_spec -= 10.0 * torch.log10(torch.full_like(spectrogram, max(amin, ref)))
22 |
23 | log_spec = torch.max(
24 | log_spec,
25 | log_spec.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values - top_db
26 | )
27 |
28 | return log_spec
29 |
30 |
31 | class MFCC(torch.nn.Module):
32 |
33 | def __init__(self,
34 | sample_rate: int = 22050,
35 | n_mfcc: int = 128,
36 | n_fft: int = 1024,
37 | hop_length: int = 512,
38 | window: str = 'hann'):
39 |
40 | super(MFCC, self).__init__()
41 |
42 | mel_filterbank = librosa.filters.mel(
43 | sr=sample_rate,
44 | n_fft=n_fft,
45 | n_mels=n_mfcc
46 | )
47 | mel_filterbank = torch.from_numpy(mel_filterbank).to(torch.get_default_dtype())
48 | self.register_buffer('mel', mel_filterbank)
49 |
50 | dct_buf = spf.dct(np.eye(n_mfcc), type=2, norm='ortho').T
51 | dct_buf = torch.from_numpy(dct_buf).to(torch.get_default_dtype())
52 | self.register_buffer('dct_mat', dct_buf)
53 |
54 | window_buffer: torch.Tensor = torch.from_numpy(
55 | sps.get_window(window=window, Nx=n_fft, fftbins=True)
56 | ).to(torch.get_default_dtype())
57 | self.register_buffer('window', window_buffer)
58 |
59 | self.sample_rate = sample_rate
60 | self.n_fft = n_fft
61 | self.n_mfcc = n_mfcc
62 | self.hop_length = hop_length
63 |
64 | def dct2(self, x):
65 | x_dct = self.dct_mat.view(1, *self.dct_mat.shape) @ x
66 |
67 | return x_dct
68 |
69 | def forward(self, x: torch.Tensor) -> torch.Tensor:
70 | spec = torch.stft(
71 | x.view(-1, x.shape[-1]),
72 | n_fft=self.n_fft,
73 | hop_length=self.hop_length,
74 | win_length=self.n_fft,
75 | window=self.window,
76 | normalized=True
77 | )
78 |
79 | power_spec = spec[..., 0] ** 2 + spec[..., 1] ** 2
80 | log_power_spec = 10 * torch.log10(power_spec.add(1e-18))
81 |
82 | mel_spec = self.mel.view(1, *self.mel.shape) @ log_power_spec
83 | mfcc = self.dct2(mel_spec)
84 | mfcc = mfcc.view(x.shape[0], 1, *mfcc.shape[-2:])
85 |
86 | return mfcc
87 |
88 |
89 | class Chroma(torch.nn.Module):
90 |
91 | def __init__(self,
92 | sample_rate: int = 22050,
93 | norm: float = float('inf'),
94 | n_fft: int = 2048,
95 | tuning: float = 0.0,
96 | n_chroma: int = 12,
97 | ctroct: float = 5.0,
98 | octwidth: float = 2.0,
99 | base_c: bool = True):
100 |
101 | super(Chroma, self).__init__()
102 |
103 | chroma_fb_buf = librosa.filters.chroma(
104 | sr=sample_rate,
105 | n_fft=n_fft,
106 | n_chroma=n_chroma,
107 | tuning=tuning,
108 | ctroct=ctroct,
109 | octwidth=octwidth,
110 | norm=norm,
111 | base_c=base_c
112 | )
113 | self.register_buffer('chroma_fb', torch.from_numpy(chroma_fb_buf).to(torch.get_default_dtype()))
114 |
115 | self.norm = norm
116 |
117 | def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
118 | chroma = self.chroma_fb @ spectrogram
119 | chroma = chroma / torch.norm(chroma, p=self.norm, dim=-2, keepdim=True)
120 |
121 | return chroma
122 |
123 |
124 | class Tonnetz(Chroma):
125 |
126 | def __init__(self,
127 | sample_rate: int = 22050,
128 | norm: float = float('inf'),
129 | n_fft: int = 2048,
130 | tuning: float = 0.0,
131 | n_chroma: int = 12,
132 | ctroct: float = 5.0,
133 | octwidth: float = 2.0,
134 | base_c: bool = True):
135 |
136 | super(Tonnetz, self).__init__(
137 | sample_rate=sample_rate,
138 | norm=norm,
139 | n_fft=n_fft,
140 | tuning=tuning,
141 | n_chroma=n_chroma,
142 | ctroct=ctroct,
143 | octwidth=octwidth,
144 | base_c=base_c
145 | )
146 |
147 | # Generate Transformation matrix
148 | dim_map = np.linspace(0, 12, n_chroma, endpoint=False)
149 |
150 | scale = np.asarray([7. / 6, 7. / 6,
151 | 3. / 2, 3. / 2,
152 | 2. / 3, 2. / 3])
153 |
154 | V = np.multiply.outer(scale, dim_map)
155 |
156 | # Even rows compute sin()
157 | V[::2] -= 0.5
158 |
159 | R = np.array([1, 1, # Fifths
160 | 1, 1, # Minor
161 | 0.5, 0.5]) # Major
162 |
163 | phi_buf = R[:, np.newaxis] * np.cos(np.pi * V)
164 |
165 | self.register_buffer('phi', torch.from_numpy(phi_buf).to(torch.get_default_dtype()))
166 |
167 | def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
168 | chroma = super(Tonnetz, self).forward(spectrogram)
169 | chroma = chroma / torch.norm(chroma, p=1, dim=-2, keepdim=True)
170 | tonnetz = self.phi @ chroma
171 |
172 | return tonnetz
173 |
174 |
175 | class SpectralContrast(torch.nn.Module):
176 |
177 | def __init__(self,
178 | sample_rate: int = 22050,
179 | n_fft: int = 2048,
180 | freq: Optional[torch.Tensor] = None,
181 | fmin: float = 200.0,
182 | n_bands: int = 6,
183 | quantile: float = 0.02,
184 | linear: bool = False):
185 |
186 | super(SpectralContrast, self).__init__()
187 |
188 | # Compute the center frequencies of each bin
189 | if freq is None:
190 | freq = fft_frequencies(sample_rate=sample_rate, n_fft=n_fft)
191 |
192 | self.register_buffer('freq', freq)
193 |
194 | if n_bands < 1 or not isinstance(n_bands, int):
195 | raise ValueError('n_bands must be a positive integer')
196 |
197 | self.n_bands = n_bands
198 |
199 | if not 0.0 < quantile < 1.0:
200 | raise ValueError('quantile must lie in the range (0, 1)')
201 |
202 | self.quantile = quantile
203 |
204 | if fmin <= 0:
205 | raise ValueError('fmin must be a positive number')
206 |
207 | octa_buf = torch.zeros(n_bands + 2)
208 | octa_buf[1:] = fmin * (2.0 ** torch.arange(0, n_bands + 1, dtype=torch.float32))
209 |
210 | if torch.any(octa_buf[:-1] >= 0.5 * sample_rate):
211 | raise ValueError('Frequency band exceeds Nyquist. Reduce either fmin or n_bands.')
212 |
213 | self.register_buffer('octa', octa_buf)
214 |
215 | self.linear = linear
216 |
217 | def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
218 | valley = torch.zeros(
219 | *spectrogram.shape[:-2], self.n_bands + 1, spectrogram.shape[-1],
220 | dtype=spectrogram.dtype,
221 | device=spectrogram.device
222 | )
223 | peak = torch.zeros_like(valley)
224 |
225 | for k, (f_low, f_high) in enumerate(zip(self.octa[:-1], self.octa[1:])):
226 | current_band: torch.Tensor = (self.freq >= f_low) & (self.freq <= f_high)
227 |
228 | idx = torch.nonzero(torch.flatten(current_band))
229 |
230 | if k > 0:
231 | current_band[idx[0] - 1] = True
232 |
233 | if k == self.n_bands:
234 | current_band[idx[-1] + 1:] = True
235 |
236 | sub_band = spectrogram[..., current_band, :]
237 |
238 | if k < self.n_bands:
239 | sub_band = sub_band[..., :-1, :]
240 |
241 | # Always take at least one bin from each side
242 | idx = np.rint(self.quantile * torch.sum(current_band).item())
243 | idx = int(np.maximum(idx, 1))
244 |
245 | sortedr, _ = torch.sort(sub_band, dim=-2)
246 |
247 | valley[..., k, :] = torch.mean(sortedr[..., :idx, :], dim=-2)
248 | peak[..., k, :] = torch.mean(sortedr[..., -idx:, :], dim=-2)
249 |
250 | if self.linear:
251 | return peak - valley
252 | else:
253 | return power_to_db(peak) - power_to_db(valley)
254 |
255 |
256 | class Melspectrogram(torch.nn.Module):
257 |
258 | def __init__(self,
259 | sample_rate: int = 22050,
260 | n_fft: int = 2048,
261 | n_mels: int = 128,
262 | fmin: float = 0.0,
263 | fmax: Optional[float] = None):
264 |
265 | super(Melspectrogram, self).__init__()
266 |
267 | mel_fb_buf = librosa.filters.mel(
268 | sr=sample_rate,
269 | n_fft=n_fft,
270 | n_mels=n_mels,
271 | fmin=fmin,
272 | fmax=fmax
273 | )
274 | self.register_buffer('mel_fb', torch.from_numpy(mel_fb_buf).to(torch.get_default_dtype()))
275 |
276 | def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
277 | lm = self.mel_fb @ spectrogram
278 | lm = power_to_db(lm)
279 |
280 | lm = transforms.scale(
281 | lm,
282 | lm.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values,
283 | lm.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values,
284 | -1.0,
285 | 1.0
286 | )
287 |
288 | return lm
289 |
290 |
291 | class CST(torch.nn.Module):
292 |
293 | def __init__(self,
294 | sample_rate: int = 22050,
295 | norm: float = float('inf'),
296 | n_fft: int = 2048,
297 | tuning: float = 0.0,
298 | n_chroma: int = 12,
299 | ctroct: float = 5.0,
300 | octwidth: float = 2.0,
301 | base_c: bool = True,
302 | freq: Optional[torch.Tensor] = None,
303 | fmin: float = 200.0,
304 | n_bands: int = 6,
305 | quantile: float = 0.02,
306 | linear: bool = False):
307 |
308 | super(CST, self).__init__()
309 |
310 | self.chroma = Chroma(
311 | sample_rate=sample_rate,
312 | norm=norm,
313 | n_fft=n_fft,
314 | tuning=tuning,
315 | n_chroma=n_chroma,
316 | ctroct=ctroct,
317 | octwidth=octwidth,
318 | base_c=base_c
319 | )
320 | self.spectral_contrast = SpectralContrast(
321 | sample_rate=sample_rate,
322 | n_fft=n_fft,
323 | freq=freq,
324 | fmin=fmin,
325 | n_bands=n_bands,
326 | quantile=quantile,
327 | linear=linear
328 | )
329 | self.tonnetz = Tonnetz(
330 | sample_rate=sample_rate,
331 | norm=norm,
332 | n_fft=n_fft,
333 | tuning=tuning,
334 | n_chroma=n_chroma,
335 | ctroct=ctroct,
336 | octwidth=octwidth,
337 | base_c=base_c
338 | )
339 |
340 | def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
341 | chroma = self.chroma(spectrogram)
342 | spectral_contrast = self.spectral_contrast(spectrogram)
343 | tonnetz = self.tonnetz(spectrogram)
344 |
345 | chroma = transforms.scale(
346 | chroma,
347 | chroma.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values,
348 | chroma.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values,
349 | -1.0,
350 | 1.0
351 | )
352 | spectral_contrast = transforms.scale(
353 | spectral_contrast,
354 | spectral_contrast.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values,
355 | spectral_contrast.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values,
356 | -1.0,
357 | 1.0
358 | )
359 | tonnetz = transforms.scale(
360 | tonnetz,
361 | tonnetz.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values,
362 | tonnetz.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values,
363 | -1.0,
364 | 1.0
365 | )
366 |
367 | cst = torch.cat((
368 | tonnetz,
369 | spectral_contrast,
370 | chroma
371 | ), dim=-2)
372 |
373 | return cst
374 |
375 |
376 | class LMC(torch.nn.Module):
377 |
378 | def __init__(self,
379 | sample_rate: int = 22050,
380 | norm: float = float('inf'),
381 | n_fft: int = 2048,
382 | n_mels: int = 128,
383 | tuning: float = 0.0,
384 | n_chroma: int = 12,
385 | ctroct: float = 5.0,
386 | octwidth: float = 2.0,
387 | base_c: bool = True,
388 | freq: Optional[torch.Tensor] = None,
389 | fmin: float = 200.0,
390 | fmax: Optional[float] = None,
391 | n_bands: int = 6,
392 | quantile: float = 0.02,
393 | linear: bool = False):
394 |
395 | super(LMC, self).__init__()
396 |
397 | self.lm = Melspectrogram(
398 | sample_rate=sample_rate,
399 | n_fft=n_fft,
400 | n_mels=n_mels,
401 | fmin=fmin,
402 | fmax=fmax
403 | )
404 |
405 | self.cst = CST(
406 | sample_rate=sample_rate,
407 | norm=norm,
408 | n_fft=n_fft,
409 | tuning=tuning,
410 | n_chroma=n_chroma,
411 | ctroct=ctroct,
412 | octwidth=octwidth,
413 | base_c=base_c,
414 | freq=freq,
415 | fmin=fmin,
416 | n_bands=n_bands,
417 | quantile=quantile,
418 | linear=linear
419 | )
420 |
421 | def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
422 | lm = self.lm(spectrogram)
423 | cst = self.cst(spectrogram)
424 |
425 | lmc = torch.cat((
426 | cst,
427 | lm
428 | ), dim=-2)
429 |
430 | return lmc
431 |
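Apart from `MFCC`, which runs its own STFT, the modules above consume a power spectrogram shaped `(..., freq_bins, frames)` and emit features over the same frame axis. Below is a short sketch wiring a spectrogram into the `LMC` stack; the clip length and batch size are toy values, the `torch.stft` call uses the legacy real/imaginary output of the torch==1.4.0 release pinned in requirements.txt, and the only hard requirement is that `n_fft` matches between the STFT and the filterbanks.

```python
import torch

from utils.features import LMC, MFCC

n_fft, hop_length = 2048, 512
x = torch.randn(2, 22050 * 5)                        # two 5-second clips at 22.05 kHz (toy data)

window = torch.hann_window(n_fft)
spec = torch.stft(x, n_fft=n_fft, hop_length=hop_length, window=window)
power_spec = spec[..., 0] ** 2 + spec[..., 1] ** 2   # (batch, 1 + n_fft // 2, frames)

lmc = LMC(sample_rate=22050, n_fft=n_fft)(power_spec)
print(lmc.shape)                                      # (batch, tonnetz + contrast + chroma + mel rows, frames)

mfcc = MFCC(sample_rate=22050)(x)                     # MFCC takes the raw waveform directly
print(mfcc.shape)                                     # (batch, 1, n_mfcc, frames)
```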
--------------------------------------------------------------------------------
/model/esresnet.py:
--------------------------------------------------------------------------------
1 | import termcolor
2 |
3 | import numpy as np
4 | import scipy.signal as sps
5 |
6 | import torch
7 | import torch.nn.functional as F
8 |
9 | import torchvision as tv
10 |
11 | import ignite_trainer as it
12 |
13 | from model import attention
14 |
15 | from typing import Tuple
16 | from typing import Union
17 | from typing import Optional
18 | from typing import Sequence
19 |
20 |
21 | def conv3x3(in_planes, out_planes, stride=1):
22 | """3x3 convolution with padding"""
23 | return torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
24 |
25 |
26 | def conv1x1(in_planes, out_planes, stride=1):
27 | """1x1 convolution"""
28 | return torch.nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
29 |
30 |
31 | class BasicBlock(torch.nn.Module):
32 |
33 | expansion = 1
34 |
35 | def __init__(self, inplanes, planes, stride=1, downsample=None):
36 | super(BasicBlock, self).__init__()
37 | self.conv1 = conv3x3(inplanes, planes, stride)
38 | self.bn1 = torch.nn.BatchNorm2d(planes)
39 | self.relu = torch.nn.ReLU()
40 | self.conv2 = conv3x3(planes, planes)
41 | self.bn2 = torch.nn.BatchNorm2d(planes)
42 | self.downsample = downsample
43 | self.stride = stride
44 |
45 | def forward(self, x):
46 | identity = x
47 |
48 | out = self.conv1(x)
49 | out = self.bn1(out)
50 | out = self.relu(out)
51 |
52 | out = self.conv2(out)
53 | out = self.bn2(out)
54 |
55 | if self.downsample is not None:
56 | identity = self.downsample(x)
57 |
58 | out += identity
59 | out = self.relu(out)
60 |
61 | return out
62 |
63 |
64 | class Bottleneck(torch.nn.Module):
65 |
66 | expansion = 4
67 |
68 | def __init__(self, inplanes, planes, stride=1, downsample=None):
69 | super(Bottleneck, self).__init__()
70 | self.conv1 = conv1x1(inplanes, planes)
71 | self.bn1 = torch.nn.BatchNorm2d(planes)
72 | self.conv2 = conv3x3(planes, planes, stride)
73 | self.bn2 = torch.nn.BatchNorm2d(planes)
74 | self.conv3 = conv1x1(planes, planes * self.expansion)
75 | self.bn3 = torch.nn.BatchNorm2d(planes * self.expansion)
76 | self.relu = torch.nn.ReLU()
77 | self.downsample = downsample
78 | self.stride = stride
79 |
80 | def forward(self, x):
81 | identity = x
82 |
83 | out = self.conv1(x)
84 | out = self.bn1(out)
85 | out = self.relu(out)
86 |
87 | out = self.conv2(out)
88 | out = self.bn2(out)
89 | out = self.relu(out)
90 |
91 | out = self.conv3(out)
92 | out = self.bn3(out)
93 |
94 | if self.downsample is not None:
95 | identity = self.downsample(x)
96 |
97 | out += identity
98 | out = self.relu(out)
99 |
100 | return out
101 |
102 |
103 | class ResNet(it.AbstractNet):
104 |
105 | def __init__(self,
106 | block: Union[BasicBlock, Bottleneck],
107 | layers: Sequence[int],
108 | num_channels: int = 3,
109 | num_classes: int = 1000):
110 |
111 | super(ResNet, self).__init__()
112 |
113 | self.inplanes = 64
114 |
115 | self.conv1 = torch.nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
116 | self.bn1 = torch.nn.BatchNorm2d(64)
117 | self.relu = torch.nn.ReLU()
118 | self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
119 | self.layer1 = self._make_layer(block, 64, layers[0])
120 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
121 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
122 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
123 | self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
124 | self.fc = torch.nn.Linear(512 * block.expansion, num_classes)
125 |
126 | def _make_layer(self, block, planes, blocks, stride=1):
127 | downsample = None
128 | if stride != 1 or self.inplanes != planes * block.expansion:
129 | downsample = torch.nn.Sequential(
130 | conv1x1(self.inplanes, planes * block.expansion, stride),
131 | torch.nn.BatchNorm2d(planes * block.expansion)
132 | )
133 |
134 | layers = list()
135 | layers.append(block(self.inplanes, planes, stride, downsample))
136 | self.inplanes = planes * block.expansion
137 | for _ in range(1, blocks):
138 | layers.append(block(self.inplanes, planes))
139 |
140 | return torch.nn.Sequential(*layers)
141 |
142 | def forward(self,
143 | x: torch.Tensor,
144 | y: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
145 |
146 | x = self.conv1(x)
147 | x = self.bn1(x)
148 | x = self.relu(x)
149 | x = self.maxpool(x)
150 |
151 | x = self.layer1(x)
152 | x = self.layer2(x)
153 | x = self.layer3(x)
154 | x = self.layer4(x)
155 |
156 | x = self.avgpool(x)
157 | x = x.view(x.size(0), -1)
158 |
159 | y_pred = self.fc(x)
160 |
161 | loss = None
162 | if y is not None:
163 | loss = self.loss_fn(y_pred, y).sum()
164 |
165 | return y_pred if loss is None else (y_pred, loss)
166 |
167 | def loss_fn(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
168 | loss_pred = F.cross_entropy(y_pred, y)
169 |
170 | return loss_pred
171 |
172 | @property
173 | def loss_fn_name(self) -> str:
174 | return 'Cross Entropy'
175 |
176 |
177 | class ResNet50(ResNet):
178 |
179 | def __init__(self, num_channels: int = 3, num_classes: int = 1000):
180 | super(ResNet50, self).__init__(
181 | block=Bottleneck,
182 | layers=[3, 4, 6, 3],
183 | num_channels=num_channels,
184 | num_classes=num_classes
185 | )
186 |
187 |
188 | class ResNetWithAttention(ResNet):
189 |
190 | def __init__(self,
191 | block: Union[BasicBlock, Bottleneck],
192 | layers: Sequence[int],
193 | num_channels: int = 3,
194 | num_classes: int = 1000):
195 |
196 | super(ResNetWithAttention, self).__init__(
197 | block=block,
198 | layers=layers,
199 | num_channels=num_channels,
200 | num_classes=num_classes
201 | )
202 |
203 | self.att1 = attention.Attention2d(
204 | in_channels=64,
205 | out_channels=64 * block.expansion,
206 | num_kernels=1,
207 | kernel_size=(3, 1),
208 | padding_size=(1, 0)
209 | )
210 | self.att2 = attention.Attention2d(
211 | in_channels=64 * block.expansion,
212 | out_channels=128 * block.expansion,
213 | num_kernels=1,
214 | kernel_size=(1, 5),
215 | padding_size=(0, 2)
216 | )
217 | self.att3 = attention.Attention2d(
218 | in_channels=128 * block.expansion,
219 | out_channels=256 * block.expansion,
220 | num_kernels=1,
221 | kernel_size=(3, 1),
222 | padding_size=(1, 0)
223 | )
224 | self.att4 = attention.Attention2d(
225 | in_channels=256 * block.expansion,
226 | out_channels=512 * block.expansion,
227 | num_kernels=1,
228 | kernel_size=(1, 5),
229 | padding_size=(0, 2)
230 | )
231 | self.att5 = attention.Attention2d(
232 | in_channels=512 * block.expansion,
233 | out_channels=512 * block.expansion,
234 | num_kernels=1,
235 | kernel_size=(3, 5),
236 | padding_size=(1, 2)
237 | )
238 |
239 | def forward(self,
240 | x: torch.Tensor,
241 | y: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
242 |
243 | x = self.conv1(x)
244 | x = self.bn1(x)
245 | x = self.relu(x)
246 | x = self.maxpool(x)
247 |
248 | x_att = x.clone()
249 | x = self.layer1(x)
250 | x_att = self.att1(x_att, x.shape[-2:])
251 | x = x * x_att
252 |
253 | x_att = x.clone()
254 | x = self.layer2(x)
255 | x_att = self.att2(x_att, x.shape[-2:])
256 | x = x * x_att
257 |
258 | x_att = x.clone()
259 | x = self.layer3(x)
260 | x_att = self.att3(x_att, x.shape[-2:])
261 | x = x * x_att
262 |
263 | x_att = x.clone()
264 | x = self.layer4(x)
265 | x_att = self.att4(x_att, x.shape[-2:])
266 | x = x * x_att
267 |
268 | x_att = x.clone()
269 | x = self.avgpool(x)
270 | x_att = self.att5(x_att, x.shape[-2:])
271 | x = x * x_att
272 | x = x.view(x.size(0), -1)
273 |
274 | y_pred = self.fc(x)
275 |
276 | loss = None
277 | if y is not None:
278 | loss = self.loss_fn(y_pred, y).sum()
279 |
280 | return y_pred if loss is None else (y_pred, loss)
281 |
282 |
283 | class ResNet50WithAttention(ResNetWithAttention):
284 |
285 | def __init__(self, num_channels: int = 3, num_classes: int = 1000):
286 | super(ResNet50WithAttention, self).__init__(
287 | block=Bottleneck,
288 | layers=[3, 4, 6, 3],
289 | num_channels=num_channels,
290 | num_classes=num_classes
291 | )
292 |
293 |
294 | class _ESResNet(ResNet):
295 |
296 | def __init__(self,
297 | block: Union[BasicBlock, Bottleneck],
298 | layers: Sequence[int],
299 | n_fft: int = 256,
300 | hop_length: Optional[int] = None,
301 | win_length: Optional[int] = None,
302 | window: Optional[str] = None,
303 | normalized: bool = False,
304 | onesided: bool = True,
305 | spec_height: int = 224,
306 | spec_width: int = 224,
307 | num_classes: int = 1000,
308 | pretrained: Union[bool, str] = False,
309 | lock_pretrained: Optional[bool] = None):
310 |
311 | super(_ESResNet, self).__init__(
312 | block=block,
313 | layers=layers,
314 | num_channels=3,
315 | num_classes=num_classes
316 | )
317 |
318 | self.num_classes = num_classes
319 |
320 | self.fc = torch.nn.Identity()
321 | self.classifier = torch.nn.Linear(
322 | in_features=512 * block.expansion,
323 | out_features=self.num_classes
324 | )
325 |
326 | if hop_length is None:
327 | hop_length = int(np.floor(n_fft / 4))
328 |
329 | if win_length is None:
330 | win_length = n_fft
331 |
332 | if window is None:
333 | window = 'boxcar'
334 |
335 | self.n_fft = n_fft
336 | self.win_length = win_length
337 | self.hop_length = hop_length
338 |
339 | self.normalized = normalized
340 | self.onesided = onesided
341 |
342 | self.spec_height = spec_height
343 | self.spec_width = spec_width
344 |
345 | self.pretrained = pretrained
346 | if pretrained:
347 | err_msg = self.load_pretrained()
348 |
349 | unlocked_weights = list()
350 |
351 | for name, p in self.named_parameters():
352 | if lock_pretrained and name not in err_msg:
353 | p.requires_grad_(False)
354 | else:
355 | unlocked_weights.append(name)
356 |
357 | print(f'Following weights are unlocked: {unlocked_weights}')
358 |
359 | window_buffer: torch.Tensor = torch.from_numpy(
360 | sps.get_window(window=window, Nx=win_length, fftbins=True)
361 | ).to(torch.get_default_dtype())
362 | self.register_buffer('window', window_buffer)
363 |
364 | self.log10_eps = 1e-18
365 |
366 | def load_pretrained(self) -> str:
367 | if isinstance(self.pretrained, bool):
368 | state_dict = self.loading_func(pretrained=True).state_dict()
369 | else:
370 | state_dict = torch.load(self.pretrained, map_location='cpu')
371 |
372 | err_msg = ''
373 | try:
374 | self.load_state_dict(state_dict=state_dict, strict=True)
375 | except RuntimeError as ex:
376 | err_msg += f'While loading some errors occurred.\n{ex}'
377 | print(termcolor.colored(err_msg, 'red'))
378 |
379 | return err_msg
380 |
381 | def forward(self,
382 | x: torch.Tensor,
383 | y: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
384 |
385 | pow_spec = self.spectrogram(x)
386 | x_db = torch.log10(pow_spec).mul(10.0)
387 |
388 | outputs = list()
389 | for ch_idx in range(x_db.shape[1]):
390 | ch = x_db[:, ch_idx]
391 | out = super(_ESResNet, self).forward(ch)
392 | outputs.append(out)
393 |
394 | outputs = torch.stack(outputs, dim=-1).sum(dim=-1)
395 | y_pred = self.classifier(outputs)
396 |
397 | loss = None
398 | if y is not None:
399 | loss = self.loss_fn(y_pred, y).mean()
400 |
401 | return y_pred if loss is None else (y_pred, loss)
402 |
403 | def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
404 | spec = torch.stft(
405 | x.view(-1, x.shape[-1]),
406 | n_fft=self.n_fft,
407 | hop_length=self.hop_length,
408 | win_length=self.win_length,
409 | window=self.window,
410 | pad_mode='reflect',
411 | normalized=self.normalized,
412 | onesided=True
413 | )
414 |
415 | if not self.onesided:
416 | spec = torch.cat((torch.flip(spec, dims=(-3,)), spec), dim=-3)
417 |
418 |         # split the frequency axis into three equal bands that act as image channels
419 |         band_height = spec.shape[-3] // 3
420 |         spec = spec[:, :3 * band_height]
421 | 
422 |         spec = spec.reshape(x.shape[0], -1, band_height, *spec.shape[-2:])
423 |
424 | spec_height = spec.shape[-3] if self.spec_height < 1 else self.spec_height
425 | spec_width = spec.shape[-2] if self.spec_width < 1 else self.spec_width
426 |
427 | pow_spec = spec[..., 0] ** 2 + spec[..., 1] ** 2
428 |
429 | if spec_height != pow_spec.shape[-2] or spec_width != pow_spec.shape[-1]:
430 | pow_spec = F.interpolate(
431 | pow_spec,
432 | size=(spec_height, spec_width),
433 | mode='bilinear',
434 | align_corners=True
435 | )
436 |
437 | pow_spec = torch.where(pow_spec > 0.0, pow_spec, torch.full_like(pow_spec, self.log10_eps))
438 |
439 | pow_spec = pow_spec.view(x.shape[0], -1, 3, *pow_spec.shape[-2:])
440 |
441 | return pow_spec
442 |
443 |
444 | class ESResNet(_ESResNet):
445 |
446 | loading_func = staticmethod(tv.models.resnet50)
447 |
448 | def __init__(self,
449 | n_fft: int = 256,
450 | hop_length: Optional[int] = None,
451 | win_length: Optional[int] = None,
452 | window: Optional[str] = None,
453 | normalized: bool = False,
454 | onesided: bool = True,
455 | spec_height: int = 224,
456 | spec_width: int = 224,
457 | num_classes: int = 1000,
458 | pretrained: bool = False,
459 | lock_pretrained: Optional[bool] = None):
460 |
461 | super(ESResNet, self).__init__(
462 | block=Bottleneck,
463 | layers=[3, 4, 6, 3],
464 | n_fft=n_fft,
465 | hop_length=hop_length,
466 | win_length=win_length,
467 | window=window,
468 | normalized=normalized,
469 | onesided=onesided,
470 | spec_height=spec_height,
471 | spec_width=spec_width,
472 | num_classes=num_classes,
473 | pretrained=pretrained,
474 | lock_pretrained=lock_pretrained
475 | )
476 |
477 |
478 | class ESResNetAttention(_ESResNet, ResNetWithAttention):
479 |
480 | loading_func = staticmethod(tv.models.resnet50)
481 |
482 | def __init__(self,
483 | n_fft: int = 256,
484 | hop_length: Optional[int] = None,
485 | win_length: Optional[int] = None,
486 | window: Optional[str] = None,
487 | normalized: bool = False,
488 | onesided: bool = True,
489 | spec_height: int = 224,
490 | spec_width: int = 224,
491 | num_classes: int = 1000,
492 | pretrained: bool = False,
493 | lock_pretrained: Optional[bool] = None):
494 |
495 | super(ESResNetAttention, self).__init__(
496 | block=Bottleneck,
497 | layers=[3, 4, 6, 3],
498 | n_fft=n_fft,
499 | hop_length=hop_length,
500 | win_length=win_length,
501 | window=window,
502 | normalized=normalized,
503 | onesided=onesided,
504 | spec_height=spec_height,
505 | spec_width=spec_width,
506 | num_classes=num_classes,
507 | pretrained=pretrained,
508 | lock_pretrained=lock_pretrained
509 | )
510 |
511 | def forward(self,
512 | x: torch.Tensor,
513 | y: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
514 |
515 | pow_spec = self.spectrogram(x)
516 | x_db = torch.log10(pow_spec).mul(10.0)
517 |
518 | outputs = list()
519 | for ch_idx in range(x_db.shape[1]):
520 | ch = x_db[:, ch_idx]
521 | out = super(_ESResNet, self).forward(ch)
522 | outputs.append(out)
523 |
524 | outputs = torch.stack(outputs, dim=-1).sum(dim=-1)
525 | y_pred = self.classifier(outputs)
526 |
527 | loss = None
528 | if y is not None:
529 | loss = self.loss_fn(y_pred, y).mean()
530 |
531 | return y_pred if loss is None else (y_pred, loss)
532 |
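For orientation, a hedged end-to-end sketch of the classifier on raw audio: `spectrogram()` splits the STFT's frequency axis into three equal bands that act as the RGB channels of the ImageNet-style backbone, so the model only needs a waveform batch. The clip length, STFT parameters, and class count below are illustrative rather than the values used in `protocols/`; `pretrained=True` would additionally pull the torchvision ResNet-50 weights.

```python
import torch

from model.esresnet import ESResNet

model = ESResNet(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    window='hann',
    spec_height=224,
    spec_width=224,
    num_classes=50,            # ESC-50-sized head
    pretrained=False
)
model.eval()

x = torch.randn(2, 1, 110250)  # two mono 5-second clips at 22.05 kHz (toy data)

with torch.no_grad():
    logits = model(x)          # inference: logits only
print(logits.shape)            # torch.Size([2, 50])

y = torch.randint(0, 50, (2,))
logits, loss = model(x, y)     # with targets, forward() also returns the mean cross-entropy loss
```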
--------------------------------------------------------------------------------
/ignite_trainer/_trainer.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import glob
4 | import json
5 | import time
6 | import tqdm
7 | import signal
8 | import argparse
9 | import numpy as np
10 |
11 | import torch
12 | import torch.utils.data
13 | import torch.nn.functional
14 |
15 | import torchvision as tv
16 |
17 | import ignite.engine as ieng
18 | import ignite.metrics as imet
19 | import ignite.handlers as ihan
20 |
21 | from typing import Any
22 | from typing import Dict
23 | from typing import List
24 | from typing import Type
25 | from typing import Union
26 | from typing import Optional
27 |
28 | from termcolor import colored
29 |
30 | from collections import defaultdict
31 | from collections.abc import Iterable
32 |
33 | from ignite_trainer import _utils
34 | from ignite_trainer import _visdom
35 | from ignite_trainer import _interfaces
36 |
37 | VISDOM_HOST = 'localhost'
38 | VISDOM_PORT = 8097
39 | VISDOM_ENV_PATH = 'visdom_env'
40 | BATCH_TRAIN = 128
41 | BATCH_TEST = 1024
42 | WORKERS_TRAIN = 0
43 | WORKERS_TEST = 0
44 | EPOCHS = 100
45 | LOG_INTERVAL = 50
46 | SAVED_MODELS_PATH = os.path.join(os.path.expanduser('~'), 'saved_models')
47 |
48 |
49 | def run(experiment_name: str,
50 | visdom_host: str,
51 | visdom_port: int,
52 | visdom_env_path: str,
53 | model_class: str,
54 | model_args: Dict[str, Any],
55 | optimizer_class: str,
56 | optimizer_args: Dict[str, Any],
57 | dataset_class: str,
58 | dataset_args: Dict[str, Any],
59 | batch_train: int,
60 | batch_test: int,
61 | workers_train: int,
62 | workers_test: int,
63 | transforms: List[Dict[str, Union[str, Dict[str, Any]]]],
64 | epochs: int,
65 | log_interval: int,
66 | saved_models_path: str,
67 |         performance_metrics: Optional[Dict[str, Any]] = None,
68 | scheduler_class: Optional[str] = None,
69 | scheduler_args: Optional[Dict[str, Any]] = None,
70 | model_suffix: Optional[str] = None,
71 | setup_suffix: Optional[str] = None,
72 | orig_stdout: Optional[io.TextIOBase] = None):
73 |
74 | with _utils.tqdm_stdout(orig_stdout) as orig_stdout:
75 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
76 |
77 | transforms_train = list()
78 | transforms_test = list()
79 |
80 | for idx, transform in enumerate(transforms):
81 | use_train = transform.get('train', True)
82 | use_test = transform.get('test', True)
83 |
84 | transform = _utils.load_class(transform['class'])(**transform['args'])
85 |
86 | if use_train:
87 | transforms_train.append(transform)
88 | if use_test:
89 | transforms_test.append(transform)
90 |
91 | transforms[idx]['train'] = use_train
92 | transforms[idx]['test'] = use_test
93 |
94 | transforms_train = tv.transforms.Compose(transforms_train)
95 | transforms_test = tv.transforms.Compose(transforms_test)
96 |
97 | Dataset: Type = _utils.load_class(dataset_class)
98 |
99 | train_loader, eval_loader = _utils.get_data_loaders(
100 | Dataset,
101 | dataset_args,
102 | batch_train,
103 | batch_test,
104 | workers_train,
105 | workers_test,
106 | transforms_train,
107 | transforms_test
108 | )
109 |
110 | Network: Type = _utils.load_class(model_class)
111 | model: _interfaces.AbstractNet = Network(**model_args)
112 | model = model.to(device)
113 |
114 | Optimizer: Type = _utils.load_class(optimizer_class)
115 | optimizer: torch.optim.Optimizer = Optimizer(model.parameters(), **optimizer_args)
116 |
117 | if scheduler_class is not None:
118 | Scheduler: Type = _utils.load_class(scheduler_class)
119 |
120 | if scheduler_args is None:
121 | scheduler_args = dict()
122 |
123 | scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = Scheduler(optimizer, **scheduler_args)
124 | else:
125 | scheduler = None
126 |
127 | model_short_name = ''.join([c for c in Network.__name__ if c == c.upper()])
128 | model_name = '{}{}'.format(
129 | model_short_name,
130 | '-{}'.format(model_suffix) if model_suffix is not None else ''
131 | )
132 | visdom_env_name = '{}_{}_{}{}'.format(
133 | Dataset.__name__,
134 | experiment_name,
135 | model_name,
136 | '-{}'.format(setup_suffix) if setup_suffix is not None else ''
137 | )
138 |
139 | vis, vis_pid = _visdom.get_visdom_instance(visdom_host, visdom_port, visdom_env_name, visdom_env_path)
140 |
141 | prog_bar_epochs = tqdm.tqdm(total=epochs, desc='Epochs', file=orig_stdout, dynamic_ncols=True, unit='epoch')
142 | prog_bar_iters = tqdm.tqdm(desc='Batches', file=orig_stdout, dynamic_ncols=True)
143 |
144 | tqdm.tqdm.write(f'\n{repr(model)}\n')
145 | tqdm.tqdm.write('Total number of parameters: {:.2f}M'.format(sum(p.numel() for p in model.parameters()) / 1e6))
146 |
147 |         def training_step(_: ieng.Engine, batch: _interfaces.TensorPair) -> float:
148 | model.train()
149 |
150 | optimizer.zero_grad()
151 |
152 | x, y = batch
153 |
154 | x = x.to(device)
155 | y = y.to(device)
156 |
157 | _, loss = model(x, y)
158 |
159 | loss.backward(retain_graph=False)
160 |             optimizer.step()
161 |
162 | return loss.item()
163 |
164 | def eval_step(_: ieng.Engine, batch: _interfaces.TensorPair) -> _interfaces.TensorPair:
165 | model.eval()
166 |
167 | with torch.no_grad():
168 | x, y = batch
169 |
170 | x = x.to(device)
171 | y = y.to(device)
172 |
173 | y_pred = model(x)
174 |
175 | return y_pred, y
176 |
177 | trainer = ieng.Engine(training_step)
178 | validator_train = ieng.Engine(eval_step)
179 | validator_eval = ieng.Engine(eval_step)
180 |
181 | # placeholder for summary window
182 | vis.text(
183 | text='',
184 | win=experiment_name,
185 | env=visdom_env_name,
186 | opts={'title': 'Summary', 'width': 940, 'height': 416},
187 | append=vis.win_exists(experiment_name, visdom_env_name)
188 | )
189 |
190 | default_metrics = {
191 | "Loss": {
192 | "window_name": None,
193 | "x_label": "#Epochs",
194 | "y_label": model.loss_fn_name,
195 | "width": 940,
196 | "height": 416,
197 | "lines": [
198 | {
199 | "line_label": "SMA",
200 | "object": imet.RunningAverage(output_transform=lambda x: x),
201 | "test": False,
202 | "update_rate": "iteration"
203 | },
204 | {
205 | "line_label": "Val.",
206 | "object": imet.Loss(model.loss_fn)
207 | }
208 | ]
209 | }
210 | }
211 |
212 |         performance_metrics = {**default_metrics, **(performance_metrics or {})}
213 | checkpoint_metrics = list()
214 |
215 | for scope_name, scope in performance_metrics.items():
216 | scope['window_name'] = scope.get('window_name', scope_name) or scope_name
217 |
218 | for line in scope['lines']:
219 | if 'object' not in line:
220 | line['object']: imet.Metric = _utils.load_class(line['class'])(**line['args'])
221 |
222 | line['metric_label'] = '{}: {}'.format(scope['window_name'], line['line_label'])
223 |
224 | line['update_rate'] = line.get('update_rate', 'epoch')
225 | line_suffixes = list()
226 | if line['update_rate'] == 'iteration':
227 | line['object'].attach(trainer, line['metric_label'])
228 | line['train'] = False
229 | line['test'] = False
230 |
231 | line_suffixes.append(' Train.')
232 |
233 | if line.get('train', True):
234 | line['object'].attach(validator_train, line['metric_label'])
235 | line_suffixes.append(' Train.')
236 | if line.get('test', True):
237 | line['object'].attach(validator_eval, line['metric_label'])
238 | line_suffixes.append(' Eval.')
239 |
240 | if line.get('is_checkpoint', False):
241 | checkpoint_metrics.append(line['metric_label'])
242 |
243 | for line_suffix in line_suffixes:
244 | _visdom.plot_line(
245 | vis=vis,
246 | window_name=scope['window_name'],
247 | env=visdom_env_name,
248 | line_label=line['line_label'] + line_suffix,
249 | x_label=scope['x_label'],
250 | y_label=scope['y_label'],
251 | width=scope['width'],
252 | height=scope['height'],
253 | draw_marker=(line['update_rate'] == 'epoch')
254 | )
255 |
256 | if checkpoint_metrics:
257 | score_name = 'performance'
258 |
259 | def get_score(engine: ieng.Engine) -> float:
260 | current_mode = getattr(engine.state.dataloader.iterable.dataset, dataset_args['training']['key'])
261 | val_mode = dataset_args['training']['no']
262 |
263 | score = 0.0
264 | if current_mode == val_mode:
265 | for metric_name in checkpoint_metrics:
266 | try:
267 | score += engine.state.metrics[metric_name]
268 | except KeyError:
269 | pass
270 |
271 | return score
272 |
273 | model_saver = ihan.ModelCheckpoint(
274 | os.path.join(saved_models_path, visdom_env_name),
275 | filename_prefix=visdom_env_name,
276 | score_name=score_name,
277 | score_function=get_score,
278 | n_saved=3,
279 | save_as_state_dict=True,
280 | require_empty=False,
281 | create_dir=True
282 | )
283 |
284 | validator_eval.add_event_handler(ieng.Events.EPOCH_COMPLETED, model_saver, {model_name: model})
285 |
286 | @trainer.on(ieng.Events.EPOCH_STARTED)
287 | def reset_progress_iterations(engine: ieng.Engine):
288 | prog_bar_iters.clear()
289 | prog_bar_iters.n = 0
290 | prog_bar_iters.last_print_n = 0
291 | prog_bar_iters.start_t = time.time()
292 | prog_bar_iters.last_print_t = time.time()
293 | prog_bar_iters.total = len(engine.state.dataloader)
294 |
295 | @trainer.on(ieng.Events.ITERATION_COMPLETED)
296 | def log_training(engine: ieng.Engine):
297 | prog_bar_iters.update(1)
298 |
299 | num_iter = (engine.state.iteration - 1) % len(train_loader) + 1
300 |
301 | early_stop = np.isnan(engine.state.output) or np.isinf(engine.state.output)
302 |
303 | if num_iter % log_interval == 0 or num_iter == len(train_loader) or early_stop:
304 | tqdm.tqdm.write(
305 | 'Epoch[{}] Iteration[{}/{}] Loss: {:.4f}'.format(
306 | engine.state.epoch, num_iter, len(train_loader), engine.state.output
307 | )
308 | )
309 |
310 | x_pos = engine.state.epoch + num_iter / len(train_loader) - 1
311 | for scope_name, scope in performance_metrics.items():
312 | for line in scope['lines']:
313 | if line['update_rate'] == 'iteration':
314 | line_label = '{} Train.'.format(line['line_label'])
315 | line_value = engine.state.metrics[line['metric_label']]
316 |
317 | if engine.state.epoch > 1:
318 | _visdom.plot_line(
319 | vis=vis,
320 | window_name=scope['window_name'],
321 | env=visdom_env_name,
322 | line_label=line_label,
323 | x_label=scope['x_label'],
324 | y_label=scope['y_label'],
325 | x=np.full(1, x_pos),
326 | y=np.full(1, line_value)
327 | )
328 |
329 | if early_stop:
330 | tqdm.tqdm.write(colored('Early stopping due to invalid loss value.', 'red'))
331 | trainer.terminate()
332 |
333 | def log_validation(engine: ieng.Engine,
334 | train: bool = True):
335 |
336 | if train:
337 | run_type = 'Train.'
338 | data_loader = train_loader
339 | validator = validator_train
340 | else:
341 | run_type = 'Eval.'
342 | data_loader = eval_loader
343 | validator = validator_eval
344 |
345 | prog_bar_validation = tqdm.tqdm(
346 | data_loader,
347 | desc=f'Validation {run_type}',
348 | file=orig_stdout,
349 | dynamic_ncols=True,
350 | leave=False
351 | )
352 | validator.run(prog_bar_validation)
353 | prog_bar_validation.clear()
354 | prog_bar_validation.close()
355 |
356 | tqdm_info = [
357 | 'Epoch: {}'.format(engine.state.epoch)
358 | ]
359 | for scope_name, scope in performance_metrics.items():
360 | for line in scope['lines']:
361 | if line['update_rate'] == 'epoch':
362 | try:
363 | line_label = '{} {}'.format(line['line_label'], run_type)
364 | line_value = validator.state.metrics[line['metric_label']]
365 |
366 | _visdom.plot_line(
367 | vis=vis,
368 | window_name=scope['window_name'],
369 | env=visdom_env_name,
370 | line_label=line_label,
371 | x_label=scope['x_label'],
372 | y_label=scope['y_label'],
373 | x=np.full(1, engine.state.epoch),
374 | y=np.full(1, line_value),
375 | draw_marker=True
376 | )
377 |
378 | tqdm_info.append('{}: {:.4f}'.format(line_label, line_value))
379 | except KeyError:
380 | pass
381 |
382 | tqdm.tqdm.write('{} results - {}'.format(run_type, '; '.join(tqdm_info)))
383 |
384 | @trainer.on(ieng.Events.EPOCH_COMPLETED)
385 | def log_validation_train(engine: ieng.Engine):
386 | log_validation(engine, True)
387 |
388 | @trainer.on(ieng.Events.EPOCH_COMPLETED)
389 | def log_validation_eval(engine: ieng.Engine):
390 | log_validation(engine, False)
391 |
392 | if engine.state.epoch == 1:
393 | summary = _utils.build_summary_str(
394 | experiment_name=experiment_name,
395 | model_short_name=model_name,
396 | model_class=model_class,
397 | model_args=model_args,
398 | optimizer_class=optimizer_class,
399 | optimizer_args=optimizer_args,
400 | dataset_class=dataset_class,
401 | dataset_args=dataset_args,
402 | transforms=transforms,
403 | epochs=epochs,
404 | batch_train=batch_train,
405 | log_interval=log_interval,
406 | saved_models_path=saved_models_path,
407 | scheduler_class=scheduler_class,
408 | scheduler_args=scheduler_args
409 | )
410 | _visdom.create_summary_window(
411 | vis=vis,
412 | visdom_env_name=visdom_env_name,
413 | experiment_name=experiment_name,
414 | summary=summary
415 | )
416 |
417 | vis.save([visdom_env_name])
418 |
419 | prog_bar_epochs.update(1)
420 |
421 | if scheduler is not None:
422 | scheduler.step(engine.state.epoch)
423 |
424 | trainer.run(train_loader, max_epochs=epochs)
425 |
426 | if vis_pid is not None:
427 | tqdm.tqdm.write('Stopping visdom')
428 | os.kill(vis_pid, signal.SIGTERM)
429 |
430 | del vis
431 | del train_loader
432 | del eval_loader
433 |
434 | prog_bar_iters.clear()
435 | prog_bar_iters.close()
436 |
437 | prog_bar_epochs.clear()
438 | prog_bar_epochs.close()
439 |
440 | tqdm.tqdm.write('\n')
441 |
442 |
443 | def main():
444 | with _utils.tqdm_stdout() as orig_stdout:
445 | parser = argparse.ArgumentParser()
446 |
447 | parser.add_argument('-c', '--config', type=str, required=True)
448 | parser.add_argument('-H', '--visdom-host', type=str, required=False)
449 | parser.add_argument('-P', '--visdom-port', type=int, required=False)
450 | parser.add_argument('-E', '--visdom-env-path', type=str, required=False)
451 | parser.add_argument('-b', '--batch-train', type=int, required=False)
452 | parser.add_argument('-B', '--batch-test', type=int, required=False)
453 | parser.add_argument('-w', '--workers-train', type=int, required=False)
454 | parser.add_argument('-W', '--workers-test', type=int, required=False)
455 | parser.add_argument('-e', '--epochs', type=int, required=False)
456 | parser.add_argument('-L', '--log-interval', type=int, required=False)
457 | parser.add_argument('-M', '--saved-models-path', type=str, required=False)
458 | parser.add_argument('-R', '--random-seed', type=int, required=False)
459 | parser.add_argument('-s', '--suffix', type=str, required=False)
460 |
461 | args, unknown_args = parser.parse_known_args()
462 |
463 | if args.batch_test is None:
464 | args.batch_test = args.batch_train
465 |
466 | if args.random_seed is not None:
467 | args.suffix = '{}r-{}'.format(
468 | '{}_'.format(args.suffix) if args.suffix is not None else '',
469 | args.random_seed
470 | )
471 |
472 | np.random.seed(args.random_seed)
473 | torch.random.manual_seed(args.random_seed)
474 | if torch.cuda.is_available():
475 | torch.cuda.manual_seed(args.random_seed)
476 | torch.backends.cudnn.deterministic = True
477 | torch.backends.cudnn.benchmark = False
478 |
479 | configs_found = list(sorted(glob.glob(os.path.expanduser(args.config))))
480 | prog_bar_exps = tqdm.tqdm(
481 | configs_found,
482 | desc='Experiments',
483 | unit='setup',
484 | file=orig_stdout,
485 | dynamic_ncols=True
486 | )
487 |
488 | for config_path in prog_bar_exps:
489 | config = json.load(open(config_path))
490 |
491 | if unknown_args:
492 | tqdm.tqdm.write('\nParsing additional arguments...')
493 |
494 | args_not_found = list()
495 | for arg in unknown_args:
496 | if arg.startswith('--'):
497 | keys = arg.strip('-').split('.')
498 |
499 | section = config
500 | found = True
501 | for key in keys:
502 | if key in section:
503 | section = section[key]
504 | else:
505 | found = False
506 | break
507 |
508 | if found:
509 | override_parser = argparse.ArgumentParser()
510 |
511 | section_nargs = None
512 | section_type = type(section) if section is not None else str
513 |
514 |                         # command-line booleans arrive as strings; treat '0'/'false'/'no' as False
515 |                         if section_type is bool:
516 |                             def infer_bool(x: str) -> bool:
517 |                                 return x.lower() not in ('0', 'false', 'no')
518 | 
519 |                             section_type = infer_bool
520 |
521 | if isinstance(section, Iterable) and section_type is not str:
522 | section_nargs = '+'
523 | section_type = {type(value) for value in section}
524 |
525 | if len(section_type) == 1:
526 | section_type = section_type.pop()
527 | else:
528 | section_type = str
529 |
530 | override_parser.add_argument(arg, nargs=section_nargs, type=section_type)
531 | overridden_args, _ = override_parser.parse_known_args(unknown_args)
532 | overridden_args = vars(overridden_args)
533 |
534 | overridden_key = arg.strip('-')
535 | overriding_value = overridden_args[overridden_key]
536 |
537 | section = config
538 | old_value = None
539 | for i, key in enumerate(keys, 1):
540 | if i == len(keys):
541 | old_value = section[key]
542 | section[key] = overriding_value
543 | else:
544 | section = section[key]
545 |
546 | tqdm.tqdm.write(
547 | colored(f'Overriding "{overridden_key}": {old_value} -> {overriding_value}', 'magenta')
548 | )
549 | else:
550 | args_not_found.append(arg)
551 |
552 | if args_not_found:
553 | tqdm.tqdm.write(
554 | colored(
555 | '\nThere are unrecognized arguments to override: {}'.format(
556 | ', '.join(args_not_found)
557 | ),
558 | 'red'
559 | )
560 | )
561 |
562 | config = defaultdict(None, config)
563 |
564 | experiment_name = config['Setup']['name']
565 |
566 | visdom_host = _utils.arg_selector(
567 | args.visdom_host, config['Visdom']['host'], VISDOM_HOST
568 | )
569 | visdom_port = int(_utils.arg_selector(
570 | args.visdom_port, config['Visdom']['port'], VISDOM_PORT
571 | ))
572 | visdom_env_path = _utils.arg_selector(
573 | args.visdom_env_path, config['Visdom']['env_path'], VISDOM_ENV_PATH
574 | )
575 | batch_train = int(_utils.arg_selector(
576 | args.batch_train, config['Setup']['batch_train'], BATCH_TRAIN
577 | ))
578 | batch_test = int(_utils.arg_selector(
579 | args.batch_test, config['Setup']['batch_test'], BATCH_TEST
580 | ))
581 | workers_train = _utils.arg_selector(
582 | args.workers_train, config['Setup']['workers_train'], WORKERS_TRAIN
583 | )
584 | workers_test = _utils.arg_selector(
585 | args.workers_test, config['Setup']['workers_test'], WORKERS_TEST
586 | )
587 | epochs = _utils.arg_selector(
588 | args.epochs, config['Setup']['epochs'], EPOCHS
589 | )
590 | log_interval = _utils.arg_selector(
591 | args.log_interval, config['Setup']['log_interval'], LOG_INTERVAL
592 | )
593 | saved_models_path = _utils.arg_selector(
594 | args.saved_models_path, config['Setup']['saved_models_path'], SAVED_MODELS_PATH
595 | )
596 |
597 | model_class = config['Model']['class']
598 | model_args = config['Model']['args']
599 |
600 | optimizer_class = config['Optimizer']['class']
601 | optimizer_args = config['Optimizer']['args']
602 |
603 | if 'Scheduler' in config:
604 | scheduler_class = config['Scheduler']['class']
605 | scheduler_args = config['Scheduler']['args']
606 | else:
607 | scheduler_class = None
608 | scheduler_args = None
609 |
610 | dataset_class = config['Dataset']['class']
611 | dataset_args = config['Dataset']['args']
612 |
613 | transforms = config['Transforms']
614 | performance_metrics = config['Metrics']
615 |
616 | tqdm.tqdm.write(f'\nStarting experiment "{experiment_name}"\n')
617 |
618 | run(
619 | experiment_name=experiment_name,
620 | visdom_host=visdom_host,
621 | visdom_port=visdom_port,
622 | visdom_env_path=visdom_env_path,
623 | model_class=model_class,
624 | model_args=model_args,
625 | optimizer_class=optimizer_class,
626 | optimizer_args=optimizer_args,
627 | dataset_class=dataset_class,
628 | dataset_args=dataset_args,
629 | batch_train=batch_train,
630 | batch_test=batch_test,
631 | workers_train=workers_train,
632 | workers_test=workers_test,
633 | transforms=transforms,
634 | epochs=epochs,
635 | log_interval=log_interval,
636 | saved_models_path=saved_models_path,
637 | performance_metrics=performance_metrics,
638 | scheduler_class=scheduler_class,
639 | scheduler_args=scheduler_args,
640 |                 model_suffix=config['Setup'].get('suffix'),
641 | setup_suffix=args.suffix,
642 | orig_stdout=orig_stdout
643 | )
644 |
645 | prog_bar_exps.close()
646 |
647 | tqdm.tqdm.write('\n')
648 |
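`main()` above reads a fixed set of sections from each protocol JSON: `Setup`, `Visdom`, `Model`, `Optimizer`, an optional `Scheduler`, `Dataset`, `Transforms`, and `Metrics`. The sketch below writes out a configuration with that shape so it can be passed to `main.py -c`; the class paths, argument values, the `ToTensor1D` transform name, and the extra fields of the `training` block (only `key` and `no` are read here, by the checkpoint score function) are illustrative assumptions rather than a copy of the shipped protocols.

```python
import json

config = {
    "Setup": {
        "name": "example-experiment",
        "suffix": "cv1",
        "batch_train": 16, "batch_test": 16,
        "workers_train": 4, "workers_test": 4,
        "epochs": 100, "log_interval": 50,
        "saved_models_path": "saved_models"
    },
    "Visdom": {"host": "localhost", "port": 8097, "env_path": "visdom_env"},
    "Model": {"class": "model.esresnet.ESResNet", "args": {"num_classes": 50, "pretrained": True}},
    "Optimizer": {"class": "torch.optim.SGD", "args": {"lr": 1e-3, "momentum": 0.9}},
    "Scheduler": {"class": "torch.optim.lr_scheduler.StepLR", "args": {"step_size": 30}},
    "Dataset": {
        "class": "utils.datasets.ESC50",
        "args": {
            "root": "/path/to/ESC-50",
            # get_score() reads training.key / training.no to tell the evaluation loader apart
            "training": {"key": "train", "yes": True, "no": False}
        }
    },
    "Transforms": [
        # hypothetical transform class name; substitute the ones defined in utils/transforms.py
        {"class": "utils.transforms.ToTensor1D", "args": {}, "train": True, "test": True}
    ],
    "Metrics": {
        "Accuracy": {
            "window_name": None,
            "x_label": "#Epochs", "y_label": "Accuracy",
            "width": 940, "height": 416,
            "lines": [
                {"line_label": "Acc.", "class": "ignite.metrics.Accuracy", "args": {}, "is_checkpoint": True}
            ]
        }
    }
}

with open('protocols/example.json', 'w') as fp:
    json.dump(config, fp, indent=2)
```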
--------------------------------------------------------------------------------