├── tests
│   ├── __init__.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── test_optimizers.py
│   │   └── test_schedulers.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── test_configs.py
│   │   └── test_devices.py
│   ├── configs
│   │   ├── __init__.py
│   │   ├── invalid_config.yml
│   │   ├── valid_config.yml
│   │   └── test_model_trainers_config.py
│   ├── events
│   │   ├── __init__.py
│   │   ├── handlers
│   │   │   ├── __init__.py
│   │   │   ├── test_early_stopping.py
│   │   │   ├── test_visdom.py
│   │   │   ├── test_console.py
│   │   │   └── test_checkpoints.py
│   │   └── publishers
│   │       └── __init__.py
│   ├── loggers
│   │   ├── __init__.py
│   │   └── visdom
│   │       ├── __init__.py
│   │       └── test_visdom.py
│   ├── metrics
│   │   ├── __init__.py
│   │   └── test_metrics.py
│   ├── models
│   │   └── __init__.py
│   ├── nn
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   └── test_gradients.py
│   │   ├── __init__.py
│   │   ├── test_functional.py
│   │   ├── test_apex.py
│   │   └── test_criterions.py
│   ├── parsers
│   │   ├── __init__.py
│   │   ├── test_yaml.yml
│   │   └── yaml_parser_test.py
│   ├── scheduler
│   │   └── __init__.py
│   ├── training
│   │   ├── __init__.py
│   │   └── test_trainers.py
│   ├── functionals
│   │   ├── __init__.py
│   │   ├── distributed
│   │   │   ├── __init__.py
│   │   │   ├── run.sh
│   │   │   ├── config.yml
│   │   │   ├── mnist_trainer.py
│   │   │   └── main.py
│   │   └── models.py
│   ├── constants.py
│   └── base_test.py
├── kerosene
│   ├── __init__.py
│   ├── nn
│   │   ├── __init__.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   └── gradients.py
│   │   ├── functional.py
│   │   ├── apex.py
│   │   └── criterions.py
│   ├── configs
│   │   ├── __init__.py
│   │   ├── exceptions.py
│   │   ├── parsers.py
│   │   └── configs.py
│   ├── loggers
│   │   ├── __init__.py
│   │   └── visdom
│   │       ├── __init__.py
│   │       ├── data.py
│   │       ├── config.py
│   │       ├── visdom.py
│   │       └── plots.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── gauges.py
│   │   └── metrics.py
│   ├── models
│   │   ├── __init__.py
│   │   └── models.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── schedulers.py
│   │   └── optimizers.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── constants.py
│   │   ├── configs.py
│   │   ├── devices.py
│   │   ├── files.py
│   │   └── tensors.py
│   ├── events
│   │   ├── handlers
│   │   │   ├── __init__.py
│   │   │   ├── early_stopping.py
│   │   │   ├── base_handler.py
│   │   │   ├── checkpoints.py
│   │   │   ├── base_monitor_watcher.py
│   │   │   ├── console.py
│   │   │   └── visdom.py
│   │   ├── publishers
│   │   │   ├── __init__.py
│   │   │   └── base_publisher.py
│   │   ├── exceptions.py
│   │   └── __init__.py
│   └── training
│       ├── __init__.py
│       └── events.py
├── examples
│   ├── mnist
│   │   ├── __init__.py
│   │   ├── requirements.txt
│   │   ├── config.yml
│   │   ├── models.py
│   │   ├── main.py
│   │   └── main_hyperopt.py
│   └── config_template.yml
├── icons
│   └── oil.png
├── deploy.sh
├── deploy.bat
├── .travis.yml
├── requirements.txt
├── .github
│   └── ISSUE_TEMPLATE
│       ├── other-issues.md
│       ├── documentation-issue.md
│       ├── feature_request.md
│       └── bug---performance-report.md
├── LICENSE
├── setup.py
├── .gitignore
└── README.md
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/kerosene/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/kerosene/nn/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/optim/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/mnist/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/kerosene/configs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/kerosene/loggers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/kerosene/metrics/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/kerosene/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/kerosene/optim/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/kerosene/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/configs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/events/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/loggers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/metrics/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/nn/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/parsers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/scheduler/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/training/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/kerosene/nn/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/events/handlers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/functionals/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/loggers/visdom/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/optim/test_optimizers.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/optim/test_schedulers.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/training/test_trainers.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/kerosene/events/handlers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/kerosene/events/publishers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/events/publishers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/loggers/visdom/test_visdom.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/functionals/distributed/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/events/handlers/test_early_stopping.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/mnist/requirements.txt:
--------------------------------------------------------------------------------
1 | torch-kerosene
2 | hyperopt
3 |
--------------------------------------------------------------------------------
/icons/oil.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/banctilrobitaille/kerosene/HEAD/icons/oil.png
--------------------------------------------------------------------------------
/deploy.sh:
--------------------------------------------------------------------------------
1 | python3 setup.py sdist
2 | python3 -m twine upload dist/torch-kerosene-0.2.4.tar.gz
--------------------------------------------------------------------------------
/kerosene/configs/exceptions.py:
--------------------------------------------------------------------------------
1 | class InvalidConfigurationError(Exception):
2 |     pass
3 |
--------------------------------------------------------------------------------
/deploy.bat:
--------------------------------------------------------------------------------
1 | python setup.py sdist
2 | python -m twine upload dist/torch-kerosene-0.2.4.tar.gz
3 |
4 |
--------------------------------------------------------------------------------
/tests/parsers/test_yaml.yml:
--------------------------------------------------------------------------------
1 | weights: !torch/tensor [1, 1, 1, 1, 1, 1]
2 | tuple: !python/tuple [0.1,0.2]
--------------------------------------------------------------------------------
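
The `!torch/tensor` and `!python/tuple` tags above are custom YAML tags resolved by `CustomYamlParser` (defined in `kerosene/configs/parsers.py`, whose body is not included in this listing). A minimal sketch of how such tags can be registered with PyYAML, using a hypothetical `CustomLoader` as a stand-in for the real parser:

```python
# Minimal sketch of the custom YAML tags used in test_yaml.yml.
# CustomLoader is a hypothetical stand-in; kerosene's actual CustomYamlParser
# (kerosene/configs/parsers.py) is not shown in this dump.
import torch
import yaml


class CustomLoader(yaml.SafeLoader):
    pass


def torch_tensor_constructor(loader, node):
    # Build a torch.Tensor from a YAML sequence such as [1, 1, 1].
    return torch.tensor(loader.construct_sequence(node))


def python_tuple_constructor(loader, node):
    # Build a Python tuple from a YAML sequence such as [0.1, 0.2].
    return tuple(loader.construct_sequence(node))


CustomLoader.add_constructor("!torch/tensor", torch_tensor_constructor)
CustomLoader.add_constructor("!python/tuple", python_tuple_constructor)

with open("tests/parsers/test_yaml.yml") as file:
    config = yaml.load(file, Loader=CustomLoader)

assert isinstance(config["weights"], torch.Tensor)
assert isinstance(config["tuple"], tuple)
```
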
/tests/functionals/distributed/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | CUDA_VISIBLE_DEVICES=0,1 python3 -m torch.distributed.launch --nproc_per_node=2 main.py --amp-opt-level=O1
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 |   - "3.6"
4 | before_install:
5 |   - pip install --upgrade pip
6 | install:
7 |   - pip install --progress-bar off -r requirements.txt
8 | script:
9 |   - pytest
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | PyYAML>=5.1
2 | PyHamcrest>=1.9.0
3 | numpy>=1.16.1
4 | pyparsing>=2.3.1
5 | pytest>=4.3.0
6 | torch>=1.1
7 | torchfile>=0.1.0
8 | torchvision>=0.2.1
9 | visdom>=0.1.8.8
10 | pytorch-ignite>=0.2.0
11 | mockito>=1.1.1
12 | mock>=3.0.5
13 | beautifultable
14 | crayons
15 |
--------------------------------------------------------------------------------
/kerosene/nn/functional.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def js_div(inputs):
6 |     """Jensen-Shannon divergence of the distributions stacked along dim 1 of ``inputs``:
7 |     the mean KL divergence between each distribution and the average distribution."""
8 |     js_divs = []
9 |     for dist in range(inputs.size(1)):
10 |         js_divs.append(
11 |             F.kl_div(torch.mean(inputs, dim=1, keepdim=True).log(), inputs[:, dist].unsqueeze(0), reduction='sum'))
12 | 
13 |     return torch.tensor(js_divs).mean()
14 | 
--------------------------------------------------------------------------------
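
A usage sketch for `js_div`, mirroring the shapes used in `tests/nn/test_functional.py` later in this listing: dimension 1 of the input indexes the probability distributions being compared.

```python
# Usage sketch for kerosene.nn.functional.js_div; dim 1 indexes the distributions.
import torch

import kerosene.nn.functional as F

dists = torch.tensor([[[1 / 2, 1 / 2, 0], [0, 1 / 10, 9 / 10], [1 / 3, 1 / 3, 1 / 3]]])
print(F.js_div(dists))  # ~0.3789 for these three distributions (see tests/nn/test_functional.py)

same = torch.tensor([[[1 / 2, 1 / 2, 0]] * 3])
print(F.js_div(same))  # 0.0: identical distributions have zero divergence
```
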
/tests/configs/invalid_config.yml:
--------------------------------------------------------------------------------
1 | training:
2 |   nb_epochs: 10
3 |   batch_size_train: 10
4 |   batch_size_valid: 10
5 | 
6 | models:
7 |   SimpleNet:
8 |     type: SimpleNet
9 |     scheduler:
10 |       type: ReduceLROnPlateau
11 |       params:
12 |         mode: 'min'
13 |         factor: 0.1
14 |         patience: 3
15 |     criterion:
16 |       type: CrossEntropyLoss
17 |     metric:
18 |       type: Accuracy
19 |
--------------------------------------------------------------------------------
/kerosene/events/exceptions.py:
--------------------------------------------------------------------------------
1 | class UnsupportedEventException(Exception):
2 | 
3 |     def __init__(self, supported_events, unsupported_event):
4 |         super(UnsupportedEventException, self).__init__(
5 |             "Unsupported event provided ({}). Only {} are permitted".format(str(unsupported_event),
6 |                                                                             [str(event) for event in
7 |                                                                              supported_events]))
8 |
--------------------------------------------------------------------------------
/examples/config_template.yml:
--------------------------------------------------------------------------------
1 | training:
2 |   nb_epochs: 250
3 |   batch_size: 1
4 | 
5 | models:
6 |   MyModel:
7 |     type: resnet
8 |     params:
9 |       group_size: 32
10 |     optimizer:
11 |       type: "SGD"
12 |       params:
13 |         lr: 0.001
14 |         momentum: 0.9
15 |         weight_decay: 0.1
16 |     scheduler:
17 |       type: "ReduceLROnPlateau"
18 |       params:
19 |         mode: 'min'
20 |         factor: 0.1
21 |         patience: 3
22 |     criterion:
23 |       type: "BCELoss"
24 |       params:
--------------------------------------------------------------------------------
/tests/events/handlers/test_visdom.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from mockito import mock, spy
4 |
5 | from kerosene.events.handlers.visdom import PlotLosses
6 | from kerosene.loggers.visdom.visdom import VisdomLogger
7 | from tests.base_test import BaseTestEnvironment
8 |
9 |
10 | class VisdomTest(BaseTestEnvironment):
11 |
12 |     def setUp(self):
13 |         self._visdom_logger_mock = mock(VisdomLogger)
14 |         self._plot_losses_handler = spy(PlotLosses)
15 | 
16 |     def test_should_plot_losses(self):
17 |         pass
18 |
--------------------------------------------------------------------------------
/tests/constants.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | DELTA = 0.0001
4 | ZERO = torch.tensor([0.0])
5 | ONE = torch.tensor([1.0])
6 | TARGET_CLASS_0 = torch.tensor([0])
7 | TARGET_CLASS_1 = torch.tensor([1])
8 | MODEL_PREDICTION_CLASS_0 = torch.tensor([[1.0, 0.0]])
9 | MODEL_PREDICTION_CLASS_1 = torch.tensor([[0.0, 1.0]])
10 | MINIMUM_BINARY_CROSS_ENTROPY_LOSS = torch.tensor(0.3133)
11 | MAXIMUM_BINARY_CROSS_ENTROPY_LOSS = torch.tensor(1.3133)
12 | AVERAGED_BINARY_CROSS_ENTROPY_LOSS = (MINIMUM_BINARY_CROSS_ENTROPY_LOSS + MAXIMUM_BINARY_CROSS_ENTROPY_LOSS) / 2
13 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/other-issues.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Other issues
3 | about: Use this template for any other non-support related issues
4 | title: "[ISSUE]"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | This template is for miscellaneous issues not covered by the other issue categories.
11 |
12 | **Describe the issue**
13 | A clear and concise description of what the problem is.
14 |
15 | **Expectations**
16 | A clear and concise description of what you expect by submitting this issue.
17 |
18 | **Additional context**
19 | Add any other context or screenshots about the issue here.
20 |
--------------------------------------------------------------------------------
/examples/mnist/config.yml:
--------------------------------------------------------------------------------
1 | training:
2 |   nb_epochs: 1
3 |   batch_size_train: 10
4 |   batch_size_valid: 10
5 | 
6 | models:
7 |   SimpleNet:
8 |     type: SimpleNet
9 |     optimizer:
10 |       type: SGD
11 |       params:
12 |         lr: 0.001
13 |         momentum: 0.9
14 |         weight_decay: 0.001
15 |     scheduler:
16 |       type: ReduceLROnPlateau
17 |       params:
18 |         mode: 'min'
19 |         factor: 0.1
20 |         patience: 3
21 |     criterion:
22 |       CrossEntropy:
23 |         type: CrossEntropyLoss
24 | 
25 | visdom:
26 |   server: localhost
27 |   port: 8097
28 |   env: "Kerosene-MNIST"
--------------------------------------------------------------------------------
/tests/parsers/yaml_parser_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import torch
4 | from hamcrest import *
5 |
6 | from kerosene.configs.parsers import CustomYamlParser
7 |
8 |
9 | class YamlParserTest(unittest.TestCase):
10 |
11 |     def setUp(self) -> None:
12 |         self._path = "tests/parsers/test_yaml.yml"
13 |         self._parser = CustomYamlParser()
14 | 
15 |     def test_should_parse_file_and_create_tensor(self):
16 |         with open(self._path) as file:
17 |             config = self._parser.safe_load(file)
18 | 
19 |         assert_that(config["weights"], instance_of(torch.Tensor))
20 |         assert_that(config["tuple"], instance_of(tuple))
21 |
--------------------------------------------------------------------------------
/tests/functionals/distributed/config.yml:
--------------------------------------------------------------------------------
1 | training:
2 |   nb_epochs: 100
3 |   batch_size: 32
4 | 
5 | models:
6 |   SimpleNet:
7 |     type: SimpleNet
8 |     params:
9 |     optimizer:
10 |       type: FusedSGD
11 |       params:
12 |         lr: 0.01
13 |         momentum: 0.5
14 |         weight_decay: 0
15 |     scheduler:
16 |       type: ReduceLROnPlateau
17 |       params:
18 |         mode: 'min'
19 |         factor: 0.1
20 |         patience: 3
21 |     criterion:
22 |       type: CrossEntropyLoss
23 |       params:
24 |     metric:
25 |       type: Accuracy
26 |       params:
27 | 
28 | visdom:
29 |   server: "http://10.0.3.9"
30 |   port: 8097
31 |   env: "Kerosene-MNIST"
--------------------------------------------------------------------------------
/kerosene/training/__init__.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class BaseStatus(Enum):
5 | 
6 |     def __str__(self):
7 |         return self.value
8 | 
9 |     def __eq__(self, other):
10 |         if isinstance(other, str):
11 |             return self.value == other
12 |         elif isinstance(other, BaseStatus):
13 |             return self.value == other.value
14 |         return NotImplemented
15 | 
16 |     def __hash__(self):
17 |         return hash(self.value)
18 | 
19 | 
20 | class Status(BaseStatus):
21 |     INITIALIZING = "Initializing"
22 |     INITIALIZED = "Initialized"
23 |     READY = "Ready"
24 |     TRAINING = "Training"
25 |     VALIDATING = "Validating"
26 |     TESTING = "Testing"
27 |     FINALIZED = "Finalized"
28 | 
--------------------------------------------------------------------------------
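
The `__eq__`/`__hash__` overrides above let a status compare against both raw strings and other statuses; a small illustration:

```python
# Illustration of the BaseStatus equality semantics defined above.
from kerosene.training import Status

assert Status.TRAINING == "Training"        # equality against the raw string value
assert Status.TRAINING == Status.TRAINING   # equality against another status
assert str(Status.READY) == "Ready"         # __str__ returns the underlying value
assert hash(Status.READY) == hash("Ready")  # __hash__ delegates to the value
```
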
/.github/ISSUE_TEMPLATE/documentation-issue.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Documentation issue
3 | about: Create a report to help us improve documentation
4 | title: "[DOCUMENTATION]"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the documentation issue**
11 | A clear and concise description of what the documentation issue is.
12 |
13 | **System information**
14 | - Framework version: [e.g. 0.1.0]
15 | - Python version: [e.g. Python 3.7]
16 | - Doc Link:
17 |
18 | **Expected**
19 | A clear and concise description of what you expected to happen.
20 |
21 | **Screenshots**
22 | If applicable, add screenshots to help explain your problem.
23 |
24 | **Additional context**
25 | Add any other context about the problem here.
26 |
--------------------------------------------------------------------------------
/examples/mnist/models.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from torch.nn import Module, Linear, Dropout2d, Conv2d
3 |
4 |
5 | class SimpleConvNet(Module):
6 |     def __init__(self):
7 |         super().__init__()
8 |         self.conv1 = Conv2d(1, 10, kernel_size=5)
9 |         self.conv2 = Conv2d(10, 20, kernel_size=5)
10 |         self.conv2_drop = Dropout2d()
11 |         self.fc1 = Linear(320, 50)
12 |         self.fc2 = Linear(50, 10)
13 | 
14 |     def forward(self, x):
15 |         x = F.relu(F.max_pool2d(self.conv1(x), 2))
16 |         x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
17 |         x = x.view(-1, 320)
18 |         x = F.relu(self.fc1(x))
19 |         x = F.dropout(x, training=self.training)
20 |         x = self.fc2(x)
21 |         return x
22 |
--------------------------------------------------------------------------------
/tests/nn/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 |
--------------------------------------------------------------------------------
/tests/functionals/models.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from torch.nn import Module, Linear, Dropout2d, Conv2d
3 |
4 |
5 | class SimpleNet(Module):
6 |     def __init__(self):
7 |         super().__init__()
8 |         self.conv1 = Conv2d(1, 10, kernel_size=5)
9 |         self.conv2 = Conv2d(10, 20, kernel_size=5)
10 |         self.conv2_drop = Dropout2d()
11 |         self.fc1 = Linear(320, 50)
12 |         self.fc2 = Linear(50, 10)
13 | 
14 |     def forward(self, x):
15 |         x = F.relu(F.max_pool2d(self.conv1(x), 2))
16 |         x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
17 |         x = x.view(-1, 320)
18 |         x = F.relu(self.fc1(x))
19 |         x = F.dropout(x, training=self.training)
20 |         x = self.fc2(x)
21 |         return x
--------------------------------------------------------------------------------
/kerosene/utils/constants.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | EPSILON = 1e-15
17 | CHECKPOINT_EXT = ".tar"
18 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: "[FEATURE]"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is.
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Will this change the current API? How?**
20 | A clear and concise description of how this feature request could change the current API.
21 |
22 | **Additional context**
23 | Add any other context or screenshots about the feature request here.
24 |
--------------------------------------------------------------------------------
/kerosene/loggers/visdom/__init__.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class PlotType(Enum):
5 |     LINE_PLOT = "Line Plot"
6 |     IMAGES_PLOT = "Images Plot"
7 |     IMAGE_PLOT = "Image Plot"
8 |     PIE_PLOT = "Pie Plot"
9 |     TEXT_PLOT = "Text Plot"
10 |     HISTOGRAM_PLOT = "Histogram Plot"
11 |     SCATTER_PLOT = "Scatter Plot"
12 |     STEM_PLOT = "Stem Plot"
13 |     HEATMAP_PLOT = "Heatmap Plot"
14 |     BAR_PLOT = "Bar Plot"
15 |     BOX_PLOT = "Box Plot"
16 |     SURFACE_PLOT = "Surface Plot"
17 |     CONTOUR_PLOT = "Contour Plot"
18 |     QUIVER_PLOT = "Quiver Plot"
19 |     MESH_PLOT = "Mesh Plot"
20 |     MATPLOTLIB_PLOT = "Matplotlib Plot"
21 | 
22 |     def __str__(self):
23 |         return self.value
24 | 
25 | 
26 | class PlotFrequency(Enum):
27 |     EVERY_STEP = "Step"
28 |     EVERY_EPOCH = "Epoch"
29 | 
30 |     def __str__(self):
31 |         return self.value
32 |
--------------------------------------------------------------------------------
/kerosene/models/models.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from abc import ABC, abstractmethod
17 |
18 |
19 | class ModelFactory(ABC):
20 |     @abstractmethod
21 |     def create(self, model_type, params):
22 |         raise NotImplementedError
23 |
--------------------------------------------------------------------------------
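
`ModelFactory` is the extension point trainers use to build models from the `type`/`params` pairs found in the YAML configs. A minimal sketch of a concrete factory, using the `SimpleNet` model from `tests/functionals/models.py` as an assumed example; the `SimpleNetFactory` name and its type mapping are illustrative, not part of kerosene itself:

```python
# Minimal sketch of a concrete ModelFactory (illustrative, not part of kerosene).
from kerosene.models.models import ModelFactory
from tests.functionals.models import SimpleNet


class SimpleNetFactory(ModelFactory):
    MODEL_TYPES = {"SimpleNet": SimpleNet}

    def create(self, model_type, params):
        # 'params' holds the optional constructor kwargs from the YAML config.
        return self.MODEL_TYPES[model_type](**(params or {}))


model = SimpleNetFactory().create("SimpleNet", None)
```
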
/kerosene/events/handlers/early_stopping.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from kerosene.events import MonitorMode, TemporalEvent
4 | from kerosene.events.handlers.base_monitor_watcher import MonitorWatcher, MonitorPatienceExceeded
5 | from kerosene.training.trainers import Trainer
6 |
7 |
8 | class EarlyStopping(MonitorWatcher):
9 |     LOGGER = logging.getLogger("EarlyStopping")
10 | 
11 |     def __init__(self, monitor_fn, mode: MonitorMode, min_delta=0.01, patience=3):
12 |         super(EarlyStopping, self).__init__(monitor_fn, mode, min_delta, patience)
13 | 
14 |     def __call__(self, temporal_event: TemporalEvent, monitors, trainer: Trainer):
15 |         for model_trainer in trainer.model_trainers:
16 |             try:
17 |                 value = self._monitor_fn(model_trainer)
18 |                 self.watch(model_trainer.name, value)
19 |             except MonitorPatienceExceeded:
20 |                 model_trainer.finalize()
21 |
--------------------------------------------------------------------------------
/tests/utils/test_configs.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from hamcrest import *
4 |
5 | from kerosene.configs.configs import TrainerConfiguration, DatasetConfiguration
6 | from kerosene.utils.configs import configs_to_html
7 |
8 |
9 | class ConfigTests(unittest.TestCase):
10 | 
11 |     def setUp(self) -> None:
12 |         self._config_1 = TrainerConfiguration({"num_epochs": 30, "batch_size": 32})
13 |         self._config_2 = DatasetConfiguration({"path": "/home/data", "validation_split": 0.2})
14 | 
15 |     def test_should_produce_html_from_configs(self):
16 |         config_html = configs_to_html([self._config_1, self._config_2])
17 |         # The expected strings contain HTML markup; assert on the recoverable text fragments.
18 |         assert_that(config_html[0], contains_string("Training Configuration"))
19 |         assert_that(config_html[0], contains_string("num_epochs: 30"))
20 |         assert_that(config_html[0], contains_string("batch_size: 32"))
21 |         assert_that(config_html[1], contains_string("Dataset Configuration"))
22 |         assert_that(config_html[1], contains_string("path: /home/data"))
23 |         assert_that(config_html[1], contains_string("validation_split: 0.2"))
24 | 
--------------------------------------------------------------------------------
/tests/nn/test_functional.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import torch
4 | from hamcrest import assert_that, equal_to
5 |
6 | import kerosene.nn.functional as F
7 |
8 |
9 | class TestJensenShannonDivergence(unittest.TestCase):
10 |     PROB_DIST1, PROB_DIST2, PROB_DIST3 = [1 / 2, 1 / 2, 0], [0, 1 / 10, 9 / 10], [1 / 3, 1 / 3, 1 / 3]
11 | 
12 |     def test_should_compute_jensen_shannon_divergence(self):
13 |         prob_distributions = torch.tensor([[self.PROB_DIST1, self.PROB_DIST2, self.PROB_DIST3]])
14 |         expected_results = torch.tensor([0.378889])
15 | 
16 |         assert_that(F.js_div(prob_distributions), equal_to(expected_results))
17 | 
18 |     def test_should_compute_jensen_shannon_divergence_of_same_distribution(self):
19 |         prob_distributions = torch.tensor([[self.PROB_DIST1, self.PROB_DIST1, self.PROB_DIST1]])
20 |         expected_results = torch.tensor([0.0])
21 | 
22 |         assert_that(F.js_div(prob_distributions), equal_to(expected_results))
23 |
--------------------------------------------------------------------------------
/kerosene/utils/configs.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from typing import List
17 |
18 | from kerosene.configs.configs import HtmlConfiguration
19 |
20 |
21 | def configs_to_html(configs: List[HtmlConfiguration]):
22 |     return list(map(lambda config: config.to_html(), configs))
23 |
--------------------------------------------------------------------------------
/tests/nn/utils/test_gradients.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from hamcrest import *
4 |
5 | from kerosene.nn.utils.gradients import GradientClippingStrategyFactory, GradientClippingStrategyType, \
6 | GradientClippingStrategy
7 |
8 |
9 | class TestGradientClippingStrategyFactory(unittest.TestCase):
10 |     def setUp(self) -> None:
11 |         pass
12 | 
13 |     def test_should_return_a_gradient_norm_clipping_strategy(self):
14 |         strategy = GradientClippingStrategyFactory().create(GradientClippingStrategyType.Norm,
15 |                                                             {"max_norm": 1.0, "norm_type": 2})
16 | 
17 |         assert_that(strategy, instance_of(GradientClippingStrategy))
18 | 
19 |     def test_should_return_a_gradient_value_clipping_strategy(self):
20 |         strategy = GradientClippingStrategyFactory().create(GradientClippingStrategyType.Value, {"clip_value": 1.0})
21 | 
22 |         assert_that(strategy, instance_of(GradientClippingStrategy))
23 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug---performance-report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug / Performance report
3 | about: Create a report to help us improve
4 | title: "[BUG]"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Execute this script '...'
17 | 3. ...
18 |
19 | **Expected behavior**
20 | A clear and concise description of what you expected to happen.
21 |
22 | **Screenshots**
23 | If applicable, add screenshots to help explain your problem.
24 |
25 | **Desktop (please complete the following information):**
26 | - OS platform and distribution: [e.g. Linux Ubuntu 18.04.2 LTS]
27 | - Python version: [e.g. Python 3.7]
28 | - CUDA/cuDNN version: [e.g. CUDA 10.0 / cuDNN 7.3.1]
29 | - GPU model and memory: [e.g. NVIDIA GeForce RTX 2080 8 GB]
30 |
31 | **Additional context**
32 | Add any other context about the problem here.
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Benoit Anctil-Robitaille
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 |
3 | from setuptools import find_packages, setup
4 |
5 | # The text of the README file
6 | README_CONTENT = (pathlib.Path(__file__).parent / "README.md").read_text()
7 |
8 | setup(
9 |     name='torch-kerosene',
10 |     version='0.2.4',
11 |     description='Deep Learning framework for fast and clean research development with Pytorch',
12 |     long_description=README_CONTENT,
13 |     long_description_content_type='text/markdown',
14 |     author='Benoit Anctil-Robitaille',
15 |     author_email='benoit.anctil-robitaille.1@ens.etsmtl.ca',
16 |     license='MIT',
17 |     classifiers=[
18 |         "License :: OSI Approved :: MIT License",
19 |         "Programming Language :: Python :: 3",
20 |         "Programming Language :: Python :: 3.7"],
21 |     packages=find_packages(exclude=("tests",)),
22 |     install_requires=['numpy>=1.16.1',
23 |                       'visdom>=0.1.8.8',
24 |                       'pytorch-ignite>=0.2.0',
25 |                       'torch>=1.1',
26 |                       'torchvision>=0.2.1',
27 |                       'PyYAML',
28 |                       'crayons',
29 |                       'beautifultable']
30 | )
31 |
--------------------------------------------------------------------------------
/kerosene/events/publishers/base_publisher.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from abc import ABC, abstractmethod
17 |
18 | from kerosene.events import BaseEvent, TemporalEvent
19 |
20 |
21 | class EventPublisher(ABC):
22 |     def __init__(self):
23 |         self._event_handlers = {}
24 | 
25 |     @property
26 |     @abstractmethod
27 |     def sender(self):
28 |         raise NotImplementedError()
29 | 
30 |     @abstractmethod
31 |     def with_event_handler(self, handler, event: BaseEvent):
32 |         raise NotImplementedError()
33 | 
34 |     def fire(self, temporal_event: TemporalEvent, monitors: dict = None):
35 |         if temporal_event.event in self._event_handlers.keys():
36 |             for handler in self._event_handlers[temporal_event.event]:
37 |                 handler(temporal_event, monitors, self.sender)
38 |
--------------------------------------------------------------------------------
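
A minimal sketch of a concrete publisher showing the contract above: subclasses expose a `sender`, register handlers per event, and `fire` dispatches `(temporal_event, monitors, sender)` to each registered handler. `SimplePublisher` is a hypothetical name, not a kerosene class:

```python
# Minimal sketch of a concrete EventPublisher (SimplePublisher is hypothetical).
from kerosene.events import BaseEvent
from kerosene.events.publishers.base_publisher import EventPublisher


class SimplePublisher(EventPublisher):
    @property
    def sender(self):
        # The object passed to every handler as the third argument of __call__.
        return self

    def with_event_handler(self, handler, event: BaseEvent):
        self._event_handlers.setdefault(event, []).append(handler)
        return self  # enables the chained .with_event_handler(...) style of examples/mnist/main.py
```
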
/tests/base_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from ignite.metrics import Accuracy, Recall
4 | from mockito import mock, spy
5 | from torch import nn
6 | from torch.optim import Optimizer, lr_scheduler
7 | from torch.utils.data import DataLoader
8 |
9 | from kerosene.nn.utils.gradients import GradientClippingStrategy
10 | from kerosene.training.trainers import ModelTrainer
11 |
12 |
13 | class BaseTestEnvironment(unittest.TestCase):
14 |     MODEL_NAME = "Harry Potter"
15 |     TRAINER_NAME = "Drago Malfoy"
16 | 
17 |     def setUp(self) -> None:
18 |         self._model_mock = mock(nn.Module)
19 |         self._criterion_mock = spy(nn.CrossEntropyLoss())
20 |         self._optimizer_mock = mock(Optimizer)
21 |         self._scheduler_mock = mock(lr_scheduler)
22 |         self._accuracy_computer_mock = spy(Accuracy())
23 |         self._recall_computer_mock = spy(Recall())
24 |         self._gradient_clipping_strategy = mock(GradientClippingStrategy)
25 | 
26 |         self._training_data_loader_mock = mock(DataLoader)
27 |         self._valid_data_loader_mock = mock(DataLoader)
28 |         self._test_data_loader_mock = mock(DataLoader)
29 | 
30 |         self._model_trainer = ModelTrainer(self.MODEL_NAME, self._model_mock, self._criterion_mock,
31 |                                            self._optimizer_mock, self._scheduler_mock,
32 |                                            {"Accuracy": self._accuracy_computer_mock,
33 |                                             "Recall": self._recall_computer_mock}, self._gradient_clipping_strategy)
34 |
--------------------------------------------------------------------------------
/tests/utils/test_devices.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import torch
4 | from hamcrest import *
5 |
6 | from kerosene.utils.devices import on_multiple_gpus, on_single_gpu, get_devices
7 |
8 |
9 | class DeviceTests(unittest.TestCase):
10 | 
11 |     def setUp(self) -> None:
12 |         self._multiple_gpus_devices = [torch.device("cuda:0"), torch.device("cuda:1")]
13 |         self._single_gpu_device = [torch.device("cuda:0")]
14 |         self._single_cpu_device = [torch.device("cpu")]
15 | 
16 |     def test_on_single_gpu_should_return_true_with_single_GPU_device(self):
17 |         assert_that(on_single_gpu(self._single_gpu_device), is_(True))
18 | 
19 |     def test_on_single_gpu_should_return_false_with_multiple_GPU_devices(self):
20 |         assert_that(on_single_gpu(self._multiple_gpus_devices), is_(False))
21 | 
22 |     def test_on_multiple_gpu_should_return_false_with_single_GPU_device(self):
23 |         assert_that(on_multiple_gpus(self._single_gpu_device), is_(False))
24 | 
25 |     def test_on_multiple_gpu_should_return_true_with_multiple_GPU_devices(self):
26 |         assert_that(on_multiple_gpus(self._multiple_gpus_devices), is_(True))
27 | 
28 |     def test_get_devices_should_return_at_least_one_cuda_enabled_device(self):
29 |         if torch.cuda.is_available():
30 |             assert_that(get_devices(), has_item(torch.device("cuda:0")))
31 |             assert_that(get_devices(), has_length(torch.cuda.device_count()))
32 |             assert_that(get_devices(), is_(
33 |                 [torch.device("cuda:{}".format(id)) for id in range(int(torch.cuda.device_count()))]))
34 |         else:
35 |             assert_that(get_devices(), is_([torch.device("cpu")]))
36 |
--------------------------------------------------------------------------------
/tests/configs/valid_config.yml:
--------------------------------------------------------------------------------
1 | training:
2 |   nb_epochs: 10
3 |   batch_size_train: 10
4 |   batch_size_valid: 10
5 | 
6 | models:
7 |   SimpleNet:
8 |     type: SimpleNet
9 |     optimizer:
10 |       type: SGD
11 |       params:
12 |         lr: 0.01
13 |         momentum: 0.5
14 |         weight_decay: 0
15 |     scheduler:
16 |       type: ReduceLROnPlateau
17 |       params:
18 |         mode: 'min'
19 |         factor: 0.1
20 |         patience: 3
21 |     criterion:
22 |       cycle:
23 |         type: "L1Loss"
24 |       gan:
25 |         type: "MSELoss"
26 |     metrics:
27 |       Dice:
28 |         type: Dice
29 |         params:
30 |           num_classes: 4
31 |           reduction: !!null
32 |           ignore_index: 0
33 |           average: !!null
34 |           weight: !!null
35 |       Accuracy:
36 |         type: Accuracy
37 |     gradients:
38 |       type: 'norm'
39 |       params:
40 |         max_norm: 1.0
41 |   SimpleNet2:
42 |     type: SimpleNet
43 |     optimizer:
44 |       type: SGD
45 |       params:
46 |         lr: 0.01
47 |         momentum: 0.5
48 |         weight_decay: 0
49 |     scheduler:
50 |       type: ReduceLROnPlateau
51 |       params:
52 |         mode: 'min'
53 |         factor: 0.1
54 |         patience: 3
55 |     criterion:
56 |       cycle:
57 |         type: "L1Loss"
58 |       gan:
59 |         type: "MSELoss"
60 |     metrics:
61 |       Dice:
62 |         type: Dice
63 |         params:
64 |           num_classes: 4
65 |           reduction: !!null
66 |           ignore_index: 0
67 |           average: !!null
68 |           weight: !!null
69 |       Accuracy:
70 |         type: Accuracy
71 |     gradients:
72 |       type: 'norm'
73 |       params:
74 |         max_norm: 1.0
75 |
--------------------------------------------------------------------------------
/kerosene/metrics/gauges.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from abc import ABC, abstractmethod
17 |
18 |
19 | class Gauge(ABC):
20 | 
21 |     @abstractmethod
22 |     def update(self, **kwargs):
23 |         raise NotImplementedError()
24 | 
25 |     @abstractmethod
26 |     def compute(self):
27 |         raise NotImplementedError()
28 | 
29 |     @abstractmethod
30 |     def reset(self):
31 |         raise NotImplementedError()
32 | 
33 | 
34 | class AverageGauge(Gauge):
35 | 
36 |     def __init__(self):
37 |         self._value = 0.0
38 |         self._sum = 0.0
39 |         self._count = 0
40 | 
41 |     @property
42 |     def count(self):
43 |         return self._count
44 | 
45 |     def has_been_updated(self):
46 |         return self._count > 0
47 | 
48 |     def update(self, value, n=1):
49 |         self._value = value
50 |         self._sum += value * n
51 |         self._count += n
52 | 
53 |     def compute(self):
54 |         return self._sum / self._count if self._count != 0 else 0.0
55 | 
56 |     def reset(self):
57 |         self._value = 0.0
58 |         self._sum = 0.0
59 |         self._count = 0
60 |
--------------------------------------------------------------------------------
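
`AverageGauge` keeps a running, optionally weighted, average; a usage sketch:

```python
# Usage sketch for AverageGauge: a running, optionally weighted, average.
from kerosene.metrics.gauges import AverageGauge

gauge = AverageGauge()
gauge.update(0.5, n=2)  # e.g. a batch loss of 0.5 averaged over 2 samples
gauge.update(1.0)       # n defaults to 1

assert gauge.has_been_updated()
assert gauge.count == 3
assert abs(gauge.compute() - (0.5 * 2 + 1.0) / 3) < 1e-8

gauge.reset()
assert gauge.compute() == 0.0  # an empty gauge computes to 0.0 rather than dividing by zero
```
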
/kerosene/utils/devices.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from typing import List
17 |
18 | import torch
19 |
20 |
21 | def on_cpu(device: torch.device):
22 |     return str(device) == "cpu"
23 | 
24 | 
25 | def on_gpu(device: torch.device):
26 |     return device.type == "cuda"
27 | 
28 | 
29 | def on_gpus(devices: List[torch.device]):
30 |     return all([device.type == "cuda" for device in devices])
31 | 
32 | 
33 | def on_single_device(devices: List[torch.device]):
34 |     return len(devices) == 1
35 | 
36 | 
37 | def on_multiple_devices(devices: List[torch.device]):
38 |     return len(devices) > 1
39 | 
40 | 
41 | def on_single_gpu(devices: List[torch.device]):
42 |     return on_single_device(devices) and on_gpus(devices)
43 | 
44 | 
45 | def on_multiple_gpus(devices: List[torch.device]):
46 |     return on_multiple_devices(devices) and on_gpus(devices)
47 | 
48 | 
49 | def get_devices():
50 |     return [torch.device("cuda:{}".format(device_id)) for device_id in
51 |             range(torch.cuda.device_count())] if torch.cuda.is_available() else [torch.device("cpu")]
52 | 
53 | 
54 | def num_gpus():
55 |     # Count only CUDA devices; the previous version counted every device, CPU included.
56 |     return len([device for device in get_devices() if device.type == "cuda"])
56 |
--------------------------------------------------------------------------------
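
A usage sketch for the device helpers above; the result depends on the host machine:

```python
# Usage sketch for kerosene.utils.devices; output depends on the host machine.
import torch

from kerosene.utils.devices import get_devices, num_gpus, on_cpu, on_multiple_gpus

devices = get_devices()  # e.g. [cuda:0, cuda:1] on a 2-GPU box, else [cpu]
print(devices, num_gpus())

if torch.cuda.is_available():
    assert on_multiple_gpus(devices) == (torch.cuda.device_count() > 1)
else:
    assert on_cpu(devices[0])
```
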
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | idea
107 | .idea
108 | /data/*
109 | /tests/functionals/distributed/files/
110 | files
--------------------------------------------------------------------------------
/kerosene/events/handlers/base_handler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from abc import ABC, abstractmethod
17 | from typing import List, Union
18 |
19 | from kerosene.events import TemporalEvent, BaseEvent
20 | from kerosene.events.exceptions import UnsupportedEventException
21 |
22 |
23 | class EventHandler(ABC):
24 | 
25 |     def __init__(self, supported_events: List[Union[BaseEvent, TemporalEvent]] = None, every=1):
26 |         self._every = every
27 |         self._supported_events = supported_events
28 | 
29 |     @property
30 |     def every(self):
31 |         return self._every
32 | 
33 |     @property
34 |     def supported_events(self):
35 |         return self._supported_events
36 | 
37 |     def should_handle(self, event: TemporalEvent):
38 |         if (self.supported_events is not None) and (event not in self.supported_events):
39 |             raise UnsupportedEventException(self.supported_events, event)
40 | 
41 |         if event.iteration == 0 and self._every != 1:
42 |             return False
43 |         else:
44 |             return event.iteration % self._every == 0
45 | 
46 |     @abstractmethod
47 |     def __call__(self, temporal_event: TemporalEvent, monitors, sender):
48 |         raise NotImplementedError()
49 |
--------------------------------------------------------------------------------
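
A minimal sketch of a concrete handler: `should_handle` gives subclasses event validation plus the `every` throttling used by handlers such as `PlotMonitors(every=500, ...)` in `examples/mnist/main.py`. `PrintEvent` is a hypothetical name, not a kerosene class:

```python
# Minimal sketch of a concrete EventHandler (PrintEvent is hypothetical);
# 'every' throttles how often it fires, as with PlotMonitors(every=500, ...).
from kerosene.events import TemporalEvent
from kerosene.events.handlers.base_handler import EventHandler


class PrintEvent(EventHandler):
    def __call__(self, temporal_event: TemporalEvent, monitors, sender):
        if self.should_handle(temporal_event):
            print("{} fired with monitors {}".format(temporal_event, monitors))
```
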
/kerosene/loggers/visdom/data.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List
2 |
3 |
4 | class VisdomData(object):
5 | 
6 |     def __init__(self, source_name, variable_name: Union[List[str], str], plot_type, plot_frequency, x, y,
7 |                  params: dict = None):
8 |         self._source_name = source_name
9 |         self._variable_name = variable_name
10 |         self._plot_type = plot_type
11 |         self._plot_frequency = plot_frequency
12 |         self._x = x
13 |         self._y = y
14 |         self._params = params
15 | 
16 |     @property
17 |     def id(self):
18 |         return hash(self)
19 | 
20 |     @property
21 |     def source_name(self):
22 |         return self._source_name
23 | 
24 |     @property
25 |     def variable_name(self):
26 |         return self._variable_name
27 | 
28 |     @property
29 |     def plot_type(self):
30 |         return self._plot_type
31 | 
32 |     @property
33 |     def plot_frequency(self):
34 |         return self._plot_frequency
35 | 
36 |     @property
37 |     def x(self):
38 |         return self._x
39 | 
40 |     @property
41 |     def y(self):
42 |         return self._y
43 | 
44 |     @property
45 |     def params(self):
46 |         return self._params
47 | 
48 |     def __hash__(self):
49 |         # str() keeps hashing working when variable_name is a list of names.
50 |         return hash(self._source_name + str(self._variable_name) + str(self._plot_frequency) + str(self._plot_type))
51 | 
52 |     def __str__(self):
53 |         return "Source: {} Variable: {} Plot Type: {} Frequency: {} x: {} y: {}".format(self.source_name,
54 |                                                                                         self.variable_name,
55 |                                                                                         str(self.plot_type),
56 |                                                                                         str(self.plot_frequency),
57 |                                                                                         self.x, self.y)
58 | 
--------------------------------------------------------------------------------
/kerosene/utils/files.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | import os
17 | from glob import glob
18 | from typing import List, Tuple
19 |
20 |
21 | def split_filename(filepath: str) -> Tuple[str, str, str]:
22 |     """
23 |     Split a filepath into the directory, base, and extension.
24 | 
25 |     Args:
26 |         filepath (str): The base file path.
27 | 
28 |     Returns:
29 |         Tuple: The directory path, file base name and file extension.
30 |     """
31 |     path = os.path.dirname(filepath)
32 |     filename = os.path.basename(filepath)
33 |     base, ext = os.path.splitext(filename)
34 |     if ext == '.gz':
35 |         # Compressed files keep their compound extension (e.g. '.nii.gz').
36 |         base, ext2 = os.path.splitext(base)
37 |         ext = ext2 + ext
38 |     return path, base, ext
39 | 
40 | 
41 | def extract_file_paths(path: str, ext='*.nii*') -> List[str]:
42 |     """
43 |     Grab all `ext` files in a directory and sort them for consistency.
44 | 
45 |     Args:
46 |         path (str): Directory path to search.
47 |         ext (str): Glob pattern of the files to grab.
48 | 
49 |     Returns:
50 |         list: A sorted list of matching file paths.
51 |     """
52 |     file_paths = sorted(glob(os.path.join(path, ext)))
53 |     return file_paths
54 | 
55 | 
56 | def should_create_dir(path, dir_name):
57 |     return not os.path.exists(os.path.join(path, dir_name))
58 | 
--------------------------------------------------------------------------------
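
A usage sketch for the path helpers above, including the two-step extension handling for `.nii.gz` files (the paths are illustrative):

```python
# Usage sketch for kerosene.utils.files; note the special-casing of .gz so
# that compressed NIfTI files keep their compound extension.
from kerosene.utils.files import extract_file_paths, split_filename

path, base, ext = split_filename("/data/scans/brain.nii.gz")
assert (path, base, ext) == ("/data/scans", "brain", ".nii.gz")

# All NIfTI files under a directory, sorted for reproducible ordering.
nii_files = extract_file_paths("/data/scans")           # default ext='*.nii*'
txt_files = extract_file_paths("/data/notes", "*.txt")  # any glob pattern works
```
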
/kerosene/loggers/visdom/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from kerosene.configs.parsers import YamlConfigurationParser
17 |
18 |
19 | class VisdomConfiguration(object):
20 |     def __init__(self, port, server, env, filename, offline=False):
21 |         self._port = port
22 |         self._server = server
23 |         self._env = env
24 |         self._filename = filename
25 |         self._offline = offline
26 | 
27 |     @property
28 |     def port(self):
29 |         return self._port
30 | 
31 |     @property
32 |     def server(self):
33 |         return self._server
34 | 
35 |     @property
36 |     def env(self):
37 |         return self._env
38 | 
39 |     @property
40 |     def filename(self):
41 |         return self._filename
42 | 
43 |     @property
44 |     def offline(self):
45 |         return self._offline
46 | 
47 |     @classmethod
48 |     def from_dict(cls, config_dict):
49 |         return cls(config_dict.get("port", 8097), config_dict.get("server", "http://localhost"),
50 |                    config_dict.get("env", "main"), config_dict.get("filename", None),
51 |                    config_dict.get("offline", False))
52 | 
53 |     @classmethod
54 |     def from_yml(cls, yml_file, yml_tag="visdom"):
55 |         config = YamlConfigurationParser.parse_section(yml_file, yml_tag)
56 |         return VisdomConfiguration.from_dict(config)
57 |
--------------------------------------------------------------------------------
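
`VisdomConfiguration` can be built from a dict or, via `from_yml`, from a YAML section like the `visdom:` block in `examples/mnist/config.yml`; unspecified keys fall back to the defaults in `from_dict`:

```python
# Usage sketch: build a VisdomConfiguration from a dict (or use from_yml with
# a config file containing a 'visdom:' section, as in examples/mnist/config.yml).
from kerosene.loggers.visdom.config import VisdomConfiguration

config = VisdomConfiguration.from_dict({"server": "localhost", "port": 8097, "env": "Kerosene-MNIST"})
assert config.env == "Kerosene-MNIST"
assert config.offline is False  # unspecified keys fall back to from_dict's defaults

default_config = VisdomConfiguration.from_dict({})
assert (default_config.port, default_config.server) == (8097, "http://localhost")
```
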
/kerosene/utils/tensors.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | import torch
17 |
18 |
19 | def to_onehot(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
20 | """
21 | Convert a tensor of indices of any shape `(N, ...)` to a tensor of one-hot indicators of shape
22 | `(N, num_classes, ...)`.
23 |
24 | Args:
25 | indices (:obj:`torch.Tensor`): Tensor of indices.
26 | num_classes (int): The number of classes of the problem.
27 |
28 | Returns:
29 | :obj:`torch.Tensor`: The one-hot representation of the input tensor.
30 | """
31 | onehot = torch.zeros(indices.shape[0], num_classes, *indices.shape[1:], device=indices.device)
32 | return onehot.scatter_(1, indices.unsqueeze(1), 1)
33 |
34 |
35 | def flatten(tensor: torch.Tensor) -> torch.Tensor:
36 | """
37 | Flattens a given tensor such that the channel axis is first.
38 | The shapes are transformed as follows: (N, C, D, H, W) -> (C, N * D * H * W)
39 |
40 | Args:
41 | tensor (:obj:`torch.Tensor`): Tensor to flatten.
42 |
43 | Returns:
44 | :obj:`torch.Tensor`: The flattened, 2-D tensor of shape `(C, N * D * H * W)`.
45 | """
46 | C = tensor.size(1)
47 | # new axis order
48 | axis_order = (1, 0) + tuple(range(2, tensor.dim()))
49 | # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
50 | transposed = tensor.permute(axis_order).contiguous()
51 | # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
52 | return transposed.view(C, -1)
53 |
--------------------------------------------------------------------------------
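A short sketch of the two helpers above with the shapes worked out; the tensors here are illustrative only.

```python
import torch

from kerosene.utils.tensors import flatten, to_onehot

# to_onehot: indices of shape (N, ...) -> one-hot indicators of shape (N, num_classes, ...).
indices = torch.tensor([[0, 2], [1, 0]])        # shape (2, 2)
onehot = to_onehot(indices, num_classes=3)      # shape (2, 3, 2)
assert onehot.shape == (2, 3, 2)

# flatten: (N, C, D, H, W) -> (C, N * D * H * W), channel axis first.
tensor = torch.rand(4, 3, 8, 8)                 # (N, C, H, W) works the same way
assert flatten(tensor).shape == (3, 4 * 8 * 8)  # (C, N * H * W)
```
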
/examples/mnist/main.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torchvision
4 | from torch.utils.data import DataLoader
5 | from torchvision.transforms import Compose, ToTensor, Normalize
6 |
7 | from examples.mnist.models import SimpleConvNet
8 | from kerosene.configs.configs import RunConfiguration
9 | from kerosene.configs.parsers import YamlConfigurationParser
10 | from kerosene.events.handlers.console import PrintTrainingStatus
11 | from kerosene.events.handlers.visdom import PlotMonitors, PlotAvgGradientPerLayer
12 | from kerosene.loggers.visdom.config import VisdomConfiguration
13 | from kerosene.loggers.visdom.visdom import VisdomLogger
14 | from kerosene.training.events import Event
15 | from kerosene.training.trainers import ModelTrainerFactory, SimpleTrainer
16 |
17 | if __name__ == "__main__":
18 | logging.basicConfig(level=logging.INFO)
19 | CONFIG_FILE_PATH = "config.yml"
20 |
21 | model_trainer_config, training_config = YamlConfigurationParser.parse(CONFIG_FILE_PATH)
22 |
23 | train_loader = DataLoader(torchvision.datasets.MNIST('./files/', train=True, download=True, transform=Compose(
24 | [ToTensor(), Normalize((0.1307,), (0.3081,))])), batch_size=training_config.batch_size_train, shuffle=True)
25 |
26 | test_loader = DataLoader(torchvision.datasets.MNIST('./files/', train=False, download=True, transform=Compose(
27 | [ToTensor(), Normalize((0.1307,), (0.3081,))])), batch_size=training_config.batch_size_valid, shuffle=True)
28 |
29 | visdom_logger = VisdomLogger(VisdomConfiguration.from_yml(CONFIG_FILE_PATH))
30 |
31 | # Initialize the model trainers
32 | model_trainer = ModelTrainerFactory(model=SimpleConvNet()).create(model_trainer_config)
33 |
34 | # Train with the training strategy
35 | SimpleTrainer("MNIST Trainer", train_loader, test_loader, None, model_trainer, RunConfiguration(use_amp=False)) \
36 | .with_event_handler(PlotMonitors(every=500, visdom_logger=visdom_logger), Event.ON_BATCH_END) \
37 | .with_event_handler(PlotMonitors(visdom_logger=visdom_logger), Event.ON_EPOCH_END) \
38 | .with_event_handler(PlotAvgGradientPerLayer(every=500, visdom_logger=visdom_logger), Event.ON_TRAIN_BATCH_END) \
39 | .with_event_handler(PrintTrainingStatus(every=100), Event.ON_BATCH_END) \
40 | .train(training_config.nb_epochs)
41 |
--------------------------------------------------------------------------------
/kerosene/optim/schedulers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from enum import Enum
17 | from typing import Union
18 |
19 | from torch.optim import Optimizer
20 | from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR, \
21 | _LRScheduler
22 |
23 |
24 | class SchedulerType(Enum):
25 | ReduceLROnPlateau = "ReduceLROnPlateau"
26 | StepLR = "StepLR"
27 | MultiStepLR = "MultiStepLR"
28 | ExponentialLR = "ExponentialLR"
29 | CosineAnnealingLR = "CosineAnnealingLR"
30 | CyclicLR = "CyclicLR"
31 | CosineAnnealingWarmRestarts = "CosineAnnealingWarmRestarts"
32 |
33 |
34 | class SchedulerFactory(object):
35 |
36 | def __init__(self):
37 | self._schedulers = {
38 | "ReduceLROnPlateau": ReduceLROnPlateau,
39 | "StepLR": StepLR,
40 | "MultiStepLR": MultiStepLR,
41 | "ExponentialLR": ExponentialLR,
42 | "CosineAnnealingLR": CosineAnnealingLR,
43 | }
44 |
45 | def create(self, scheduler_type: Union[str, SchedulerType], optimizer: Optimizer, params):
46 | return self._schedulers[str(scheduler_type)](optimizer=optimizer, **params) if params is not None else \
47 | self._schedulers[str(scheduler_type)](optimizer=optimizer)
48 |
49 | def register(self, function: str, creator: _LRScheduler):
50 | """
51 | Register a new learning-rate scheduler.
52 | Args:
53 | function (str): The scheduler name used as registry key.
54 | creator (:obj:`_LRScheduler`): The class of the custom scheduler to register.
55 | """
56 | self._schedulers[function] = creator
57 |
--------------------------------------------------------------------------------
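A minimal sketch of the factory above. Scheduler names are passed as strings here because `create` resolves `str(scheduler_type)` against a string-keyed registry (note that, unlike `OptimizerType`, `SchedulerType` does not override `__str__`); `register` extends that registry with schedulers the factory does not ship with.

```python
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from kerosene.optim.schedulers import SchedulerFactory

optimizer = SGD(nn.Linear(10, 2).parameters(), lr=0.1)
factory = SchedulerFactory()

# Built-in scheduler, configured through a params dict.
scheduler = factory.create("StepLR", optimizer, {"step_size": 10, "gamma": 0.5})

# Register a scheduler the factory does not know about, then create it by name.
factory.register("CosineAnnealingWarmRestarts", CosineAnnealingWarmRestarts)
warm_restarts = factory.create("CosineAnnealingWarmRestarts", optimizer, {"T_0": 10})
```
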
/examples/mnist/main_hyperopt.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 |
4 | import torchvision
5 | from hyperopt import fmin, tpe, hp, STATUS_OK
6 | from torch.utils.data import DataLoader
7 | from torchvision.transforms import Compose, ToTensor, Normalize
8 |
9 | from examples.mnist.models import SimpleConvNet
10 | from kerosene.configs.configs import RunConfiguration
11 | from kerosene.configs.parsers import YamlConfigurationParser
12 | from kerosene.events import Monitor, Phase
13 | from kerosene.events.handlers.console import PrintTrainingStatus
14 | from kerosene.training.events import Event
15 | from kerosene.training.trainers import ModelTrainerFactory, SimpleTrainer
16 |
17 |
18 | def objective(hyper_params):
19 | # Update the trainer with the new hyper-parameters
20 | model_config.update(hyper_params)
21 |
22 | # Create the model trainer
23 | model_trainer = ModelTrainerFactory(model=SimpleConvNet()).create(model_config)
24 |
25 | # Train with the training strategy
26 | monitor = SimpleTrainer("MNIST Trainer", train_loader, valid_loader, None, model_trainer,
27 | RunConfiguration(use_amp=False)) \
28 | .with_event_handler(PrintTrainingStatus(every=100), Event.ON_BATCH_END) \
29 | .train(training_config.nb_epochs)
30 |
31 | return {'loss': monitor["SimpleNet"][Phase.VALIDATION][Monitor.LOSS]["CrossEntropy"], 'status': STATUS_OK}
32 |
33 |
34 | if __name__ == "__main__":
35 | logging.basicConfig(level=logging.INFO)
36 | CONFIG_FILE_PATH = "config.yml"
37 |
38 | model_config, training_config = YamlConfigurationParser.parse(CONFIG_FILE_PATH)
39 |
40 | train_loader = DataLoader(torchvision.datasets.MNIST('./files/', train=True, download=True, transform=Compose(
41 | [ToTensor(), Normalize((0.1307,), (0.3081,))])), batch_size=training_config.batch_size_train, shuffle=True)
42 |
43 | valid_loader = DataLoader(torchvision.datasets.MNIST('./files/', train=False, download=True, transform=Compose(
44 | [ToTensor(), Normalize((0.1307,), (0.3081,))])), batch_size=training_config.batch_size_valid, shuffle=True)
45 |
46 | search_space = {
47 | 'SimpleNet': {'optimizer': {'params': {'lr': hp.loguniform('lr', math.log(0.0005), math.log(0.01))}}}
48 | }
49 |
50 | best = fmin(objective, space=search_space, algo=tpe.suggest, max_evals=2)
51 |
52 | print("The best hyper-parameters are: {}".format(best))
53 |
--------------------------------------------------------------------------------
/kerosene/events/handlers/checkpoints.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from typing import Callable, List, Union
4 |
5 | import torch
6 |
7 | from kerosene.events import MonitorMode, TemporalEvent, Phase, Monitor
8 | from kerosene.events.handlers.base_monitor_watcher import MonitorWatcher, MonitorPatienceExceeded
9 | from kerosene.training.events import Event
10 | from kerosene.training.trainers import Trainer
11 | from kerosene.utils.constants import CHECKPOINT_EXT
12 | from kerosene.utils.files import should_create_dir
13 |
14 |
15 | class Checkpoint(MonitorWatcher):
16 | LOGGER = logging.getLogger("Checkpoint")
17 | SUPPORTED_EVENTS = [Event.ON_EPOCH_END]
18 |
19 | def __init__(self, path: str, model_name: Union[str, None], monitor_names: Union[List[str], str], delta: float,
20 | mode: MonitorMode, reduction: Callable = torch.mean):
21 | super(Checkpoint, self).__init__(mode, delta, patience=0, supported_events=self.SUPPORTED_EVENTS)
22 | self._path = path
23 | self._model_name = model_name
24 | self._reduction = reduction
25 | self._monitors_names = monitor_names if isinstance(monitor_names, list) else [monitor_names]
26 |
27 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer):
28 | if self.should_handle(event):
29 | try:
30 | model_monitors = {**monitors[self._model_name][Phase.VALIDATION][Monitor.METRICS],
31 | **monitors[self._model_name][Phase.VALIDATION][Monitor.LOSS]}
32 | values = torch.cat([model_monitors[monitor_name] for monitor_name in self._monitors_names])
33 |
34 | self.watch(self._model_name, self._reduction(values))
35 | except MonitorPatienceExceeded as e:
36 | self._save(trainer.epoch, self._model_name, trainer.model_trainers[self._model_name].model_state,
37 | trainer.model_trainers[self._model_name].optimizer_state)
38 | except KeyError as e:
39 | self.LOGGER.warning("Invalid model or monitor name: {}".format(e))
40 |
41 | def _save(self, epoch_num, model_name, model_state, optimizer_states):
42 | if should_create_dir(self._path, model_name):
43 | os.makedirs(os.path.join(self._path, model_name))
44 | torch.save({"epoch_num": epoch_num,
45 | "model_state_dict": model_state,
46 | "optimizer_state_dict": optimizer_states},
47 | os.path.join(self._path, model_name, model_name + CHECKPOINT_EXT))
48 |
--------------------------------------------------------------------------------
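A hedged sketch of wiring the handler above, following the `with_event_handler` pattern used in the MNIST example; the path, model name and monitor name are hypothetical.

```python
import torch

from kerosene.events import MonitorMode
from kerosene.events.handlers.checkpoints import Checkpoint
from kerosene.training.events import Event

# Hypothetical values: watch the validation "CrossEntropy" loss of "SimpleNet"
# in MIN mode and save a checkpoint under ./checkpoints when it improves by
# more than the given delta.
checkpoint_handler = Checkpoint("./checkpoints", "SimpleNet", "CrossEntropy",
                                delta=0.01, mode=MonitorMode.MIN, reduction=torch.mean)

# Attached like any other handler; Checkpoint only supports ON_EPOCH_END.
# trainer.with_event_handler(checkpoint_handler, Event.ON_EPOCH_END)
```
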
/kerosene/loggers/visdom/visdom.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from abc import ABC, abstractmethod
3 | from typing import Union, List, Dict
4 |
5 | from visdom import Visdom
6 |
7 | from kerosene.loggers.visdom.config import VisdomConfiguration
8 | from kerosene.loggers.visdom.data import VisdomData
9 | from kerosene.loggers.visdom.plots import VisdomPlotFactory, VisdomPlot
10 |
11 |
12 | class BaseVisdomLogger(ABC):
13 |
14 | def __init__(self, visdom_config: VisdomConfiguration):
15 | self._visdom_config = visdom_config
16 |
17 | @abstractmethod
18 | def __call__(self, visdom_data: Union[List[VisdomData], VisdomData]):
19 | raise NotImplementedError()
20 |
21 | @property
22 | def visdom_config(self):
23 | return self._visdom_config
24 |
25 |
26 | class VisdomLogger(BaseVisdomLogger):
27 | LOGGER = logging.getLogger("VisdomLogger")
28 |
29 | def __init__(self, visdom_config: VisdomConfiguration, plot_factory: VisdomPlotFactory = VisdomPlotFactory()):
30 | super().__init__(visdom_config)
31 | self._plots: Dict[int, VisdomPlot] = {}
32 | self._plot_factory = plot_factory
33 | self._visdom_config = visdom_config
34 |
35 | self._visdom = Visdom(server=visdom_config.server, port=visdom_config.port, env=visdom_config.env,
36 | log_to_filename=visdom_config.filename, offline=visdom_config.offline)
37 |
38 | @property
39 | def plots(self):
40 | return self._plots
41 |
42 | def save(self):
43 | self._visdom.save(envs=[self._visdom_config.env])
44 |
45 | def __call__(self, visdom_data: Union[List[VisdomData], VisdomData]):
46 | visdom_data = [visdom_data] if not hasattr(visdom_data, '__iter__') else visdom_data
47 |
48 | for visdom_datum in visdom_data:
49 | if visdom_datum.id not in self._plots.keys():
50 | self._plots[visdom_datum.id] = self._plot_factory.create_plot(self._visdom, visdom_datum.plot_type)
51 |
52 | self._plots[visdom_datum.id].update(visdom_datum)
53 |
54 |
55 | class DummyVisdomLogger(BaseVisdomLogger):
56 | LOGGER = logging.getLogger("DummyVisdomLogger")
57 |
58 | def __init__(self, visdom_config: VisdomConfiguration = None):
59 | super().__init__(visdom_config)
60 |
61 | def __call__(self, visdom_data: Union[List[VisdomData], VisdomData]):
62 | if not isinstance(visdom_data, list):
63 | visdom_data = [visdom_data]
64 |
65 | self.LOGGER.debug("\n".join(list(map(lambda visdom_datum: str(visdom_datum), visdom_data))))
66 |
--------------------------------------------------------------------------------
/tests/nn/test_apex.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import torch
4 | from hamcrest import assert_that, equal_to
5 |
6 | from kerosene.nn.apex import ApexLoss
7 |
8 |
9 | class TestApexLoss(unittest.TestCase):
10 | VALID_LOSS_VALUE = torch.tensor([1.0, 2.0, 3.0])
11 | INVALID_LOSS_VALUE = "Ford Fiesta"
12 |
13 | VALID_SCALAR = 2.0
14 | INVALID_SCALAR = "Ford Focus"
15 |
16 | VALID_LOSS_ID = 0
17 | INVALID_LOSS_ID = "Ford F150"
18 |
19 | def test_should_add_losses(self):
20 | expected_result = ApexLoss(self.VALID_LOSS_ID, torch.tensor([2.0, 4.0, 6.0]), None)
21 |
22 | apex_loss1 = ApexLoss(self.VALID_LOSS_ID, self.VALID_LOSS_VALUE, None)
23 | apex_loss2 = ApexLoss(self.VALID_LOSS_ID, self.VALID_LOSS_VALUE, None)
24 |
25 | assert_that(apex_loss1 + apex_loss2, equal_to(expected_result))
26 |
27 | def test_should_multiply_by_a_scalar(self):
28 | expected_result = ApexLoss(self.VALID_LOSS_ID, self.VALID_SCALAR * self.VALID_LOSS_VALUE, None)
29 |
30 | loss = ApexLoss(self.VALID_LOSS_ID, self.VALID_LOSS_VALUE, None)
31 |
32 | assert_that(self.VALID_SCALAR * loss, equal_to(expected_result))
33 | assert_that(loss * self.VALID_SCALAR, equal_to(expected_result))
34 |
35 | # noinspection PyTypeChecker
36 | def test_should_divide_by_a_scalar(self):
37 | left_div_expected_result = ApexLoss(self.VALID_LOSS_ID, self.VALID_LOSS_VALUE / self.VALID_SCALAR,
38 | None)
39 | right_div_expected_result = ApexLoss(self.VALID_LOSS_ID, self.VALID_SCALAR / self.VALID_LOSS_VALUE,
40 | None)
41 |
42 | loss = ApexLoss(self.VALID_LOSS_ID, self.VALID_LOSS_VALUE, None)
43 |
44 | assert_that(loss / self.VALID_SCALAR, equal_to(left_div_expected_result))
45 | assert_that(self.VALID_SCALAR / loss, equal_to(right_div_expected_result))
46 |
47 | def test_should_compute_the_mean(self):
48 | expected_result = ApexLoss(self.VALID_LOSS_ID, torch.tensor([2.0]), None)
49 |
50 | loss = ApexLoss(self.VALID_LOSS_ID, self.VALID_LOSS_VALUE, None)
51 |
52 | assert_that(loss.mean(), equal_to(expected_result))
53 |
54 | # noinspection PyTypeChecker
55 | def test_should_subtract_a_scalar(self):
56 | left_sub_expected_result = ApexLoss(self.VALID_LOSS_ID, self.VALID_LOSS_VALUE - self.VALID_SCALAR, None)
57 | right_sub_expected_result = ApexLoss(self.VALID_LOSS_ID, self.VALID_SCALAR - self.VALID_LOSS_VALUE, None)
58 |
59 | loss = ApexLoss(self.VALID_LOSS_ID, self.VALID_LOSS_VALUE, None)
60 |
61 | assert_that(loss - self.VALID_SCALAR, equal_to(left_sub_expected_result))
62 | assert_that(self.VALID_SCALAR - loss, equal_to(right_sub_expected_result))
63 |
--------------------------------------------------------------------------------
/tests/functionals/distributed/mnist_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Pierre-Luc Delisle. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | from typing import List
18 |
19 | import torch
20 | from torch.utils.data import DataLoader
21 |
22 | from kerosene.configs.configs import RunConfiguration
23 | from kerosene.training.trainers import ModelTrainer
24 | from kerosene.training.trainers import Trainer
25 | from kerosene.utils.devices import on_single_device
26 |
27 |
28 | class MNISTTrainer(Trainer):
29 |
30 | def __init__(self, training_config, model_trainers: List[ModelTrainer],
31 | train_data_loader: DataLoader, valid_data_loader: DataLoader, run_config: RunConfiguration):
32 | super(MNISTTrainer, self).__init__("MNISTTrainer", train_data_loader, valid_data_loader, model_trainers,
33 | run_config)
34 |
35 | self._training_config = training_config
36 |
37 | def train_step(self, inputs, target):
38 | model = self._model_trainers[0]
39 |
40 | pred = model.forward(inputs)
41 | model.compute_train_metric(pred, target)
42 | loss = model.compute_and_update_train_loss(pred, target)
43 |
44 | model.zero_grad()
45 | loss.backward()
46 |
47 | if not on_single_device(self._run_config.devices):
48 | self.average_gradients(model)
49 |
50 | model.step()
51 |
52 | def validate_step(self, inputs, target):
53 | model = self._model_trainers[0]
54 |
55 | pred = model.forward(inputs)
56 | model.compute_valid_metric(pred, target)
57 | model.compute_and_update_valid_loss(pred, target)
58 |
59 | def scheduler_step(self):
60 | self._model_trainers[0].scheduler_step()
61 |
62 | @staticmethod
63 | def average_gradients(model):
64 | size = float(torch.distributed.get_world_size())
65 | for param in model.parameters():
66 | torch.distributed.all_reduce(param.grad.data, op=torch.distributed.ReduceOp.SUM)
67 | param.grad.data /= size
68 |
69 | def on_epoch_begin(self):
70 | pass
71 |
72 | def on_epoch_end(self):
73 | pass
74 |
--------------------------------------------------------------------------------
/kerosene/events/handlers/base_monitor_watcher.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from typing import Dict, List, Union
3 |
4 | from kerosene.events import MonitorMode, BaseEvent, TemporalEvent
5 | from kerosene.events.handlers.base_handler import EventHandler
6 |
7 |
8 | class MonitorPatienceExceeded(Exception):
9 | pass
10 |
11 |
12 | class MonitorInspection(object):
13 | def __init__(self, value=0, inspection_num=0):
14 | self._value = value
15 | self._inspection_num = inspection_num
16 |
17 | @property
18 | def value(self):
19 | return self._value
20 |
21 | @value.setter
22 | def value(self, new_value):
23 | self._value = new_value
24 |
25 | @property
26 | def inspection_num(self):
27 | return self._inspection_num
28 |
29 | def add_inspection(self):
30 | self._inspection_num = self._inspection_num + 1
31 |
32 | def reset_inspection_num(self):
33 | self._inspection_num = 0
34 | return self
35 |
36 | def with_value(self, value):
37 | self._value = value
38 | return self
39 |
40 |
41 | class MonitorWatcher(EventHandler, ABC):
42 | def __init__(self, mode: MonitorMode, min_delta, patience, every=1,
43 | supported_events: List[Union[BaseEvent, TemporalEvent]] = None):
44 | super().__init__(supported_events, every)
45 |
46 | self._mode = mode
47 | self._min_delta = min_delta
48 | self._patience = patience
49 | self._monitor_values: Dict[str, MonitorInspection] = {}
50 |
51 | @property
52 | def mode(self):
53 | return self._mode
54 |
55 | @property
56 | def min_delta(self):
57 | return self._min_delta
58 |
59 | @property
60 | def patience(self):
61 | return self._patience
62 |
63 | @property
64 | def monitor_values(self):
65 | return self._monitor_values
66 |
67 | def watch(self, source_name, current_monitor_value):
68 | if source_name not in self._monitor_values.keys():
69 | self._monitor_values[source_name] = MonitorInspection(value=current_monitor_value)
70 | else:
71 | if self._mode is MonitorMode.MIN:
72 | delta = self._monitor_values[source_name].value - current_monitor_value
73 | else:
74 | delta = current_monitor_value - self._monitor_values[source_name].value
75 |
76 | if delta <= self._min_delta:
77 | self._monitor_values[source_name].with_value(current_monitor_value).reset_inspection_num()
78 | else:
79 | self._monitor_values[source_name].add_inspection()
80 | if self._monitor_values[source_name].inspection_num >= self._patience:
81 | raise MonitorPatienceExceeded()
82 |
--------------------------------------------------------------------------------
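A sketch of the `watch` bookkeeping using a hypothetical no-op subclass (`MonitorWatcher` is abstract), assuming the `EventHandler` base accepts the default `supported_events` and `every` arguments. `MonitorPatienceExceeded` is raised once the monitored value has moved away from the recorded value by more than `min_delta` on `patience` consecutive inspections; `Checkpoint` above relies on this signal with `patience=0` to trigger a save.

```python
from kerosene.events import MonitorMode
from kerosene.events.handlers.base_monitor_watcher import MonitorPatienceExceeded, MonitorWatcher


class DummyWatcher(MonitorWatcher):
    def __call__(self, *args, **kwargs):
        pass


watcher = DummyWatcher(MonitorMode.MIN, min_delta=0.01, patience=2)

watcher.watch("SimpleNet", 1.0)        # first value is only recorded
try:
    watcher.watch("SimpleNet", 0.5)    # moved by 0.5 > min_delta: inspection 1
    watcher.watch("SimpleNet", 0.2)    # moved again: inspection 2 >= patience
except MonitorPatienceExceeded:
    print("SimpleNet moved by more than min_delta on consecutive inspections")
```
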
/kerosene/optim/optimizers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from enum import Enum
17 | from typing import Union
18 |
19 | from torch import optim
20 | from torch.optim import Optimizer
21 |
22 | try:
23 | from apex.optimizers import FusedSGD, FusedAdam
24 |
25 | APEX_AVAILABLE = True
26 | except ModuleNotFoundError:
27 | APEX_AVAILABLE = False
28 | except ImportError as e:
29 | APEX_AVAILABLE = False
30 | print("Unable to import apex optimizers, please upgrade your apex version: {}".format(e))
31 |
32 |
33 | class OptimizerType(Enum):
34 | FusedSGD = "FusedSGD"
35 | FusedAdam = "FusedAdam"
36 | SGD = "SGD"
37 | Adam = "Adam"
38 | Adagrad = "Adagrad"
39 | Adadelta = "Adadelta"
40 | SparseAdam = "SparseAdam"
41 | Adamax = "Adamax"
42 | Rprop = "Rprop"
43 | RMSprop = "RMSprop"
44 | ASGD = "ASGD"
45 |
46 | def __str__(self):
47 | return self.value
48 |
49 |
50 | class OptimizerFactory(object):
51 |
52 | def __init__(self):
53 | self._optimizers = {
54 | "Adam": optim.Adam,
55 | "SGD": optim.SGD,
56 | "RMSprop": optim.RMSprop,
57 | "Adagrad": optim.Adagrad,
58 | "Adadelta": optim.Adadelta,
59 | "SparseAdam": optim.SparseAdam,
60 | "Adamax": optim.Adamax,
61 | "Rprop": optim.Rprop,
62 | "ASGD": optim.ASGD
63 | }
64 |
65 | if APEX_AVAILABLE:
66 | self.register("FusedSGD", FusedSGD)
67 | self.register("FusedAdam", FusedAdam)
68 |
69 | def create(self, optimizer_type: Union[str, OptimizerType], model_params, params):
70 | return self._optimizers[str(optimizer_type)](model_params, **params) if params is not None else \
71 | self._optimizers[str(optimizer_type)](model_params)
72 |
73 | def register(self, function: str, creator: Optimizer):
74 | """
75 | Register a new optimizer.
76 | Args:
77 | function (str): The optimizer name used as registry key.
78 | creator (:obj:`torch.optim.Optimizer`): The class of the custom optimizer to register.
79 | """
80 | self._optimizers[function] = creator
81 |
--------------------------------------------------------------------------------
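A minimal sketch of the factory above: the type can be given as a string or as an `OptimizerType` member (its `__str__` returns the registry key), and `params=None` falls back to the optimizer's own defaults.

```python
from torch import nn

from kerosene.optim.optimizers import OptimizerFactory, OptimizerType

model = nn.Linear(10, 2)
factory = OptimizerFactory()

# Type as an enum member, with explicit hyper-parameters.
sgd = factory.create(OptimizerType.SGD, model.parameters(), {"lr": 0.01, "momentum": 0.9})

# Type as a string; params=None keeps Adam's defaults.
adam = factory.create("Adam", model.parameters(), None)
```
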
/kerosene/nn/utils/gradients.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | import abc
18 | from enum import Enum
19 | from typing import Union, Callable, Iterable
20 |
21 | import torch
22 | from torch.nn.utils import clip_grad_norm_, clip_grad_value_
23 |
24 |
25 | class GradientClippingStrategyType(Enum):
26 | Value = "value"
27 | Norm = "norm"
28 |
29 | def __str__(self):
30 | return self.value
31 |
32 |
33 | class GradientClippingStrategy(metaclass=abc.ABCMeta):
34 |
35 | @abc.abstractmethod
36 | def clip(self, model_parameters: Union[Iterable[torch.Tensor], torch.Tensor]):
37 | raise NotImplementedError()
38 |
39 |
40 | class GradientNormClipping(GradientClippingStrategy):
41 |
42 | def __init__(self, max_norm: float, norm_type: Union[int, float] = 2):
43 | """"
44 | Gradient norm clipping strategy.
45 |
46 | Args:
47 | max_norm (float or int): max norm of the gradients
48 | norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
49 | """
50 | self._max_norm = max_norm
51 | self._norm_type = norm_type
52 |
53 | def clip(self, model_parameters):
54 | clip_grad_norm_(model_parameters, self._max_norm, self._norm_type)
55 |
56 |
57 | class GradientValueClipping(GradientClippingStrategy):
58 |
59 | def __init__(self, clip_value: Union[int, float]):
60 | """"
61 | Gradient value clipping strategy.
62 |
63 | Args:
64 | clip_value (float or int): maximum allowed value of the gradients. The gradients are clipped in the range
65 | :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`
66 | """
67 | self._clip_value = clip_value
68 |
69 | def clip(self, model_parameters: Union[Iterable[torch.Tensor], torch.Tensor]):
70 | clip_grad_value_(model_parameters, self._clip_value)
71 |
72 |
73 | class GradientClippingStrategyFactory(object):
74 | def __init__(self):
75 | self._clipping_strategies = {
76 | "value": GradientValueClipping,
77 | "norm": GradientNormClipping
78 | }
79 |
80 | def create(self, clipping_strategy: Union[str, GradientClippingStrategyType], params):
81 | if clipping_strategy is not None:
82 | return self._clipping_strategies[str(clipping_strategy)](**params)
83 | else:
84 | return None
85 |
86 | def register(self, function: str, creator: Callable):
87 | self._clipping_strategies[function] = creator
88 |
--------------------------------------------------------------------------------
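A minimal sketch of the factory above, built from a config-style (type, params) pair; `create` returns `None` when no clipping strategy is configured, which is why the caller guards the `clip` call.

```python
import torch
from torch import nn

from kerosene.nn.utils.gradients import GradientClippingStrategyFactory

model = nn.Linear(10, 2)
model(torch.rand(4, 10)).sum().backward()  # populate gradients

clipping_strategy = GradientClippingStrategyFactory().create("norm", {"max_norm": 1.0})
if clipping_strategy is not None:
    clipping_strategy.clip(model.parameters())
```
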
/kerosene/configs/parsers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | import logging
17 |
18 | import torch
19 | import yaml
20 |
21 | from kerosene.configs.configs import ModelConfiguration, TrainerConfiguration, ConfigurationList
22 |
23 |
24 | class CustomYamlParser(object):
25 |
26 | def __init__(self):
27 | yaml.SafeLoader.add_constructor(u"!torch/tensor", CustomYamlParser.parse_tensor)
28 | yaml.SafeLoader.add_constructor(u"!python/tuple", CustomYamlParser.parse_tuple)
29 |
30 | @staticmethod
31 | def safe_load(file):
32 | return yaml.safe_load(file)
33 |
34 | @staticmethod
35 | def parse_tensor(loader, node):
36 | value = loader.construct_sequence(node, deep=True)
37 | tensor = torch.Tensor().new_tensor(value)
38 | return tensor
39 |
40 | @staticmethod
41 | def parse_tuple(loader, node):
42 | value = loader.construct_sequence(node, deep=True)
43 | tuple_ = tuple(value)
44 | return tuple_
45 |
46 |
47 | class YamlConfigurationParser(object):
48 | LOGGER = logging.getLogger("YamlConfigurationParser")
49 |
50 | @staticmethod
51 | def parse(config_file_path):
52 | with open(config_file_path, 'r') as config_file:
53 | try:
54 | config = CustomYamlParser().safe_load(config_file)
55 |
56 | model_trainer_configs = ConfigurationList(
57 | [ModelConfiguration.from_dict(model_name, config["models"][model_name]) for
58 | model_name in config["models"].keys()])
59 |
60 | model_trainer_configs = model_trainer_configs if len(model_trainer_configs) > 1 else \
61 | model_trainer_configs[0]
62 | training_config = TrainerConfiguration(config['training'])
63 | return model_trainer_configs, training_config
64 | except yaml.YAMLError as e:
65 | YamlConfigurationParser.LOGGER.warning(
66 | "Unable to read the training config file: {} with error {}".format(config_file_path, e))
67 |
68 | @staticmethod
69 | def parse_section(config_file_path, yml_tag):
70 | with open(config_file_path, 'r') as config_file:
71 | try:
72 | config = CustomYamlParser().safe_load(config_file)
73 |
74 | return config[yml_tag]
75 | except yaml.YAMLError as e:
76 | YamlConfigurationParser.LOGGER.warning(
77 | "Unable to read the training config file: {} with error {}".format(config_file_path, e))
78 |
--------------------------------------------------------------------------------
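A short sketch of the two custom YAML tags above; instantiating the parser registers both constructors on `yaml.SafeLoader`, and `safe_load` accepts a string as well as a file object.

```python
from kerosene.configs.parsers import CustomYamlParser

yml = """
weights: !torch/tensor [0.2, 0.8]
kernel_size: !python/tuple [3, 3]
"""

config = CustomYamlParser().safe_load(yml)
print(config["weights"])      # tensor([0.2000, 0.8000])
print(config["kernel_size"])  # (3, 3)
```
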
/tests/events/handlers/test_console.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import unittest
3 |
4 | import torch
5 | from mockito import verify, spy
6 |
7 | from kerosene.events import Phase, Monitor
8 | from kerosene.events.handlers.console import PrintTrainingStatus, StatusConsoleColorPalette, MonitorsTable
9 |
10 |
11 | class ConsoleHandlerTest(unittest.TestCase):
12 |
13 | def setUp(self) -> None:
14 | logging.basicConfig(level=logging.INFO)
15 | self._logger_mock = spy(logging.getLogger())
16 |
17 | def test_print_status_with_colors(self):
18 | expected_string = "\nCurrent states: \x1b[92mTraining\x1b[0m | Epoch: 0 | Training step: 0 | Validation step: 0 | Test step: 0\n"
19 | handler = PrintTrainingStatus(colors=StatusConsoleColorPalette.DEFAULT)
20 | handler.LOGGER = self._logger_mock
21 | handler.print_status(Phase.TRAINING, 0, 0, 0, 0)
22 | verify(self._logger_mock).info(expected_string)
23 |
24 | def test_print_status_without_colors(self):
25 | expected_string = "\nCurrent states: Training\x1b[0m | Epoch: 0 | Training step: 0 | Validation step: 0 | Test step: 0\n"
26 | handler = PrintTrainingStatus()
27 | handler.LOGGER = self._logger_mock
28 | handler.print_status(Phase.TRAINING, 0, 0, 0, 0)
29 | verify(self._logger_mock).info(expected_string)
30 |
31 | def test_print_monitors_table(self):
32 | monitors = {
33 | "Model 1": {Phase.TRAINING: {Monitor.METRICS: {"Accuracy": torch.tensor([0.6])},
34 | Monitor.LOSS: {"MSELoss": torch.tensor([0.6])}},
35 | Phase.VALIDATION: {Monitor.METRICS: {"Accuracy": torch.tensor([0.6])},
36 | Monitor.LOSS: {"MSELoss": torch.tensor([0.6])}},
37 | Phase.TEST: {Monitor.METRICS: {"Accuracy": torch.tensor([0.6])},
38 | Monitor.LOSS: {"MSELoss": torch.tensor([0.6])}}}}
39 |
40 | training_values = {**monitors["Model 1"][Phase.TRAINING][Monitor.METRICS],
41 | **monitors["Model 1"][Phase.TRAINING][Monitor.LOSS]}
42 | validation_values = {**monitors["Model 1"][Phase.VALIDATION][Monitor.METRICS],
43 | **monitors["Model 1"][Phase.VALIDATION][Monitor.LOSS]}
44 |
45 | table = MonitorsTable("Model 1", 20)
46 | table.update(training_values, validation_values)
47 | table.show()
48 |
49 | monitors = {
50 | "Model 1": {Phase.TRAINING: {Monitor.METRICS: {"Accuracy": torch.tensor([0.7])},
51 | Monitor.LOSS: {"MSELoss": torch.tensor([0.7])}},
52 | Phase.VALIDATION: {Monitor.METRICS: {"Accuracy": torch.tensor([0.4])},
53 | Monitor.LOSS: {"MSELoss": torch.tensor([0.6])}},
54 | Phase.TEST: {Monitor.METRICS: {"Accuracy": torch.tensor([0.4])},
55 | Monitor.LOSS: {"MSELoss": torch.tensor([0.8])}}}}
56 |
57 | training_values = {**monitors["Model 1"][Phase.TRAINING][Monitor.METRICS],
58 | **monitors["Model 1"][Phase.TRAINING][Monitor.LOSS]}
59 | validation_values = {**monitors["Model 1"][Phase.VALIDATION][Monitor.METRICS],
60 | **monitors["Model 1"][Phase.VALIDATION][Monitor.LOSS]}
61 |
62 | table.update(training_values, validation_values)
63 | table.show()
64 |
--------------------------------------------------------------------------------
/tests/functionals/distributed/main.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | from argparse import ArgumentParser
5 |
6 | from kerosene.configs.configs import RunConfiguration
7 |
8 | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)) + '/../../../')
9 | import torchvision
10 | from torchvision.transforms import Compose, ToTensor, Normalize
11 |
12 | from kerosene.configs.parsers import YamlConfigurationParser
13 | from kerosene.training.events import Event
14 | from kerosene.dataloaders.factories import DataloaderFactory
15 | from kerosene.events.handlers.console import ConsoleLogger
16 | from kerosene.loggers.visdom.config import VisdomConfiguration
17 | from kerosene.loggers.visdom.visdom import VisdomLogger
18 | from kerosene.events.preprocessors.visdom import PlotAllModelStateVariables
19 | from kerosene.training.trainers import ModelTrainerFactory
20 | from tests.functionals.models import SimpleNet
21 | from tests.functionals.distributed.mnist_trainer import MNISTTrainer
22 |
23 |
24 | class ArgsParserFactory(object):
25 |
26 | @staticmethod
27 | def create_parser():
28 | parser = ArgumentParser(description='Kerosene MNIST Distributed Training')
29 | parser.add_argument("--use_amp", dest="use_amp", action="store_true", default=True)
30 | parser.add_argument("--amp-opt-level", dest="amp_opt_level", type=str, default="O2",
31 | help="O0 - FP32 training, O1 - Mixed Precision (recommended), O2 - Almost FP16 Mixed Precision, O3 - FP16 Training.")
32 | parser.add_argument("--local_rank", dest="local_rank", default=0, type=int, help="The local_rank of the GPU.")
33 | return parser
34 |
35 |
36 | if __name__ == '__main__':
37 | logging.basicConfig(level=logging.INFO)
38 | CONFIG_FILE_PATH = "config.yml"
39 | args = ArgsParserFactory.create_parser().parse_args()
40 | run_config = RunConfiguration(args.use_amp, args.amp_opt_level, args.local_rank)
41 |
42 | model_trainer_config, training_config = YamlConfigurationParser.parse(CONFIG_FILE_PATH)
43 |
44 | # Initialize the dataset. This is the only part the user must define manually.
45 | train_dataset = torchvision.datasets.MNIST('./files/', train=True, download=True, transform=Compose(
46 | [ToTensor(), Normalize((0.1307,), (0.3081,))]))
47 | test_dataset = torchvision.datasets.MNIST('./files/', train=False, download=True, transform=Compose(
48 | [ToTensor(), Normalize((0.1307,), (0.3081,))]))
49 |
50 | # Initialize loaders.
51 | train_loader, valid_loader = DataloaderFactory(train_dataset, test_dataset).create(run_config, training_config)
52 |
53 | # Initialize the loggers.
54 | if run_config.local_rank == 0:
55 | visdom_logger = VisdomLogger(VisdomConfiguration.from_yml(CONFIG_FILE_PATH))
56 |
57 | # Initialize the model trainers.
58 | model_trainer = ModelTrainerFactory(model=SimpleNet()).create(model_trainer_config, run_config)
59 |
60 | # Train with the training strategy.
61 | if run_config.local_rank == 0:
62 | trainer = MNISTTrainer(training_config, model_trainer, train_loader, valid_loader, run_config) \
63 | .with_event_handler(ConsoleLogger(), Event.ON_EPOCH_END) \
64 | .with_event_handler(visdom_logger, Event.ON_EPOCH_END, PlotAllModelStateVariables()) \
65 | .train(training_config.nb_epochs)
66 | else:
67 | trainer = MNISTTrainer(training_config, model_trainer, train_loader, valid_loader, run_config) \
68 | .train(training_config.nb_epochs)
69 |
--------------------------------------------------------------------------------
/kerosene/events/__init__.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from enum import Enum
3 |
4 |
5 | class Phase(Enum):
6 | TRAINING = "Training"
7 | VALIDATION = "Validation"
8 | TEST = "Test"
9 | ALL = [TRAINING, VALIDATION, TEST]
10 |
11 | def __str__(self):
12 | return self.value
13 |
14 | def __eq__(self, other):
15 | if isinstance(other, str):
16 | return self.value == other
17 | elif isinstance(other, Phase):
18 | return self.value == other.value
19 |
20 | def __hash__(self):
21 | return hash(self.value)
22 |
23 |
24 | class Frequency(Enum):
25 | STEP = "Step"
26 | EPOCH = "Epoch"
27 | PHASE = "Phase"
28 |
29 | def __str__(self):
30 | return self.value
31 |
32 | def __eq__(self, other):
33 | if isinstance(other, str):
34 | return self.value == other
35 | elif isinstance(other, Frequency):
36 | return self.value == other.value
37 |
38 | def __hash__(self):
39 | return hash(self.value)
40 |
41 |
42 | class Moment(object):
43 | def __init__(self, iteration, frequency: Frequency, phase: Phase, time: datetime = None):
44 | self._datetime = datetime.now() if time is None else time
45 | self._frequency = frequency
46 | self._iteration = iteration
47 | self._phase = phase
48 |
49 | @property
50 | def datetime(self):
51 | return self._datetime
52 |
53 | @datetime.setter
54 | def datetime(self, value):
55 | self._datetime = value
56 |
57 | @property
58 | def frequency(self):
59 | return self._frequency
60 |
61 | @frequency.setter
62 | def frequency(self, value):
63 | self._frequency = value
64 |
65 | @property
66 | def iteration(self):
67 | return self._iteration
68 |
69 | @iteration.setter
70 | def iteration(self, value):
71 | self._iteration = value
72 |
73 | @property
74 | def phase(self):
75 | return self._phase
76 |
77 | @phase.setter
78 | def phase(self, value):
79 | self._phase = value
80 |
81 |
82 | class BaseEvent(Enum):
83 | def __init__(self, value):
84 | self._value_ = value
85 |
86 | def __str__(self):
87 | return self.value
88 |
89 | def __eq__(self, other):
90 | if isinstance(other, str):
91 | return self.value == other
92 | elif isinstance(other, BaseEvent):
93 | return self.value == other.value
94 |
95 | def __hash__(self):
96 | return hash(self.value)
97 |
98 |
99 | class TemporalEvent(object):
100 | def __init__(self, event: BaseEvent, moment: Moment):
101 | self._event = event
102 | self._moment = moment
103 |
104 | @property
105 | def event(self):
106 | return self._event
107 |
108 | @property
109 | def frequency(self):
110 | return self._moment.frequency
111 |
112 | @property
113 | def phase(self):
114 | return self._moment.phase
115 |
116 | @property
117 | def datetime(self):
118 | return self._moment.datetime
119 |
120 | @property
121 | def iteration(self):
122 | return self._moment.iteration
123 |
124 | def __eq__(self, other):
125 | if isinstance(other, BaseEvent):
126 | return self.event == other
127 | elif isinstance(other, TemporalEvent):
128 | return self._event == other.event and self.frequency == other.frequency and \
129 | self.phase == other.phase and self.datetime == other.datetime and self.iteration == other.iteration
130 |
131 | def __str__(self):
132 | return str(self.event)
133 |
134 |
135 | class BaseVariable(Enum):
136 | def __str__(self):
137 | return self.value
138 |
139 | def __eq__(self, other):
140 | if isinstance(other, str):
141 | return self.value == other
142 | elif isinstance(other, BaseVariable):
143 | return self.value == other.value
144 |
145 | def __hash__(self):
146 | return hash(self.value)
147 |
148 |
149 | class Monitor(BaseVariable):
150 | LOSS = "loss"
151 | METRICS = "metrics"
152 |
153 | TRAIN_LOSS = "train_loss"
154 | VALID_LOSS = "valid_loss"
155 | TEST_LOSS = "test_loss"
156 | TRAIN_METRICS = "train_metrics"
157 | VALID_METRICS = "valid_metrics"
158 | TEST_METRICS = "test_metrics"
159 |
160 | def is_loss(self):
161 | return "loss" in self.value
162 |
163 | def is_metrics(self):
164 | return "metrics" in self.value
165 |
166 |
167 | class MonitorMode(Enum):
168 | MIN = -1
169 | MAX = 1
170 |
171 | def __str__(self):
172 | return str(self.value)
173 |
--------------------------------------------------------------------------------
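A short sketch of how these pieces fit together, assuming `Event` (from `kerosene.training.events`) derives from `BaseEvent`: a `TemporalEvent` pairs an event with the `Moment` it occurred, and comparing it against a `BaseEvent` only checks the event type, which is how handlers decide whether to handle an incoming event.

```python
from kerosene.events import Frequency, Moment, Phase, TemporalEvent
from kerosene.training.events import Event

moment = Moment(iteration=10, frequency=Frequency.STEP, phase=Phase.TRAINING)
temporal_event = TemporalEvent(Event.ON_BATCH_END, moment)

# Equality against a BaseEvent ignores the moment...
assert temporal_event == Event.ON_BATCH_END
# ...while the moment's attributes stay accessible through the wrapper.
assert temporal_event.phase == Phase.TRAINING and temporal_event.iteration == 10
```
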
/tests/configs/test_model_trainers_config.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from hamcrest import *
4 |
5 | from kerosene.configs.configs import ModelConfiguration
6 | from kerosene.configs.exceptions import InvalidConfigurationError
7 | from kerosene.configs.parsers import YamlConfigurationParser
8 |
9 |
10 | class TestModelTrainerConfiguration(unittest.TestCase):
11 | VALID_CONFIG_FILE_PATH = "tests/configs/valid_config.yml"
12 | INVALID_CONFIG_FILE_PATH = "tests/configs/invalid_config.yml"
13 |
14 | MODELS_CONFIG_YML_TAG = "models"
15 |
16 | SIMPLE_NET_NAME = "SimpleNet"
17 | SIMPLE_NET_NAME_2 = "SimpleNet2"
18 |
19 | SIMPLE_NET_TYPE = "SimpleNet"
20 |
21 | SIMPLE_NET_OPTIMIZER_TYPE = "SGD"
22 | SIMPLE_NET_OPTIMIZER_PARAMS = {'lr': 0.01, 'momentum': 0.5, 'weight_decay': 0}
23 |
24 | SIMPLE_NET_SCHEDULER_TYPE = "ReduceLROnPlateau"
25 | SIMPLE_NET_SCHEDULER_PARAMS = {'mode': 'min', 'factor': 0.1, 'patience': 3}
26 |
27 | SIMPLE_NET_CRITERION_TYPE = "L1Loss"
28 | SIMPLE_NET_CRITERION_TYPE_2 = "MSELoss"
29 |
30 | SIMPLE_NET_METRIC_TYPE_1 = "Dice"
31 | SIMPLE_NET_METRIC_PARAMS_1 = {"num_classes": 4, "reduction": None, "ignore_index": 0, "average": None,
32 | "weight": None}
33 |
34 | SIMPLE_NET_METRIC_TYPE_2 = "Accuracy"
35 |
36 | SIMPLE_NET_GRADIENT_CLIPPING = {"type": "norm", "params": {"max_norm": 1.0}}
37 |
38 | def test_should_parse_valid_model_trainer_config(self):
39 | expected_config_dict = {self.SIMPLE_NET_NAME: {'type': self.SIMPLE_NET_TYPE,
40 | 'optimizer': {'type': self.SIMPLE_NET_OPTIMIZER_TYPE,
41 | 'params': self.SIMPLE_NET_OPTIMIZER_PARAMS},
42 | 'scheduler': {'type': self.SIMPLE_NET_SCHEDULER_TYPE,
43 | 'params': self.SIMPLE_NET_SCHEDULER_PARAMS},
44 | 'criterion': {"cycle": {'type': self.SIMPLE_NET_CRITERION_TYPE},
45 | "gan": {'type': self.SIMPLE_NET_CRITERION_TYPE_2}},
46 | 'metrics': {'Dice': {'type': self.SIMPLE_NET_METRIC_TYPE_1,
47 | 'params': self.SIMPLE_NET_METRIC_PARAMS_1},
48 | 'Accuracy': {'type': self.SIMPLE_NET_METRIC_TYPE_2}},
49 | 'gradients': self.SIMPLE_NET_GRADIENT_CLIPPING},
50 | self.SIMPLE_NET_NAME_2: {'type': self.SIMPLE_NET_TYPE,
51 | 'optimizer': {'type': self.SIMPLE_NET_OPTIMIZER_TYPE,
52 | 'params': self.SIMPLE_NET_OPTIMIZER_PARAMS},
53 | 'scheduler': {'type': self.SIMPLE_NET_SCHEDULER_TYPE,
54 | 'params': self.SIMPLE_NET_SCHEDULER_PARAMS},
55 | 'criterion': {
56 | "cycle": {'type': self.SIMPLE_NET_CRITERION_TYPE},
57 | "gan": {'type': self.SIMPLE_NET_CRITERION_TYPE_2}},
58 | 'metrics': {'Dice': {'type': self.SIMPLE_NET_METRIC_TYPE_1,
59 | 'params': self.SIMPLE_NET_METRIC_PARAMS_1},
60 | 'Accuracy': {
61 | 'type': self.SIMPLE_NET_METRIC_TYPE_2}},
62 | 'gradients': self.SIMPLE_NET_GRADIENT_CLIPPING}}
63 | config_dict = YamlConfigurationParser.parse_section(self.VALID_CONFIG_FILE_PATH, self.MODELS_CONFIG_YML_TAG)
64 | model_trainer_config = ModelConfiguration.from_dict(self.SIMPLE_NET_NAME,
65 | config_dict[self.SIMPLE_NET_NAME])
66 |
67 | assert_that(config_dict, equal_to(expected_config_dict))
68 |
69 | assert_that(model_trainer_config.optimizer_config.type, equal_to(self.SIMPLE_NET_OPTIMIZER_TYPE))
70 | assert_that(model_trainer_config.optimizer_config.params, equal_to(self.SIMPLE_NET_OPTIMIZER_PARAMS))
71 | assert_that(model_trainer_config.scheduler_config.type, equal_to(self.SIMPLE_NET_SCHEDULER_TYPE))
72 | assert_that(model_trainer_config.scheduler_config.params, equal_to(self.SIMPLE_NET_SCHEDULER_PARAMS))
73 | assert_that(model_trainer_config.criterions_configs[0].type, equal_to(self.SIMPLE_NET_CRITERION_TYPE))
74 | assert_that(model_trainer_config.criterions_configs[1].type, equal_to(self.SIMPLE_NET_CRITERION_TYPE_2))
75 | assert_that(model_trainer_config.metrics_configs[0].type, equal_to(self.SIMPLE_NET_METRIC_TYPE_1))
76 | assert_that(model_trainer_config.metrics_configs[0].params, equal_to(self.SIMPLE_NET_METRIC_PARAMS_1))
77 |
78 | def test_should_throw_on_invalid_model_trainer_config(self):
79 | config_dict = YamlConfigurationParser.parse_section(self.INVALID_CONFIG_FILE_PATH, self.MODELS_CONFIG_YML_TAG)
80 |
81 | assert_that(calling(ModelConfiguration.from_dict).with_args(self.SIMPLE_NET_NAME,
82 | config_dict[self.SIMPLE_NET_NAME]),
83 | raises(InvalidConfigurationError))
84 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Kerosene
2 | > Kerosene is a high-level deep learning framework for fast and clean research development with PyTorch - [see the docs for more details](https://kerosene.readthedocs.io/en/latest/). Kerosene lets you focus on your model and data by providing clean and readable code for training, visualizing and debugging your architecture, without forcing you to implement a rigid interface for your model.
3 |
4 | ## Out of The Box Features
5 | - [X] Basic training logic and user-defined trainers
6 | - [X] Fine-grained event system with multiple handlers
7 | - [X] Multiple metrics and criterions support
8 | - [X] Automatic configuration parsing and model instantiation
9 | - [X] Automatic support of mixed precision with [Apex](https://github.com/NVIDIA/apex) and data-parallel training
10 | - [X] Automatic Visdom logging
11 | - [X] Integrated [Ignite](https://github.com/pytorch/ignite) metrics and [PyTorch](https://github.com/pytorch/pytorch) criterions
12 |
13 | ## MNIST Example
14 | > Here is a simple example that shows how easy and clean it is to train a simple network. In very few lines of code, the model is trained using mixed precision, and you get Visdom and console logging automatically. See the full example here: [MNIST-Kerosene](https://github.com/banctilrobitaille/kerosene-mnist)
15 |
16 | ```python
17 | if __name__ == "__main__":
18 | logging.basicConfig(level=logging.INFO)
19 | CONFIG_FILE_PATH = "config.yml"
20 |
21 | model_trainer_config, training_config = YamlConfigurationParser.parse(CONFIG_FILE_PATH)
22 |
23 | train_loader = DataLoader(torchvision.datasets.MNIST('./files/', train=True, download=True, transform=Compose(
24 | [ToTensor(), Normalize((0.1307,), (0.3081,))])), batch_size=training_config.batch_size_train, shuffle=True)
25 |
26 | test_loader = DataLoader(torchvision.datasets.MNIST('./files/', train=False, download=True, transform=Compose(
27 | [ToTensor(), Normalize((0.1307,), (0.3081,))])), batch_size=training_config.batch_size_valid, shuffle=True)
28 |
29 | visdom_logger = VisdomLogger(VisdomConfiguration.from_yml(CONFIG_FILE_PATH))
30 |
31 | # Initialize the model trainers
32 | model_trainer = ModelTrainerFactory(model=SimpleNet()).create(model_trainer_config)
33 |
34 | # Train with the training strategy
35 | SimpleTrainer("MNIST Trainer", train_loader, test_loader, None, model_trainer, RunConfiguration(use_amp=False)) \
36 | .with_event_handler(PlotMonitors(every=500, visdom_logger=visdom_logger), Event.ON_BATCH_END) \
37 | .with_event_handler(PlotAvgGradientPerLayer(every=500, visdom_logger=visdom_logger), Event.ON_TRAIN_BATCH_END) \
38 | .with_event_handler(PrintTrainingStatus(every=100), Event.ON_BATCH_END) \
39 | .train(training_config.nb_epochs)
40 | ```
41 |
42 | ## Events
43 |
44 | | Event | Description |
45 | | ------------- | ------------- |
46 | | ON_TRAINING_BEGIN | At the beginning of the training phase |
47 | | ON_TRAINING_END | At the end of the training phase |
48 | | ON_VALID_BEGIN | At the beginning of the validation phase |
49 | | ON_VALID_END | At the end of the validation phase |
50 | | ON_TEST_BEGIN | At the beginning of the test phase |
51 | | ON_TEST_END | At the end of the test phase |
52 | | ON_EPOCH_BEGIN | At the beginning of each epoch (training, validation, test) |
53 | | ON_EPOCH_END | At the end of each epoch (training, validation, test) |
54 | | ON_TRAIN_EPOCH_BEGIN | At the beginning of each training epoch |
55 | | ON_TRAIN_EPOCH_END | At the end of each training epoch |
56 | | ON_VALID_EPOCH_BEGIN | At the beginning of each validation epoch |
57 | | ON_VALID_EPOCH_END | At the end of each validation epoch |
58 | | ON_TEST_EPOCH_BEGIN | At the beginning of each test epoch |
59 | | ON_TEST_EPOCH_END | At the end of each test epoch |
60 | | ON_BATCH_BEGIN | At the beginning of each batch (training, validation, test) |
61 | | ON_BATCH_END | At the end of each batch (training, validation, test) |
62 | | ON_TRAIN_BATCH_BEGIN | At the beginning of each train batch |
63 | | ON_TRAIN_BATCH_END | At the end of each train batch |
64 | | ON_VALID_BATCH_BEGIN | At the beginning of each validation batch |
65 | | ON_VALID_BATCH_END | At the end of each validation batch |
66 | | ON_TEST_BATCH_BEGIN | At the beginning of each test batch |
67 | | ON_TEST_BATCH_END | At the end of each test batch |
68 | | ON_FINALIZE | Before the end of the process |
69 |
70 | ## Handlers
71 | - [X] PrintTrainingStatus (Console)
72 | - [X] PrintMonitors (Console)
73 | - [X] PlotMonitors (Visdom)
74 | - [X] PlotLosses (Visdom)
75 | - [X] PlotMetrics (Visdom)
76 | - [X] PlotCustomVariables (Visdom)
77 | - [X] PlotLR (Visdom)
78 | - [X] PlotAvgGradientPerLayer (Visdom)
79 | - [X] Checkpoint
80 | - [X] EarlyStopping
81 |
82 | ## Contributing
83 |
84 | #### How to contribute?
85 | - [X] Create a branch by feature and/or bug fix
86 | - [X] Get the code
87 | - [X] Commit and push
88 | - [X] Create a pull request
89 |
90 | #### Branch naming
91 |
92 | ##### Feature branch
93 | > feature/ [Short feature description] [Issue number]
94 |
95 | ##### Bug branch
96 | > fix/ [Short fix description] [Issue number]
97 |
98 | #### Commits syntax:
99 |
100 | ##### Adding code:
101 | > \+ Added [Short Description] [Issue Number]
102 |
103 | ##### Deleting code:
104 | > \- Deleted [Short Description] [Issue Number]
105 |
106 | ##### Modifying code:
107 | > \* Changed [Short Description] [Issue Number]
108 |
109 | ##### Merging code:
110 | > Y Merged [Short Description] [Issue Number]
111 |
112 |
113 | Icons made by Freepik from www.flaticon.com are licensed under CC 3.0 BY
114 |
--------------------------------------------------------------------------------
/kerosene/nn/apex.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from __future__ import division
17 |
18 | from abc import ABC
19 | from typing import Optional
20 |
21 | import torch
22 | from torch import Tensor, nn
23 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
24 | from torch.optim import Optimizer
25 |
26 | from kerosene.utils.devices import on_multiple_gpus, get_devices
27 |
28 | try:
29 | from apex import amp
30 | from apex.parallel import DistributedDataParallel as ApexDDP
31 |
32 | APEX_AVAILABLE = True
33 | except ModuleNotFoundError:
34 | APEX_AVAILABLE = False
35 |
36 |
37 | class ApexLoss(object):
38 | def __init__(self, loss_id: int, loss: Tensor, optimizer: Optimizer):
39 | super().__init__()
40 | self._loss_id = loss_id
41 | self._loss = loss
42 | self._optimizer = optimizer
43 |
44 | @property
45 | def loss_id(self):
46 | return self._loss_id
47 |
48 | @property
49 | def loss(self):
50 | return self._loss
51 |
52 | def backward(self, gradient: Optional[Tensor] = None, retain_graph=False,
53 | create_graph=False) -> None:
54 | if APEX_AVAILABLE:
55 | with amp.scale_loss(self._loss, self._optimizer, loss_id=self._loss_id) as scaled_loss:
56 | scaled_loss.backward(gradient, retain_graph, create_graph)
57 | else:
58 | self._loss.backward(gradient, retain_graph, create_graph)
59 |
60 | def cpu(self):
61 | return ApexLoss(self._loss_id, self._loss.cpu(), self._optimizer)
62 |
63 | def detach(self):
64 | return ApexLoss(self._loss_id, self._loss.detach(), self._optimizer)
65 |
66 | def float(self):
67 | return ApexLoss(self._loss_id, self._loss.float(), self._optimizer)
68 |
69 | def numpy(self):
70 | return self._loss.numpy()
71 |
72 | def mean(self):
73 | return ApexLoss(self._loss_id, torch.mean(self._loss), self._optimizer)
74 |
75 | def item(self):
76 | return self._loss.item()
77 |
78 | def __add__(self, other):
79 | if isinstance(other, (Tensor, int, float)):
80 | loss = ApexLoss(self._loss_id, torch.add(self._loss, other), self._optimizer)
81 | elif isinstance(other, ApexLoss):
82 | loss = ApexLoss(self._loss_id, torch.add(self._loss, other.loss), self._optimizer)
83 | else:
84 | raise NotImplementedError("Cannot add an element of type: {} to an ApexLoss.".format(str(type(other))))
85 | return loss
86 |
87 | def __mul__(self, other):
88 | if isinstance(other, (Tensor, int, float)):
89 | loss = ApexLoss(self._loss_id, torch.mul(self._loss, other), self._optimizer)
90 | elif isinstance(other, ApexLoss):
91 | loss = ApexLoss(self._loss_id, torch.mul(self._loss, other.loss), self._optimizer)
92 | else:
93 | raise NotImplementedError("Cannot mul an element of type: {} to an ApexLoss.".format(str(type(other))))
94 | return loss
95 |
96 | def __rmul__(self, other):
97 | if isinstance(other, (Tensor, int, float)):
98 | loss = ApexLoss(self._loss_id, torch.mul(self._loss, other), self._optimizer)
99 | else:
100 | raise NotImplementedError("Cannot rmul an element of type: {} to an ApexLoss.".format(str(type(other))))
101 | return loss
102 |
103 | def __truediv__(self, other):
104 | if isinstance(other, (Tensor, int, float)):
105 | loss = ApexLoss(self._loss_id, torch.div(self._loss, other), self._optimizer)
106 | elif isinstance(other, ApexLoss):
107 | loss = ApexLoss(self._loss_id, torch.div(self._loss, other.loss), self._optimizer)
108 | else:
109 | raise NotImplementedError("Cannot truediv an element of type: {} to an ApexLoss.".format(str(type(other))))
110 | return loss
111 |
112 | def __rtruediv__(self, other):
113 | if isinstance(other, (Tensor, int, float)):
114 | loss = ApexLoss(self._loss_id, torch.div(other, self._loss), self._optimizer)
115 | else:
116 | raise NotImplementedError("Cannot rtruediv an element of type: {} to an ApexLoss.".format(str(type(other))))
117 | return loss
118 |
119 | def __sub__(self, other):
120 | if isinstance(other, (Tensor, int, float)):
121 | loss = ApexLoss(self._loss_id, torch.sub(self._loss, other), self._optimizer)
122 | elif isinstance(other, ApexLoss):
123 | loss = ApexLoss(self._loss_id, torch.sub(self._loss, other.loss), self._optimizer)
124 | else:
125 | raise NotImplementedError(
126 | "Cannot subtract an element of type: {} from an ApexLoss.".format(str(type(other))))
127 | return loss
128 |
129 | def __rsub__(self, other):
130 | if isinstance(other, (Tensor, int, float)):
131 | loss = ApexLoss(self._loss_id, torch.sub(other, self._loss), self._optimizer)
132 | else:
133 | raise NotImplementedError(
134 | "Cannot subtract an element of type: {} from an ApexLoss.".format(str(type(other))))
135 | return loss
136 |
137 | def __eq__(self, other):
138 | if isinstance(other, Tensor):
139 | is_equal = torch.all(torch.eq(self._loss, other))
140 | elif isinstance(other, ApexLoss):
141 | is_equal = torch.all(torch.eq(self._loss, other.loss))
142 | else:
143 | raise NotImplementedError("Cannot compare an element of type: {} to an ApexLoss.".format(str(type(other))))
144 | return is_equal
145 |
146 |
147 | class ApexModule(ABC, nn.Module):
148 | def __init__(self, model, optimizer, amp_id=0, use_amp=True):
149 | super().__init__()
150 | self._amp_id = amp_id
151 | self._use_amp = use_amp
152 | self._model = model
153 | self._optimizer = optimizer
154 |
155 | @property
156 | def model(self):
157 | return self._model
158 |
159 | @model.setter
160 | def model(self, value):
161 | self._model = value
162 |
163 | @property
164 | def optimizer(self):
165 | return self._optimizer
166 |
167 | @property
168 | def amp_id(self):
169 | return self._amp_id
170 |
171 | @property
172 | def use_amp(self):
173 | return self._use_amp
174 |
175 | def initialize(self, amp_id: int, num_losses: int, use_amp: bool, amp_opt_level: str, device: torch.device):
176 | self._amp_id = amp_id
177 | self._use_amp = use_amp
178 |
179 | if APEX_AVAILABLE and self._use_amp:
180 | self._model, self._optimizer = amp.initialize(
181 | self._model, self._optimizer, opt_level=amp_opt_level, num_losses=num_losses)
182 | if on_multiple_gpus(get_devices()):
183 | self._model = ApexDDP(self._model, delay_allreduce=True)
184 | if not (APEX_AVAILABLE and self._use_amp) and on_multiple_gpus(get_devices()):
185 | self._model = DDP(self._model, device_ids=[device])
186 |
--------------------------------------------------------------------------------
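The arithmetic overloads above let an ApexLoss be scaled, combined and reduced like a plain tensor while keeping its loss id and optimizer attached. A minimal sketch of that behaviour (illustration only; it assumes the ApexLoss(loss_id, loss, optimizer) constructor defined earlier in this module, and passes None for the optimizer since it plays no role in pure arithmetic):

import torch

from kerosene.nn.apex import ApexLoss

loss_a = ApexLoss(0, torch.tensor([1.0, 3.0]), None)
loss_b = ApexLoss(0, torch.tensor([1.0, 1.0]), None)

# __rmul__, __add__ and mean() each return a new ApexLoss.
combined = (0.5 * loss_a + 0.5 * loss_b).mean()
print(combined.item())  # 1.5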
/tests/metrics/test_metrics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 SAMITorch Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 |
18 | import torch
19 | import numpy as np
20 | import unittest
21 |
22 | from kerosene.metrics.metrics import Dice, GeneralizedDice
23 | from hamcrest import *
24 |
25 |
26 | def get_y_true_y_pred():
27 | # Generate an image with labels 0 (background), 1, 2
28 | # 3 classes:
29 | y_true = np.zeros((30, 30), dtype=int)
30 | y_true[1:11, 1:11] = 1
31 | y_true[15:25, 15:25] = 2
32 |
33 | y_pred = np.zeros((30, 30), dtype=int)
34 | y_pred[5:15, 1:11] = 1
35 | y_pred[20:30, 20:30] = 2
36 | return y_true, y_pred
37 |
38 |
39 | def compute_tensor_y_true_y_logits(y_true, y_pred):
40 | # Create torch.tensor from numpy
41 | y_true_tensor = torch.from_numpy(y_true).unsqueeze(0).type(torch.long)
42 | # Create logits torch.tensor:
43 | num_classes = max(np.max(y_true), np.max(y_pred)) + 1
44 | y_probs = np.ones((num_classes,) + y_true.shape) * -10
45 | for i in range(num_classes):
46 | y_probs[i, (y_pred == i)] = 720
47 | y_logits = torch.from_numpy(y_probs).unsqueeze(0)
48 | return y_true_tensor, y_logits
49 |
50 |
51 | def compute_dice_truth(y_true, y_pred):
52 | true_res = [0, 0, 0]
53 | for index in range(3):
54 | bin_y_true = y_true == index
55 | bin_y_pred = y_pred == index
56 | intersection = bin_y_true & bin_y_pred
57 | true_res[index] = 2 * intersection.sum() / (bin_y_pred.sum() + bin_y_true.sum())
58 | return true_res
59 |
60 |
61 | def compute_generalized_dice_truth(y_true, y_pred):
62 | true_res = [0, 0, 0]
63 | weights = [0, 0, 0]
64 | for index in range(3):
65 | bin_y_true = y_true == index
66 | bin_y_pred = y_pred == index
67 | weights[index] = (1.0 / (np.sum(bin_y_true) * np.sum(bin_y_true) + 1e-15))
68 | intersection = (bin_y_true & bin_y_pred)
69 | true_res[index] = 2 * intersection.sum() * weights[index] / (
70 | ((bin_y_pred.sum() + bin_y_true.sum()) * weights[index]) + 1e-15)
71 | return true_res, weights
72 |
73 |
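# Worked example (illustration only): for class 1 of the fixture above,
# y_true fills rows 1..10 and y_pred rows 5..14 of columns 1..10, so the
# overlap is 6 rows x 10 columns = 60 pixels and
#   Dice = 2 * |A n B| / (|A| + |B|) = 2 * 60 / (100 + 100) = 0.6,
# which is exactly the value compute_dice_truth returns at index 1.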
74 | class TestDiceMetric(unittest.TestCase):
75 | INVALID_VALUE_1 = -1
76 | INVALID_REDUCTION = "sum"
77 | INVALID_VALUE_3 = 10
78 | INVALID_VALUE_4 = 11
79 |
80 | def setUp(self):
81 | self.y_true, self.y_pred = get_y_true_y_pred()
82 | self.y_true_tensor, self.y_logits = compute_tensor_y_true_y_logits(self.y_true, self.y_pred)
83 | self.dice_truth = compute_dice_truth(self.y_true, self.y_pred)
84 | self.mean_dice_truth = np.mean(self.dice_truth)
85 |
86 | def test_should_compute_dice_for_multiclass(self):
87 | dice_coefficient = Dice(num_classes=3, reduction=None)
88 | dice_coefficient.update((self.y_logits, self.y_true_tensor))
89 | res = dice_coefficient.compute().numpy()
90 | assert np.all(res == self.dice_truth)
91 |
92 | def test_should_compute_mean_dice_for_multiclass(self):
93 | dice_coefficient = Dice(num_classes=3, reduction="mean")
94 | dice_coefficient.update((self.y_logits, self.y_true_tensor))
95 | res = dice_coefficient.compute().numpy()
96 | truth = np.array(self.dice_truth).mean()
97 | assert res == truth
98 |
99 | def test_should_compute_dice_for_multiclass_with_ignored_index(self):
100 | for ignore_index in range(3):
101 | dice_coefficient = Dice(num_classes=3, ignore_index=ignore_index, reduction=None)
102 | dice_coefficient.update((self.y_logits, self.y_true_tensor))
103 | res = dice_coefficient.compute().numpy()
104 | true_res = self.dice_truth[:ignore_index] + self.dice_truth[ignore_index + 1:]
105 | assert np.all(res == true_res), "{}: {} vs {}".format(ignore_index, res, true_res)
106 |
107 | def test_should_compute_mean_dice(self):
108 | mean_dice_coefficient = Dice(num_classes=3, reduction="mean")
109 | mean_dice_coefficient.update((self.y_logits, self.y_true_tensor))
110 | res = mean_dice_coefficient.compute().numpy()
111 | assert_that(res, equal_to(self.mean_dice_truth))
112 |
113 | def test_should_compute_mean_dice_with_ignored_index(self):
114 | for ignore_index in range(3):
115 | mean_dice_coefficient = Dice(num_classes=3, reduction="mean", ignore_index=ignore_index)
116 | mean_dice_coefficient.update((self.y_logits, self.y_true_tensor))
117 | res = mean_dice_coefficient.compute().numpy()
118 | true_res = np.mean(self.dice_truth[:ignore_index] + self.dice_truth[ignore_index + 1:])
119 | assert_that(res, equal_to(true_res)), "{}: {} vs {}".format(ignore_index, res, true_res)
120 |
121 |
122 | class TestGeneralizedDiceMetric(unittest.TestCase):
123 | INVALID_VALUE_1 = -1
124 | INVALID_VALUE_2 = "STEVE JOBS"
125 | INVALID_VALUE_3 = 10
126 | INVALID_VALUE_4 = 11
127 |
128 | def setUp(self):
129 | self.y_true, self.y_pred = get_y_true_y_pred()
130 | self.y_true_tensor, self.y_logits = compute_tensor_y_true_y_logits(self.y_true, self.y_pred)
131 | self.generalized_dice_truth, weights = compute_generalized_dice_truth(self.y_true, self.y_pred)
132 | self.weights = torch.from_numpy(np.array(weights))
133 | self.generalized_mean_dice_truth = np.mean(self.generalized_dice_truth)
134 |
135 | def test_should_compute_dice_for_multiclass(self):
136 | generalized_dice_coefficient = GeneralizedDice(num_classes=3)
137 | generalized_dice_coefficient.update((self.y_logits, self.y_true_tensor))
138 | res = generalized_dice_coefficient.compute().numpy()
139 | assert np.all(res == self.generalized_dice_truth)
140 |
141 | def test_should_compute_dice_for_multiclass_with_ignored_index(self):
142 | for ignore_index in range(3):
143 | generalized_dice_coefficient = GeneralizedDice(num_classes=3, ignore_index=ignore_index)
144 | generalized_dice_coefficient.update((self.y_logits, self.y_true_tensor))
145 | res = generalized_dice_coefficient.compute().numpy()
146 | true_res = self.generalized_dice_truth[:ignore_index] + self.generalized_dice_truth[ignore_index + 1:]
147 | assert np.all(res == true_res), "{}: {} vs {}".format(ignore_index, res, true_res)
148 |
149 | def test_should_compute_mean_dice(self):
150 | mean_generalized_dice_coefficient = GeneralizedDice(num_classes=3, reduction="mean")
151 | mean_generalized_dice_coefficient.update((self.y_logits, self.y_true_tensor))
152 | res = mean_generalized_dice_coefficient.compute().numpy()
153 | assert_that(res, equal_to(self.generalized_mean_dice_truth))
154 |
155 | def test_should_compute_mean_dice_with_ignored_index(self):
156 | for ignore_index in range(3):
157 | mean_generalized_dice_coefficient = GeneralizedDice(num_classes=3, reduction="mean", ignore_index=ignore_index)
158 | mean_generalized_dice_coefficient.update((self.y_logits, self.y_true_tensor))
159 | res = mean_generalized_dice_coefficient.compute().numpy()
160 | true_res = np.mean(
161 | self.generalized_dice_truth[:ignore_index] + self.generalized_dice_truth[ignore_index + 1:])
162 | assert_that(res, equal_to(true_res)), "{}: {} vs {}".format(ignore_index, res, true_res)
--------------------------------------------------------------------------------
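The tests above drive the metric through its update/compute protocol; the same two calls are all that is needed outside the test suite. A minimal sketch (illustration only; the logits and targets are random placeholders):

import torch

from kerosene.metrics.metrics import Dice

dice = Dice(num_classes=3, reduction="mean")

logits = torch.randn(1, 3, 30, 30)          # (batch, classes, H, W)
targets = torch.randint(0, 3, (1, 30, 30))  # (batch, H, W), long labels

dice.update((logits, targets))  # accumulate one batch, as in the tests
print(dice.compute().item())    # mean Dice over the three classes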
/kerosene/events/handlers/console.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | import logging
17 | import math
18 | from abc import ABC
19 | from enum import Enum
20 | from typing import Dict, Union, List, Optional
21 |
22 | from kerosene.events import Phase, TemporalEvent, BaseEvent, Monitor
23 | from kerosene.events.handlers.base_handler import EventHandler
24 | from kerosene.training import Status, BaseStatus
25 | from kerosene.training.events import Event
26 | from kerosene.training.trainers import Trainer
27 | import crayons
28 | from beautifultable import BeautifulTable
29 |
30 |
31 | class ConsoleColors(Enum):
32 | PURPLE = '\033[95m'
33 | BLUE = '\033[94m'
34 | GREEN = '\033[92m'
35 | YELLOW = '\033[93m'
36 | RED = '\033[91m'
37 | ENDC = '\033[0m'
38 | NONE = ""
39 |
40 | def __str__(self):
41 | return self.value
42 |
43 |
44 | class StatusConsoleColorPalette(Dict[Phase, ConsoleColors]):
45 | DEFAULT = {
46 | Phase.TRAINING: ConsoleColors.GREEN,
47 | Phase.VALIDATION: ConsoleColors.BLUE,
48 | Phase.TEST: ConsoleColors.PURPLE
49 | }
50 |
51 |
52 | class BaseConsoleLogger(EventHandler, ABC):
53 | LOGGER = logging.getLogger("ConsoleLogger")
54 |
55 | def __init__(self, supported_events: List[Union[BaseEvent, TemporalEvent]] = None, every=1):
56 | super(BaseConsoleLogger, self).__init__(supported_events, every)
57 |
58 |
59 | class ColoredConsoleLogger(BaseConsoleLogger, ABC):
60 |
61 | def __init__(self, supported_events: List[Union[BaseEvent, TemporalEvent]] = None, every=1,
62 | colors: Dict[object, ConsoleColors] = None):
63 | super().__init__(supported_events, every)
64 | self._colors = colors if colors is not None else {}
65 |
66 | def color(self, text, color_key):
67 | return str(self._colors.get(color_key, ConsoleColors.NONE)) + str(text) + str(ConsoleColors.ENDC)
68 |
69 |
70 | class PrintTrainingStatus(ColoredConsoleLogger):
71 | SUPPORTED_EVENTS = [Event.ON_BATCH_END, Event.ON_EPOCH_END, Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END,
72 | Event.ON_TEST_BATCH_END]
73 |
74 | def __init__(self, every=1, colors: Union[Dict[BaseStatus, ConsoleColors], StatusConsoleColorPalette] = None):
75 | super().__init__(self.SUPPORTED_EVENTS, every, colors)
76 |
77 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer):
78 | if self.should_handle(event):
79 | self.print_status(event.phase, trainer.epoch, trainer.current_train_step, trainer.current_valid_step,
80 | trainer.current_test_step)
81 |
82 | def print_status(self, status, epoch, train_step, valid_step, test_step):
83 | self.LOGGER.info(
84 | "\nCurrent states: {} | Epoch: {} | Training step: {} | Validation step: {} | Test step: {}\n".format(
85 | self.color(status, color_key=status), epoch, train_step, valid_step, test_step))
86 |
87 |
88 | class PrintMonitors(BaseConsoleLogger):
89 | SUPPORTED_EVENTS = [Event.ON_BATCH_END, Event.ON_EPOCH_END, Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END,
90 | Event.ON_TEST_BATCH_END]
91 |
92 | def __init__(self, every=1):
93 | super().__init__(self.SUPPORTED_EVENTS, every)
94 | self._monitors = {}
95 |
96 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer):
97 | if self.should_handle(event):
98 | for model, monitor in monitors.items():
99 | if model in self._monitors:
100 | self._monitors[model].update(monitor)
101 | else:
102 | self._monitors[model] = monitor
103 |
104 | self.LOGGER.info("Model {}, {}".format(model, self._monitors[model]))
105 |
106 |
107 | class PrintMonitorsTable(BaseConsoleLogger):
108 | SUPPORTED_EVENTS = [Event.ON_BATCH_END, Event.ON_EPOCH_END, Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END,
109 | Event.ON_TEST_BATCH_END]
110 |
111 | def __init__(self, every=1, max_table_width=100):
112 | super().__init__(self.SUPPORTED_EVENTS, every)
113 | self._monitors = {}
114 | self._monitors_tables = {}
115 | self._max_table_width = max_table_width
116 |
117 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer):
118 | if self.should_handle(event):
119 | for model, monitor in monitors.items():
120 | training_values = {**monitors[model][Phase.TRAINING][Monitor.METRICS],
121 | **monitors[model][Phase.TRAINING][Monitor.LOSS]}
122 | validation_values = {**monitors[model][Phase.VALIDATION][Monitor.METRICS],
123 | **monitors[model][Phase.VALIDATION][Monitor.LOSS]}
124 | test_values = {**monitors[model][Phase.TEST][Monitor.METRICS],
125 | **monitors[model][Phase.TEST][Monitor.LOSS]}
126 |
127 | if model in self._monitors_tables:
128 | self._monitors_tables[model].update(training_values, validation_values, test_values)
129 | else:
130 | self._monitors_tables[model] = MonitorsTable(model, self._max_table_width)
131 | self._monitors_tables[model].update(training_values, validation_values, test_values)
132 |
133 | self._monitors_tables[model].show()
134 |
135 |
136 | class MonitorsTable(object):
137 | def __init__(self, model_name: str, max_width):
138 | self.model_name = model_name
139 | self._training_monitors = {}
140 | self._validation_monitors = {}
141 | self._test_monitors = {}
142 | self._max_width = max_width
143 | self.table = BeautifulTable(max_width)
144 |
145 | def append(self, values: dict, old_values: dict):
146 | self.table.rows.append(list(map(lambda key: self.color(values[key], old_values.get(key, None)), values.keys())))
147 |
148 | def update(self, training_monitors: dict, validation_monitors: dict, test_monitors: Optional[dict] = None):
149 | self.table.clear()
150 | self.append(training_monitors, self._training_monitors)
151 | self.append(validation_monitors, self._validation_monitors)
152 |
153 | if test_monitors is not None:
154 | self.append(test_monitors, self._test_monitors)
155 | self.table.rows.header = ["Training", "Validation", "Test"]
156 | else:
157 | self.table.rows.header = ["Training", "Validation"]
158 |
159 | self.table.columns.header = training_monitors.keys()
160 |
161 | self._training_monitors = dict(training_monitors)
162 | self._validation_monitors = dict(validation_monitors)
163 | self._test_monitors = dict(test_monitors) if test_monitors is not None else {}
164 |
165 | def show(self):
166 | self.table._compute_width()
167 | width = self.table._width - 2 if self.table._width == self._max_width else self.table._width + 11
168 | topline = "".join(["+", "-" * width, "+"])
169 | print(topline)
170 | spc = (len(topline) - 2 - len(self.model_name)) / 2
171 | print("%s%s%s%s%s" % ("|", " " * math.ceil(spc), self.model_name, " " * math.floor(spc), "|"))
172 | print(self.table)
173 |
174 | def color(self, value, old_value):
175 | if old_value is not None:
176 | if value > old_value:
177 | return crayons.green("{} \u2197".format(value), bold=True)
178 | if value < old_value:
179 | return crayons.red("{} \u2198".format(value), bold=True)
180 |
181 | return crayons.white("{} \u2192".format(value), bold=True)
182 |
--------------------------------------------------------------------------------
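MonitorsTable can also be exercised on its own. A minimal sketch (illustration only; the monitor names and values are made up). On the second update, color() renders each cell green with an upward arrow if it increased, red with a downward arrow if it decreased, and white otherwise:

from kerosene.events.handlers.console import MonitorsTable

table = MonitorsTable("SimpleNet", max_width=100)

table.update({"Accuracy": 0.71, "CrossEntropyLoss": 0.95},
             {"Accuracy": 0.68, "CrossEntropyLoss": 1.02})
table.show()

table.update({"Accuracy": 0.75, "CrossEntropyLoss": 0.83},
             {"Accuracy": 0.70, "CrossEntropyLoss": 0.97})
table.show()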
/kerosene/training/events.py:
--------------------------------------------------------------------------------
1 | from kerosene.events import TemporalEvent, Moment, Frequency, Phase, BaseEvent
2 | from kerosene.training import Status
3 |
4 |
5 | class Event(BaseEvent):
6 | ON_TRAINING_BEGIN = "training_begin"
7 | ON_TRAINING_END = "training_end"
8 | ON_VALID_BEGIN = "valid_begin"
9 | ON_VALID_END = "valid_end"
10 | ON_TEST_BEGIN = "test_begin"
11 | ON_TEST_END = "test_end"
12 | ON_EPOCH_BEGIN = "epoch_begin"
13 | ON_EPOCH_END = "epoch_end"
14 | ON_TRAIN_EPOCH_BEGIN = "train_epoch_begin"
15 | ON_TRAIN_EPOCH_END = "train_epoch_end"
16 | ON_VALID_EPOCH_BEGIN = "valid_epoch_begin"
17 | ON_VALID_EPOCH_END = "valid_epoch_end"
18 | ON_TEST_EPOCH_BEGIN = "test_epoch_begin"
19 | ON_TEST_EPOCH_END = "test_epoch_end"
20 | ON_BATCH_BEGIN = "batch_begin"
21 | ON_TRAIN_BATCH_BEGIN = "train_batch_begin"
22 | ON_TRAIN_BATCH_END = "train_batch_end"
23 | ON_VALID_BATCH_BEGIN = "valid_batch_begin"
24 | ON_VALID_BATCH_END = "valid_batch_end"
25 | ON_TEST_BATCH_BEGIN = "test_batch_begin"
26 | ON_TEST_BATCH_END = "test_batch_end"
27 | ON_BATCH_END = "batch_end"
28 | ON_FINALIZE = "finalizing"
29 |
30 | def __call__(self, moment: Moment):
31 | return TemporalEvent(self, moment)
32 |
33 |
34 | class BatchEventPublisherMixin(object):
35 | @property
36 | def iteration_and_phase(self):
37 | if self.status is Status.TRAINING:
38 | iteration = self.current_train_step
39 | phase = Phase.TRAINING
40 | elif self.status is Status.VALIDATING:
41 | iteration = self.current_valid_step
42 | phase = Phase.VALIDATION
43 | else:
44 | iteration = self.current_test_step
45 | phase = Phase.TEST
46 |
47 | return iteration, phase
48 |
49 | def on_batch_begin(self):
50 | pass
51 |
52 | def on_batch_end(self):
53 | pass
54 |
55 | def on_train_batch_begin(self):
56 | pass
57 |
58 | def on_train_batch_end(self):
59 | pass
60 |
61 | def on_valid_batch_begin(self):
62 | pass
63 |
64 | def on_valid_batch_end(self):
65 | pass
66 |
67 | def on_test_batch_begin(self):
68 | pass
69 |
70 | def on_test_batch_end(self):
71 | pass
72 |
73 | def _on_batch_begin(self):
74 | self.on_batch_begin()
75 |
76 | iteration, phase = self.iteration_and_phase
77 |
78 | self.fire(Event.ON_BATCH_BEGIN(Moment(iteration, Frequency.STEP, phase)))
79 |
80 | def _on_batch_end(self):
81 | self.on_batch_end()
82 |
83 | iteration, phase = self.iteration_and_phase
84 |
85 | self.fire(Event.ON_BATCH_END(Moment(iteration, Frequency.STEP, phase)), self.step_monitors(phase))
86 |
87 | def _on_train_batch_begin(self):
88 | self.on_train_batch_begin()
89 | self.fire(Event.ON_TRAIN_BATCH_BEGIN(Moment(self.current_train_step, Frequency.STEP, Phase.TRAINING)))
90 |
91 | def _on_train_batch_end(self):
92 | self.on_train_batch_end()
93 | self.fire(Event.ON_TRAIN_BATCH_END(Moment(self.current_train_step, Frequency.STEP, Phase.TRAINING)),
94 | self.step_monitors(Phase.TRAINING))
95 |
96 | def _on_valid_batch_begin(self):
97 | self.on_valid_batch_begin()
98 | self.fire(Event.ON_VALID_BATCH_BEGIN(Moment(self.current_valid_step, Frequency.STEP, Phase.VALIDATION)))
99 |
100 | def _on_valid_batch_end(self):
101 | self.on_valid_batch_end()
102 | self.fire(Event.ON_VALID_BATCH_END(Moment(self.current_valid_step, Frequency.STEP, Phase.VALIDATION)),
103 | self.step_monitors(Phase.VALIDATION))
104 |
105 | def _on_test_batch_begin(self):
106 | self.on_test_batch_begin()
107 | self.fire(Event.ON_TEST_BATCH_BEGIN(Moment(self.current_test_step, Frequency.STEP, Phase.TEST)))
108 |
109 | def _on_test_batch_end(self):
110 | self.on_test_batch_end()
111 | self.fire(Event.ON_TEST_BATCH_END(Moment(self.current_test_step, Frequency.STEP, Phase.TEST)),
112 | self.step_monitors(Phase.TEST))
113 |
114 |
115 | class EpochEventPublisherMixin(object):
116 | @property
117 | def epoch_and_phase(self):
118 | if self.status is Status.TRAINING:
119 | phase = Phase.TRAINING
120 | elif self.status is Status.VALIDATING:
121 | phase = Phase.VALIDATION
122 | else:
123 | phase = Phase.TEST
124 |
125 | return self.epoch, phase
126 |
127 | @property
128 | def phase(self):
129 | if self.status is Status.TRAINING:
130 | phase = Phase.TRAINING
131 | elif self.status is Status.VALIDATING:
132 | phase = Phase.VALIDATION
133 | else:
134 | phase = Phase.TEST
135 |
136 | return phase
137 |
138 | def on_epoch_begin(self):
139 | pass
140 |
141 | def on_epoch_end(self):
142 | pass
143 |
144 | def on_train_epoch_begin(self):
145 | pass
146 |
147 | def on_train_epoch_end(self):
148 | pass
149 |
150 | def on_valid_epoch_begin(self):
151 | pass
152 |
153 | def on_valid_epoch_end(self):
154 | pass
155 |
156 | def on_test_epoch_begin(self):
157 | pass
158 |
159 | def on_test_epoch_end(self):
160 | pass
161 |
162 | def _on_epoch_begin(self):
163 | self._reset_model_trainers()
164 | self.on_epoch_begin()
165 |
166 | epoch, phase = self.epoch_and_phase
167 |
168 | self.fire(Event.ON_EPOCH_BEGIN(Moment(epoch, Frequency.EPOCH, phase)))
169 |
170 | def _on_epoch_end(self):
171 | self.on_epoch_end()
172 |
173 | epoch, phase = self.epoch_and_phase
174 |
175 | self.fire(Event.ON_EPOCH_END(Moment(epoch, Frequency.EPOCH, phase)), self.epoch_monitors(phase))
176 |
177 | def _on_train_epoch_begin(self):
178 | self._status = Status.TRAINING
179 | self.on_train_epoch_begin()
180 | self.fire(Event.ON_TRAIN_EPOCH_BEGIN(Moment(self.epoch, Frequency.EPOCH, Phase.TRAINING)))
181 |
182 | def _on_train_epoch_end(self):
183 | self.on_train_epoch_end()
184 | self.fire(Event.ON_TRAIN_EPOCH_END(Moment(self.epoch, Frequency.EPOCH, Phase.TRAINING)),
185 | self.epoch_monitors(Phase.TRAINING))
186 |
187 | def _on_valid_epoch_begin(self):
188 | self.on_valid_epoch_begin()
189 | self.fire(Event.ON_VALID_EPOCH_BEGIN(Moment(self.epoch, Frequency.EPOCH, Phase.VALIDATION)))
190 |
191 | def _on_valid_epoch_end(self):
192 | self.on_valid_epoch_end()
193 | self.fire(Event.ON_VALID_EPOCH_END(Moment(self.epoch, Frequency.EPOCH, Phase.VALIDATION)),
194 | self.epoch_monitors(Phase.VALIDATION))
195 |
196 | def _on_test_epoch_begin(self):
197 | self.on_test_epoch_begin()
198 | self.fire(Event.ON_TEST_EPOCH_BEGIN(Moment(self.epoch, Frequency.EPOCH, Phase.TEST)))
199 |
200 | def _on_test_epoch_end(self):
201 | self._current_valid_batch = 0
202 | self._current_test_batch = 0
203 | self.on_test_epoch_end()
204 | self.fire(Event.ON_TEST_EPOCH_END(Moment(self.epoch, Frequency.EPOCH, Phase.TEST)),
205 | self.epoch_monitors(Phase.TEST))
206 |
207 |
208 | class TrainingPhaseEventPublisherMixin(object):
209 | def on_training_begin(self):
210 | pass
211 |
212 | def on_training_end(self):
213 | pass
214 |
215 | def on_valid_begin(self):
216 | pass
217 |
218 | def on_valid_end(self):
219 | pass
220 |
221 | def on_test_begin(self):
222 | pass
223 |
224 | def on_test_end(self):
225 | pass
226 |
227 | def finalize(self):
228 | pass
229 |
230 | def _on_training_begin(self):
231 | self.on_training_begin()
232 | self._status = Status.TRAINING
233 | self.fire(Event.ON_TRAINING_BEGIN(Moment(0, Frequency.PHASE, Phase.TRAINING)))
234 |
235 | def _on_training_end(self):
236 | self.on_training_end()
237 | self.scheduler_step()
238 | self.fire(Event.ON_TRAINING_END(Moment(0, Frequency.PHASE, Phase.TRAINING)))
239 |
240 | def _on_valid_begin(self):
241 | self.on_valid_begin()
242 | self._status = Status.VALIDATING
243 | self.fire(Event.ON_VALID_BEGIN(Moment(0, Frequency.PHASE, Phase.VALIDATION)))
244 |
245 | def _on_valid_end(self):
246 | self.on_valid_end()
247 | self.fire(Event.ON_VALID_END(Moment(0, Frequency.PHASE, Phase.VALIDATION)))
248 |
249 | def _on_test_begin(self):
250 | self.on_test_begin()
251 | self._status = Status.TESTING
252 | self.fire(Event.ON_TEST_BEGIN(Moment(0, Frequency.PHASE, Phase.TEST)))
253 |
254 | def _on_test_end(self):
255 | self.on_test_end()
256 | self.fire(Event.ON_TEST_END(Moment(0, Frequency.PHASE, Phase.TEST)))
257 |
258 | def _finalize(self):
259 | for model_trainer in self.model_trainers:
260 | model_trainer.finalize()
261 | self.finalize()
262 | self._status = Status.FINALIZED
263 |
--------------------------------------------------------------------------------
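The mixins keep event publication inside the private _on_* wrappers and expose matching public no-op hooks, so a concrete trainer overrides only what it needs. A minimal sketch (illustration only, assuming SimpleTrainer from kerosene.training.trainers as used in the tests and examples):

from kerosene.training.trainers import SimpleTrainer

class MNISTTrainer(SimpleTrainer):
    # Only the public hooks are overridden; the mixins' _on_* wrappers still
    # fire the corresponding events with the right Moment, as defined above.
    def on_epoch_begin(self):
        print("Starting epoch {}".format(self.epoch))

    def on_train_batch_end(self):
        pass  # e.g. refresh custom variables consumed by the Visdom handlers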
/kerosene/loggers/visdom/plots.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from abc import ABC, abstractmethod
17 |
18 | from visdom import Visdom
19 |
20 | from kerosene.loggers.visdom import PlotType
21 | from kerosene.loggers.visdom.data import VisdomData
22 |
23 |
24 | class VisdomPlot(ABC):
25 |
26 | def __init__(self, visdom: Visdom):
27 | self._visdom = visdom
28 | self._window = None
29 |
30 | @property
31 | def visdom(self):
32 | return self._visdom
33 |
34 | @property
35 | def window(self):
36 | return self._window
37 |
38 | @abstractmethod
39 | def update(self, visdom_data: VisdomData):
40 | raise NotImplementedError()
41 |
42 |
43 | class ImagePlot(VisdomPlot):
44 | def __init__(self, visdom: Visdom):
45 | super().__init__(visdom)
46 |
47 | def update(self, visdom_data: VisdomData):
48 | if self._window is None:
49 | self._window = self.visdom.image(img=visdom_data.y, **visdom_data.params)
50 | else:
51 | self.visdom.image(img=visdom_data.y, win=self._window, **visdom_data.params)
52 |
53 |
54 | class ImagesPlot(VisdomPlot):
55 | def __init__(self, visdom: Visdom):
56 | super().__init__(visdom)
57 |
58 | def update(self, visdom_data: VisdomData):
59 | if self._window is None:
60 | self._window = self._visdom.images(tensor=visdom_data.y, **visdom_data.params)
61 | else:
62 | self._visdom.images(tensor=visdom_data.y, win=self._window, **visdom_data.params)
63 |
64 |
65 | class LinePlot(VisdomPlot):
66 | def __init__(self, visdom):
67 | super().__init__(visdom)
68 |
69 | def update(self, visdom_data: VisdomData):
70 | if self._window is None:
71 | self._window = self._visdom.line(X=visdom_data.x, Y=visdom_data.y, **visdom_data.params)
72 | else:
73 | self._visdom.line(X=visdom_data.x, Y=visdom_data.y, win=self._window, update='append',
74 | name=visdom_data.params['opts']['name'])
75 |
76 |
77 | class PiePlot(VisdomPlot):
78 | def __init__(self, visdom):
79 | super().__init__(visdom)
80 |
81 | def update(self, visdom_data: VisdomData):
82 | if self._window is None:
83 | self._window = self._visdom.pie(X=visdom_data.y, **visdom_data.params)
84 | else:
85 | self._visdom.pie(X=visdom_data.y, win=self._window, **visdom_data.params)
86 |
87 |
88 | class TextPlot(VisdomPlot):
89 | def __init__(self, visdom):
90 | super().__init__(visdom)
91 |
92 | def update(self, visdom_data: VisdomData):
93 | if self._window is None:
94 | self._window = self._visdom.text(visdom_data.y)
95 |
96 | else:
97 | self._visdom.text(visdom_data.y, win=self._window)
98 |
99 |
100 | class HistogramPlot(VisdomPlot):
101 |
102 | def __init__(self, visdom: Visdom):
103 | super().__init__(visdom)
104 |
105 | def update(self, visdom_data: VisdomData):
106 | if self._window is None:
107 | self._window = self._visdom.histogram(X=visdom_data.y, **visdom_data.params)
108 | else:
109 | self._visdom.histogram(X=visdom_data.y, win=self._window, **visdom_data.params)
110 |
111 |
112 | class ScatterPlot(VisdomPlot):
113 |
114 | def __init__(self, visdom: Visdom):
115 | super().__init__(visdom)
116 |
117 | def update(self, visdom_data: VisdomData):
118 | if self._window is None:
119 | self._window = self._visdom.scatter(X=visdom_data.x, Y=visdom_data.y, **visdom_data.params)
120 | else:
121 | self._visdom.scatter(X=visdom_data.x, Y=visdom_data.y, win=self._window, **visdom_data.params)
122 |
123 |
124 | class StemPlot(VisdomPlot):
125 |
126 | def __init__(self, visdom: Visdom):
127 | super().__init__(visdom)
128 |
129 | def update(self, visdom_data: VisdomData):
130 | if self._window is None:
131 | self._window = self._visdom.stem(X=visdom_data.x, Y=visdom_data.y, **visdom_data.params)
132 | else:
133 | self._visdom.stem(X=visdom_data.x, Y=visdom_data.y, win=self._window, **visdom_data.params)
134 |
135 |
136 | class HeatmapPlot(VisdomPlot):
137 |
138 | def __init__(self, visdom: Visdom):
139 | super().__init__(visdom)
140 |
141 | def update(self, visdom_data: VisdomData):
142 | if self._window is None:
143 | self._window = self._visdom.heatmap(X=visdom_data.y, **visdom_data.params)
144 | else:
145 | self._visdom.heatmap(X=visdom_data.y, win=self._window, **visdom_data.params)
146 |
147 |
148 | class BoxPlot(VisdomPlot):
149 |
150 | def __init__(self, visdom: Visdom):
151 | super().__init__(visdom)
152 |
153 | def update(self, visdom_data: VisdomData):
154 | if self._window is None:
155 | self._window = self._visdom.boxplot(X=visdom_data.y, **visdom_data.params)
156 | else:
157 | self._visdom.boxplot(X=visdom_data.y, win=self._window, **visdom_data.params)
158 |
159 |
160 | class SurfacePlot(VisdomPlot):
161 |
162 | def __init__(self, visdom: Visdom):
163 | super().__init__(visdom)
164 |
165 | def update(self, visdom_data: VisdomData):
166 | if self._window is None:
167 | self._window = self._visdom.surf(X=visdom_data.y, **visdom_data.params)
168 | else:
169 | self._visdom.surf(X=visdom_data.y, win=self._window, **visdom_data.params)
170 |
171 |
172 | class ContourPlot(VisdomPlot):
173 |
174 | def __init__(self, visdom: Visdom):
175 | super().__init__(visdom)
176 |
177 | def update(self, visdom_data: VisdomData):
178 | if self._window is None:
179 | self._window = self._visdom.contour(X=visdom_data.y, **visdom_data.params)
180 | else:
181 | self._visdom.contour(X=visdom_data.y, win=self._window, **visdom_data.params)
182 |
183 |
184 | class QuiverPlot(VisdomPlot):
185 |
186 | def __init__(self, visdom: Visdom):
187 | super().__init__(visdom)
188 |
189 | def update(self, visdom_data: VisdomData):
190 | if self._window is None:
191 | self._window = self._visdom.quiver(X=visdom_data.x, Y=visdom_data.y, **visdom_data.params)
192 | else:
193 | self._visdom.quiver(X=visdom_data.x, Y=visdom_data.y, win=self._window, **visdom_data.params)
194 |
195 |
196 | class MeshPlot(VisdomPlot):
197 |
198 | def __init__(self, visdom: Visdom):
199 | super().__init__(visdom)
200 |
201 | def update(self, visdom_data: VisdomData):
202 | if self._window is None:
203 | self._window = self._visdom.mesh(X=visdom_data.x, Y=visdom_data.y, **visdom_data.params)
204 | else:
205 | self._visdom.mesh(X=visdom_data.x, Y=visdom_data.y, win=self._window, **visdom_data.params)
206 |
207 |
208 | class BarPlot(VisdomPlot):
209 | def __init__(self, visdom):
210 | super().__init__(visdom)
211 |
212 | def update(self, visdom_data: VisdomData):
213 | if self._window is None:
214 | self._window = self._visdom.bar(X=visdom_data.x, Y=visdom_data.y, **visdom_data.params)
215 | else:
216 | self._visdom.bar(X=visdom_data.x, Y=visdom_data.y, win=self._window, **visdom_data.params)
217 |
218 |
219 | class MatplotlibPlot(VisdomPlot):
220 | def __init__(self, visdom):
221 | super().__init__(visdom)
222 |
223 | def update(self, visdom_data: VisdomData):
224 | if self._window is None:
225 | self._window = self._visdom.matplot(plot=visdom_data.y, **visdom_data.params)
226 | else:
227 | self._visdom.matplot(plot=visdom_data.y, win=self._window, **visdom_data.params)
228 |
229 |
230 | class VisdomPlotFactory(object):
231 |
232 | def __init__(self):
233 | self._plot = {
234 | PlotType.LINE_PLOT: LinePlot,
235 | PlotType.IMAGE_PLOT: ImagePlot,
236 | PlotType.IMAGES_PLOT: ImagesPlot,
237 | PlotType.PIE_PLOT: PiePlot,
238 | PlotType.TEXT_PLOT: TextPlot,
239 | PlotType.BAR_PLOT: BarPlot,
240 | PlotType.HISTOGRAM_PLOT: HistogramPlot,
241 | PlotType.SCATTER_PLOT: ScatterPlot,
242 | PlotType.STEM_PLOT: StemPlot,
243 | PlotType.HEATMAP_PLOT: HeatmapPlot,
244 | PlotType.BOX_PLOT: BoxPlot,
245 | PlotType.SURFACE_PLOT: SurfacePlot,
246 | PlotType.CONTOUR_PLOT: ContourPlot,
247 | PlotType.QUIVER_PLOT: QuiverPlot,
248 | PlotType.MESH_PLOT: MeshPlot,
249 | PlotType.MATPLOTLIB_PLOT: MatplotlibPlot
250 | }
251 |
252 | def create_plot(self, visdom, plot_type: PlotType):
253 | return self._plot[plot_type](visdom)
254 |
--------------------------------------------------------------------------------
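A minimal sketch of the factory in use (illustration only; it assumes a Visdom server is reachable with the default host and port, and the VisdomData signature used by the event handlers):

from visdom import Visdom

from kerosene.events import Frequency
from kerosene.loggers.visdom import PlotType
from kerosene.loggers.visdom.data import VisdomData
from kerosene.loggers.visdom.plots import VisdomPlotFactory

plot = VisdomPlotFactory().create_plot(Visdom(), PlotType.LINE_PLOT)

data = VisdomData("SimpleNet", "MSELoss", PlotType.LINE_PLOT, Frequency.EPOCH,
                  [[1]], [[0.42]],
                  params={'opts': {'name': 'Training',
                                   'title': 'SimpleNet MSELoss per Epoch'}})

plot.update(data)  # the first call creates the window
plot.update(data)  # later calls append to it; 'name' selects the trace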
/kerosene/configs/configs.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | import abc
17 | import os
18 | from socket import socket
19 | from typing import List
20 |
21 | import torch
22 |
23 | try:
24 | from torch.distributed import is_nccl_available
25 |
26 | NCCL_AVAILABLE = is_nccl_available()
27 | except ImportError:
28 | NCCL_AVAILABLE = False
29 |
30 | from kerosene.configs.exceptions import InvalidConfigurationError
31 | from kerosene.utils.devices import get_devices, on_multiple_gpus, num_gpus, on_cpu
32 |
33 |
34 | class HtmlConfiguration(abc.ABC):
35 |
36 | @abc.abstractmethod
37 | def to_html(self):
38 | raise NotImplementedError()
39 |
40 |
41 | class DatasetConfiguration(HtmlConfiguration):
42 | def __init__(self, config_dict):
43 | for key in config_dict:
44 | setattr(self, key, config_dict[key])
45 |
46 | def to_html(self):
47 | configuration_values = '\n'.join("%s: %s
" % item for item in vars(self).items())
48 | return "Dataset Configuration
\n {}".format(configuration_values)
49 |
50 |
51 | class Configuration(object):
52 | def __init__(self, name, type, params):
53 | self._name = name
54 | self._type = type
55 | self._params = params
56 |
57 | @property
58 | def name(self):
59 | return self._name
60 |
61 | @name.setter
62 | def name(self, value):
63 | self._name = value
64 |
65 | @property
66 | def type(self):
67 | return self._type
68 |
69 | @type.setter
70 | def type(self, value):
71 | self._type = value
72 |
73 | @property
74 | def params(self):
75 | return self._params
76 |
77 | @params.setter
78 | def params(self, value):
79 | self._params = value
80 |
81 | def update(self, config_dict):
82 | if "type" in config_dict:
83 | self._type = config_dict["type"]
84 |
85 | if "params" in config_dict:
86 | self._params.update(config_dict["params"])
87 |
88 | @classmethod
89 | def from_dict(cls, name, config_dict):
90 | return cls(name, config_dict["type"], config_dict.get("params", {}))
91 |
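# Example (illustration only):
#   Configuration.from_dict("optimizer", {"type": "SGD", "params": {"lr": 0.01}})
# yields a Configuration with name "optimizer", type "SGD" and
# params {"lr": 0.01}; the "params" key is optional and defaults to {}.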
92 |
93 | class ConfigurationList(Configuration):
94 | def __init__(self, configurations: List[Configuration]):
95 | super().__init__(list(map(lambda config: config.name, configurations)),
96 | list(map(lambda config: config.type, configurations)),
97 | list(map(lambda config: config.params, configurations)))
98 | self._configurations = configurations
99 |
100 | def update(self, config_dict):
101 | for configuration in self._configurations:
102 | if configuration.name in config_dict:
103 | configuration.update(config_dict[configuration.name])
104 |
105 | def __len__(self):
106 | return len(self._configurations)
107 |
108 | def __getitem__(self, index):
109 | return self._configurations[index]
110 |
111 | def __iter__(self):
112 | return iter(self._configurations)
113 |
114 |
115 | class ModelConfiguration(Configuration, HtmlConfiguration):
116 | def __init__(self, model_name, model_type, model_params, path, optimizer_config: Configuration,
117 | scheduler_config: Configuration, criterions_configs: ConfigurationList,
118 | metrics_configs: ConfigurationList, gradient_clipping_config: Configuration):
119 | super().__init__(model_name, model_type, model_params)
120 | self._path = path
121 | self._optimizer_config = optimizer_config
122 | self._scheduler_config = scheduler_config
123 | self._criterions_configs = criterions_configs
124 | self._metrics_configs = metrics_configs
125 | self._gradient_clipping_config = gradient_clipping_config
126 |
127 | @property
128 | def path(self):
129 | return self._path
130 |
131 | @property
132 | def optimizer_config(self):
133 | return self._optimizer_config
134 |
135 | @property
136 | def scheduler_config(self):
137 | return self._scheduler_config
138 |
139 | @property
140 | def criterions_configs(self):
141 | return self._criterions_configs
142 |
143 | @property
144 | def metrics_configs(self):
145 | return self._metrics_configs
146 |
147 | @property
148 | def gradient_clipping_config(self):
149 | return self._gradient_clipping_config
150 |
151 | def update(self, config_dict):
152 | if "type" in config_dict:
153 | self._type = config_dict["type"]
154 |
155 | if "params" in config_dict:
156 | self._params.update(config_dict["params"])
157 |
158 | if "optimizer" in config_dict:
159 | self._optimizer_config.update(config_dict["optimizer"])
160 |
161 | if "scheduler" in config_dict:
162 | self._scheduler_config.update(config_dict["scheduler"])
163 |
164 | if "criterions" in config_dict:
165 | self._criterions_configs.update(config_dict["criterions"])
166 |
167 | if "metrics" in config_dict:
168 | self._metrics_configs.update(config_dict["metrics"])
169 |
170 | @classmethod
171 | def from_dict(cls, model_name, config_dict):
172 | try:
173 | optimizer_config = Configuration.from_dict("", config_dict["optimizer"])
174 | scheduler_config = Configuration.from_dict("", config_dict["scheduler"])
175 | criterion_configs = ConfigurationList(
176 | [Configuration.from_dict(criterion, config_dict["criterion"][criterion]) for criterion
177 | in config_dict["criterion"].keys()])
178 |
179 | if "metrics" in config_dict.keys():
180 | metric_configs = ConfigurationList(
181 | [Configuration.from_dict(metric, config_dict["metrics"][metric]) for metric in
182 | config_dict["metrics"].keys()])
183 | else:
184 | metric_configs = None
185 |
186 | if "gradients" in config_dict.keys():
187 | gradient_clipping_config = Configuration.from_dict("", config_dict["gradients"])
188 | else:
189 | gradient_clipping_config = None
190 |
191 | return cls(model_name, config_dict["type"], config_dict.get("params"), config_dict.get("path"),
192 | optimizer_config, scheduler_config, criterion_configs, metric_configs, gradient_clipping_config)
193 | except KeyError as e:
194 | raise InvalidConfigurationError(
195 | "The provided model configuration is invalid. The section {} is missing.".format(e))
196 |
197 | def to_html(self):
198 | configuration_values = '\n'.join("%s: %s
" % item for item in vars(self).items())
199 | return "Model Configuration \n
{}
\n {}".format(self.name, configuration_values)
200 |
201 |
202 | class RunConfiguration(HtmlConfiguration):
203 |
204 | def __init__(self, use_amp: bool = True, amp_opt_level: str = 'O1', local_rank: int = 0,
205 | world_size: int = num_gpus()):
206 | self._use_amp = use_amp
207 | self._amp_opt_level = amp_opt_level
208 | self._local_rank = local_rank
209 | self._world_size = world_size
210 |
211 | self._devices = get_devices()
212 | self._device = self._devices[self._local_rank]
213 |
214 | if not on_cpu(self._device):
215 | torch.cuda.set_device(self._device)
216 | self._initialize_ddp_process_group()
217 |
218 | @property
219 | def use_amp(self):
220 | return self._use_amp
221 |
222 | @property
223 | def amp_opt_level(self):
224 | return self._amp_opt_level
225 |
226 | @property
227 | def local_rank(self):
228 | return self._local_rank
229 |
230 | @property
231 | def world_size(self):
232 | return self._world_size
233 |
234 | @property
235 | def devices(self):
236 | return self._devices
237 |
238 | @property
239 | def device(self):
240 | return self._device
241 |
242 | @device.setter
243 | def device(self, device):
244 | self._device = device
245 |
246 | @staticmethod
247 | def _get_random_free_port():
248 | with socket() as s:
249 | s.bind(("", 0))
250 | return s.getsockname()[1]
251 |
252 | def _initialize_ddp_process_group(self):
253 | if on_multiple_gpus(self._devices):
254 | if NCCL_AVAILABLE:
255 | if os.environ.get("MASTER_ADDR") is None:
256 | os.environ["MASTER_ADDR"] = "127.0.0.1"
257 | if os.environ.get("MASTER_PORT") is None:
258 | os.environ["MASTER_PORT"] = str(self._get_random_free_port())
259 | if os.environ.get("WORLD_SIZE") is None:
260 | os.environ["WORLD_SIZE"] = str(self._world_size)
261 | torch.distributed.init_process_group(backend='nccl', init_method='env://',
262 | world_size=int(os.environ["WORLD_SIZE"]), rank=self._local_rank)
263 | else:
264 | raise Exception("NCCL not available and required for multi-GPU training.")
265 |
266 | def to_html(self):
267 | configuration_values = '\n'.join("%s: %s
" % item for item in vars(self).items())
268 | return "Run Configuration
\n {}".format(configuration_values)
269 |
270 |
271 | class TrainerConfiguration(HtmlConfiguration):
272 | def __init__(self, config_dict):
273 | for key in config_dict:
274 | setattr(self, key, config_dict[key])
275 |
276 | def to_html(self):
277 | configuration_values = '\n'.join("%s: %s
" % item for item in vars(self).items())
278 | return "Training Configuration
\n {}".format(configuration_values)
279 |
--------------------------------------------------------------------------------
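A minimal sketch of ModelConfiguration.from_dict with a parsed-YAML-shaped dict (illustration only; every value is made up, and the optional "gradients" and "path" sections are omitted here, which leaves them as None):

from kerosene.configs.configs import ModelConfiguration

config_dict = {
    "type": "SimpleNet",
    "params": {},
    "optimizer": {"type": "SGD", "params": {"lr": 0.01, "momentum": 0.9}},
    "scheduler": {"type": "ReduceLROnPlateau", "params": {"mode": "min"}},
    "criterion": {"CrossEntropyLoss": {"type": "CrossEntropyLoss"}},
    "metrics": {"Accuracy": {"type": "Accuracy"}},
}

model_config = ModelConfiguration.from_dict("SimpleNet", config_dict)
print(model_config.optimizer_config.type)  # "SGD"

A missing mandatory section ("optimizer", "scheduler" or "criterion") raises InvalidConfigurationError instead.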
/tests/events/handlers/test_checkpoints.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import unittest
4 | from unittest.mock import PropertyMock
5 |
6 | import mock
7 | import mockito
8 | import torch
9 | from hamcrest import *
10 | from ignite.metrics import Accuracy
11 | from torch import nn
12 | from torch.optim import Optimizer, lr_scheduler
13 |
14 | from kerosene.events import MonitorMode, Moment, Frequency, Phase, TemporalEvent, Monitor
15 | from kerosene.events.handlers.checkpoints import Checkpoint
16 | from kerosene.nn.utils.gradients import GradientClippingStrategy
17 | from kerosene.training.events import Event
18 | from kerosene.training.trainers import SimpleTrainer, ModelTrainer, ModelTrainerList
19 |
20 |
21 | class ModelCheckpointIfBetterTest(unittest.TestCase):
22 | MODEL_NAME = "test_model"
23 | TRAINER_NAME = "test_trainer"
24 | SAVE_PATH = "tests/output/"
25 |
26 | def setUp(self) -> None:
27 | self._model_mock = mockito.mock(nn.Module)
28 | self._criterion_mock = {"CrossEntropyLoss": mockito.mock(nn.CrossEntropyLoss)}
29 | self._optimizer_mock = mockito.mock(Optimizer)
30 | self._scheduler_mock = mockito.mock(lr_scheduler)
31 | self._metric_computer_mock = {"Accuracy": mockito.mock(Accuracy)}
32 | self._gradient_clipping_strategy = mockito.mock(GradientClippingStrategy)
33 |
34 | self._model_trainer = ModelTrainer(self.MODEL_NAME, self._model_mock, self._criterion_mock,
35 | self._optimizer_mock, self._scheduler_mock, self._metric_computer_mock,
36 | self._gradient_clipping_strategy)
37 |
38 | self._trainer_mock = mockito.mock(SimpleTrainer)
39 | self._trainer_mock.epoch = 0
40 | self._trainer_mock.model_trainers = ModelTrainerList([self._model_trainer])
41 |
42 | def tearDown(self) -> None:
43 | if os.path.exists(self.SAVE_PATH):
44 | shutil.rmtree(self.SAVE_PATH)
45 |
46 | @mock.patch("kerosene.training.trainers.ModelTrainer.optimizer_state", new_callable=PropertyMock)
47 | @mock.patch("kerosene.training.trainers.ModelTrainer.model_state", new_callable=PropertyMock)
48 | def test_should_not_save_model_with_higher_valid_loss(self, model_state_mock, optimizer_states_mock):
49 | model_state_mock.return_value = dict()
50 | optimizer_states_mock.return_value = list(dict())
51 | moment = Moment(200, Frequency.EPOCH, Phase.VALIDATION)
52 |
53 | handler_mock = mockito.spy(Checkpoint(self.SAVE_PATH, self.MODEL_NAME, "MSELoss", 0.01, MonitorMode.MIN))
54 |
55 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}},
56 | Phase.VALIDATION: {Monitor.METRICS: {},
57 | Monitor.LOSS: {"MSELoss": torch.tensor([0.5])}},
58 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}}
59 |
60 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock)
61 |
62 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}},
63 | Phase.VALIDATION: {Monitor.METRICS: {},
64 | Monitor.LOSS: {"MSELoss": torch.tensor([0.6])}},
65 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}}
66 |
67 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock)
68 |
69 | assert_that(not os.path.exists(os.path.join(self.SAVE_PATH, self.MODEL_NAME, self.MODEL_NAME + ".tar")))
70 |
71 | @mock.patch("kerosene.training.trainers.ModelTrainer.optimizer_state", new_callable=PropertyMock)
72 | @mock.patch("kerosene.training.trainers.ModelTrainer.model_state", new_callable=PropertyMock)
73 | def test_should_not_save_model_with_higher_valid_losses(self, model_state_mock, optimizer_states_mock):
74 | model_state_mock.return_value = dict()
75 | optimizer_states_mock.return_value = list(dict())
76 | moment = Moment(200, Frequency.EPOCH, Phase.VALIDATION)
77 |
78 | handler_mock = mockito.spy(
79 | Checkpoint(self.SAVE_PATH, self.MODEL_NAME, ["MSELoss", "L1Loss"], 0.01, MonitorMode.MIN))
80 |
81 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}},
82 | Phase.VALIDATION: {Monitor.METRICS: {},
83 | Monitor.LOSS: {"MSELoss": torch.tensor([0.5]),
84 | "L1Loss": torch.tensor([0.5])}},
85 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}}
86 |
87 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock)
88 |
89 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}},
90 | Phase.VALIDATION: {Monitor.METRICS: {},
91 | Monitor.LOSS: {"MSELoss": torch.tensor([0.5]),
92 | "L1Loss": torch.tensor([0.6])}},
93 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}}
94 |
95 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock)
96 |
97 | assert_that(not os.path.exists(os.path.join(self.SAVE_PATH, self.MODEL_NAME, self.MODEL_NAME + ".tar")))
98 |
99 | @mock.patch("kerosene.training.trainers.ModelTrainer.optimizer_state", new_callable=PropertyMock)
100 | @mock.patch("kerosene.training.trainers.ModelTrainer.model_state", new_callable=PropertyMock)
101 | def test_should_save_model_with_lower_valid_loss(self, model_state_mock, optimizer_states_mock):
102 | model_state_mock.return_value = dict()
103 | optimizer_states_mock.return_value = list(dict())
104 | moment = Moment(200, Frequency.EPOCH, Phase.VALIDATION)
105 |
106 | handler_mock = mockito.spy(Checkpoint(self.SAVE_PATH, self.MODEL_NAME, "MSELoss", 0.01, MonitorMode.MIN))
107 |
108 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}},
109 | Phase.VALIDATION: {Monitor.METRICS: {},
110 | Monitor.LOSS: {"MSELoss": torch.tensor([0.5])}},
111 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}}
112 |
113 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock)
114 |
115 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}},
116 | Phase.VALIDATION: {Monitor.METRICS: {},
117 | Monitor.LOSS: {"MSELoss": torch.tensor([0.3])}},
118 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}}
119 |
120 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock)
121 |
122 | assert_that(os.path.exists(os.path.join(self.SAVE_PATH, self.MODEL_NAME, self.MODEL_NAME + ".tar")))
123 |
124 | @mock.patch("kerosene.training.trainers.ModelTrainer.optimizer_state", new_callable=PropertyMock)
125 | @mock.patch("kerosene.training.trainers.ModelTrainer.model_state", new_callable=PropertyMock)
126 | def test_should_save_model_with_higher_valid_metric(self, model_state_mock, optimizer_states_mock):
127 | model_state_mock.return_value = dict()
128 | optimizer_states_mock.return_value = list(dict())
129 | moment = Moment(200, Frequency.EPOCH, Phase.VALIDATION)
130 |
131 | handler_mock = mockito.spy(Checkpoint(self.SAVE_PATH, self.MODEL_NAME, "Accuracy", 0.01, MonitorMode.MAX))
132 |
133 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}},
134 | Phase.VALIDATION: {Monitor.METRICS: {"Accuracy": torch.tensor([0.5])},
135 | Monitor.LOSS: {}},
136 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}}
137 |
138 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock)
139 |
140 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}},
141 | Phase.VALIDATION: {Monitor.METRICS: {"Accuracy": torch.tensor([0.6])},
142 | Monitor.LOSS: {}},
143 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}}
144 |
145 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock)
146 |
147 | assert_that(os.path.exists(os.path.join(self.SAVE_PATH, self.MODEL_NAME, self.MODEL_NAME + ".tar")))
148 |
149 | @mock.patch("kerosene.training.trainers.ModelTrainer.optimizer_state", new_callable=PropertyMock)
150 | @mock.patch("kerosene.training.trainers.ModelTrainer.model_state", new_callable=PropertyMock)
151 | def test_should_not_save_model_with_lower_valid_metric(self, model_state_mock, optimizer_states_mock):
152 | model_state_mock.return_value = dict()
153 | optimizer_states_mock.return_value = list(dict())
154 | moment = Moment(200, Frequency.EPOCH, Phase.VALIDATION)
155 |
156 | handler_mock = mockito.spy(Checkpoint(self.SAVE_PATH, self.MODEL_NAME, "Accuracy", 0.01, MonitorMode.MAX))
157 |
158 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}},
159 | Phase.VALIDATION: {Monitor.METRICS: {"Accuracy": torch.tensor([0.5])},
160 | Monitor.LOSS: {}},
161 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}}
162 |
163 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock)
164 |
165 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}},
166 | Phase.VALIDATION: {Monitor.METRICS: {"Accuracy": torch.tensor([0.4])},
167 | Monitor.LOSS: {}},
168 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}}
169 |
170 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock)
171 |
172 | assert_that(not os.path.exists(os.path.join(self.SAVE_PATH, self.MODEL_NAME, self.MODEL_NAME + ".tar")))
173 |
--------------------------------------------------------------------------------
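In training code, the handler exercised above is constructed once and then receives (event, monitors, trainer) on every epoch end, exactly as the direct calls in the tests. A minimal sketch (illustration only, assuming the min-delta semantics the tests verify):

from kerosene.events import MonitorMode
from kerosene.events.handlers.checkpoints import Checkpoint

# Watch the validation MSELoss and write <path>/<model>/<model>.tar whenever
# it improves on the best value seen so far by more than 0.01.
checkpoint = Checkpoint("output/", "test_model", "MSELoss", 0.01, MonitorMode.MIN)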
/kerosene/events/handlers/visdom.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | import logging
17 | from abc import ABC
18 | from typing import Union, List
19 |
20 | from kerosene.events import TemporalEvent, Monitor, BaseEvent
21 | from kerosene.events.handlers.base_handler import EventHandler
22 | from kerosene.loggers.visdom import PlotType
23 | from kerosene.loggers.visdom.visdom import VisdomLogger, VisdomData
24 | from kerosene.training.events import Event
25 | from kerosene.training.trainers import Trainer, ModelTrainer
26 |
27 |
28 | class BaseVisdomHandler(EventHandler, ABC):
29 | def __init__(self, supported_events: List[Union[BaseEvent, TemporalEvent]], visdom_logger: VisdomLogger, every=1):
30 | super().__init__(supported_events, every)
31 | self._visdom_logger = visdom_logger
32 |
33 | @property
34 | def visdom_logger(self):
35 | return self._visdom_logger
36 |
37 | def create_visdom_data(self, *args, **kwargs):
38 | raise NotImplementedError()
39 |
40 | def flatten(self, list_of_visdom_data):
41 | return [item for sublist in list_of_visdom_data for item in sublist]
42 |
43 |
44 | class PlotMonitors(BaseVisdomHandler):
45 | SUPPORTED_EVENTS = [Event.ON_EPOCH_END, Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END, Event.ON_TEST_BATCH_END,
46 | Event.ON_BATCH_END]
47 |
48 | def __init__(self, visdom_logger: VisdomLogger, every=1):
49 | super().__init__(self.SUPPORTED_EVENTS, visdom_logger, every)
50 | self._plot_losses = PlotLosses(visdom_logger, every)
51 | self._plot_metrics = PlotMetrics(visdom_logger, every)
52 |
53 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer):
54 | if self.should_handle(event):
55 | self._plot_losses(event, monitors, trainer)
56 | self._plot_metrics(event, monitors, trainer)
57 |
58 | def create_visdom_data(self, event, model_name, monitors):
59 | data = self._plot_losses.create_visdom_data(event, model_name, monitors)
60 | data.extend(self._plot_metrics.create_visdom_data(event, model_name, monitors))
61 | return data
61 |
62 |
63 | class PlotLosses(BaseVisdomHandler):
64 | SUPPORTED_EVENTS = [Event.ON_EPOCH_END, Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END, Event.ON_TEST_BATCH_END,
65 | Event.ON_BATCH_END]
66 |
67 | def __init__(self, visdom_logger: VisdomLogger, every=1):
68 | super().__init__(self.SUPPORTED_EVENTS, visdom_logger, every)
69 |
70 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer):
71 | data = list()
72 |
73 | if self.should_handle(event):
74 | for model_name, monitor in monitors.items():
75 | data.extend(self.create_visdom_data(event, model_name, monitor[event.phase][Monitor.LOSS]))
76 |
77 | if data:
78 | self.visdom_logger(data)
79 |
80 | def create_visdom_data(self, event, model_name, monitors):
81 | return [VisdomData(model_name, loss_name, PlotType.LINE_PLOT, event.frequency, [[event.iteration]],
82 | [[loss_value]], params={'opts': {'xlabel': str(event.frequency), 'ylabel': loss_name,
83 | 'title': "{} {} per {}".format(model_name, loss_name,
84 | str(event.frequency)),
85 | 'name': str(event.phase),
86 | 'legend': [str(event.phase)]}})
87 | for loss_name, loss_value in monitors.items()]
88 |
89 |
90 | class PlotMetrics(BaseVisdomHandler):
91 | SUPPORTED_EVENTS = [Event.ON_EPOCH_END, Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END, Event.ON_TEST_BATCH_END,
92 | Event.ON_BATCH_END]
93 |
94 | def __init__(self, visdom_logger: VisdomLogger, every=1):
95 | super().__init__(self.SUPPORTED_EVENTS, visdom_logger, every)
96 |
97 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer):
98 | data = list()
99 |
100 | if self.should_handle(event):
101 | for model_name, monitor in monitors.items():
102 | data.extend(self.create_visdom_data(event, model_name, monitor[event.phase][Monitor.METRICS]))
103 |
104 | if data:
105 | self.visdom_logger(data)
106 |
107 | def create_visdom_data(self, event, model_name, monitors):
108 | return [VisdomData(model_name, metric_name, PlotType.LINE_PLOT, event.frequency, [[event.iteration]],
109 | [[metric_value]], params={'opts': {'xlabel': str(event.frequency), 'ylabel': metric_name,
110 | 'title': "{} {} per {}".format(model_name, metric_name,
111 | str(event.frequency)),
112 | 'name': str(event.phase), 'legend': [str(event.phase)]}})
113 | for metric_name, metric_value in monitors.items()]
114 |
115 |
116 | class PlotCustomVariables(BaseVisdomHandler):
117 | LOGGER = logging.getLogger("PlotCustomVariables")
118 | SUPPORTED_EVENTS = [Event.ON_EPOCH_END, Event.ON_TRAIN_EPOCH_END, Event.ON_VALID_EPOCH_END, Event.ON_TEST_EPOCH_END,
119 | Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END, Event.ON_TEST_BATCH_END, Event.ON_BATCH_END]
120 |
121 | def __init__(self, visdom_logger: VisdomLogger, variable_name, plot_type: PlotType, params, every=1):
122 | super().__init__(self.SUPPORTED_EVENTS, visdom_logger, every)
123 | self._variable_name = variable_name
124 | self._plot_type = plot_type
125 | self._params = params
126 |
127 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer):
128 | data = None
129 |
130 | if self.should_handle(event):
131 | try:
132 | data = self.create_visdom_data(event, trainer)
133 | except KeyError:
134 | self.LOGGER.warning(
135 | "Unable to plot custom variable: {}. The variable is not in the trainer's custom variable dict.".format(
136 | self._variable_name))
137 |
138 | if data is not None:
139 | self.visdom_logger(data)
140 |
141 | def create_visdom_data(self, event: TemporalEvent, trainer):
142 |         if self._plot_type == PlotType.LINE_PLOT and "name" not in self._params['opts']:
143 | self._params['opts']['name'] = str(event.phase)
144 |
145 | return [VisdomData(trainer.name, self._variable_name, self._plot_type, event.frequency,
146 | [event.iteration], trainer.custom_variables[self._variable_name], self._params)]
147 |
148 |
149 | class PlotLR(BaseVisdomHandler):
150 | SUPPORTED_EVENTS = [Event.ON_EPOCH_END, Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END, Event.ON_TEST_BATCH_END,
151 | Event.ON_BATCH_END]
152 |
153 | def __init__(self, visdom_logger: VisdomLogger, every=1):
154 | super().__init__(self.SUPPORTED_EVENTS, visdom_logger, every)
155 |
156 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer):
157 | data = None
158 |
159 | if self.should_handle(event):
160 | data = list(map(
161 | lambda model_trainer: self.create_visdom_data(event, model_trainer), trainer.model_trainers))
162 |
163 | if data is not None:
164 | self.visdom_logger(data)
165 |
166 | def create_visdom_data(self, event, model_trainer: ModelTrainer):
167 | return VisdomData(model_trainer.name, "Learning Rate", PlotType.LINE_PLOT, event.frequency, [event.iteration],
168 | model_trainer.optimizer_lr,
169 | params={'opts': {'xlabel': str(event.frequency),
170 | 'ylabel': "Learning Rate",
171 | 'title': "{} {} per {}".format(model_trainer.name, "Learning Rate",
172 | str(event.frequency)),
173 | 'name': model_trainer.name,
174 | 'legend': [model_trainer.name]}})
175 |
176 |
177 | class PlotAvgGradientPerLayer(BaseVisdomHandler):
178 | SUPPORTED_EVENTS = [Event.ON_EPOCH_END, Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END, Event.ON_TEST_BATCH_END,
179 | Event.ON_BATCH_END]
180 |
181 | def __init__(self, visdom_logger: VisdomLogger, every=1):
182 | super().__init__(self.SUPPORTED_EVENTS, visdom_logger, every)
183 |
184 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer):
185 | data = None
186 |
187 | if self.should_handle(event):
188 | data = list(
189 | map(lambda model_trainer: self.create_visdom_data(event, model_trainer), trainer.model_trainers))
190 |
191 | if data is not None:
192 | self.visdom_logger(data)
193 |
194 | def create_visdom_data(self, event: TemporalEvent, model_trainer: ModelTrainer):
195 | avg_grads, layers = [], []
196 |
197 | for n, p in model_trainer.named_parameters():
198 | if p.requires_grad and ("bias" not in n):
199 | layers.append(n)
200 | if p.grad is not None:
201 | avg_grads.append(p.grad.abs().mean().item())
202 | else:
203 | avg_grads.append(0)
204 |
205 | return VisdomData(model_trainer.name, "Gradient Flow", PlotType.BAR_PLOT, event.frequency, y=layers,
206 | x=avg_grads, params={'opts': {'xlabel': "Layers", 'ylabel': "Avg. Gradients",
207 | 'title': "{} {} per {}".format(model_trainer.name,
208 | "Avg. Gradient", "Layer"),
209 | 'marginbottom': 200}})
210 |
211 |
212 | class SaveVisdomEnvToFile(BaseVisdomHandler):
213 | SUPPORTED_EVENTS = [Event.ON_EPOCH_END, Event.ON_TRAIN_EPOCH_END, Event.ON_VALID_EPOCH_END, Event.ON_TEST_EPOCH_END,
214 | Event.ON_FINALIZE]
215 |
216 | def __init__(self, visdom_logger: VisdomLogger, params, every=1):
217 | super().__init__(self.SUPPORTED_EVENTS, visdom_logger, every)
218 | self._params = params
219 |
220 | def create_visdom_data(self, *args, **kwargs):
221 | pass
222 |
223 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer):
224 | if self.should_handle(event):
225 | self._visdom_logger.save()
226 |
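227 |
228 | # Minimal usage sketch (illustrative; assumes a VisdomLogger already configured elsewhere,
229 | # since its constructor lives in kerosene/loggers/visdom/visdom.py): every handler above is a
230 | # plain callable invoked as handler(event, monitors, trainer), so wiring them up only requires
231 | # constructing them with a shared logger and dispatching events to them.
232 | #
233 | #   handlers = [PlotMonitors(visdom_logger, every=1),
234 | #               PlotLR(visdom_logger, every=10),
235 | #               PlotAvgGradientPerLayer(visdom_logger, every=100)]
236 | #   for handler in handlers:
237 | #       handler(event, monitors, trainer)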
--------------------------------------------------------------------------------
/tests/nn/test_criterions.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 | import torch
5 | from hamcrest import *
6 |
7 | from kerosene.nn.criterions import DiceLoss, GeneralizedDiceLoss, WeightedCrossEntropyLoss, TverskyLoss
8 | from kerosene.utils.tensors import to_onehot
9 |
10 |
11 | def get_y_true_y_pred():
12 | # Generate an image with labels 0 (background), 1, 2
13 | # 3 classes:
14 |     y_true = np.zeros((30, 30), dtype=int)
15 | y_true[1:11, 1:11] = 1
16 | y_true[15:25, 15:25] = 2
17 |
18 |     y_pred = np.zeros((30, 30), dtype=int)
19 | y_pred[5:15, 1:11] = 1
20 | y_pred[20:30, 20:30] = 2
21 | return y_true, y_pred
22 |
23 |
24 | def compute_tensor_y_true_y_logits(y_true, y_pred):
25 | # Create torch.tensor from numpy
26 | y_true_tensor = torch.from_numpy(y_true).unsqueeze(0).type(torch.long)
27 | # Create logits torch.tensor:
28 | num_classes = max(np.max(y_true), np.max(y_pred)) + 1
29 |     y_probas = np.zeros((num_classes,) + y_true.shape)
30 | for i in range(num_classes):
31 | y_probas[i, (y_pred == i)] = 1.0
32 | y_logits = torch.from_numpy(y_probas).unsqueeze(0).type(torch.float32)
33 | return y_true_tensor, y_logits
34 |
35 |
36 | def compute_dice_truth(y_true, y_pred):
37 | true_res = [0, 0, 0]
38 | for index in range(3):
39 | bin_y_true = y_true == index
40 | bin_y_pred = y_pred == index
41 | intersection = bin_y_true & bin_y_pred
42 | true_res[index] = 2.0 * intersection.sum() / (bin_y_pred.sum() + bin_y_true.sum())
43 | return true_res
44 |
45 |
46 | def compute_generalized_dice_loss_truth(y_true, y_pred):
47 | true_res = [0, 0, 0]
48 | for index in range(3):
49 | bin_y_true = y_true == index
50 | bin_y_pred = y_pred == index
51 | weights = (1.0 / (np.sum(bin_y_true) * np.sum(bin_y_true) + 1e-15))
52 | intersection = (bin_y_true & bin_y_pred)
53 | true_res[index] = 2 * intersection.sum() * weights / (((bin_y_pred.sum() + bin_y_true.sum()) * weights) + 1e-15)
54 | return true_res
55 |
56 |
57 | class TestDiceLoss(unittest.TestCase):
58 | INVALID_VALUE_1 = -1
59 | INVALID_VALUE_2 = "STEVE JOBS"
60 | INVALID_VALUE_3 = 10
61 | INVALID_VALUE_4 = 11
62 |
63 | def setUp(self):
64 | self.y_true, self.y_pred = get_y_true_y_pred()
65 | self.y_true_tensor, self.y_logits = compute_tensor_y_true_y_logits(self.y_true, self.y_pred)
66 | self.dice = compute_dice_truth(self.y_true, self.y_pred)
67 | self.mean_dice_loss = np.subtract(1.0, np.mean(self.dice))
68 |
69 | def test_should_raise_exception_with_bad_values(self):
70 | dice_loss = DiceLoss()
71 | assert_that(calling(dice_loss.forward).with_args(inputs=None, targets=None),
72 | raises(AttributeError))
73 | assert_that(calling(dice_loss.forward).with_args(inputs=self.y_logits, targets=None),
74 | raises(AttributeError))
75 | assert_that(calling(dice_loss.forward).with_args(inputs=None, targets=self.y_true_tensor),
76 | raises(AttributeError))
77 |
78 | def test_should_compute_dice(self):
79 | dice_loss = DiceLoss(reduction=None)
80 | loss = dice_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
81 |
82 | np.testing.assert_almost_equal(loss.numpy(), np.subtract(1.0, self.dice))
83 |
84 | def test_should_compute_dice_for_multiclass_with_ignored_index(self):
85 | for ignore_index in range(3):
86 | dice_loss = DiceLoss(reduction=None, ignore_index=ignore_index)
87 | res = dice_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
88 | true_res = np.subtract(1.0, self.dice[:ignore_index] + self.dice[ignore_index + 1:])
89 |             np.testing.assert_almost_equal(res.numpy(), true_res, err_msg="{}: {} vs {}".format(ignore_index, res, true_res))
90 |
91 | def test_should_compute_mean_dice(self):
92 | dice_loss = DiceLoss(reduction="mean")
93 | loss = dice_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
94 |
95 | np.testing.assert_almost_equal(loss.numpy(), self.mean_dice_loss)
96 |
97 | def test_should_compute_mean_dice_for_multiclass_with_ignored_index(self):
98 | for ignore_index in range(3):
99 | dice_loss = DiceLoss(ignore_index=ignore_index)
100 | res = dice_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
101 | true_res = np.subtract(1.0, self.dice[:ignore_index] + self.dice[ignore_index + 1:]).mean()
102 |             np.testing.assert_almost_equal(res.numpy(), true_res, err_msg="{}: {} vs {}".format(ignore_index, res, true_res))
103 |
104 |
105 | class TestGeneralizedDiceLoss(unittest.TestCase):
106 | INVALID_REDUCTION = "sum"
107 | INVALID_INDEX = -1
108 |
109 | def setUp(self):
110 | self.y_true, self.y_pred = get_y_true_y_pred()
111 | self.y_true_tensor, self.y_logits = compute_tensor_y_true_y_logits(self.y_true, self.y_pred)
112 | self.generalized_dice_loss = compute_generalized_dice_loss_truth(self.y_true, self.y_pred)
113 | self.mean_generalized_dice_loss = np.subtract(1.0, np.mean(self.generalized_dice_loss))
114 |
115 | def test_should_raise_exception_with_bad_values(self):
116 | generalized_dice_loss = GeneralizedDiceLoss()
117 | assert_that(calling(GeneralizedDiceLoss).with_args(reduction=self.INVALID_REDUCTION),
118 | raises(NotImplementedError))
119 | assert_that(calling(generalized_dice_loss.forward).with_args(inputs=None, targets=None),
120 | raises(AttributeError))
121 | assert_that(calling(generalized_dice_loss.forward).with_args(inputs=self.y_logits, targets=None),
122 | raises(AttributeError))
123 | assert_that(calling(generalized_dice_loss.forward).with_args(inputs=None, targets=self.y_true_tensor),
124 | raises(AttributeError))
125 |
126 | def test_should_raise_exception_with_bad_ignore_index_values(self):
127 | generalized_dice_loss = GeneralizedDiceLoss(ignore_index=self.INVALID_INDEX)
128 |
129 | assert_that(calling(generalized_dice_loss.forward).with_args(inputs=self.y_logits,
130 | targets=to_onehot(self.y_true_tensor,
131 | num_classes=3)),
132 | raises(IndexError))
133 |
134 | def test_should_compute_generalized_dice(self):
135 | generalized_dice_loss = GeneralizedDiceLoss()
136 | loss = generalized_dice_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
137 | np.testing.assert_almost_equal(loss.numpy(), self.mean_generalized_dice_loss)
138 |
139 | def test_should_compute_generalized_dice_for_multiclass_with_ignored_index(self):
140 | for ignore_index in range(3):
141 | generalized_dice_loss = GeneralizedDiceLoss(reduction=None, ignore_index=ignore_index)
142 | res = generalized_dice_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
143 | true_res = np.subtract(1.0, self.generalized_dice_loss[:ignore_index] + self.generalized_dice_loss[
144 | ignore_index + 1:])
145 |             np.testing.assert_almost_equal(res.numpy(), true_res, err_msg="{}: {} vs {}".format(ignore_index, res, true_res))
146 |
147 | def test_should_compute_mean_generalized_dice(self):
148 | dice_loss = GeneralizedDiceLoss()
149 | loss = dice_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
150 |
151 | np.testing.assert_almost_equal(loss.numpy(), self.mean_generalized_dice_loss)
152 |
153 | def test_should_compute_mean_generalized_dice_for_multiclass_with_ignored_index(self):
154 | for ignore_index in range(3):
155 | dice_loss = GeneralizedDiceLoss(ignore_index=ignore_index)
156 | res = dice_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
157 | true_res = np.subtract(1.0, self.generalized_dice_loss[:ignore_index] + self.generalized_dice_loss[
158 | ignore_index + 1:]).mean()
159 |             np.testing.assert_almost_equal(res.numpy(), true_res, err_msg="{}: {} vs {}".format(ignore_index, res, true_res))
160 |
161 |
162 | class TestWeightedCrossEntropy(unittest.TestCase):
163 | WEIGHTED_CROSS_ENTROPY_LOSS_TRUTH = 1.0808
164 |
165 | def setUp(self):
166 | self.y_true, self.y_pred = get_y_true_y_pred()
167 | self.y_true_tensor, self.y_logits = compute_tensor_y_true_y_logits(self.y_true, self.y_pred)
168 |
169 | def test_should_raise_exception_with_bad_values(self):
170 | weighted_cross_entropy_loss = WeightedCrossEntropyLoss()
171 | assert_that(calling(weighted_cross_entropy_loss.forward).with_args(inputs=None, targets=None),
172 | raises(AttributeError))
173 | assert_that(calling(weighted_cross_entropy_loss.forward).with_args(inputs=self.y_logits, targets=None),
174 | raises(AttributeError))
175 | assert_that(calling(weighted_cross_entropy_loss.forward).with_args(inputs=None, targets=self.y_true_tensor),
176 | raises(AttributeError))
177 |
178 | def test_should_compute_weights(self):
179 | weights = WeightedCrossEntropyLoss.compute_class_weights(self.y_logits)
180 | np.testing.assert_almost_equal(weights.numpy(), np.array([0.2857143, 8.0, 8.0]), decimal=7)
181 |
182 | def test_should_return_loss(self):
183 | weighted_cross_entropy_loss = WeightedCrossEntropyLoss()
184 | loss = weighted_cross_entropy_loss.forward(self.y_logits, self.y_true_tensor)
185 | np.testing.assert_almost_equal(loss.numpy(), self.WEIGHTED_CROSS_ENTROPY_LOSS_TRUTH)
186 |
187 |
188 | class TestTverskyLoss(unittest.TestCase):
189 | INVALID_VALUE_1 = -1
190 | INVALID_VALUE_2 = "STEVE JOBS"
191 | INVALID_VALUE_3 = 10
192 | INVALID_VALUE_4 = 11
193 |
194 | def setUp(self):
195 | self.y_true, self.y_pred = get_y_true_y_pred()
196 | self.y_true_tensor, self.y_logits = compute_tensor_y_true_y_logits(self.y_true, self.y_pred)
197 | self.dice = compute_dice_truth(self.y_true, self.y_pred)
198 | self.mean_dice_loss = np.subtract(1.0, np.mean(self.dice))
199 |
200 | def test_should_raise_exception_with_bad_values(self):
201 | tversky_loss = TverskyLoss()
202 | assert_that(calling(tversky_loss.forward).with_args(inputs=None, targets=None),
203 | raises(AttributeError))
204 | assert_that(calling(tversky_loss.forward).with_args(inputs=self.y_logits, targets=None),
205 | raises(AttributeError))
206 | assert_that(calling(tversky_loss.forward).with_args(inputs=None, targets=self.y_true_tensor),
207 | raises(AttributeError))
208 |
209 | def test_should_compute_tversky_index(self):
210 | tversky_loss = TverskyLoss(reduction=None)
211 | loss = tversky_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
212 |
213 | np.testing.assert_almost_equal(loss.numpy(), np.subtract(1.0, self.dice))
214 |
215 | def test_should_compute_dice_for_multiclass_with_ignored_index(self):
216 | for ignore_index in range(3):
217 | tversky_loss = TverskyLoss(reduction=None, ignore_index=ignore_index)
218 | res = tversky_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
219 | true_res = np.subtract(1.0, self.dice[:ignore_index] + self.dice[ignore_index + 1:])
220 |             np.testing.assert_almost_equal(res.numpy(), true_res, err_msg="{}: {} vs {}".format(ignore_index, res, true_res))
221 |
222 | def test_should_compute_mean_dice(self):
223 | tversky_loss = TverskyLoss(reduction="mean")
224 | loss = tversky_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
225 |
226 | np.testing.assert_almost_equal(loss.numpy(), self.mean_dice_loss)
227 |
228 | def test_should_compute_mean_dice_for_multiclass_with_ignored_index(self):
229 | for ignore_index in range(3):
230 | tversky_loss = TverskyLoss(ignore_index=ignore_index)
231 | res = tversky_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
232 | true_res = np.subtract(1.0, self.dice[:ignore_index] + self.dice[ignore_index + 1:]).mean()
233 |             np.testing.assert_almost_equal(res.numpy(), true_res, err_msg="{}: {} vs {}".format(ignore_index, res, true_res))
234 |
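235 |
236 | # Hand-worked sanity check of the synthetic fixture above (an illustrative addition, not part of
237 | # the original suite): class 1 covers rows 1-10 in y_true and rows 5-14 in y_pred, columns 1-10
238 | # in both, so they overlap on 6 rows x 10 columns = 60 pixels and Dice_1 = 2 * 60 / (100 + 100) = 0.6.
239 | # Class 2 overlaps on rows/columns 20-24 (25 pixels), so Dice_2 = 2 * 25 / (100 + 100) = 0.25.
240 | if __name__ == "__main__":
241 |     y_true, y_pred = get_y_true_y_pred()
242 |     dice = compute_dice_truth(y_true, y_pred)
243 |     np.testing.assert_almost_equal(dice[1], 0.6)
244 |     np.testing.assert_almost_equal(dice[2], 0.25)
245 |     unittest.main()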
--------------------------------------------------------------------------------
/kerosene/metrics/metrics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from enum import Enum
17 | from typing import Union, Tuple
18 |
19 | import torch
20 | from ignite.metrics import Accuracy, Precision, MeanAbsoluteError, MeanPairwiseDistance, MeanSquaredError, \
21 | Recall, RootMeanSquaredError, TopKCategoricalAccuracy, Metric, IoU, ConfusionMatrix, MetricsLambda
22 |
23 | from kerosene.utils.constants import EPSILON
24 | from kerosene.utils.tensors import flatten, to_onehot
25 |
26 |
27 | class MetricType(Enum):
28 | Accuracy = "Accuracy"
29 | ConfusionMatrix = "ConfusionMatrix"
30 | DiceCoefficient = "Dice"
31 | GeneralizedDice = "GeneralizedDice"
32 | IoU = "IoU"
33 | MeanAbsoluteError = "MeanAbsoluteError"
34 | MeanPairwiseDistance = "MeanPairwiseDistance"
35 | MeanSquaredError = "MeanSquaredError"
36 | Precision = "Precision"
37 | Recall = "Recall"
38 | RootMeanSquaredError = "RootMeanSquaredError"
39 | TopKCategoricalAccuracy = "TopKCategoricalAccuracy"
40 | VariableAccumulation = "VariableAccumulation"
41 |
42 | def __str__(self):
43 | return self.value
44 |
45 |
46 | class MetricFactory(object):
47 | def __init__(self):
48 | self._metrics = {
49 | "Accuracy": Accuracy,
50 | "ConfusionMatrix": ConfusionMatrix,
51 | "Dice": Dice,
52 | "GeneralizedDice": GeneralizedDice,
53 | "IoU": IoU_,
54 | "MeanAbsoluteError": MeanAbsoluteError,
55 | "MeanPairwiseDistance": MeanPairwiseDistance,
56 | "MeanSquaredError": MeanSquaredError,
57 | "Precision": Precision,
58 | "Recall": Recall,
59 | "RootMeanSquaredError": RootMeanSquaredError,
60 | "TopKCategoricalAccuracy": TopKCategoricalAccuracy
61 | }
62 |
63 | def create(self, metric_type: Union[str, MetricType], params):
64 | return self._metrics[str(metric_type)](**params) if params is not None else self._metrics[str(metric_type)]()
65 |
66 | def register(self, metric: str, creator: Metric):
67 | """
68 | Add a new metric.
69 | Args:
70 | metric (str): Metric's name.
71 | creator: A torch or ignite module object wrapping the new custom metric function.
72 | """
73 | self._metrics[metric] = creator
74 |
75 |
76 | class Dice(Metric):
77 | """
78 | The Dice Metric.
79 | """
80 | SUPPORTED_REDUCTIONS = [None, "mean"]
81 |
82 | def __init__(self, num_classes: int, reduction: Union[None, str] = "mean", average: str = None,
83 | weight: torch.Tensor = None, ignore_index: int = -100,
84 | output_transform: callable = lambda x: x) -> None:
85 | """
86 | Metric initializer.
87 | Args:
88 | num_classes (int): The number of classes in the problem. In case of images, num_classes should also count the background index 0.
89 | average (str, optional): Confusion matrix values averaging schema: None, "samples", "recall", "precision".
90 | Default is None. If `average="samples"` then confusion matrix values are normalized by the number of seen
91 | samples. If `average="recall"` then confusion matrix values are normalized such that diagonal values
92 | represent class recalls. If `average="precision"` then confusion matrix values are normalized such that
93 | diagonal values represent class precisions.
94 | reduction (str): The type of reduction to apply (e.g. 'mean').
95 | ignore_index (int, optional): To ignore an index in Dice computation.
96 | output_transform (callable, optional): a callable that is used to transform the
97 | output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and
98 | you want to compute the metric with respect to one of the outputs.
99 | """
100 | if reduction not in self.SUPPORTED_REDUCTIONS:
101 | raise NotImplementedError("Reduction type not supported.")
102 | self._num_classes = num_classes
103 | self._ignore_index = ignore_index
104 | self._reduction = reduction
105 | self._weight = weight
106 | self._cm = ConfusionMatrix(num_classes=num_classes, average=average, output_transform=output_transform)
107 | self._metric = self.create_dice_metric(self._cm)
108 | super(Dice, self).__init__(output_transform=output_transform)
109 |
110 | def reset(self) -> None:
111 | """
112 | Reset the confusion matrix object.
113 | """
114 | self._cm.confusion_matrix = torch.zeros(self._num_classes, self._num_classes, dtype=torch.float64)
115 |
116 | def compute(self) -> torch.Tensor:
117 | """
118 | Compute the metric.
119 | Returns:
120 | :obj:`torch.Tensor`: The dice coefficient for each class.
121 | """
122 | return self._metric.compute()
123 |
124 | def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
125 | """
126 | Update the confusion matrix with output values.
127 | Args:
128 | output (tuple of :obj:`torch.Tensor`): A tuple containing predictions and ground truth of the form `(y_pred, y)`.
129 | """
130 | self._cm.update(output)
131 |
132 | def create_dice_metric(self, cm: ConfusionMatrix):
133 | """
134 | Computes the Sørensen–Dice Coefficient (https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient)
135 | Args:
136 | cm (:obj:`ignite.metrics.ConfusionMatrix`): A confusion matrix representing the classification of data.
137 | Returns:
138 | array or float: The Sørensen–Dice Coefficient for each class or the mean Sørensen–Dice Coefficient.
139 | """
140 | # Increase floating point precision
141 | cm = cm.type(torch.float64)
142 | dice = 2 * cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) + EPSILON)
143 |
144 | if self._ignore_index != -100:
145 | def remove_index(dice_vector):
146 | try:
147 | indices = list(range(len(dice_vector)))
148 | indices.remove(self._ignore_index)
149 | return dice_vector[indices]
150 |                 except ValueError:
151 |                     raise IndexError(
152 |                         "'ignore_index' must be non-negative and lower than the number of classes, but {} was given.".format(
153 |                             self._ignore_index))
154 |
155 | dice = MetricsLambda(remove_index, dice)
156 |
157 | if self._weight is not None:
158 | def multiply_weights(dice_vector):
159 | return self._weight * dice_vector
160 |
161 | dice = MetricsLambda(multiply_weights, dice)
162 |
163 | if self._reduction == "mean":
164 | dice = dice.mean()
165 |
166 | return dice
167 |
168 |
169 | class GeneralizedDice(Metric):
170 | """
171 | The Generalized Dice Metric.
172 | """
173 | SUPPORTED_REDUCTIONS = [None, "mean"]
174 |
175 |     def __init__(self, num_classes: int, reduction: Union[None, str] = None, average: str = None,
176 | ignore_index: int = -100, output_transform: callable = lambda x: x) -> None:
177 | """
178 | Metric initializer.
179 | Args:
180 | num_classes (int): The number of classes in the problem. In case of images, num_classes should also count the background index 0.
181 | average (str, optional): Confusion matrix values averaging schema: None, "samples", "recall", "precision".
182 | Default is None. If `average="samples"` then confusion matrix values are normalized by the number of seen
183 | samples. If `average="recall"` then confusion matrix values are normalized such that diagonal values
184 | represent class recalls. If `average="precision"` then confusion matrix values are normalized such that
185 | diagonal values represent class precisions.
186 | reduction (str): The type of reduction to apply (e.g. 'mean').
187 | ignore_index (int, optional): To ignore an index in Dice computation.
188 | output_transform (callable, optional): a callable that is used to transform the
189 | output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and
190 | you want to compute the metric with respect to one of the outputs.
191 | """
192 | if reduction not in self.SUPPORTED_REDUCTIONS:
193 | raise NotImplementedError("Reduction type not supported.")
194 | self._num_classes = num_classes
195 | self._average = average
196 | self._ignore_index = ignore_index
197 | self._metric = None
198 | self._reduction = reduction
199 | self._cm = ConfusionMatrix(num_classes=num_classes, average=average,
200 | output_transform=output_transform)
201 | super(GeneralizedDice, self).__init__(output_transform=output_transform)
202 |
203 | def reset(self) -> None:
204 | """
205 | Reset the confusion matrix object.
206 | """
207 | self._cm.confusion_matrix = torch.zeros(self._num_classes, self._num_classes, dtype=torch.float64)
208 |
209 | def compute(self) -> torch.Tensor:
210 | """
211 | Compute the metric.
212 | Returns:
213 |             :obj:`torch.Tensor`: The Generalized Dice coefficient for each class.
214 | """
215 | return self._metric.compute()
216 |
217 | def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
218 | """
219 | Update the confusion matrix with output values.
220 | Args:
221 | output (tuple of :obj:`torch.Tensor`): A tuple containing predictions and ground truth of the form `(y_pred, y)`.
222 | """
223 | flattened_targets = flatten(to_onehot(output[1], output[0].size(1))).double()
224 | ones = torch.Tensor().new_ones((output[0].size(1),), dtype=torch.double, device=flattened_targets.device)
225 | weights = ones / torch.pow(flattened_targets.sum(-1), 2).clamp(min=EPSILON)
226 |
227 | self._metric = self.create_generalized_dice_metric(self._cm, weights)
228 |
229 | if self._reduction == "mean":
230 | self._metric = self._metric.mean()
231 | elif self._reduction is None:
232 | pass
233 | else:
234 | raise NotImplementedError("Reduction method not implemented.")
235 |
236 | self._cm.update(output)
237 |
238 | def create_generalized_dice_metric(self, cm: ConfusionMatrix, weight: torch.Tensor):
239 | """
240 |         Computes the Generalized Dice Coefficient, a class-weighted variant of the Sørensen–Dice Coefficient (https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient)
241 | Args:
242 | cm (:obj:`ignite.metrics.ConfusionMatrix`): A confusion matrix representing the classification of data.
243 | weight (:obj:`torch.Tensor`): A weight vector which length equals to the number of classes.
244 | Returns:
245 | ignite.Metric: The Generalized Dice Coefficient Metric object.
246 | """
247 |
248 | # Increase floating point precision
249 | cm = cm.type(torch.float64)
250 | dice = 2 * (cm.diag() * weight) / (((cm.sum(dim=1) + cm.sum(dim=0)) * weight) + EPSILON)
251 |
252 | if self._ignore_index != -100:
253 | def remove_index(dice_vector):
254 | try:
255 | indices = list(range(len(dice_vector)))
256 | indices.remove(self._ignore_index)
257 | return dice_vector[indices]
258 |                 except ValueError:
259 |                     raise IndexError(
260 |                         "'ignore_index' must be non-negative and lower than the number of classes, but {} was given.".format(
261 |                             self._ignore_index))
262 |
263 | return MetricsLambda(remove_index, dice)
264 | else:
265 | return dice
266 |
267 |
268 | class IoU_(Metric):
269 | """
270 | The IoU Metric.
271 | """
272 | SUPPORTED_REDUCTIONS = [None, "mean"]
273 |
274 | def __init__(self, num_classes: int, reduction: Union[None, str] = "mean", average: str = None,
275 | weight: torch.Tensor = None, ignore_index: int = -100,
276 | output_transform: callable = lambda x: x) -> None:
277 | """
278 | Metric initializer.
279 | Args:
280 | num_classes (int): The number of classes in the problem. In case of images, num_classes should also count the background index 0.
281 | average (str, optional): Confusion matrix values averaging schema: None, "samples", "recall", "precision".
282 | Default is None. If `average="samples"` then confusion matrix values are normalized by the number of seen
283 | samples. If `average="recall"` then confusion matrix values are normalized such that diagonal values
284 | represent class recalls. If `average="precision"` then confusion matrix values are normalized such that
285 | diagonal values represent class precisions.
286 | reduction (str): The type of reduction to apply (e.g. 'mean').
287 | ignore_index (int, optional): To ignore an index in Dice computation.
288 | output_transform (callable, optional): a callable that is used to transform the
289 | output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and
290 | you want to compute the metric with respect to one of the outputs.
291 | """
292 | if reduction not in self.SUPPORTED_REDUCTIONS:
293 | raise NotImplementedError("Reduction type not supported.")
294 | self._num_classes = num_classes
295 | self._ignore_index = ignore_index
296 | self._reduction = reduction
297 | self._weight = weight
298 | self._cm = ConfusionMatrix(num_classes=num_classes, average=average, output_transform=output_transform)
299 | self._metric = self.create_iou_metric(self._cm)
300 | super(IoU_, self).__init__(output_transform=output_transform)
301 |
302 | def reset(self) -> None:
303 | """
304 | Reset the confusion matrix object.
305 | """
306 | self._cm.confusion_matrix = torch.zeros(self._num_classes, self._num_classes, dtype=torch.float64)
307 |
308 | def compute(self) -> torch.Tensor:
309 | """
310 | Compute the metric.
311 | Returns:
312 |             :obj:`torch.Tensor`: The IoU for each class.
313 | """
314 | return self._metric.compute()
315 |
316 | def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
317 | """
318 | Update the confusion matrix with output values.
319 | Args:
320 | output (tuple of :obj:`torch.Tensor`): A tuple containing predictions and ground truth of the form `(y_pred, y)`.
321 | """
322 | self._cm.update(output)
323 |
324 | def create_iou_metric(self, cm: ConfusionMatrix):
325 | """
326 |         Computes the Intersection over Union (Jaccard index, https://en.wikipedia.org/wiki/Jaccard_index)
327 | Args:
328 | cm (:obj:`ignite.metrics.ConfusionMatrix`): A confusion matrix representing the classification of data.
329 | Returns:
330 |             ignite.metrics.Metric: The IoU for each class, or the mean IoU if reduction is "mean".
331 | """
332 |
333 | metric = IoU(cm, ignore_index=self._ignore_index)
334 |
335 | if self._reduction == "mean":
336 | metric = metric.mean()
337 | return metric
338 |
339 |
340 | def to_tensor(value):
341 | if not isinstance(value, torch.Tensor):
342 | return torch.tensor([value])
343 | else:
344 | return value
345 |
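346 |
347 | # Minimal usage sketch (assumes only the factory API defined above; MyMetric is a hypothetical
348 | # ignite Metric subclass): metrics are built by name with optional constructor params, and
349 | # custom metrics can be registered at runtime.
350 | #
351 | #   factory = MetricFactory()
352 | #   mean_dice = factory.create(MetricType.DiceCoefficient, {"num_classes": 3, "reduction": "mean"})
353 | #   factory.register("MyMetric", MyMetric)
354 | #   my_metric = factory.create("MyMetric", None)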
--------------------------------------------------------------------------------
/kerosene/nn/criterions.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from enum import Enum
17 | from typing import Union
18 |
19 | import torch
20 | from ignite.metrics import MetricsLambda
21 | from torch import nn
22 | from torch.nn.modules.loss import _Loss, _WeightedLoss
23 |
24 | from kerosene.utils.constants import EPSILON
25 | from kerosene.utils.tensors import flatten
26 |
27 |
28 | class CriterionType(Enum):
29 | DiceLoss = "DiceLoss"
30 | GeneralizedDiceLoss = "GeneralizedDiceLoss"
31 | TverskyLoss = "TverskyLoss"
32 | FocalTverskyLoss = "FocalTverskyLoss"
33 | BCELoss = "BCELoss"
34 |     BCEWithLogitsLoss = "BCEWithLogitsLoss"
35 | PoissonNLLLoss = "PoissonNLLLoss"
36 | CosineEmbeddingLoss = "CosineEmbeddingLoss"
37 | CrossEntropyLoss = "CrossEntropyLoss"
38 | CTCLoss = "CTCLoss"
39 | HingeEmbeddingLoss = "HingeEmbeddingLoss"
40 | KLDivLoss = "KLDivLoss"
41 | L1Loss = "L1Loss"
42 | MSELoss = "MSELoss"
43 | MarginRankingLoss = "MarginRankingLoss"
44 | MultiLabelMarginLoss = "MultiLabelMarginLoss"
45 | MultiLabelSoftMarginLoss = "MultiLabelSoftMarginLoss"
46 | MultiMarginLoss = "MultiMarginLoss"
47 | NLLLoss = "NLLLoss"
48 | SmoothL1Loss = "SmoothL1Loss"
49 | SoftMarginLoss = "SoftMarginLoss"
50 | TripletMarginLoss = "TripletMarginLoss"
51 |
52 |
53 | class CriterionFactory(object):
54 | def __init__(self):
55 | super(CriterionFactory, self).__init__()
56 | self._criterion = {
57 | "DiceLoss": DiceLoss,
58 | "GeneralizedDiceLoss": GeneralizedDiceLoss,
59 | "TverskyLoss": TverskyLoss,
60 | "FocalTverskyLoss": FocalTverskyLoss,
61 | "BCELoss": nn.BCELoss,
62 | "BCEWithLogitsLoss": nn.BCEWithLogitsLoss,
63 | "PoissonNLLLoss": nn.PoissonNLLLoss,
64 | "CosineEmbeddingLoss": nn.CosineEmbeddingLoss,
65 | "CrossEntropyLoss": nn.CrossEntropyLoss,
66 | "CTCLoss": nn.CTCLoss,
67 | "HingeEmbeddingLoss": nn.HingeEmbeddingLoss,
68 | "KLDivLoss": nn.KLDivLoss,
69 | "L1Loss": nn.L1Loss,
70 | "MSELoss": nn.MSELoss,
71 | "MarginRankingLoss": nn.MarginRankingLoss,
72 | "MultiLabelMarginLoss": nn.MultiLabelMarginLoss,
73 | "MultiLabelSoftMarginLoss": nn.MultiLabelSoftMarginLoss,
74 | "MultiMarginLoss": nn.MultiMarginLoss,
75 | "NLLLoss": nn.NLLLoss,
76 | "SmoothL1Loss": nn.SmoothL1Loss,
77 | "SoftMarginLoss": nn.SoftMarginLoss,
78 | "TripletMarginLoss": nn.TripletMarginLoss,
79 | }
80 |
81 |     def create(self, criterion_type: Union[str, CriterionType], params):
82 |         name = criterion_type.value if isinstance(criterion_type, CriterionType) else criterion_type
83 |         return self._criterion[name](**params) if params is not None else self._criterion[name]()
84 |
85 | def register(self, function: str, creator: _Loss):
86 | """
87 | Add a new criterion.
88 | Args:
89 | function (str): Criterion's name.
90 | creator (:obj:`torch.nn.Module`): A torch module object wrapping the new custom criterion function.
91 | """
92 | self._criterion[function] = creator
93 |
94 |
95 | class DiceLoss(_WeightedLoss):
96 | """
97 | The Sørensen-Dice Loss.
98 | """
99 | SUPPORTED_REDUCTIONS = [None, "mean"]
100 |
101 | def __init__(self, reduction: Union[None, str] = "mean", ignore_index: int = -100, weight: torch.Tensor = None):
102 | if reduction not in self.SUPPORTED_REDUCTIONS:
103 | raise NotImplementedError("Reduction type not supported.")
104 | super(DiceLoss, self).__init__(weight=weight, reduction=reduction)
105 | self._ignore_index = ignore_index
106 |
107 | if self.weight is not None:
108 | if self.weight.requires_grad is not False:
109 | self.weight.requires_grad = False
110 |
111 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
112 | """
113 | Computes the Sørensen–Dice loss.
114 |         Note that PyTorch optimizers minimize a loss. Since we want to maximize the Dice coefficient, the loss
115 |         returned is one minus the Dice coefficient.
116 | Args:
117 | inputs (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The model prediction on which the loss has to
118 | be computed.
119 | targets (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The ground truth.
120 | Returns:
121 | :obj:`torch.Tensor`: The Sørensen–Dice loss for each class or reduced according to reduction method.
122 | """
123 | if not inputs.size() == targets.size():
124 | raise ValueError("'Inputs' and 'Targets' must have the same shape.")
125 |
126 | inputs = flatten(inputs)
127 | targets = flatten(targets).float()
128 |
129 | # Compute per channel Dice Coefficient
130 | intersection = (inputs * targets).sum(-1)
131 |
132 | if self.weight is not None:
133 | intersection = self.weight * intersection
134 |
135 | cardinality = (inputs + targets).sum(-1)
136 |
137 | ones = torch.Tensor().new_ones((inputs.size(0),), dtype=torch.float, device=inputs.device)
138 |
139 | dice = ones - (2.0 * intersection / cardinality.clamp(min=EPSILON))
140 |
141 | if self._ignore_index != -100:
142 | def ignore_index_fn(dice_vector):
143 | try:
144 | indices = list(range(len(dice_vector)))
145 | indices.remove(self._ignore_index)
146 | return dice_vector[indices]
147 |                 except ValueError:
148 |                     raise IndexError(
149 |                         "'ignore_index' must be non-negative and lower than the number of classes, but {} was given.".format(
150 |                             self._ignore_index))
151 |
152 | dice = MetricsLambda(ignore_index_fn, dice).compute()
153 |
154 | if self.reduction == "mean":
155 | dice = dice.mean()
156 |
157 | return dice
158 |
159 |
160 | class GeneralizedDiceLoss(_Loss):
161 | """
162 | The Generalized Sørensen-Dice Loss.
163 | """
164 | SUPPORTED_REDUCTIONS = [None, "mean"]
165 |
166 | def __init__(self, reduction: Union[None, str] = "mean", ignore_index: int = -100):
167 | if reduction not in self.SUPPORTED_REDUCTIONS:
168 | raise NotImplementedError("Reduction type not supported.")
169 | super(GeneralizedDiceLoss, self).__init__(reduction=reduction)
170 | self._ignore_index = ignore_index
171 |
172 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
173 | """
174 |         Computes the Generalized Sørensen–Dice loss.
175 |         Note that PyTorch optimizers minimize a loss. Since we want to maximize the Dice coefficient, the loss
176 |         returned is one minus the class-weighted Dice coefficient.
177 | Args:
178 | inputs (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The model prediction on which the loss has to
179 | be computed.
180 | targets (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The ground truth.
181 | Returns:
182 | :obj:`torch.Tensor`: The Sørensen–Dice loss for each class or reduced according to reduction method.
183 | """
184 | if not inputs.size() == targets.size():
185 | raise ValueError("'Inputs' and 'Targets' must have the same shape.")
186 |
187 | inputs = flatten(inputs)
188 | targets = flatten(targets).float()
189 | ones = torch.Tensor().new_ones((inputs.size(0),), dtype=torch.float, device=inputs.device)
190 | class_weights = ones / torch.pow(targets.sum(-1), 2).clamp(min=EPSILON)
191 |
192 | # Compute per channel Dice Coefficient
193 | intersect = (inputs * targets).sum(-1) * class_weights
194 |
195 | denominator = (inputs + targets).sum(-1) * class_weights
196 |
197 | dice = ones - (2.0 * intersect / denominator.clamp(min=EPSILON))
198 |
199 | if self._ignore_index != -100:
200 | def ignore_index_fn(dice_vector):
201 | try:
202 | indices = list(range(len(dice_vector)))
203 | indices.remove(self._ignore_index)
204 | return dice_vector[indices]
205 |                 except ValueError:
206 |                     raise IndexError(
207 |                         "'ignore_index' must be non-negative and lower than the number of classes, but {} was given.".format(
208 |                             self._ignore_index))
209 |
210 | dice = MetricsLambda(ignore_index_fn, dice).compute()
211 |
212 | if self.reduction == "mean":
213 | dice = dice.mean()
214 |
215 | return dice
216 |
217 |
218 | class TverskyLoss(_WeightedLoss):
219 | SUPPORTED_REDUCTIONS = [None, "mean"]
220 |
221 | def __init__(self, reduction: Union[None, str] = "mean", ignore_index: int = -100, weight: torch.Tensor = None,
222 | alpha: float = 0.5, beta: float = 0.5):
223 | if reduction not in self.SUPPORTED_REDUCTIONS:
224 | raise NotImplementedError("Reduction type not supported.")
225 | super(TverskyLoss, self).__init__(weight=weight, reduction=reduction)
226 | self._ignore_index = ignore_index
227 | self._alpha = alpha
228 | self._beta = beta
229 |
230 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
231 | """
232 | Computes the Tversky loss based on https://arxiv.org/pdf/1706.05721.pdf
233 |         Note that PyTorch optimizers minimize a loss. Since we want to maximize the Tversky index, the loss
234 |         returned is one minus the Tversky index.
235 | Args:
236 | inputs (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The model prediction on which the loss has to
237 | be computed.
238 | targets (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The ground truth.
239 | Returns:
240 | :obj:`torch.Tensor`: The Tversky loss for each class or reduced according to reduction method.
241 | """
242 | if not inputs.size() == targets.size():
243 | raise ValueError("'Inputs' and 'Targets' must have the same shape.")
244 |
245 | inputs = flatten(inputs)
246 | targets = flatten(targets).float()
247 | ones = torch.Tensor().new_ones((inputs.size()), dtype=torch.float, device=inputs.device)
248 |
249 | P_G = (inputs * targets).sum(-1)
250 | if self.weight is not None:
251 | P_G = self.weight * P_G
252 |
253 | P_NG = (inputs * (ones - targets)).sum(-1)
254 | NP_G = ((ones - inputs) * targets).sum(-1)
255 |
256 | ones = torch.Tensor().new_ones((inputs.size(0),), dtype=torch.float, device=inputs.device)
257 | tversky = P_G / (P_G + self._alpha * P_NG + self._beta * NP_G + EPSILON)
258 |
259 | tversky_loss = ones - tversky
260 |
261 | if self._ignore_index != -100:
262 | def ignore_index_fn(tversky_vector):
263 | try:
264 | indices = list(range(len(tversky_vector)))
265 | indices.remove(self._ignore_index)
266 | return tversky_vector[indices]
267 |                 except ValueError:
268 |                     raise IndexError(
269 |                         "'ignore_index' must be non-negative and lower than the number of classes, but {} was given.".format(
270 |                             self._ignore_index))
271 |
272 | tversky_loss = MetricsLambda(ignore_index_fn, tversky_loss).compute()
273 |
274 | if self.reduction == "mean":
275 | tversky_loss = tversky_loss.mean()
276 |
277 | return tversky_loss
278 |
279 |
280 | class FocalTverskyLoss(_WeightedLoss):
281 | SUPPORTED_REDUCTIONS = [None, "mean"]
282 |
283 | def __init__(self, reduction: Union[None, str] = "mean", ignore_index: int = -100, weight: torch.Tensor = None,
284 | alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0):
285 | if reduction not in self.SUPPORTED_REDUCTIONS:
286 | raise NotImplementedError("Reduction type not supported.")
287 | super(FocalTverskyLoss, self).__init__(weight=weight, reduction=reduction)
288 | self._ignore_index = ignore_index
289 | self._alpha = alpha
290 | self._beta = beta
291 | self._gamma = gamma
292 |
293 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
294 | """
295 |         Computes the Focal Tversky Loss, which applies the focal mechanism of https://arxiv.org/pdf/1708.02002.pdf
296 |         to the Tversky index. Note that PyTorch optimizers minimize a loss. Since we want to maximize the
297 |         Tversky index, the loss is based on one minus the Tversky index, raised to the power 1 / gamma.
298 | Args:
299 | inputs (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The model prediction on which the loss has to
300 | be computed.
301 | targets (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The ground truth.
302 | Returns:
303 |             :obj:`torch.Tensor`: The Focal Tversky loss for each class or reduced according to reduction method.
304 | """
305 | if not inputs.size() == targets.size():
306 | raise ValueError("'Inputs' and 'Targets' must have the same shape.")
307 |
308 | inputs = flatten(inputs)
309 | targets = flatten(targets).float()
310 | ones = torch.Tensor().new_ones((inputs.size()), dtype=torch.float, device=inputs.device)
311 |
312 | P_G = (inputs * targets).sum(-1)
313 | if self.weight is not None:
314 | P_G = self.weight * P_G
315 |
316 | P_NG = (inputs * (ones - targets)).sum(-1)
317 | NP_G = ((ones - inputs) * targets).sum(-1)
318 |
319 | ones = torch.Tensor().new_ones((inputs.size(0),), dtype=torch.float, device=inputs.device)
320 | tversky = P_G / (P_G + self._alpha * P_NG + self._beta * NP_G + EPSILON)
321 |
322 |         tversky = torch.pow((ones - tversky), 1 / self._gamma)
323 |
324 | if self._ignore_index != -100:
325 | def ignore_index_fn(tversky_vector):
326 | try:
327 | indices = list(range(len(tversky_vector)))
328 | indices.remove(self._ignore_index)
329 | return tversky_vector[indices]
330 |                 except ValueError:
331 |                     raise IndexError(
332 |                         "'ignore_index' must be non-negative and lower than the number of classes, but {} was given.".format(
333 |                             self._ignore_index))
334 |
335 | tversky = MetricsLambda(ignore_index_fn, tversky).compute()
336 |
337 | if self.reduction == "mean":
338 | tversky = tversky.mean()
339 |
340 | return tversky
341 |
342 |
343 | class WeightedCrossEntropyLoss(_Loss):
344 |
345 | def __init__(self, ignore_index: int = -100):
346 | super(WeightedCrossEntropyLoss, self).__init__()
347 | self._ignore_index = ignore_index
348 |
349 |     def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
350 | """
351 | Computes the Weighted Cross Entropy Loss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf
352 | Args:
353 | inputs (:obj:`torch.Tensor`): A tensor of shape (B, C, ..). The model's prediction on which the loss has to be computed.
354 | targets (:obj:`torch.Tensor`): A tensor of shape (B, ..). The ground truth.
355 | Returns:
356 |             :obj:`torch.Tensor`: The weighted cross-entropy loss value.
357 | """
358 | class_weights = self.compute_class_weights(inputs)
359 |
360 | return torch.nn.functional.cross_entropy(inputs, targets, weight=class_weights,
361 | ignore_index=self._ignore_index)
362 |
363 | @staticmethod
364 | def compute_class_weights(inputs: torch.Tensor):
365 | """
366 | Compute weights for each class as described in https://arxiv.org/pdf/1707.03237.pdf
367 | Args:
368 | inputs: (:obj:`torch.Tensor`): A tensor of shape (B, C, ..). The model's prediction on which the loss has to be computed.
369 | Returns:
370 | :obj:`torch.Tensor`: A tensor containing class weights.
371 | """
372 | flattened_inputs = flatten(inputs)
373 | class_weights = (flattened_inputs.shape[1] - flattened_inputs.sum(-1)) / flattened_inputs.sum(-1)
374 | return class_weights
375 |
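376 |
377 | # Minimal usage sketch (assumes only the APIs defined above; `probabilities` and `targets` are
378 | # placeholder tensors): per-channel Dice is 2 * |P ∩ G| / (|P| + |G|), so the Dice-family losses
379 | # expect one-hot targets with the same (B, C, ...) shape as the softmaxed inputs.
380 | #
381 | #   from kerosene.utils.tensors import to_onehot
382 | #   factory = CriterionFactory()
383 | #   criterion = factory.create("DiceLoss", {"reduction": "mean", "ignore_index": 0})
384 | #   loss = criterion(probabilities, to_onehot(targets, num_classes=3))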
--------------------------------------------------------------------------------