├── tests ├── __init__.py ├── optim │ ├── __init__.py │ ├── test_optimizers.py │ └── test_schedulers.py ├── utils │ ├── __init__.py │ ├── test_configs.py │ └── test_devices.py ├── configs │ ├── __init__.py │ ├── invalid_config.yml │ ├── valid_config.yml │ └── test_model_trainers_config.py ├── events │ ├── __init__.py │ ├── handlers │ │ ├── __init__.py │ │ ├── test_early_stopping.py │ │ ├── test_visdom.py │ │ ├── test_console.py │ │ └── test_checkpoints.py │ └── publishers │ │ └── __init__.py ├── loggers │ ├── __init__.py │ └── visdom │ │ ├── __init__.py │ │ └── test_visdom.py ├── metrics │ ├── __init__.py │ └── test_metrics.py ├── models │ └── __init__.py ├── nn │ ├── utils │ │ ├── __init__.py │ │ └── test_gradients.py │ ├── __init__.py │ ├── test_functional.py │ ├── test_apex.py │ └── test_criterions.py ├── parsers │ ├── __init__.py │ ├── test_yaml.yml │ └── yaml_parser_test.py ├── scheduler │ └── __init__.py ├── training │ ├── __init__.py │ └── test_trainers.py ├── functionals │ ├── __init__.py │ ├── distributed │ │ ├── __init__.py │ │ ├── run.sh │ │ ├── config.yml │ │ ├── mnist_trainer.py │ │ └── main.py │ └── models.py ├── constants.py └── base_test.py ├── kerosene ├── __init__.py ├── nn │ ├── __init__.py │ ├── utils │ │ ├── __init__.py │ │ └── gradients.py │ ├── functional.py │ ├── apex.py │ └── criterions.py ├── configs │ ├── __init__.py │ ├── exceptions.py │ ├── parsers.py │ └── configs.py ├── loggers │ ├── __init__.py │ └── visdom │ │ ├── __init__.py │ │ ├── data.py │ │ ├── config.py │ │ ├── visdom.py │ │ └── plots.py ├── metrics │ ├── __init__.py │ ├── gauges.py │ └── metrics.py ├── models │ ├── __init__.py │ └── models.py ├── optim │ ├── __init__.py │ ├── schedulers.py │ └── optimizers.py ├── utils │ ├── __init__.py │ ├── constants.py │ ├── configs.py │ ├── devices.py │ ├── files.py │ └── tensors.py ├── events │ ├── handlers │ │ ├── __init__.py │ │ ├── early_stopping.py │ │ ├── base_handler.py │ │ ├── checkpoints.py │ │ ├── base_monitor_watcher.py │ │ ├── console.py │ │ └── visdom.py │ ├── publishers │ │ ├── __init__.py │ │ └── base_publisher.py │ ├── exceptions.py │ └── __init__.py └── training │ ├── __init__.py │ └── events.py ├── examples ├── mnist │ ├── __init__.py │ ├── requirements.txt │ ├── config.yml │ ├── models.py │ ├── main.py │ └── main_hyperopt.py └── config_template.yml ├── icons └── oil.png ├── deploy.sh ├── deploy.bat ├── .travis.yml ├── requirements.txt ├── .github └── ISSUE_TEMPLATE │ ├── other-issues.md │ ├── documentation-issue.md │ ├── feature_request.md │ └── bug---performance-report.md ├── LICENSE ├── setup.py ├── .gitignore └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kerosene/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kerosene/nn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/optim/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /examples/mnist/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kerosene/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kerosene/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kerosene/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kerosene/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kerosene/optim/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kerosene/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/events/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/nn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/parsers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kerosene/nn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/events/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /tests/functionals/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/loggers/visdom/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/optim/test_optimizers.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/optim/test_schedulers.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/training/test_trainers.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kerosene/events/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kerosene/events/publishers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/events/publishers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/loggers/visdom/test_visdom.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/functionals/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/events/handlers/test_early_stopping.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/mnist/requirements.txt: -------------------------------------------------------------------------------- 1 | torch-kerosene 2 | hyperopt 3 | -------------------------------------------------------------------------------- /icons/oil.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/banctilrobitaille/kerosene/HEAD/icons/oil.png -------------------------------------------------------------------------------- /deploy.sh: -------------------------------------------------------------------------------- 1 | python3 setup.py sdist 2 | python3 -m twine upload dist/torch-kerosene-0.2.4.tar.gz -------------------------------------------------------------------------------- /kerosene/configs/exceptions.py: -------------------------------------------------------------------------------- 1 | class InvalidConfigurationError(Exception): 2 | pass 3 | -------------------------------------------------------------------------------- /deploy.bat: -------------------------------------------------------------------------------- 1 | python setup.py sdist 2 | python -m twine upload dist/torch-kerosene-0.2.4.tar.gz 3 | 4 | 
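:: Optional pre-upload sanity check (assumes twine >= 1.12, which provides the "check" command): python -m twine check dist/torch-kerosene-0.2.4.tar.gz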
-------------------------------------------------------------------------------- /tests/parsers/test_yaml.yml: -------------------------------------------------------------------------------- 1 | weights: !torch/tensor [1, 1, 1, 1, 1, 1] 2 | tuple: !python/tuple [0.1,0.2] -------------------------------------------------------------------------------- /tests/functionals/distributed/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CUDA_VISIBLE_DEVICES=0,1 python3 -m torch.distributed.launch --nproc_per_node=2 main.py --amp-opt-level=O1 -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.6" 4 | before_install: 5 | - pip install --upgrade pip 6 | install: 7 | - pip install --progress-bar off -r requirements.txt 8 | script: 9 | - pytest -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | PyYAML>=5.1 2 | PyHamcrest>=1.9.0 3 | numpy>=1.16.1 4 | pyparsing>=2.3.1 5 | pytest>=4.3.0 6 | torch>=1.1 7 | torchfile>=0.1.0 8 | torchvision>=0.2.1 9 | visdom>=0.1.8.8 10 | pytorch-ignite>= 0.2.0 11 | mockito >= 1.1.1 12 | mock >= 3.0.5 13 | beautifultable 14 | crayons 15 | -------------------------------------------------------------------------------- /kerosene/nn/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def js_div(inputs): 6 | js_divs = [] 7 | for dist in range(inputs.size(1)): 8 | js_divs.append( 9 | F.kl_div(torch.mean(inputs, dim=1, keepdim=True).log(), inputs[:, dist].unsqueeze(0), reduction='sum')) 10 | 11 | return torch.tensor(js_divs).mean() 12 | -------------------------------------------------------------------------------- /tests/configs/invalid_config.yml: -------------------------------------------------------------------------------- 1 | training: 2 | nb_epochs: 10 3 | batch_size_train: 10 4 | batch_size_valid: 10 5 | 6 | models: 7 | SimpleNet: 8 | type: SimpleNet 9 | scheduler: 10 | type: ReduceLROnPlateau 11 | params: 12 | mode: 'min' 13 | factor: 0.1 14 | patience: 3 15 | criterion: 16 | type: CrossEntropyLoss 17 | metric: 18 | type: Accuracy 19 | -------------------------------------------------------------------------------- /kerosene/events/exceptions.py: -------------------------------------------------------------------------------- 1 | class UnsupportedEventException(Exception): 2 | 3 | def __init__(self, supported_events, unsupported_event): 4 | super(UnsupportedEventException, self).__init__( 5 | "Unsupported event provided ({}). 
Only {} are permitted".format(str(unsupported_event), 6 |                                                                            [str(event) for event in 7 |                                                                             supported_events])) 8 | -------------------------------------------------------------------------------- /examples/config_template.yml: -------------------------------------------------------------------------------- 1 | training: 2 |   nb_epochs: 250 3 |   batch_size: 1 4 | 5 | models: 6 |   MyModel: 7 |     type: resnet 8 |     params: 9 |       group_size: 32 10 |     optimizer: 11 |       type: "SGD" 12 |       params: 13 |         lr: 0.001 14 |         momentum: 0.9 15 |         weight_decay: 0.1 16 |     scheduler: 17 |       type: "ReduceLROnPlateau" 18 |       params: 19 |         mode: 'min' 20 |         factor: 0.1 21 |         patience: 3 22 |     criterion: 23 |       type: "BCELoss" 24 |       params: -------------------------------------------------------------------------------- /tests/events/handlers/test_visdom.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from mockito import mock, spy 4 | 5 | from kerosene.events.handlers.visdom import PlotLosses 6 | from kerosene.loggers.visdom.visdom import VisdomLogger 7 | from tests.base_test import BaseTestEnvironment 8 | 9 | 10 | class VisdomTest(BaseTestEnvironment): 11 | 12 |     def setUp(self): 13 |         self._visdom_logger_mock = mock(VisdomLogger) 14 |         self._plot_losses_handler = spy(PlotLosses) 15 | 16 |     def test_should_plot_losses(self): 17 |         pass 18 | -------------------------------------------------------------------------------- /tests/constants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | DELTA = 0.0001 4 | ZERO = torch.tensor([0.0]) 5 | ONE = torch.tensor([1.0]) 6 | TARGET_CLASS_0 = torch.tensor([0]) 7 | TARGET_CLASS_1 = torch.tensor([1]) 8 | MODEL_PREDICTION_CLASS_0 = torch.tensor([[1.0, 0.0]]) 9 | MODEL_PREDICTION_CLASS_1 = torch.tensor([[0.0, 1.0]]) 10 | MINIMUM_BINARY_CROSS_ENTROPY_LOSS = torch.tensor(0.3133) 11 | MAXIMUM_BINARY_CROSS_ENTROPY_LOSS = torch.tensor(1.3133) 12 | AVERAGED_BINARY_CROSS_ENTROPY_LOSS = (MINIMUM_BINARY_CROSS_ENTROPY_LOSS + MAXIMUM_BINARY_CROSS_ENTROPY_LOSS) / 2 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/other-issues.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Other issues 3 | about: Use this template for any other non-support related issues 4 | title: "[ISSUE]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | This template is for miscellaneous issues not covered by the other issue categories. 11 | 12 | **Describe the issue** 13 | A clear and concise description of what the problem is. 14 | 15 | **Expectations** 16 | A clear and concise description of what you expect by submitting this issue. 17 | 18 | **Additional context** 19 | Add any other context or screenshots about the issue here. 
20 | -------------------------------------------------------------------------------- /examples/mnist/config.yml: -------------------------------------------------------------------------------- 1 | training: 2 | nb_epochs: 1 3 | batch_size_train: 10 4 | batch_size_valid: 10 5 | 6 | models: 7 | SimpleNet: 8 | type: SimpleNet 9 | optimizer: 10 | type: SGD 11 | params: 12 | lr: 0.001 13 | momentum: 0.9 14 | weight_decay: 0.001 15 | scheduler: 16 | type: ReduceLROnPlateau 17 | params: 18 | mode: 'min' 19 | factor: 0.1 20 | patience: 3 21 | criterion: 22 | CrossEntropy: 23 | type: CrossEntropyLoss 24 | 25 | 26 | 27 | visdom: 28 | server: localhost 29 | port: 8097 30 | env: "Kerosene-MNIST" -------------------------------------------------------------------------------- /tests/parsers/yaml_parser_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | from hamcrest import * 5 | 6 | from kerosene.configs.parsers import CustomYamlParser 7 | 8 | 9 | class YamlParserTest(unittest.TestCase): 10 | 11 | def setUp(self) -> None: 12 | self._path = "tests/parsers/test_yaml.yml" 13 | self._parser = CustomYamlParser() 14 | 15 | def test_should_parse_file_and_create_tensor(self): 16 | with open(self._path) as file: 17 | config = self._parser.safe_load(file) 18 | 19 | assert_that(config["weights"], instance_of(torch.Tensor)) 20 | assert_that(config["tuple"], instance_of(tuple)) 21 | -------------------------------------------------------------------------------- /tests/functionals/distributed/config.yml: -------------------------------------------------------------------------------- 1 | training: 2 | nb_epochs: 100 3 | batch_size: 32 4 | 5 | models: 6 | SimpleNet: 7 | type: SimpleNet 8 | params: 9 | optimizer: 10 | type: FusedSGD 11 | params: 12 | lr: 0.01 13 | momentum: 0.5 14 | weight_decay: 0 15 | scheduler: 16 | type: ReduceLROnPlateau 17 | params: 18 | mode: 'min' 19 | factor: 0.1 20 | patience: 3 21 | criterion: 22 | type: CrossEntropyLoss 23 | params: 24 | metric: 25 | type: Accuracy 26 | params: 27 | 28 | visdom: 29 | server: "http://10.0.3.9" 30 | port: 8097 31 | env: "Kerosene-MNIST" -------------------------------------------------------------------------------- /kerosene/training/__init__.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class BaseStatus(Enum): 5 | 6 | def __str__(self): 7 | return self.value 8 | 9 | def __eq__(self, other): 10 | if isinstance(other, str): 11 | return self.value == other 12 | elif isinstance(other, BaseStatus): 13 | return self.value == other.value 14 | 15 | def __hash__(self): 16 | return hash(self.value) 17 | 18 | 19 | class Status(BaseStatus): 20 | INITIALIZING = "Initializing" 21 | INITIALIZED = "Initialized" 22 | READY = "Ready" 23 | TRAINING = "Training" 24 | VALIDATING = "Validating" 25 | TESTING = "Testing" 26 | FINALIZED = "Finalized" 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation-issue.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Documentation issue 3 | about: Create a report to help us improve documentation 4 | title: "[DOCUMENTATION]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the documentation issue** 11 | A clear and concise description of what the documentation issue is. 12 | 13 | **System information** 14 | - Framework version: [e.g. 
0.1.0] 15 | - Python version: [e.g. Python 3.7] 16 | - Doc Link: 17 | 18 | **Expected** 19 | A clear and concise description of what you expected to happen. 20 | 21 | **Screenshots** 22 | If applicable, add screenshots to help explain your problem. 23 | 24 | **Additional context** 25 | Add any other context about the problem here. 26 | -------------------------------------------------------------------------------- /examples/mnist/models.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torch.nn import Module, Linear, Dropout2d, Conv2d 3 | 4 | 5 | class SimpleConvNet(Module): 6 |     def __init__(self): 7 |         super().__init__() 8 |         self.conv1 = Conv2d(1, 10, kernel_size=5) 9 |         self.conv2 = Conv2d(10, 20, kernel_size=5) 10 |         self.conv2_drop = Dropout2d() 11 |         self.fc1 = Linear(320, 50) 12 |         self.fc2 = Linear(50, 10) 13 | 14 |     def forward(self, x): 15 |         x = F.relu(F.max_pool2d(self.conv1(x), 2)) 16 |         x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 17 |         x = x.view(-1, 320) 18 |         x = F.relu(self.fc1(x)) 19 |         x = F.dropout(x, training=self.training) 20 |         x = self.fc2(x) 21 |         return x 22 | -------------------------------------------------------------------------------- /tests/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 SAMITorch Authors. All Rights Reserved. 3 | # # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # # 8 | #     https://opensource.org/licenses/MIT 9 | # # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | -------------------------------------------------------------------------------- /tests/functionals/models.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torch.nn import Module, Linear, Dropout2d, Conv2d 3 | 4 | 5 | class SimpleNet(Module): 6 |     def __init__(self): 7 |         super().__init__() 8 |         self.conv1 = Conv2d(1, 10, kernel_size=5) 9 |         self.conv2 = Conv2d(10, 20, kernel_size=5) 10 |         self.conv2_drop = Dropout2d() 11 |         self.fc1 = Linear(320, 50) 12 |         self.fc2 = Linear(50, 10) 13 | 14 |     def forward(self, x): 15 |         x = F.relu(F.max_pool2d(self.conv1(x), 2)) 16 |         x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 17 |         x = x.view(-1, 320) 18 |         x = F.relu(self.fc1(x)) 19 |         x = F.dropout(x, training=self.training) 20 |         x = self.fc2(x) 21 |         return x -------------------------------------------------------------------------------- /kerosene/utils/constants.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 Kerosene Authors. All Rights Reserved. 3 | # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | EPSILON = 1e-15 17 | CHECKPOINT_EXT = ".tar" 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[FEATURE]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Will this change the current api? How?** 20 | A clear and concise description of how this feature request could change the current api. 21 | 22 | **Additional context** 23 | Add any other context or screenshots about the feature request here. 24 | -------------------------------------------------------------------------------- /kerosene/loggers/visdom/__init__.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class PlotType(Enum): 5 | LINE_PLOT = "Line Plot" 6 | IMAGES_PLOT = "Images Plot" 7 | IMAGE_PLOT = "Image Plot" 8 | PIE_PLOT = "Pie Plot" 9 | TEXT_PLOT = "Text Plot" 10 | HISTOGRAM_PLOT = "Histogram Plot" 11 | SCATTER_PLOT = "Scatter Plot" 12 | STEM_PLOT = "Stem Plot" 13 | HEATMAP_PLOT = "Heatmap Plot" 14 | BAR_PLOT = "Bar Plot" 15 | BOX_PLOT = "Box Plot" 16 | SURFACE_PLOT = "Surface Plot" 17 | CONTOUR_PLOT = "Contour Plot" 18 | QUIVER_PLOT = "Quiver Plot" 19 | MESH_PLOT = "Mesh Plot" 20 | MATPLOTLIB_PLOT = "Matplotlib Plot" 21 | 22 | def __str__(self): 23 | return self.value 24 | 25 | 26 | class PlotFrequency(Enum): 27 | EVERY_STEP = "Step" 28 | EVERY_EPOCH = "Epoch" 29 | 30 | def __str__(self): 31 | return self.value 32 | -------------------------------------------------------------------------------- /kerosene/models/models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 Kerosene Authors. All Rights Reserved. 3 | # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | from abc import ABC, abstractmethod 17 | 18 | 19 | class ModelFactory(ABC): 20 | @abstractmethod 21 | def create(self, model_type, params): 22 | raise NotImplementedError 23 | -------------------------------------------------------------------------------- /kerosene/events/handlers/early_stopping.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from kerosene.events import MonitorMode, TemporalEvent 4 | from kerosene.events.handlers.base_monitor_watcher import MonitorWatcher, MonitorPatienceExceeded 5 | from kerosene.training.trainers import Trainer 6 | 7 | 8 | class EarlyStopping(MonitorWatcher): 9 | LOGGER = logging.getLogger("EarlyStopping") 10 | 11 | def __init__(self, monitor_fn, mode: MonitorMode, min_delta=0.01, patience=3): 12 | super(EarlyStopping, self).__init__(monitor_fn, mode, min_delta, patience) 13 | 14 | def __call__(self, temporal_event: TemporalEvent, monitors, trainer: Trainer): 15 | for model_trainer in trainer.model_trainers: 16 | try: 17 | value = self._monitor_fn(model_trainer) 18 | self.watch(model_trainer.name, value) 19 | except MonitorPatienceExceeded as e: 20 | model_trainer.finalize() 21 | -------------------------------------------------------------------------------- /tests/utils/test_configs.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from hamcrest import * 4 | 5 | from kerosene.configs.configs import TrainerConfiguration, DatasetConfiguration 6 | from kerosene.utils.configs import configs_to_html 7 | 8 | 9 | class ConfigTests(unittest.TestCase): 10 | 11 | def setUp(self) -> None: 12 | self._config_1 = TrainerConfiguration({"num_epochs": 30, "batch_size": 32}) 13 | self._config_2 = DatasetConfiguration({"path": "/home/data", "validation_split": 0.2}) 14 | 15 | def test_should_produce_html_from_configs(self): 16 | config_html = configs_to_html([self._config_1, self._config_2]) 17 | assert_that(config_html[0], 18 | is_('
<h4>Training Configuration</h4> \n <p>num_epochs: 30</p> \n <p>batch_size: 32</p> ')) 19 |         assert_that(config_html[1], 20 |                     is_('<h4>Dataset Configuration</h4> \n <p>path: /home/data</p> \n <p>validation_split: 0.2</p>
')) 21 | -------------------------------------------------------------------------------- /tests/nn/test_functional.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | from hamcrest import assert_that, equal_to 5 | 6 | import kerosene.nn.functional as F 7 | 8 | 9 | class TestJensenShannonDivergence(unittest.TestCase): 10 | PROB_DIST1, PROB_DIST2, PROB_DIST3 = [1 / 2, 1 / 2, 0], [0, 1 / 10, 9 / 10], [1 / 3, 1 / 3, 1 / 3] 11 | 12 | def test_should_compute_jensen_shannon_divergence(self): 13 | prob_distributions = torch.tensor([[self.PROB_DIST1, self.PROB_DIST2, self.PROB_DIST3]]) 14 | expected_results = torch.tensor([0.378889]) 15 | 16 | assert_that(F.js_div(prob_distributions), equal_to(expected_results)) 17 | 18 | def test_should_compute_jensen_shannon_divergence_of_same_distribution(self): 19 | prob_distributions = torch.tensor([[self.PROB_DIST1, self.PROB_DIST1, self.PROB_DIST1]]) 20 | expected_results = torch.tensor([0.0]) 21 | 22 | assert_that(F.js_div(prob_distributions), equal_to(expected_results)) 23 | -------------------------------------------------------------------------------- /kerosene/utils/configs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 Kerosene Authors. All Rights Reserved. 3 | # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | from typing import List 17 | 18 | from kerosene.configs.configs import HtmlConfiguration 19 | 20 | 21 | def configs_to_html(configs: List[HtmlConfiguration]): 22 | return list(map(lambda config: config.to_html(), configs)) 23 | -------------------------------------------------------------------------------- /tests/nn/utils/test_gradients.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from hamcrest import * 4 | 5 | from kerosene.nn.utils.gradients import GradientClippingStrategyFactory, GradientClippingStrategyType, \ 6 | GradientClippingStrategy 7 | 8 | 9 | class TestGradientClippingStrategyFactory(unittest.TestCase): 10 | def setUp(self) -> None: 11 | pass 12 | 13 | def test_should_return_a_gradient_norm_clipping_strategy(self): 14 | strategy = GradientClippingStrategyFactory().create(GradientClippingStrategyType.Norm, 15 | {"max_norm": 1.0, "norm_type": 2}) 16 | 17 | assert_that(strategy, instance_of(GradientClippingStrategy)) 18 | 19 | def test_should_return_a_gradient_value_clipping_strategy(self): 20 | strategy = GradientClippingStrategyFactory().create(GradientClippingStrategyType.Value, {"clip_value": 1.0}) 21 | 22 | assert_that(strategy, instance_of(GradientClippingStrategy)) 23 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug---performance-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug / Performance report 3 | about: Create a report to help us improve 4 | title: "[BUG]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Execute this script '...' 17 | 3. ... 18 | 19 | **Expected behavior** 20 | A clear and concise description of what you expected to happen. 21 | 22 | **Screenshots** 23 | If applicable, add screenshots to help explain your problem. 24 | 25 | **Desktop (please complete the following information):** 26 | - OS platform and distribution: [e.g. Linux Ubuntu 18.04.2 LTS] 27 | - Python version: [e.g. Python 3.7] 28 | - CUDA/cuDNN version: [e.g. CUDA 10.0 / cuDNN 7.3.1] 29 | - GPU model and memory: [e.g. NVIDIA GeForce RTX 2080 8 GB] 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Benoit Anctil-Robitaille 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | from setuptools import find_packages, setup 4 | 5 | # The text of the README file 6 | README_CONTENT = (pathlib.Path(__file__).parent / "README.md").read_text() 7 | 8 | setup( 9 | name='torch-kerosene', 10 | version='0.2.4', 11 | description='Deep Learning framework for fast and clean research development with Pytorch', 12 | long_description=README_CONTENT, 13 | long_description_content_type='text/markdown', 14 | author='Benoit Anctil-Robitaille', 15 | author_email='benoit.anctil-robitaille.1@ens.etsmtl.ca', 16 | license='MIT', 17 | classifiers=[ 18 | "License :: OSI Approved :: MIT License", 19 | "Programming Language :: Python :: 3", 20 | "Programming Language :: Python :: 3.7"], 21 | packages=find_packages(exclude=("tests",)), 22 | install_requires=['numpy>=1.16.1', 23 | 'visdom>=0.1.8.8', 24 | 'pytorch-ignite>= 0.2.0', 25 | 'torch>=1.1', 26 | 'torchvision>=0.2.1', 27 | 'PyYAML', 28 | 'crayons', 29 | 'beautifultable'] 30 | ) 31 | -------------------------------------------------------------------------------- /kerosene/events/publishers/base_publisher.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 Kerosene Authors. All Rights Reserved. 3 | # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | from abc import ABC, abstractmethod 17 | 18 | from kerosene.events import BaseEvent, TemporalEvent 19 | 20 | 21 | class EventPublisher(ABC): 22 | def __init__(self): 23 | self._event_handlers = {} 24 | 25 | @property 26 | @abstractmethod 27 | def sender(self): 28 | raise NotImplementedError() 29 | 30 | @abstractmethod 31 | def with_event_handler(self, handler, event: BaseEvent): 32 | raise NotImplementedError() 33 | 34 | def fire(self, temporal_event: TemporalEvent, monitors: dict = None): 35 | if temporal_event.event in self._event_handlers.keys(): 36 | for handler in self._event_handlers[temporal_event.event]: 37 | handler(temporal_event, monitors, self.sender) 38 | -------------------------------------------------------------------------------- /tests/base_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from ignite.metrics import Accuracy, Recall 4 | from mockito import mock, spy 5 | from torch import nn 6 | from torch.optim import Optimizer, lr_scheduler 7 | from torch.utils.data import DataLoader 8 | 9 | from kerosene.nn.utils.gradients import GradientClippingStrategy 10 | from kerosene.training.trainers import ModelTrainer 11 | 12 | 13 | class BaseTestEnvironment(unittest.TestCase): 14 | MODEL_NAME = "Harry Potter" 15 | TRAINER_NAME = "Drago Malfoy" 16 | 17 | def setUp(self) -> None: 18 | self._model_mock = mock(nn.Module) 19 | self._criterion_mock = spy(nn.CrossEntropyLoss()) 20 | self._optimizer_mock = mock(Optimizer) 21 | self._scheduler_mock = mock(lr_scheduler) 22 | self._accuracy_computer_mock = spy(Accuracy()) 23 | self._recall_computer_mock = spy(Recall()) 24 | self._gradient_clipping_strategy = mock(GradientClippingStrategy) 25 | 26 | self._training_data_loader_mock = mock(DataLoader) 27 | self._valid_data_loader_mock = mock(DataLoader) 28 | self._test_data_loader_mock = mock(DataLoader) 29 | 30 | self._model_trainer = ModelTrainer(self.MODEL_NAME, self._model_mock, self._criterion_mock, 31 | self._optimizer_mock, self._scheduler_mock, 32 | {"Accuracy": self._accuracy_computer_mock, 33 | "Recall": self._recall_computer_mock}, self._gradient_clipping_strategy) 34 | -------------------------------------------------------------------------------- /tests/utils/test_devices.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | from hamcrest import * 5 | 6 | from kerosene.utils.devices import on_multiple_gpus, on_single_gpu, get_devices 7 | 8 | 9 | class DeviceTests(unittest.TestCase): 10 | 11 | def setUp(self) -> None: 12 | self._multiple_gpus_devices = [torch.device("cuda:0"), torch.device("cuda:1")] 13 | self._single_gpu_device = [torch.device("cuda:0")] 14 | self._single_cpu_device = [torch.device("cpu")] 15 | 16 | def test_on_single_gpu_should_return_true_with_single_GPU_device(self): 17 | assert_that(on_single_gpu(self._single_gpu_device), is_(True)) 18 | 19 | def test_on_single_gpu_should_return_false_with_multiple_GPU_devices(self): 20 | assert_that(on_single_gpu(self._multiple_gpus_devices), is_(False)) 21 | 22 | def test_on_multiple_gpu_should_return_false_with_single_GPU_device(self): 23 | assert_that(on_multiple_gpus(self._single_gpu_device), is_(False)) 24 | 25 | def test_on_multiple_gpu_should_return_true_with_multiple_GPU_devices(self): 26 | assert_that(on_multiple_gpus(self._multiple_gpus_devices), is_(True)) 27 | 28 | def 
test_get_devices_should_return_at_least_one_cuda_enabled_device(self): 29 | if torch.cuda.is_available(): 30 | assert_that(get_devices(), has_item(torch.device("cuda:0"))) 31 | assert_that(get_devices(), has_length(torch.cuda.device_count())) 32 | assert_that(get_devices(), is_( 33 | [torch.device("cuda:{}".format(id)) for id in range(int(torch.cuda.device_count()))])) 34 | else: 35 | assert_that(get_devices(), is_([torch.device("cpu")])) 36 | -------------------------------------------------------------------------------- /tests/configs/valid_config.yml: -------------------------------------------------------------------------------- 1 | training: 2 | nb_epochs: 10 3 | batch_size_train: 10 4 | batch_size_valid: 10 5 | 6 | models: 7 | SimpleNet: 8 | type: SimpleNet 9 | optimizer: 10 | type: SGD 11 | params: 12 | lr: 0.01 13 | momentum: 0.5 14 | weight_decay: 0 15 | scheduler: 16 | type: ReduceLROnPlateau 17 | params: 18 | mode: 'min' 19 | factor: 0.1 20 | patience: 3 21 | criterion: 22 | cycle: 23 | type: "L1Loss" 24 | gan: 25 | type: "MSELoss" 26 | metrics: 27 | Dice: 28 | type: Dice 29 | params: 30 | num_classes: 4 31 | reduction: !!null 32 | ignore_index: 0 33 | average: !!null 34 | weight: !!null 35 | Accuracy: 36 | type: Accuracy 37 | gradients: 38 | type: 'norm' 39 | params: 40 | max_norm: 1.0 41 | SimpleNet2: 42 | type: SimpleNet 43 | optimizer: 44 | type: SGD 45 | params: 46 | lr: 0.01 47 | momentum: 0.5 48 | weight_decay: 0 49 | scheduler: 50 | type: ReduceLROnPlateau 51 | params: 52 | mode: 'min' 53 | factor: 0.1 54 | patience: 3 55 | criterion: 56 | cycle: 57 | type: "L1Loss" 58 | gan: 59 | type: "MSELoss" 60 | metrics: 61 | Dice: 62 | type: Dice 63 | params: 64 | num_classes: 4 65 | reduction: !!null 66 | ignore_index: 0 67 | average: !!null 68 | weight: !!null 69 | Accuracy: 70 | type: Accuracy 71 | gradients: 72 | type: 'norm' 73 | params: 74 | max_norm: 1.0 75 | -------------------------------------------------------------------------------- /kerosene/metrics/gauges.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 Kerosene Authors. All Rights Reserved. 3 | # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | from abc import ABC, abstractmethod 17 | 18 | 19 | class Gauge(ABC): 20 | 21 | @abstractmethod 22 | def update(self, **kwargs): 23 | raise NotImplementedError() 24 | 25 | @abstractmethod 26 | def compute(self): 27 | raise NotImplementedError() 28 | 29 | @abstractmethod 30 | def reset(self): 31 | raise NotImplementedError() 32 | 33 | 34 | class AverageGauge(Gauge): 35 | 36 | def __init__(self): 37 | self._value = 0.0 38 | self._sum = 0.0 39 | self._count = 0 40 | 41 | @property 42 | def count(self): 43 | return self._count 44 | 45 | def has_been_updated(self): 46 | return True if self._count > 0 else False 47 | 48 | def update(self, value, n=1): 49 | self._value = value 50 | self._sum += value * n 51 | self._count += n 52 | 53 | def compute(self): 54 | return self._sum / self._count if self._count != 0 else 0.0 55 | 56 | def reset(self): 57 | self._value = 0.0 58 | self._sum = 0.0 59 | self._count = 0 60 | -------------------------------------------------------------------------------- /kerosene/utils/devices.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 Kerosene Authors. All Rights Reserved. 3 | # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | from typing import List 17 | 18 | import torch 19 | 20 | 21 | def on_cpu(device: torch.device): 22 |     return str(device) == "cpu" 23 | 24 | 25 | def on_gpu(device: torch.device): 26 |     return device.type == "cuda" 27 | 28 | 29 | def on_gpus(devices: List[torch.device]): 30 |     return all([device.type == "cuda" for device in devices]) 31 | 32 | 33 | def on_single_device(devices: List[torch.device]): 34 |     return len(devices) == 1 35 | 36 | 37 | def on_multiple_devices(devices: List[torch.device]): 38 |     return len(devices) > 1 39 | 40 | 41 | def on_single_gpu(devices: List[torch.device]): 42 |     return on_single_device(devices) and on_gpus(devices) 43 | 44 | 45 | def on_multiple_gpus(devices: List[torch.device]): 46 |     return on_multiple_devices(devices) and on_gpus(devices) 47 | 48 | 49 | def get_devices(): 50 |     return [torch.device("cuda:{}".format(device_id)) for device_id in 51 |             range(torch.cuda.device_count())] if torch.cuda.is_available() else [torch.device("cpu")] 52 | 53 | 54 | def num_gpus(): 55 |     return len([device for device in get_devices() if device.type == "cuda"]) 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | #  Usually these files are written by a python script from a template 30 | #  before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | idea 107 | .idea 108 | /data/* 109 | /tests/functionals/distributed/files/ 110 | files -------------------------------------------------------------------------------- /kerosene/events/handlers/base_handler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 Kerosene Authors. All Rights Reserved. 
3 | # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | #     https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | from abc import ABC, abstractmethod 17 | from typing import List, Union 18 | 19 | from kerosene.events import TemporalEvent, BaseEvent 20 | from kerosene.events.exceptions import UnsupportedEventException 21 | 22 | 23 | class EventHandler(ABC): 24 | 25 |     def __init__(self, supported_events: List[Union[BaseEvent, TemporalEvent]] = None, every=1): 26 |         self._every = every 27 |         self._supported_events = supported_events 28 | 29 |     @property 30 |     def every(self): 31 |         return self._every 32 | 33 |     @property 34 |     def supported_events(self): 35 |         return self._supported_events 36 | 37 |     def should_handle(self, event: TemporalEvent): 38 |         if (self.supported_events is not None) and (event not in self.supported_events): 39 |             raise UnsupportedEventException(self.supported_events, event) 40 | 41 |         if event.iteration == 0 and self._every != 1: 42 |             return False 43 |         else: 44 |             return event.iteration % self._every == 0 45 | 46 |     @abstractmethod 47 |     def __call__(self, temporal_event: TemporalEvent, monitors, sender): 48 |         raise NotImplementedError() 49 | -------------------------------------------------------------------------------- /kerosene/loggers/visdom/data.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | 3 | 4 | class VisdomData(object): 5 | 6 |     def __init__(self, source_name, variable_name: Union[List[str], str], plot_type, plot_frequency, x, y, 7 |                  params: dict = None): 8 |         self._source_name = source_name 9 |         self._variable_name = variable_name 10 |         self._plot_type = plot_type 11 |         self._plot_frequency = plot_frequency 12 |         self._x = x 13 |         self._y = y 14 |         self._params = params 15 | 16 |     @property 17 |     def id(self): 18 |         return hash(self) 19 | 20 |     @property 21 |     def source_name(self): 22 |         return self._source_name 23 | 24 |     @property 25 |     def variable_name(self): 26 |         return self._variable_name 27 | 28 |     @property 29 |     def plot_type(self): 30 |         return self._plot_type 31 | 32 |     @property 33 |     def plot_frequency(self): 34 |         return self._plot_frequency 35 | 36 |     @property 37 |     def x(self): 38 |         return self._x 39 | 40 |     @property 41 |     def y(self): 42 |         return self._y 43 | 44 |     @property 45 |     def params(self): 46 |         return self._params 47 | 48 |     def __hash__(self): 49 |         return hash(self._source_name + self._variable_name + str(self._plot_frequency) + str(self._plot_type)) 50 | 51 |     def __str__(self): 52 |         return "Source: {} Variable: {} Plot Type: {} Frequency: {} x: {} y: {}".format(self.source_name, 53 |                                                                                          self.variable_name, 54 |                                                                                          str(self.plot_type), 55 |                                                                                          str(self.plot_frequency), 56 |                                                                                          self.x, self.y) 57 | -------------------------------------------------------------------------------- /kerosene/utils/files.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 Kerosene Authors. All Rights Reserved. 
3 | # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | #     https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | import os 17 | from glob import glob 18 | from typing import List, Tuple 19 | 20 | 21 | def split_filename(filepath: str) -> Tuple[str, str, str]: 22 |     """ 23 |     Split a filepath into the directory, base, and extension. 24 | 25 |     Args: 26 |         filepath (str): The file path to split. 27 | 28 |     Returns: 29 |         Tuple: The directory path, the base filename (without extension) and the file extension. 30 |     """ 31 |     path = os.path.dirname(filepath) 32 |     filename = os.path.basename(filepath) 33 |     base, ext = os.path.splitext(filename) 34 |     if ext == '.gz': 35 |         base, ext2 = os.path.splitext(base) 36 |         ext = ext2 + ext 37 |     return path, base, ext 38 | 39 | 40 | def extract_file_paths(path: str, ext='*.nii*') -> List[str]: 41 |     """ 42 |     Grab all `ext` files in a directory and sort them for consistency. 43 | 44 |     Args: 45 |         path (str): The directory to search. 46 |         ext (str): The file extension pattern to match. 47 | 48 |     Returns: 49 |         list: A sorted list of matching file paths. 50 |     """ 51 |     file_paths = sorted(glob(os.path.join(path, ext))) 52 |     return file_paths 53 | 54 | 55 | def should_create_dir(path, dir_name): 56 |     return not os.path.exists(os.path.join(path, dir_name)) 57 | -------------------------------------------------------------------------------- /kerosene/loggers/visdom/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 Kerosene Authors. All Rights Reserved. 3 | # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | #     https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | from kerosene.configs.parsers import YamlConfigurationParser 17 | 18 | 19 | class VisdomConfiguration(object): 20 |     def __init__(self, port, server, env, filename, offline=False): 21 |         self._port = port 22 |         self._server = server 23 |         self._env = env 24 |         self._filename = filename 25 |         self._offline = offline 26 | 27 |     @property 28 |     def port(self): 29 |         return self._port 30 | 31 |     @property 32 |     def server(self): 33 |         return self._server 34 | 35 |     @property 36 |     def env(self): 37 |         return self._env 38 | 39 |     @property 40 |     def filename(self): 41 |         return self._filename 42 | 43 |     @property 44 |     def offline(self): 45 |         return self._offline 46 | 47 |     @classmethod 48 |     def from_dict(cls, config_dict): 49 |         return cls(config_dict.get("port", 8097), config_dict.get("server", "http://localhost"), 50 |                    config_dict.get("env", "main"), config_dict.get("filename", None), 51 |                    config_dict.get("offline", False)) 52 | 53 |     @classmethod 54 |     def from_yml(cls, yml_file, yml_tag="visdom"): 55 |         config = YamlConfigurationParser.parse_section(yml_file, yml_tag) 56 |         return VisdomConfiguration.from_dict(config) 57 | -------------------------------------------------------------------------------- /kerosene/utils/tensors.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 Kerosene Authors. All Rights Reserved. 3 | # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | #     https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | import torch 17 | 18 | 19 | def to_onehot(indices: torch.Tensor, num_classes: int) -> torch.Tensor: 20 |     """ 21 |     Convert a tensor of indices of any shape `(N, ...)` to a tensor of one-hot indicators of shape 22 |     `(N, num_classes, ...)`. 23 | 24 |     Args: 25 |         indices (:obj:`torch.Tensor`): Tensor of indices. 26 |         num_classes (int): The number of classes of the problem. 27 | 28 |     Returns: 29 |         :obj:`torch.Tensor`: The one-hot representation of the input tensor. 30 |     """ 31 |     onehot = torch.zeros(indices.shape[0], num_classes, *indices.shape[1:], device=indices.device) 32 |     return onehot.scatter_(1, indices.unsqueeze(1), 1) 33 | 34 | 35 | def flatten(tensor: torch.Tensor) -> torch.Tensor: 36 |     """ 37 |     Flattens a given tensor such that the channel axis is first. 38 |     The shapes are transformed as follows: (N, C, D, H, W) -> (C, N * D * H * W) 39 | 40 |     Args: 41 |         tensor (:obj:`torch.Tensor`): Tensor to flatten. 42 | 43 |     Returns: 44 |         :obj:`torch.Tensor`: The flattened 2-D tensor of shape (C, N * D * H * W). 
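Example (illustrative): an input of shape (2, 3, 4, 4, 4), i.e. (N, C, D, H, W), flattens to (3, 128): one row of length 2 * 4 * 4 * 4 per channel. 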
45 | """ 46 | C = tensor.size(1) 47 | # new axis order 48 | axis_order = (1, 0) + tuple(range(2, tensor.dim())) 49 | # Transpose: (N, C, D, H, W) -> (C, N, D, H, W) 50 | transposed = tensor.permute(axis_order).contiguous() 51 | # Flatten: (C, N, D, H, W) -> (C, N * D * H * W) 52 | return transposed.view(C, -1) 53 | -------------------------------------------------------------------------------- /examples/mnist/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torchvision 4 | from torch.utils.data import DataLoader 5 | from torchvision.transforms import Compose, ToTensor, Normalize 6 | 7 | from examples.mnist.models import SimpleConvNet 8 | from kerosene.configs.configs import RunConfiguration 9 | from kerosene.configs.parsers import YamlConfigurationParser 10 | from kerosene.events.handlers.console import PrintTrainingStatus 11 | from kerosene.events.handlers.visdom import PlotMonitors, PlotAvgGradientPerLayer 12 | from kerosene.loggers.visdom.config import VisdomConfiguration 13 | from kerosene.loggers.visdom.visdom import VisdomLogger 14 | from kerosene.training.events import Event 15 | from kerosene.training.trainers import ModelTrainerFactory, SimpleTrainer 16 | 17 | if __name__ == "__main__": 18 | logging.basicConfig(level=logging.INFO) 19 | CONFIG_FILE_PATH = "config.yml" 20 | 21 | model_trainer_config, training_config = YamlConfigurationParser.parse(CONFIG_FILE_PATH) 22 | 23 | train_loader = DataLoader(torchvision.datasets.MNIST('./files/', train=True, download=True, transform=Compose( 24 | [ToTensor(), Normalize((0.1307,), (0.3081,))])), batch_size=training_config.batch_size_train, shuffle=True) 25 | 26 | test_loader = DataLoader(torchvision.datasets.MNIST('./files/', train=False, download=True, transform=Compose( 27 | [ToTensor(), Normalize((0.1307,), (0.3081,))])), batch_size=training_config.batch_size_valid, shuffle=True) 28 | 29 | visdom_logger = VisdomLogger(VisdomConfiguration.from_yml(CONFIG_FILE_PATH)) 30 | 31 | # Initialize the model trainers 32 | model_trainer = ModelTrainerFactory(model=SimpleConvNet()).create(model_trainer_config) 33 | 34 | # Train with the training strategy 35 | SimpleTrainer("MNIST Trainer", train_loader, test_loader, None, model_trainer, RunConfiguration(use_amp=False)) \ 36 | .with_event_handler(PlotMonitors(every=500, visdom_logger=visdom_logger), Event.ON_BATCH_END) \ 37 | .with_event_handler(PlotMonitors(visdom_logger=visdom_logger), Event.ON_EPOCH_END) \ 38 | .with_event_handler(PlotAvgGradientPerLayer(every=500, visdom_logger=visdom_logger), Event.ON_TRAIN_BATCH_END) \ 39 | .with_event_handler(PrintTrainingStatus(every=100), Event.ON_BATCH_END) \ 40 | .train(training_config.nb_epochs) 41 | -------------------------------------------------------------------------------- /kerosene/optim/schedulers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 Kerosene Authors. All Rights Reserved. 3 | # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from enum import Enum
17 | from typing import Union
18 |
19 | from torch.optim import Optimizer
20 | from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR, \
21 |     CyclicLR, CosineAnnealingWarmRestarts, _LRScheduler
22 |
23 |
24 | class SchedulerType(Enum):
25 |     ReduceLROnPlateau = "ReduceLROnPlateau"
26 |     StepLR = "StepLR"
27 |     MultiStepLR = "MultiStepLR"
28 |     ExponentialLR = "ExponentialLR"
29 |     CosineAnnealingLR = "CosineAnnealingLR"
30 |     CyclicLR = "CyclicLR"
31 |     CosineAnnealingWarmRestarts = "CosineAnnealingWarmRestarts"
32 |
33 |     def __str__(self):
34 |         # Mirrors OptimizerType.__str__ so that create() can index the factory
35 |         # dict with either a plain string or a SchedulerType member.
36 |         return self.value
37 |
38 |
39 | class SchedulerFactory(object):
40 |
41 |     def __init__(self):
42 |         self._schedulers = {
43 |             "ReduceLROnPlateau": ReduceLROnPlateau,
44 |             "StepLR": StepLR,
45 |             "MultiStepLR": MultiStepLR,
46 |             "ExponentialLR": ExponentialLR,
47 |             "CosineAnnealingLR": CosineAnnealingLR,
48 |             "CyclicLR": CyclicLR,
49 |             "CosineAnnealingWarmRestarts": CosineAnnealingWarmRestarts,
50 |         }
51 |
52 |     def create(self, scheduler_type: Union[str, SchedulerType], optimizer: Optimizer, params):
53 |         return self._schedulers[str(scheduler_type)](optimizer=optimizer, **params) if params is not None else \
54 |             self._schedulers[str(scheduler_type)](optimizer=optimizer)
55 |
56 |     def register(self, function: str, creator: _LRScheduler):
57 |         """
58 |         Register a new learning-rate scheduler.
59 |         Args:
60 |             function (str): The scheduler name.
61 |             creator (:obj:`_LRScheduler`): The scheduler class to register.
62 |         """
63 |         self._schedulers[function] = creator
64 |
--------------------------------------------------------------------------------
/examples/mnist/main_hyperopt.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 |
4 | import torchvision
5 | from hyperopt import fmin, tpe, hp, STATUS_OK
6 | from torch.utils.data import DataLoader
7 | from torchvision.transforms import Compose, ToTensor, Normalize
8 |
9 | from examples.mnist.models import SimpleConvNet
10 | from kerosene.configs.configs import RunConfiguration
11 | from kerosene.configs.parsers import YamlConfigurationParser
12 | from kerosene.events import Monitor, Phase
13 | from kerosene.events.handlers.console import PrintTrainingStatus
14 | from kerosene.training.events import Event
15 | from kerosene.training.trainers import ModelTrainerFactory, SimpleTrainer
16 |
17 |
18 | def objective(hyper_params):
19 |     # Update the trainer with the new hyper-parameters
20 |     model_config.update(hyper_params)
21 |
22 |     # Create the model trainer
23 |     model_trainer = ModelTrainerFactory(model=SimpleConvNet()).create(model_config)
24 |
25 |     # Train with the training strategy
26 |     monitor = SimpleTrainer("MNIST Trainer", train_loader, valid_loader, None, model_trainer,
27 |                             RunConfiguration(use_amp=False)) \
28 |         .with_event_handler(PrintTrainingStatus(every=100), Event.ON_BATCH_END) \
29 |         .train(training_config.nb_epochs)
30 |
31 |     return {'loss': monitor["SimpleNet"][Phase.VALIDATION][Monitor.LOSS]["CrossEntropy"], 'status': STATUS_OK}
32 |
33 |
34 | if __name__ == "__main__":
35 |     logging.basicConfig(level=logging.INFO)
36 |     CONFIG_FILE_PATH = "config.yml"
37 |
38 |     model_config, training_config = YamlConfigurationParser.parse(CONFIG_FILE_PATH)
39 |
40 |     train_loader = DataLoader(torchvision.datasets.MNIST('./files/', train=True, download=True, transform=Compose(
41 |         [ToTensor(), Normalize((0.1307,),
(0.3081,))])), batch_size=training_config.batch_size_train, shuffle=True) 42 | 43 | valid_loader = DataLoader(torchvision.datasets.MNIST('./files/', train=False, download=True, transform=Compose( 44 | [ToTensor(), Normalize((0.1307,), (0.3081,))])), batch_size=training_config.batch_size_valid, shuffle=True) 45 | 46 | search_space = { 47 | 'SimpleNet': {'optimizer': {'params': {'lr': hp.loguniform('lr', math.log(0.0005), math.log(0.01))}}} 48 | } 49 | 50 | best = fmin(objective, space=search_space, algo=tpe.suggest, max_evals=2) 51 | 52 | print("The best hyper-parameters are: {}".format(best)) 53 | -------------------------------------------------------------------------------- /kerosene/events/handlers/checkpoints.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Callable, List, Union 4 | 5 | import torch 6 | 7 | from kerosene.events import MonitorMode, TemporalEvent, Phase, Monitor 8 | from kerosene.events.handlers.base_monitor_watcher import MonitorWatcher, MonitorPatienceExceeded 9 | from kerosene.training.events import Event 10 | from kerosene.training.trainers import Trainer 11 | from kerosene.utils.constants import CHECKPOINT_EXT 12 | from kerosene.utils.files import should_create_dir 13 | 14 | 15 | class Checkpoint(MonitorWatcher): 16 | LOGGER = logging.getLogger("Checkpoint") 17 | SUPPORTED_EVENTS = [Event.ON_EPOCH_END] 18 | 19 | def __init__(self, path: str, model_name: Union[str, None], monitor_names: Union[List[str], str], delta: float, 20 | mode: MonitorMode, reduction: Callable = torch.mean): 21 | super(Checkpoint, self).__init__(mode, delta, patience=0, supported_events=self.SUPPORTED_EVENTS) 22 | self._path = path 23 | self._model_name = model_name 24 | self._reduction = reduction 25 | self._monitors_names = monitor_names if isinstance(monitor_names, list) else [monitor_names] 26 | 27 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer): 28 | if self.should_handle(event): 29 | try: 30 | model_monitors = {**monitors[self._model_name][Phase.VALIDATION][Monitor.METRICS], 31 | **monitors[self._model_name][Phase.VALIDATION][Monitor.LOSS]} 32 | values = torch.cat([model_monitors[monitor_name] for monitor_name in self._monitors_names]) 33 | 34 | self.watch(self._model_name, self._reduction(values)) 35 | except MonitorPatienceExceeded as e: 36 | self._save(trainer.epoch, self._model_name, trainer.model_trainers[self._model_name].model_state, 37 | trainer.model_trainers[self._model_name].optimizer_state) 38 | except KeyError as e: 39 | self.LOGGER.warning("Invalid model or monitor name: {}".format(e)) 40 | 41 | def _save(self, epoch_num, model_name, model_state, optimizer_states): 42 | if should_create_dir(self._path, model_name): 43 | os.makedirs(os.path.join(self._path, model_name)) 44 | torch.save({"epoch_num": epoch_num, 45 | "model_state_dict": model_state, 46 | "optimizer_state_dict": optimizer_states}, 47 | os.path.join(self._path, model_name, model_name + CHECKPOINT_EXT)) 48 | -------------------------------------------------------------------------------- /kerosene/loggers/visdom/visdom.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC, abstractmethod 3 | from typing import Union, List, Dict 4 | 5 | from visdom import Visdom 6 | 7 | from kerosene.loggers.visdom.config import VisdomConfiguration 8 | from kerosene.loggers.visdom.data import VisdomData 9 | from 
kerosene.loggers.visdom.plots import VisdomPlotFactory, VisdomPlot 10 | 11 | 12 | class BaseVisdomLogger(ABC): 13 | 14 | def __init__(self, visdom_config: VisdomConfiguration): 15 | self._visdom_config = visdom_config 16 | 17 | @abstractmethod 18 | def __call__(self, visdom_data: Union[List[VisdomData], VisdomData]): 19 | raise NotImplementedError() 20 | 21 | @property 22 | def visdom_config(self): 23 | return self._visdom_config 24 | 25 | 26 | class VisdomLogger(BaseVisdomLogger): 27 | LOGGER = logging.getLogger("VisdomLogger") 28 | 29 | def __init__(self, visdom_config: VisdomConfiguration, plot_factory: VisdomPlotFactory = VisdomPlotFactory()): 30 | super().__init__(visdom_config) 31 | self._plots: Dict[int, VisdomPlot] = {} 32 | self._plot_factory = plot_factory 33 | self._visdom_config = visdom_config 34 | 35 | self._visdom = Visdom(server=visdom_config.server, port=visdom_config.port, env=visdom_config.env, 36 | log_to_filename=visdom_config.filename, offline=visdom_config.offline) 37 | 38 | @property 39 | def plots(self): 40 | return self._plots 41 | 42 | def save(self): 43 | self._visdom.save(envs=[self._visdom_config.env]) 44 | 45 | def __call__(self, visdom_data: Union[List[VisdomData], VisdomData]): 46 | visdom_data = [visdom_data] if not hasattr(visdom_data, '__iter__') else visdom_data 47 | 48 | for visdom_datum in visdom_data: 49 | if visdom_datum.id not in self._plots.keys(): 50 | self._plots[visdom_datum.id] = self._plot_factory.create_plot(self._visdom, visdom_datum.plot_type) 51 | 52 | self._plots[visdom_datum.id].update(visdom_datum) 53 | 54 | 55 | class DummyVisdomLogger(BaseVisdomLogger): 56 | LOGGER = logging.getLogger("DummyVisdomLogger") 57 | 58 | def __init__(self, visdom_config: VisdomConfiguration = None): 59 | super().__init__(visdom_config) 60 | 61 | def __call__(self, visdom_data: Union[List[VisdomData], VisdomData]): 62 | if not isinstance(visdom_data, list): 63 | visdom_data = [visdom_data] 64 | 65 | self.LOGGER.debug("\n".join(list(map(lambda visdom_datum: str(visdom_datum), visdom_data)))) 66 | -------------------------------------------------------------------------------- /tests/nn/test_apex.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | from hamcrest import assert_that, equal_to 5 | 6 | from kerosene.nn.apex import ApexLoss 7 | 8 | 9 | class TestApexLoss(unittest.TestCase): 10 | VALID_LOSS_VALUE = torch.tensor([1.0, 2.0, 3.0]) 11 | INVALID_LOSS_VALUE = "Ford Fiesta" 12 | 13 | VALID_SCALAR = 2.0 14 | INVALID_SCALAR = "Ford Focus" 15 | 16 | VALID_LOSS_ID = 0 17 | INVALID_LOSS_ID = "Ford F150" 18 | 19 | def test_should_add_losses(self): 20 | expected_result = ApexLoss(self.VALID_LOSS_ID, torch.tensor([2.0, 4.0, 6.0]), None) 21 | 22 | apex_loss1 = ApexLoss(self.VALID_LOSS_ID, self.VALID_LOSS_VALUE, None) 23 | apex_loss2 = ApexLoss(self.VALID_LOSS_ID, self.VALID_LOSS_VALUE, None) 24 | 25 | assert_that(apex_loss1 + apex_loss2, equal_to(expected_result)) 26 | 27 | def test_should_multiply_by_a_scalar(self): 28 | expected_result = ApexLoss(self.VALID_LOSS_ID, self.VALID_SCALAR * self.VALID_LOSS_VALUE, None) 29 | 30 | loss = ApexLoss(self.VALID_LOSS_ID, self.VALID_LOSS_VALUE, None) 31 | 32 | assert_that(self.VALID_SCALAR * loss, equal_to(expected_result)) 33 | assert_that(loss * self.VALID_SCALAR, equal_to(expected_result)) 34 | 35 | # noinspection PyTypeChecker 36 | def test_should_divide_by_a_scalar(self): 37 | left_div_expected_result = ApexLoss(self.VALID_LOSS_ID, 
self.VALID_LOSS_VALUE / self.VALID_SCALAR,
38 |                                             None)
39 |         right_div_expected_result = ApexLoss(self.VALID_LOSS_ID, self.VALID_SCALAR / self.VALID_LOSS_VALUE,
40 |                                              None)
41 |
42 |         loss = ApexLoss(self.VALID_LOSS_ID, self.VALID_LOSS_VALUE, None)
43 |
44 |         assert_that(loss / self.VALID_SCALAR, equal_to(left_div_expected_result))
45 |         assert_that(self.VALID_SCALAR / loss, equal_to(right_div_expected_result))
46 |
47 |     def test_should_compute_the_mean(self):
48 |         expected_result = ApexLoss(self.VALID_LOSS_ID, torch.tensor([2.0]), None)
49 |
50 |         loss = ApexLoss(self.VALID_LOSS_ID, self.VALID_LOSS_VALUE, None)
51 |
52 |         assert_that(loss.mean(), equal_to(expected_result))
53 |
54 |     # noinspection PyTypeChecker
55 |     def test_should_subtract_a_scalar(self):
56 |         left_sub_expected_result = ApexLoss(self.VALID_LOSS_ID, self.VALID_LOSS_VALUE - self.VALID_SCALAR, None)
57 |         right_sub_expected_result = ApexLoss(self.VALID_LOSS_ID, self.VALID_SCALAR - self.VALID_LOSS_VALUE, None)
58 |
59 |         loss = ApexLoss(self.VALID_LOSS_ID, self.VALID_LOSS_VALUE, None)
60 |
61 |         assert_that(loss - self.VALID_SCALAR, equal_to(left_sub_expected_result))
62 |         assert_that(self.VALID_SCALAR - loss, equal_to(right_sub_expected_result))
63 |
--------------------------------------------------------------------------------
/tests/functionals/distributed/mnist_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Pierre-Luc Delisle. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
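# --- Editor's sketch of how ApexLoss composes arithmetically, mirroring the
# --- tests above (standalone example; not part of the original repository). ---
import torch
from kerosene.nn.apex import ApexLoss

raw = torch.tensor([1.0, 2.0, 3.0])
loss = ApexLoss(loss_id=0, loss=raw, optimizer=None)
# Every operator returns a new ApexLoss wrapping the combined tensor, so a
# weighted multi-term objective keeps its amp loss_id for scaled backward().
total = (2.0 * loss + loss) / 3.0
assert total == ApexLoss(0, raw, None)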
15 | # ============================================================================== 16 | 17 | from typing import List 18 | 19 | import torch 20 | from torch.utils.data import DataLoader 21 | 22 | from kerosene.configs.configs import RunConfiguration 23 | from kerosene.training.trainers import ModelTrainer 24 | from kerosene.training.trainers import Trainer 25 | from kerosene.utils.devices import on_single_device 26 | 27 | 28 | class MNISTTrainer(Trainer): 29 | 30 | def __init__(self, training_config, model_trainers: List[ModelTrainer], 31 | train_data_loader: DataLoader, valid_data_loader: DataLoader, run_config: RunConfiguration): 32 | super(MNISTTrainer, self).__init__("MNISTTrainer", train_data_loader, valid_data_loader, model_trainers, 33 | run_config) 34 | 35 | self._training_config = training_config 36 | 37 | def train_step(self, inputs, target): 38 | model = self._model_trainers[0] 39 | 40 | pred = model.forward(inputs) 41 | model.compute_train_metric(pred, target) 42 | loss = model.compute_and_update_train_loss(pred, target) 43 | 44 | model.zero_grad() 45 | loss.backward() 46 | 47 | if not on_single_device(self._run_config.devices): 48 | self.average_gradients(model) 49 | 50 | model.step() 51 | 52 | def validate_step(self, inputs, target): 53 | model = self._model_trainers[0] 54 | 55 | pred = model.forward(inputs) 56 | model.compute_valid_metric(pred, target) 57 | model.compute_and_update_valid_loss(pred, target) 58 | 59 | def scheduler_step(self): 60 | self._model_trainers[0].scheduler_step() 61 | 62 | @staticmethod 63 | def average_gradients(model): 64 | size = float(torch.distributed.get_world_size()) 65 | for param in model.parameters(): 66 | torch.distributed.all_reduce(param.grad.data, op=torch.distributed.ReduceOp.SUM) 67 | param.grad.data /= size 68 | 69 | def on_epoch_begin(self): 70 | pass 71 | 72 | def on_epoch_end(self): 73 | pass 74 | -------------------------------------------------------------------------------- /kerosene/events/handlers/base_monitor_watcher.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Dict, List, Union 3 | 4 | from kerosene.events import MonitorMode, BaseEvent, TemporalEvent 5 | from kerosene.events.handlers.base_handler import EventHandler 6 | 7 | 8 | class MonitorPatienceExceeded(Exception): 9 | pass 10 | 11 | 12 | class MonitorInspection(object): 13 | def __init__(self, value=0, inspection_num=0): 14 | self._value = value 15 | self._inspection_num = inspection_num 16 | 17 | @property 18 | def value(self): 19 | return self._value 20 | 21 | @value.setter 22 | def value(self, new_value): 23 | self._value = new_value 24 | 25 | @property 26 | def inspection_num(self): 27 | return self._inspection_num 28 | 29 | def add_inspection(self): 30 | self._inspection_num = self._inspection_num + 1 31 | 32 | def reset_inspection_num(self): 33 | self._inspection_num = 0 34 | return self 35 | 36 | def with_value(self, value): 37 | self._value = value 38 | return self 39 | 40 | 41 | class MonitorWatcher(EventHandler, ABC): 42 | def __init__(self, mode: MonitorMode, min_delta, patience, every=1, 43 | supported_events: List[Union[BaseEvent, TemporalEvent]] = None): 44 | super().__init__(supported_events, every) 45 | 46 | self._mode = mode 47 | self._min_delta = min_delta 48 | self._patience = patience 49 | self._monitor_values: Dict[str, MonitorInspection] = {} 50 | 51 | @property 52 | def mode(self): 53 | return self._mode 54 | 55 | @property 56 | def min_delta(self): 57 | 
        return self._min_delta
58 |
59 |     @property
60 |     def patience(self):
61 |         return self._patience
62 |
63 |     @property
64 |     def monitor_values(self):
65 |         return self._monitor_values
66 |
67 |     def watch(self, source_name, current_monitor_value):
68 |         if source_name not in self._monitor_values.keys():
69 |             self._monitor_values[source_name] = MonitorInspection(value=current_monitor_value)
70 |         else:
71 |             if self._mode is MonitorMode.MIN:
72 |                 delta = self._monitor_values[source_name].value - current_monitor_value
73 |             else:
74 |                 delta = current_monitor_value - self._monitor_values[source_name].value
75 |
76 |             if delta <= self._min_delta:
77 |                 self._monitor_values[source_name].with_value(current_monitor_value).reset_inspection_num()
78 |             else:
79 |                 self._monitor_values[source_name].add_inspection()
80 |                 if self._monitor_values[source_name].inspection_num >= self._patience:
81 |                     raise MonitorPatienceExceeded()
82 |
--------------------------------------------------------------------------------
/kerosene/optim/optimizers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from enum import Enum
17 | from typing import Union
18 |
19 | from torch import optim
20 | from torch.optim import Optimizer
21 |
22 | try:
23 |     from apex.optimizers import FusedSGD, FusedAdam
24 |
25 |     APEX_AVAILABLE = True
26 | except ModuleNotFoundError:
27 |     APEX_AVAILABLE = False
28 | except ImportError as e:
29 |     APEX_AVAILABLE = False
30 |     print("Unable to import apex optimizers, please upgrade your apex version: {}".format(e))
31 |
32 |
33 | class OptimizerType(Enum):
34 |     FusedSGD = "FusedSGD"
35 |     FusedAdam = "FusedAdam"
36 |     SGD = "SGD"
37 |     Adam = "Adam"
38 |     Adagrad = "Adagrad"
39 |     Adadelta = "Adadelta"
40 |     SparseAdam = "SparseAdam"
41 |     Adamax = "Adamax"
42 |     Rprop = "Rprop"
43 |     RMSprop = "RMSprop"
44 |     ASGD = "ASGD"
45 |
46 |     def __str__(self):
47 |         return self.value
48 |
49 |
50 | class OptimizerFactory(object):
51 |
52 |     def __init__(self):
53 |         self._optimizers = {
54 |             "Adam": optim.Adam,
55 |             "SGD": optim.SGD,
56 |             "RMSprop": optim.RMSprop,
57 |             "Adagrad": optim.Adagrad,
58 |             "Adadelta": optim.Adadelta,
59 |             "SparseAdam": optim.SparseAdam,
60 |             "Adamax": optim.Adamax,
61 |             "Rprop": optim.Rprop,
62 |             "ASGD": optim.ASGD
63 |         }
64 |
65 |         if APEX_AVAILABLE:
66 |             self.register("FusedSGD", FusedSGD)
67 |             self.register("FusedAdam", FusedAdam)
68 |
69 |     def create(self, optimizer_type: Union[str, OptimizerType], model_params, params):
70 |         return self._optimizers[str(optimizer_type)](model_params, **params) if params is not None else \
71 |             self._optimizers[str(optimizer_type)](model_params)
72 |
73 |     def register(self, function: str, creator: Optimizer):
74 |         """
75 |         Register a new optimizer.
76 |         Args:
77 |             function (str): The optimizer name.
78 |             creator (:obj:`Optimizer`): The optimizer class to register.
79 |         """
80 |         self._optimizers[function] = creator
81 |
--------------------------------------------------------------------------------
/kerosene/nn/utils/gradients.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | import abc
18 | from enum import Enum
19 | from typing import Union, Callable, Iterable
20 |
21 | import torch
22 | from torch.nn.utils import clip_grad_norm_, clip_grad_value_
23 |
24 |
25 | class GradientClippingStrategyType(Enum):
26 |     Value = "value"
27 |     Norm = "norm"
28 |
29 |     def __str__(self):
30 |         return self.value
31 |
32 |
33 | class GradientClippingStrategy(metaclass=abc.ABCMeta):
34 |
35 |     @abc.abstractmethod
36 |     def clip(self, model_parameters: Union[Iterable[torch.Tensor], torch.Tensor]):
37 |         raise NotImplementedError()
38 |
39 |
40 | class GradientNormClipping(GradientClippingStrategy):
41 |
42 |     def __init__(self, max_norm: float, norm_type: Union[int, float] = 2):
43 |         """
44 |         Gradient norm clipping strategy.
45 |
46 |         Args:
47 |             max_norm (float or int): max norm of the gradients
48 |             norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
49 |         """
50 |         self._max_norm = max_norm
51 |         self._norm_type = norm_type
52 |
53 |     def clip(self, model_parameters):
54 |         clip_grad_norm_(model_parameters, self._max_norm, self._norm_type)
55 |
56 |
57 | class GradientValueClipping(GradientClippingStrategy):
58 |
59 |     def __init__(self, clip_value: Union[int, float]):
60 |         """
61 |         Gradient value clipping strategy.
62 |
63 |         Args:
64 |             clip_value (float or int): maximum allowed value of the gradients.
The gradients are clipped in the range 65 | :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]` 66 | """ 67 | self._clip_value = clip_value 68 | 69 | def clip(self, model_parameters: Union[Iterable[torch.Tensor], torch.Tensor]): 70 | clip_grad_value_(model_parameters, self._clip_value) 71 | 72 | 73 | class GradientClippingStrategyFactory(object): 74 | def __init__(self): 75 | self._clipping_strategies = { 76 | "value": GradientValueClipping, 77 | "norm": GradientNormClipping 78 | } 79 | 80 | def create(self, clipping_strategy: Union[str, GradientClippingStrategyType], params): 81 | if clipping_strategy is not None: 82 | return self._clipping_strategies[str(clipping_strategy)](**params) 83 | else: 84 | return None 85 | 86 | def register(self, function: str, creator: Callable): 87 | self._clipping_strategies[function] = creator 88 | -------------------------------------------------------------------------------- /kerosene/configs/parsers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 Kerosene Authors. All Rights Reserved. 3 | # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | import logging 17 | 18 | import torch 19 | import yaml 20 | 21 | from kerosene.configs.configs import ModelConfiguration, TrainerConfiguration, ConfigurationList 22 | 23 | 24 | class CustomYamlParser(object): 25 | 26 | def __init__(self): 27 | yaml.SafeLoader.add_constructor(u"!torch/tensor", CustomYamlParser.parse_tensor) 28 | yaml.SafeLoader.add_constructor(u"!python/tuple", CustomYamlParser.parse_tuple) 29 | 30 | @staticmethod 31 | def safe_load(file): 32 | return yaml.safe_load(file) 33 | 34 | @staticmethod 35 | def parse_tensor(loader, node): 36 | value = loader.construct_sequence(node, deep=True) 37 | tensor = torch.Tensor().new_tensor(value) 38 | return tensor 39 | 40 | @staticmethod 41 | def parse_tuple(loader, node): 42 | value = loader.construct_sequence(node, deep=True) 43 | tuple_ = tuple(value) 44 | return tuple_ 45 | 46 | 47 | class YamlConfigurationParser(object): 48 | LOGGER = logging.getLogger("YamlConfigurationParser") 49 | 50 | @staticmethod 51 | def parse(config_file_path): 52 | with open(config_file_path, 'r') as config_file: 53 | try: 54 | config = CustomYamlParser().safe_load(config_file) 55 | 56 | model_trainer_configs = ConfigurationList( 57 | [ModelConfiguration.from_dict(model_name, config["models"][model_name]) for 58 | model_name in config["models"].keys()]) 59 | 60 | model_trainer_configs = model_trainer_configs if len(model_trainer_configs) > 1 else \ 61 | model_trainer_configs[0] 62 | training_config = TrainerConfiguration(config['training']) 63 | return model_trainer_configs, training_config 64 | except yaml.YAMLError as e: 65 | YamlConfigurationParser.LOGGER.warning( 66 | "Unable to read the training config file: {} with error {}".format(config_file_path, e)) 67 | 68 | @staticmethod 69 | def 
parse_section(config_file_path, yml_tag): 70 | with open(config_file_path, 'r') as config_file: 71 | try: 72 | config = CustomYamlParser().safe_load(config_file) 73 | 74 | return config[yml_tag] 75 | except yaml.YAMLError as e: 76 | YamlConfigurationParser.LOGGER.warning( 77 | "Unable to read the training config file: {} with error {}".format(config_file_path, e)) 78 | -------------------------------------------------------------------------------- /tests/events/handlers/test_console.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import unittest 3 | 4 | import torch 5 | from mockito import verify, spy 6 | 7 | from kerosene.events import Phase, Monitor 8 | from kerosene.events.handlers.console import PrintTrainingStatus, StatusConsoleColorPalette, MonitorsTable 9 | 10 | 11 | class ConsoleHandlerTest(unittest.TestCase): 12 | 13 | def setUp(self) -> None: 14 | logging.basicConfig(level=logging.INFO) 15 | self._logger_mock = spy(logging.getLogger()) 16 | 17 | def test_print_status_with_colors(self): 18 | expected_string = "\nCurrent states: Training | Epoch: 0 | Training step: 0 | Validation step: 0 | Test step: 0\n" 19 | handler = PrintTrainingStatus(colors=StatusConsoleColorPalette.DEFAULT) 20 | handler.LOGGER = self._logger_mock 21 | handler.print_status(Phase.TRAINING, 0, 0, 0, 0) 22 | verify(self._logger_mock).info(expected_string) 23 | 24 | def test_print_status_without_colors(self): 25 | expected_string = "\nCurrent states: Training\x1b[0m | Epoch: 0 | Training step: 0 | Validation step: 0 | Test step: 0\n" 26 | handler = PrintTrainingStatus() 27 | handler.LOGGER = self._logger_mock 28 | handler.print_status(Phase.TRAINING, 0, 0, 0, 0) 29 | verify(self._logger_mock).info(expected_string) 30 | 31 | def test_print_monitors_table(self): 32 | monitors = { 33 | "Model 1": {Phase.TRAINING: {Monitor.METRICS: {"Accuracy": torch.tensor([0.6])}, 34 | Monitor.LOSS: {"MSELoss": torch.tensor([0.6])}}, 35 | Phase.VALIDATION: {Monitor.METRICS: {"Accuracy": torch.tensor([0.6])}, 36 | Monitor.LOSS: {"MSELoss": torch.tensor([0.6])}}, 37 | Phase.TEST: {Monitor.METRICS: {"Accuracy": torch.tensor([0.6])}, 38 | Monitor.LOSS: {"MSELoss": torch.tensor([0.6])}}}} 39 | 40 | training_values = {**monitors["Model 1"][Phase.TRAINING][Monitor.METRICS], 41 | **monitors["Model 1"][Phase.TRAINING][Monitor.LOSS]} 42 | validation_values = {**monitors["Model 1"][Phase.VALIDATION][Monitor.METRICS], 43 | **monitors["Model 1"][Phase.VALIDATION][Monitor.LOSS]} 44 | 45 | table = MonitorsTable("Model 1", 20) 46 | table.update(training_values, validation_values) 47 | table.show() 48 | 49 | monitors = { 50 | "Model 1": {Phase.TRAINING: {Monitor.METRICS: {"Accuracy": torch.tensor([0.7])}, 51 | Monitor.LOSS: {"MSELoss": torch.tensor([0.7])}}, 52 | Phase.VALIDATION: {Monitor.METRICS: {"Accuracy": torch.tensor([0.4])}, 53 | Monitor.LOSS: {"MSELoss": torch.tensor([0.6])}}, 54 | Phase.TEST: {Monitor.METRICS: {"Accuracy": torch.tensor([0.4])}, 55 | Monitor.LOSS: {"MSELoss": torch.tensor([0.8])}}}} 56 | 57 | training_values = {**monitors["Model 1"][Phase.TRAINING][Monitor.METRICS], 58 | **monitors["Model 1"][Phase.TRAINING][Monitor.LOSS]} 59 | validation_values = {**monitors["Model 1"][Phase.VALIDATION][Monitor.METRICS], 60 | **monitors["Model 1"][Phase.VALIDATION][Monitor.LOSS]} 61 | 62 | table.update(training_values, validation_values) 63 | table.show() 64 | -------------------------------------------------------------------------------- /tests/functionals/distributed/main.py: 
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | from argparse import ArgumentParser
5 |
6 | from kerosene.configs.configs import RunConfiguration
7 |
8 | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)) + '/../../../')
9 | import torchvision
10 | from torchvision.transforms import Compose, ToTensor, Normalize
11 |
12 | from kerosene.configs.parsers import YamlConfigurationParser
13 | from kerosene.training.events import Event
14 | from kerosene.dataloaders.factories import DataloaderFactory
15 | from kerosene.events.handlers.console import ConsoleLogger
16 | from kerosene.loggers.visdom.config import VisdomConfiguration
17 | from kerosene.loggers.visdom.visdom import VisdomLogger
18 | from kerosene.events.preprocessors.visdom import PlotAllModelStateVariables
19 | from kerosene.training.trainers import ModelTrainerFactory
20 | from tests.functionals.models import SimpleNet
21 | from tests.functionals.distributed.mnist_trainer import MNISTTrainer
22 |
23 |
24 | class ArgsParserFactory(object):
25 |
26 |     @staticmethod
27 |     def create_parser():
28 |         parser = ArgumentParser(description='DeepNormalize Training')
29 |         parser.add_argument("--use_amp", dest="use_amp", action="store_true", default=True)
30 |         parser.add_argument("--amp-opt-level", dest="amp_opt_level", type=str, default="O2",
31 |                             help="O0 - FP32 training, O1 - Mixed Precision (recommended), O2 - Almost FP16 Mixed Precision, O3 - FP16 Training.")
32 |         parser.add_argument("--local_rank", dest="local_rank", default=0, type=int, help="The local_rank of the GPU.")
33 |         return parser
34 |
35 |
36 | if __name__ == '__main__':
37 |     logging.basicConfig(level=logging.INFO)
38 |     CONFIG_FILE_PATH = "config.yml"
39 |     args = ArgsParserFactory.create_parser().parse_args()
40 |     run_config = RunConfiguration(args.use_amp, args.amp_opt_level, args.local_rank)
41 |
42 |     model_trainer_config, training_config = YamlConfigurationParser.parse(CONFIG_FILE_PATH)
43 |
44 |     # Initialize the dataset. This is the only part the user must define manually.
45 |     train_dataset = torchvision.datasets.MNIST('./files/', train=True, download=True, transform=Compose(
46 |         [ToTensor(), Normalize((0.1307,), (0.3081,))]))
47 |     test_dataset = torchvision.datasets.MNIST('./files/', train=False, download=True, transform=Compose(
48 |         [ToTensor(), Normalize((0.1307,), (0.3081,))]))
49 |
50 |     # Initialize loaders.
51 |     train_loader, valid_loader = DataloaderFactory(train_dataset, test_dataset).create(run_config, training_config)
52 |
53 |     # Initialize the loggers.
54 |     if run_config.local_rank == 0:
55 |         visdom_logger = VisdomLogger(VisdomConfiguration.from_yml(CONFIG_FILE_PATH))
56 |
57 |     # Initialize the model trainers.
58 |     model_trainer = ModelTrainerFactory(model=SimpleNet()).create(model_trainer_config, run_config)
59 |
60 |     # Train with the training strategy.
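# Editor's note (not in the original file): the logging and plotting handlers
# are attached only on the rank-0 process below, so that under
# torch.distributed a single process owns the Visdom/console output while
# every rank still runs the same training loop.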
61 | if run_config.local_rank == 0: 62 | trainer = MNISTTrainer(training_config, model_trainer, train_loader, valid_loader, run_config) \ 63 | .with_event_handler(ConsoleLogger(), Event.ON_EPOCH_END) \ 64 | .with_event_handler(visdom_logger, Event.ON_EPOCH_END, PlotAllModelStateVariables()) \ 65 | .train(training_config.nb_epochs) 66 | else: 67 | trainer = MNISTTrainer(training_config, model_trainer, train_loader, valid_loader, run_config) \ 68 | .train(training_config.nb_epochs) 69 | -------------------------------------------------------------------------------- /kerosene/events/__init__.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from enum import Enum 3 | 4 | 5 | class Phase(Enum): 6 | TRAINING = "Training" 7 | VALIDATION = "Validation" 8 | TEST = "Test" 9 | ALL = [TRAINING, VALIDATION, TEST] 10 | 11 | def __str__(self): 12 | return self.value 13 | 14 | def __eq__(self, other): 15 | if isinstance(other, str): 16 | return self.value == other 17 | elif isinstance(other, Phase): 18 | return self.value == other.value 19 | 20 | def __hash__(self): 21 | return hash(self.value) 22 | 23 | 24 | class Frequency(Enum): 25 | STEP = "Step" 26 | EPOCH = "Epoch" 27 | PHASE = "Phase" 28 | 29 | def __str__(self): 30 | return self.value 31 | 32 | def __eq__(self, other): 33 | if isinstance(other, str): 34 | return self.value == other 35 | elif isinstance(other, Frequency): 36 | return self.value == other.value 37 | 38 | def __hash__(self): 39 | return hash(self.value) 40 | 41 | 42 | class Moment(object): 43 | def __init__(self, iteration, frequency: Frequency, phase: Phase, time: datetime = None): 44 | self._datetime = datetime.now() if time is None else time 45 | self._frequency = frequency 46 | self._iteration = iteration 47 | self._phase = phase 48 | 49 | @property 50 | def datetime(self): 51 | return self._datetime 52 | 53 | @datetime.setter 54 | def datetime(self, value): 55 | self._datetime = value 56 | 57 | @property 58 | def frequency(self): 59 | return self._frequency 60 | 61 | @frequency.setter 62 | def frequency(self, value): 63 | self._frequency = value 64 | 65 | @property 66 | def iteration(self): 67 | return self._iteration 68 | 69 | @iteration.setter 70 | def iteration(self, value): 71 | self._iteration = value 72 | 73 | @property 74 | def phase(self): 75 | return self._phase 76 | 77 | @phase.setter 78 | def phase(self, value): 79 | self._phase = value 80 | 81 | 82 | class BaseEvent(Enum): 83 | def __init__(self, value): 84 | self._value_ = value 85 | 86 | def __str__(self): 87 | return self.value 88 | 89 | def __eq__(self, other): 90 | if isinstance(other, str): 91 | return self.value == other 92 | elif isinstance(other, BaseEvent): 93 | return self.value == other.value 94 | 95 | def __hash__(self): 96 | return hash(self.value) 97 | 98 | 99 | class TemporalEvent(object): 100 | def __init__(self, event: BaseEvent, moment: Moment): 101 | self._event = event 102 | self._moment = moment 103 | 104 | @property 105 | def event(self): 106 | return self._event 107 | 108 | @property 109 | def frequency(self): 110 | return self._moment.frequency 111 | 112 | @property 113 | def phase(self): 114 | return self._moment.phase 115 | 116 | @property 117 | def datetime(self): 118 | return self._moment.datetime 119 | 120 | @property 121 | def iteration(self): 122 | return self._moment.iteration 123 | 124 | def __eq__(self, other): 125 | if isinstance(other, BaseEvent): 126 | return self.event == other 127 | elif isinstance(other, 
TemporalEvent):
128 |             return self._event == other.event and self.frequency == other.frequency and \
129 |                    self.phase == other.phase and self.datetime == other.datetime and self.iteration == other.iteration
130 |
131 |     def __str__(self):
132 |         return str(self.event)
133 |
134 |
135 | class BaseVariable(Enum):
136 |     def __str__(self):
137 |         return self.value
138 |
139 |     def __eq__(self, other):
140 |         if isinstance(other, str):
141 |             return self.value == other
142 |         elif isinstance(other, BaseVariable):
143 |             return self.value == other.value
144 |
145 |     def __hash__(self):
146 |         return hash(self.value)
147 |
148 |
149 | class Monitor(BaseVariable):
150 |     LOSS = "loss"
151 |     METRICS = "metrics"
152 |
153 |     TRAIN_LOSS = "train_loss"
154 |     VALID_LOSS = "valid_loss"
155 |     TEST_LOSS = "test_loss"
156 |     TRAIN_METRICS = "train_metrics"
157 |     VALID_METRICS = "valid_metrics"
158 |     TEST_METRICS = "test_metrics"
159 |
160 |     def is_loss(self):
161 |         return "loss" in self.value
162 |
163 |     def is_metrics(self):
164 |         return "metrics" in self.value
165 |
166 |
167 | class MonitorMode(Enum):
168 |     MIN = -1
169 |     MAX = 1
170 |
171 |     def __str__(self):
172 |         return str(self.value)
173 |
--------------------------------------------------------------------------------
/tests/configs/test_model_trainers_config.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from hamcrest import *
4 |
5 | from kerosene.configs.configs import ModelConfiguration
6 | from kerosene.configs.exceptions import InvalidConfigurationError
7 | from kerosene.configs.parsers import YamlConfigurationParser
8 |
9 |
10 | class TestModelTrainerConfiguration(unittest.TestCase):
11 |     VALID_CONFIG_FILE_PATH = "tests/configs/valid_config.yml"
12 |     INVALID_CONFIG_FILE_PATH = "tests/configs/invalid_config.yml"
13 |
14 |     MODELS_CONFIG_YML_TAG = "models"
15 |
16 |     SIMPLE_NET_NAME = "SimpleNet"
17 |     SIMPLE_NET_NAME_2 = "SimpleNet2"
18 |
19 |     SIMPLE_NET_TYPE = "SimpleNet"
20 |
21 |     SIMPLE_NET_OPTIMIZER_TYPE = "SGD"
22 |     SIMPLE_NET_OPTIMIZER_PARAMS = {'lr': 0.01, 'momentum': 0.5, 'weight_decay': 0}
23 |
24 |     SIMPLE_NET_SCHEDULER_TYPE = "ReduceLROnPlateau"
25 |     SIMPLE_NET_SCHEDULER_PARAMS = {'mode': 'min', 'factor': 0.1, 'patience': 3}
26 |
27 |     SIMPLE_NET_CRITERION_TYPE = "L1Loss"
28 |     SIMPLE_NET_CRITERION_TYPE_2 = "MSELoss"
29 |
30 |     SIMPLE_NET_METRIC_TYPE_1 = "Dice"
31 |     SIMPLE_NET_METRIC_PARAMS_1 = {"num_classes": 4, "reduction": None, "ignore_index": 0, "average": None,
32 |                                   "weight": None}
33 |
34 |     SIMPLE_NET_METRIC_TYPE_2 = "Accuracy"
35 |
36 |     SIMPLE_NET_GRADIENT_CLIPPING = {"type": "norm", "params": {"max_norm": 1.0}}
37 |
38 |     def test_should_parse_valid_model_trainer_config(self):
39 |         expected_config_dict = {self.SIMPLE_NET_NAME: {'type': self.SIMPLE_NET_TYPE,
40 |                                                        'optimizer': {'type': self.SIMPLE_NET_OPTIMIZER_TYPE,
41 |                                                                      'params': self.SIMPLE_NET_OPTIMIZER_PARAMS},
42 |                                                        'scheduler': {'type': self.SIMPLE_NET_SCHEDULER_TYPE,
43 |                                                                      'params': self.SIMPLE_NET_SCHEDULER_PARAMS},
44 |                                                        'criterion': {"cycle": {'type': self.SIMPLE_NET_CRITERION_TYPE},
45 |                                                                      "gan": {'type': self.SIMPLE_NET_CRITERION_TYPE_2}},
46 |                                                        'metrics': {'Dice': {'type': self.SIMPLE_NET_METRIC_TYPE_1,
47 |                                                                             'params': self.SIMPLE_NET_METRIC_PARAMS_1},
48 |                                                                    'Accuracy': {'type': self.SIMPLE_NET_METRIC_TYPE_2}},
49 |                                                        'gradients': self.SIMPLE_NET_GRADIENT_CLIPPING},
50 |                                 self.SIMPLE_NET_NAME_2: {'type': self.SIMPLE_NET_TYPE,
51 |                                                          'optimizer': {'type': self.SIMPLE_NET_OPTIMIZER_TYPE,
52 |                                                                        'params': self.SIMPLE_NET_OPTIMIZER_PARAMS},
53 |                                                          'scheduler': {'type': self.SIMPLE_NET_SCHEDULER_TYPE,
54 |                                                                        'params': self.SIMPLE_NET_SCHEDULER_PARAMS},
55 |                                                          'criterion': {
56 |                                                              "cycle": {'type': self.SIMPLE_NET_CRITERION_TYPE},
57 |                                                              "gan": {'type': self.SIMPLE_NET_CRITERION_TYPE_2}},
58 |                                                          'metrics': {'Dice': {'type': self.SIMPLE_NET_METRIC_TYPE_1,
59 |                                                                               'params': self.SIMPLE_NET_METRIC_PARAMS_1},
60 |                                                                      'Accuracy': {
61 |                                                                          'type': self.SIMPLE_NET_METRIC_TYPE_2}},
62 |                                                          'gradients': self.SIMPLE_NET_GRADIENT_CLIPPING}}
63 |         config_dict = YamlConfigurationParser.parse_section(self.VALID_CONFIG_FILE_PATH, self.MODELS_CONFIG_YML_TAG)
64 |         model_trainer_config = ModelConfiguration.from_dict(self.SIMPLE_NET_NAME,
65 |                                                             config_dict[self.SIMPLE_NET_NAME])
66 |
67 |         assert_that(config_dict, equal_to(expected_config_dict))
68 |
69 |         assert_that(model_trainer_config.optimizer_config.type, equal_to(self.SIMPLE_NET_OPTIMIZER_TYPE))
70 |         assert_that(model_trainer_config.optimizer_config.params, equal_to(self.SIMPLE_NET_OPTIMIZER_PARAMS))
71 |         assert_that(model_trainer_config.scheduler_config.type, equal_to(self.SIMPLE_NET_SCHEDULER_TYPE))
72 |         assert_that(model_trainer_config.scheduler_config.params, equal_to(self.SIMPLE_NET_SCHEDULER_PARAMS))
73 |         assert_that(model_trainer_config.criterions_configs[0].type, equal_to(self.SIMPLE_NET_CRITERION_TYPE))
74 |         assert_that(model_trainer_config.criterions_configs[1].type, equal_to(self.SIMPLE_NET_CRITERION_TYPE_2))
75 |         assert_that(model_trainer_config.metrics_configs[0].type, equal_to(self.SIMPLE_NET_METRIC_TYPE_1))
76 |         assert_that(model_trainer_config.metrics_configs[0].params, equal_to(self.SIMPLE_NET_METRIC_PARAMS_1))
77 |
78 |     def test_should_throw_on_invalid_model_trainer_config(self):
79 |         config_dict = YamlConfigurationParser.parse_section(self.INVALID_CONFIG_FILE_PATH, self.MODELS_CONFIG_YML_TAG)
80 |
81 |         assert_that(calling(ModelConfiguration.from_dict).with_args(self.SIMPLE_NET_NAME,
82 |                                                                     config_dict[self.SIMPLE_NET_NAME]),
83 |                     raises(InvalidConfigurationError))
84 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Kerosene
2 | > Kerosene is a high-level deep learning framework for fast and clean research development with PyTorch - [see the docs for more details](https://kerosene.readthedocs.io/en/latest/). Kerosene lets you focus on your model and data by providing clean and readable code for training, visualizing and debugging your architecture, without forcing you to implement a rigid interface for your model.
3 |
4 | ## Out of The Box Features
5 | - [X] Basic training logic and user defined trainers
6 | - [X] Fine grained event system with multiple handlers
7 | - [X] Multiple metrics and criterions support
8 | - [X] Automatic configuration parsing and model instantiation
9 | - [X] Automatic support of mixed precision with [Apex](https://github.com/NVIDIA/apex) and data-parallel training
10 | - [X] Automatic Visdom logging
11 | - [X] Integrated [Ignite](https://github.com/pytorch/ignite) metrics and [PyTorch](https://github.com/pytorch/pytorch) criterions
12 |
13 | ## MNIST Example
14 | > Here is a simple example that shows how easy and clean it is to train a simple network. In very few lines of code, the model is trained using mixed precision and you get Visdom + console logging automatically.
See the full example here: [MNIST-Kerosene](https://github.com/banctilrobitaille/kerosene-mnist)
15 |
16 | ```python
17 | if __name__ == "__main__":
18 |     logging.basicConfig(level=logging.INFO)
19 |     CONFIG_FILE_PATH = "config.yml"
20 |
21 |     model_trainer_config, training_config = YamlConfigurationParser.parse(CONFIG_FILE_PATH)
22 |
23 |     train_loader = DataLoader(torchvision.datasets.MNIST('./files/', train=True, download=True, transform=Compose(
24 |         [ToTensor(), Normalize((0.1307,), (0.3081,))])), batch_size=training_config.batch_size_train, shuffle=True)
25 |
26 |     test_loader = DataLoader(torchvision.datasets.MNIST('./files/', train=False, download=True, transform=Compose(
27 |         [ToTensor(), Normalize((0.1307,), (0.3081,))])), batch_size=training_config.batch_size_valid, shuffle=True)
28 |
29 |     visdom_logger = VisdomLogger(VisdomConfiguration.from_yml(CONFIG_FILE_PATH))
30 |
31 |     # Initialize the model trainers
32 |     model_trainer = ModelTrainerFactory(model=SimpleNet()).create(model_trainer_config)
33 |
34 |     # Train with the training strategy
35 |     SimpleTrainer("MNIST Trainer", train_loader, test_loader, None, model_trainer, RunConfiguration(use_amp=False)) \
36 |         .with_event_handler(PlotMonitors(every=500, visdom_logger=visdom_logger), Event.ON_BATCH_END) \
37 |         .with_event_handler(PlotAvgGradientPerLayer(every=500, visdom_logger=visdom_logger), Event.ON_TRAIN_BATCH_END) \
38 |         .with_event_handler(PrintTrainingStatus(every=100), Event.ON_BATCH_END) \
39 |         .train(training_config.nb_epochs)
40 | ```
41 |
42 | ## Events
43 |
44 | | Event | Description |
45 | | ------------- | ------------- |
46 | | ON_TRAINING_BEGIN | At the beginning of the training phase |
47 | | ON_TRAINING_END | At the end of the training phase |
48 | | ON_VALID_BEGIN | At the beginning of the validation phase |
49 | | ON_VALID_END | At the end of the validation phase |
50 | | ON_TEST_BEGIN | At the beginning of the test phase |
51 | | ON_TEST_END | At the end of the test phase |
52 | | ON_EPOCH_BEGIN | At the beginning of each epoch (training, validation, test) |
53 | | ON_EPOCH_END | At the end of each epoch (training, validation, test) |
54 | | ON_TRAIN_EPOCH_BEGIN | At the beginning of each training epoch |
55 | | ON_TRAIN_EPOCH_END | At the end of each training epoch |
56 | | ON_VALID_EPOCH_BEGIN | At the beginning of each validation epoch |
57 | | ON_VALID_EPOCH_END | At the end of each validation epoch |
58 | | ON_TEST_EPOCH_BEGIN | At the beginning of each test epoch |
59 | | ON_TEST_EPOCH_END | At the end of each test epoch |
60 | | ON_BATCH_BEGIN | At the beginning of each batch (training, validation, test) |
61 | | ON_BATCH_END | At the end of each batch (training, validation, test) |
62 | | ON_TRAIN_BATCH_BEGIN | At the beginning of each train batch |
63 | | ON_TRAIN_BATCH_END | At the end of each train batch |
64 | | ON_VALID_BATCH_BEGIN | At the beginning of each validation batch |
65 | | ON_VALID_BATCH_END | At the end of each validation batch |
66 | | ON_TEST_BATCH_BEGIN | At the beginning of each test batch |
67 | | ON_TEST_BATCH_END | At the end of each test batch |
68 | | ON_FINALIZE | Before the end of the process |
69 |
70 | ## Handlers
71 | - [X] PrintTrainingStatus (Console)
72 | - [X] PrintMonitors (Console)
73 | - [X] PlotMonitors (Visdom)
74 | - [X] PlotLosses (Visdom)
75 | - [X] PlotMetrics (Visdom)
76 | - [X] PlotCustomVariables (Visdom)
77 | - [X] PlotLR (Visdom)
78 | - [X] PlotAvgGradientPerLayer (Visdom)
79 | - [X] Checkpoint
80 | - [X] EarlyStopping
81 |
82 | ## Contributing
83 |
84 | #### How to contribute?
85 | - [X] Create a branch by feature and/or bug fix
86 | - [X] Get the code
87 | - [X] Commit and push
88 | - [X] Create a pull request
89 |
90 | #### Branch naming
91 |
92 | ##### Feature branch
93 | > feature/ [Short feature description] [Issue number]
94 |
95 | ##### Bug branch
96 | > fix/ [Short fix description] [Issue number]
97 |
98 | #### Commit syntax:
99 |
100 | ##### Adding code:
101 | > \+ Added [Short Description] [Issue Number]
102 |
103 | ##### Deleting code:
104 | > \- Deleted [Short Description] [Issue Number]
105 |
106 | ##### Modifying code:
107 | > \* Changed [Short Description] [Issue Number]
108 |
109 | ##### Merging code:
110 | > Y Merged [Short Description] [Issue Number]
111 |
112 |
113 | Icons made by Freepik from www.flaticon.com are licensed under CC 3.0 BY
114 |
--------------------------------------------------------------------------------
/kerosene/nn/apex.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from __future__ import division
17 |
18 | from abc import ABC
19 | from typing import Optional
20 |
21 | import torch
22 | from torch import Tensor, nn
23 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
24 | from torch.optim import Optimizer
25 |
26 | from kerosene.utils.devices import on_multiple_gpus, get_devices
27 |
28 | try:
29 |     from apex import amp
30 |     from apex.parallel import DistributedDataParallel as ApexDDP
31 |
32 |     APEX_AVAILABLE = True
33 | except ModuleNotFoundError:
34 |     APEX_AVAILABLE = False
35 |
36 |
37 | class ApexLoss(object):
38 |     def __init__(self, loss_id: int, loss: Tensor, optimizer: Optimizer):
39 |         super().__init__()
40 |         self._loss_id = loss_id
41 |         self._loss = loss
42 |         self._optimizer = optimizer
43 |
44 |     @property
45 |     def loss_id(self):
46 |         return self._loss_id
47 |
48 |     @property
49 |     def loss(self):
50 |         return self._loss
51 |
52 |     def backward(self, gradient: Optional[Tensor] = None, retain_graph=False,
53 |                  create_graph=False) -> None:
54 |         if APEX_AVAILABLE:
55 |             with amp.scale_loss(self._loss, self._optimizer, loss_id=self._loss_id) as scaled_loss:
56 |                 scaled_loss.backward(gradient, retain_graph, create_graph)
57 |         else:
58 |             self._loss.backward(gradient, retain_graph, create_graph)
59 |
60 |     def cpu(self):
61 |         return ApexLoss(self._loss_id, self._loss.cpu(), self._optimizer)
62 |
63 |     def detach(self):
64 |         return ApexLoss(self._loss_id, self._loss.detach(), self._optimizer)
65 |
66 |     def float(self):
67 |         return ApexLoss(self._loss_id, self._loss.float(), self._optimizer)
68 |
69 |     def numpy(self):
70 |         return self._loss.numpy()
71 |
72 |     def mean(self):
73 |         return ApexLoss(self._loss_id, torch.mean(self._loss), self._optimizer)
74 |
75 |     def item(self):
76 |         return self._loss.item()
77 |
78 |     def __add__(self,
other):
79 |         if isinstance(other, (Tensor, int, float)):
80 |             loss = ApexLoss(self._loss_id, torch.add(self._loss, other), self._optimizer)
81 |         elif isinstance(other, ApexLoss):
82 |             loss = ApexLoss(self._loss_id, torch.add(self._loss, other.loss), self._optimizer)
83 |         else:
84 |             raise NotImplementedError("Cannot add an element of type: {} to an ApexLoss.".format(str(type(other))))
85 |         return loss
86 |
87 |     def __mul__(self, other):
88 |         if isinstance(other, (Tensor, int, float)):
89 |             loss = ApexLoss(self._loss_id, torch.mul(self._loss, other), self._optimizer)
90 |         elif isinstance(other, ApexLoss):
91 |             loss = ApexLoss(self._loss_id, torch.mul(self._loss, other.loss), self._optimizer)
92 |         else:
93 |             raise NotImplementedError("Cannot mul an element of type: {} to an ApexLoss.".format(str(type(other))))
94 |         return loss
95 |
96 |     def __rmul__(self, other):
97 |         if isinstance(other, (Tensor, int, float)):
98 |             loss = ApexLoss(self._loss_id, torch.mul(self._loss, other), self._optimizer)
99 |         else:
100 |             raise NotImplementedError("Cannot rmul an element of type: {} to an ApexLoss.".format(str(type(other))))
101 |         return loss
102 |
103 |     def __truediv__(self, other):
104 |         if isinstance(other, (Tensor, int, float)):
105 |             loss = ApexLoss(self._loss_id, torch.div(self._loss, other), self._optimizer)
106 |         elif isinstance(other, ApexLoss):
107 |             loss = ApexLoss(self._loss_id, torch.div(self._loss, other.loss), self._optimizer)
108 |         else:
109 |             raise NotImplementedError("Cannot truediv an element of type: {} to an ApexLoss.".format(str(type(other))))
110 |         return loss
111 |
112 |     def __rtruediv__(self, other):
113 |         if isinstance(other, (Tensor, int, float)):
114 |             loss = ApexLoss(self._loss_id, torch.div(other, self._loss), self._optimizer)
115 |         else:
116 |             raise NotImplementedError("Cannot rtruediv an element of type: {} to an ApexLoss.".format(str(type(other))))
117 |         return loss
118 |
119 |     def __sub__(self, other):
120 |         if isinstance(other, (Tensor, int, float)):
121 |             loss = ApexLoss(self._loss_id, torch.sub(self._loss, other), self._optimizer)
122 |         elif isinstance(other, ApexLoss):
123 |             loss = ApexLoss(self._loss_id, torch.sub(self._loss, other.loss), self._optimizer)
124 |         else:
125 |             raise NotImplementedError(
126 |                 "Cannot subtract an element of type: {} to an ApexLoss.".format(str(type(other))))
127 |         return loss
128 |
129 |     def __rsub__(self, other):
130 |         if isinstance(other, (Tensor, int, float)):
131 |             loss = ApexLoss(self._loss_id, torch.sub(other, self._loss), self._optimizer)
132 |         else:
133 |             raise NotImplementedError(
134 |                 "Cannot subtract an element of type: {} to an ApexLoss.".format(str(type(other))))
135 |         return loss
136 |
137 |     def __eq__(self, other):
138 |         if isinstance(other, Tensor):
139 |             is_equal = torch.all(torch.eq(self._loss, other))
140 |         elif isinstance(other, ApexLoss):
141 |             is_equal = torch.all(torch.eq(self._loss, other.loss))
142 |         else:
143 |             raise NotImplementedError("Cannot compare an element of type: {} to an ApexLoss.".format(str(type(other))))
144 |         return is_equal
145 |
146 |
147 | class ApexModule(ABC, nn.Module):
148 |     def __init__(self, model, optimizer, amp_id=0, use_amp=True):
149 |         super().__init__()
150 |         self._amp_id = amp_id
151 |         self._use_amp = use_amp
152 |         self._model = model
153 |         self._optimizer = optimizer
154 |
155 |     @property
156 |     def model(self):
157 |         return self._model
158 |
159 |     @model.setter
160 |     def model(self, value):
161 |         self._model = value
162 |
163 |     @property
164 |     def optimizer(self):
165 |         return
self._optimizer 166 | 167 | @property 168 | def amp_id(self): 169 | return self._amp_id 170 | 171 | @property 172 | def use_amp(self): 173 | return self._use_amp 174 | 175 | def initialize(self, amp_id: int, num_losses: int, use_amp: bool, amp_opt_level: str, device: torch.device): 176 | self._amp_id = amp_id 177 | self._use_amp = use_amp 178 | 179 | if APEX_AVAILABLE and self._use_amp: 180 | self._model, self._optimizer = amp.initialize( 181 | self._model, self._optimizer, opt_level=amp_opt_level, num_losses=num_losses) 182 | if on_multiple_gpus(get_devices()): 183 | self._model = ApexDDP(self._model, delay_allreduce=True) 184 | if not APEX_AVAILABLE and on_multiple_gpus(get_devices()): 185 | self._model = DDP(self._model, device_ids=[device]) 186 | -------------------------------------------------------------------------------- /tests/metrics/test_metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 SAMITorch Authors. All Rights Reserved. 3 | # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | 18 | import torch 19 | import numpy as np 20 | import unittest 21 | 22 | from kerosene.metrics.metrics import Dice, GeneralizedDice 23 | from hamcrest import * 24 | 25 | 26 | def get_y_true_y_pred(): 27 | # Generate an image with labels 0 (background), 1, 2 28 | # 3 classes: 29 | y_true = np.zeros((30, 30), dtype=np.int) 30 | y_true[1:11, 1:11] = 1 31 | y_true[15:25, 15:25] = 2 32 | 33 | y_pred = np.zeros((30, 30), dtype=np.int) 34 | y_pred[5:15, 1:11] = 1 35 | y_pred[20:30, 20:30] = 2 36 | return y_true, y_pred 37 | 38 | 39 | def compute_tensor_y_true_y_logits(y_true, y_pred): 40 | # Create torch.tensor from numpy 41 | y_true_tensor = torch.from_numpy(y_true).unsqueeze(0).type(torch.long) 42 | # Create logits torch.tensor: 43 | num_classes = max(np.max(y_true), np.max(y_pred)) + 1 44 | y_probs = np.ones((num_classes,) + y_true.shape) * -10 45 | for i in range(num_classes): 46 | y_probs[i, (y_pred == i)] = 720 47 | y_logits = torch.from_numpy(y_probs).unsqueeze(0) 48 | return y_true_tensor, y_logits 49 | 50 | 51 | def compute_dice_truth(y_true, y_pred): 52 | true_res = [0, 0, 0] 53 | for index in range(3): 54 | bin_y_true = y_true == index 55 | bin_y_pred = y_pred == index 56 | intersection = bin_y_true & bin_y_pred 57 | true_res[index] = 2 * intersection.sum() / (bin_y_pred.sum() + bin_y_true.sum()) 58 | return true_res 59 | 60 | 61 | def compute_generalized_dice_truth(y_true, y_pred): 62 | true_res = [0, 0, 0] 63 | weights = [0, 0, 0] 64 | for index in range(3): 65 | bin_y_true = y_true == index 66 | bin_y_pred = y_pred == index 67 | weights[index] = (1.0 / (np.sum(bin_y_true) * np.sum(bin_y_true) + 1e-15)) 68 | intersection = (bin_y_true & bin_y_pred) 69 | true_res[index] = 2 * intersection.sum() * weights[index] / ( 70 | ((bin_y_pred.sum() + bin_y_true.sum()) * weights[index]) + 1e-15) 71 | return true_res, weights 72 | 73 
| 74 | class TestDiceMetric(unittest.TestCase): 75 | INVALID_VALUE_1 = -1 76 | INVALID_REDUCTION = "sum" 77 | INVALID_VALUE_3 = 10 78 | INVALID_VALUE_4 = 11 79 | 80 | def setUp(self): 81 | self.y_true, self.y_pred = get_y_true_y_pred() 82 | self.y_true_tensor, self.y_logits = compute_tensor_y_true_y_logits(self.y_true, self.y_pred) 83 | self.dice_truth = compute_dice_truth(self.y_true, self.y_pred) 84 | self.mean_dice_truth = np.mean(self.dice_truth) 85 | 86 | def test_should_compute_dice_for_multiclass(self): 87 | dice_coefficient = Dice(num_classes=3, reduction=None) 88 | dice_coefficient.update((self.y_logits, self.y_true_tensor)) 89 | res = dice_coefficient.compute().numpy() 90 | assert np.all(res == self.dice_truth) 91 | 92 | def test_should_compute_mean_dice_for_multiclass(self): 93 | dice_coefficient = Dice(num_classes=3, reduction="mean") 94 | dice_coefficient.update((self.y_logits, self.y_true_tensor)) 95 | res = dice_coefficient.compute().numpy() 96 | truth = np.array(self.dice_truth).mean() 97 | assert res == truth 98 | 99 | def test_should_compute_dice_for_multiclass_with_ignored_index(self): 100 | for ignore_index in range(3): 101 | dice_coefficient = Dice(num_classes=3, ignore_index=ignore_index, reduction=None) 102 | dice_coefficient.update((self.y_logits, self.y_true_tensor)) 103 | res = dice_coefficient.compute().numpy() 104 | true_res = self.dice_truth[:ignore_index] + self.dice_truth[ignore_index + 1:] 105 | assert np.all(res == true_res), "{}: {} vs {}".format(ignore_index, res, true_res) 106 | 107 | def test_should_compute_mean_dice(self): 108 | mean_dice_coefficient = Dice(num_classes=3, reduction="mean") 109 | mean_dice_coefficient.update((self.y_logits, self.y_true_tensor)) 110 | res = mean_dice_coefficient.compute().numpy() 111 | assert_that(res, equal_to(self.mean_dice_truth)) 112 | 113 | def test_should_compute_mean_dice_with_ignored_index(self): 114 | for ignore_index in range(3): 115 | mean_dice_coefficient = Dice(num_classes=3, reduction="mean", ignore_index=ignore_index) 116 | mean_dice_coefficient.update((self.y_logits, self.y_true_tensor)) 117 | res = mean_dice_coefficient.compute().numpy() 118 | true_res = np.mean(self.dice_truth[:ignore_index] + self.dice_truth[ignore_index + 1:]) 119 | assert_that(res, equal_to(true_res)), "{}: {} vs {}".format(ignore_index, res, true_res) 120 | 121 | 122 | class TestGeneralizedDiceMetric(unittest.TestCase): 123 | INVALID_VALUE_1 = -1 124 | INVALID_VALUE_2 = "STEVE JOBS" 125 | INVALID_VALUE_3 = 10 126 | INVALID_VALUE_4 = 11 127 | 128 | def setUp(self): 129 | self.y_true, self.y_pred = get_y_true_y_pred() 130 | self.y_true_tensor, self.y_logits = compute_tensor_y_true_y_logits(self.y_true, self.y_pred) 131 | self.generalized_dice_truth, weights = compute_generalized_dice_truth(self.y_true, self.y_pred) 132 | self.weights = torch.from_numpy(np.array(weights)) 133 | self.generalized_mean_dice_truth = np.mean(self.generalized_dice_truth) 134 | 135 | def test_should_compute_dice_for_multiclass(self): 136 | generalized_dice_coefficient = GeneralizedDice(num_classes=3) 137 | generalized_dice_coefficient.update((self.y_logits, self.y_true_tensor)) 138 | res = generalized_dice_coefficient.compute().numpy() 139 | assert np.all(res == self.generalized_dice_truth) 140 | 141 | def test_should_compute_dice_for_multiclass_with_ignored_index(self): 142 | for ignore_index in range(3): 143 | generalized_dice_coefficient = GeneralizedDice(num_classes=3, ignore_index=ignore_index) 144 | generalized_dice_coefficient.update((self.y_logits, 
self.y_true_tensor)) 145 | res = generalized_dice_coefficient.compute().numpy() 146 | true_res = self.generalized_dice_truth[:ignore_index] + self.generalized_dice_truth[ignore_index + 1:] 147 | assert np.all(res == true_res), "{}: {} vs {}".format(ignore_index, res, true_res) 148 | 149 | def test_should_compute_mean_dice(self): 150 | mean_generalized_dice_coefficient = GeneralizedDice(num_classes=3, reduction="mean") 151 | mean_generalized_dice_coefficient.update((self.y_logits, self.y_true_tensor)) 152 | res = mean_generalized_dice_coefficient.compute().numpy() 153 | assert_that(res, equal_to(self.generalized_mean_dice_truth)) 154 | 155 | def test_should_compute_mean_dice_with_ignored_index(self): 156 | for ignore_index in range(3): 157 | mean_generalized_dice_coefficient = GeneralizedDice(num_classes=3, reduction="mean", ignore_index=ignore_index) 158 | mean_generalized_dice_coefficient.update((self.y_logits, self.y_true_tensor)) 159 | res = mean_generalized_dice_coefficient.compute().numpy() 160 | true_res = np.mean( 161 | self.generalized_dice_truth[:ignore_index] + self.generalized_dice_truth[ignore_index + 1:]) 162 | assert_that(res, equal_to(true_res)), "{}: {} vs {}".format(ignore_index, res, true_res) -------------------------------------------------------------------------------- /kerosene/events/handlers/console.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 Kerosene Authors. All Rights Reserved. 3 | # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | import logging 17 | import math 18 | from abc import ABC 19 | from enum import Enum 20 | from typing import Dict, Union, List, Optional 21 | 22 | from kerosene.events import Phase, TemporalEvent, BaseEvent, Monitor 23 | from kerosene.events.handlers.base_handler import EventHandler 24 | from kerosene.training import Status, BaseStatus 25 | from kerosene.training.events import Event 26 | from kerosene.training.trainers import Trainer 27 | import crayons 28 | from beautifultable import BeautifulTable 29 | 30 | 31 | class ConsoleColors(Enum): 32 | PURPLE = '\033[95m' 33 | BLUE = '\033[94m' 34 | GREEN = '\033[92m' 35 | YELLOW = '\033[93m' 36 | RED = '\033[91m' 37 | ENDC = '\033[0m' 38 | NONE = "" 39 | 40 | def __str__(self): 41 | return self.value 42 | 43 | 44 | class StatusConsoleColorPalette(Dict[Status, ConsoleColors]): 45 | DEFAULT = { 46 | Phase.TRAINING: ConsoleColors.GREEN, 47 | Phase.VALIDATION: ConsoleColors.BLUE, 48 | Phase.TEST: ConsoleColors.PURPLE 49 | } 50 | 51 | 52 | class BaseConsoleLogger(EventHandler, ABC): 53 | LOGGER = logging.getLogger("ConsoleLogger") 54 | 55 | def __init__(self, supported_events: List[Union[BaseEvent, TemporalEvent]] = None, every=1): 56 | super(BaseConsoleLogger, self).__init__(supported_events, every) 57 | 58 | 59 | class ColoredConsoleLogger(BaseConsoleLogger, ABC): 60 | 61 | def __init__(self, supported_events: List[Union[BaseEvent, TemporalEvent]] = None, every=1, 62 | colors: Dict[object, ConsoleColors] = None): 63 | super().__init__(supported_events, every) 64 | self._colors = colors if colors is not None else {} 65 | 66 | def color(self, text, color_key): 67 | return str(self._colors.get(color_key, ConsoleColors.NONE)) + str(text) + str(ConsoleColors.ENDC) 68 | 69 | 70 | class PrintTrainingStatus(ColoredConsoleLogger): 71 | SUPPORTED_EVENTS = [Event.ON_BATCH_END, Event.ON_EPOCH_END, Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END, 72 | Event.ON_TEST_BATCH_END] 73 | 74 | def __init__(self, every=1, colors: Union[Dict[BaseStatus, ConsoleColors], StatusConsoleColorPalette] = None): 75 | super().__init__(self.SUPPORTED_EVENTS, every, colors) 76 | 77 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer): 78 | if self.should_handle(event): 79 | self.print_status(event.phase, trainer.epoch, trainer.current_train_step, trainer.current_valid_step, 80 | trainer.current_test_step) 81 | 82 | def print_status(self, status, epoch, train_step, valid_step, test_step): 83 | self.LOGGER.info( 84 | "\nCurrent states: {} | Epoch: {} | Training step: {} | Validation step: {} | Test step: {}\n".format( 85 | self.color(status, color_key=status), epoch, train_step, valid_step, test_step)) 86 | 87 | 88 | class PrintMonitors(BaseConsoleLogger): 89 | SUPPORTED_EVENTS = [Event.ON_BATCH_END, Event.ON_EPOCH_END, Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END, 90 | Event.ON_TEST_BATCH_END] 91 | 92 | def __init__(self, every=1): 93 | super().__init__(self.SUPPORTED_EVENTS, every) 94 | self._monitors = {} 95 | 96 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer): 97 | if self.should_handle(event): 98 | for model, monitor in monitors.items(): 99 | if model in self._monitors: 100 | self._monitors[model].update(monitor) 101 | else: 102 | self._monitors[model] = monitor 103 | 104 | self.LOGGER.info("Model {}, {}".format(model, self._monitors[model])) 105 | 106 | 107 | class PrintMonitorsTable(BaseConsoleLogger): 108 | SUPPORTED_EVENTS = 
[Event.ON_BATCH_END, Event.ON_EPOCH_END, Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END, 109 | Event.ON_TEST_BATCH_END] 110 | 111 | def __init__(self, every=1, max_table_width=100): 112 | super().__init__(self.SUPPORTED_EVENTS, every) 113 | self._monitors = {} 114 | self._monitors_tables = {} 115 | self._max_table_width = max_table_width 116 | 117 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer): 118 | if self.should_handle(event): 119 | for model, monitor in monitors.items(): 120 | training_values = {**monitors[model][Phase.TRAINING][Monitor.METRICS], 121 | **monitors[model][Phase.TRAINING][Monitor.LOSS]} 122 | validation_values = {**monitors[model][Phase.VALIDATION][Monitor.METRICS], 123 | **monitors[model][Phase.VALIDATION][Monitor.LOSS]} 124 | test_values = {**monitors[model][Phase.TEST][Monitor.METRICS], 125 | **monitors[model][Phase.TEST][Monitor.LOSS]} 126 | 127 | if model in self._monitors_tables: 128 | self._monitors_tables[model].update(training_values, validation_values, test_values) 129 | else: 130 | self._monitors_tables[model] = MonitorsTable(model, self._max_table_width) 131 | self._monitors_tables[model].update(training_values, validation_values, test_values) 132 | 133 | self._monitors_tables[model].show() 134 | 135 | 136 | class MonitorsTable(object): 137 | def __init__(self, model_name: str, max_width): 138 | self.model_name = model_name 139 | self._training_monitors = {} 140 | self._validation_monitors = {} 141 | self._test_monitors = {} 142 | self._max_width = max_width 143 | self.table = BeautifulTable(max_width) 144 | 145 | def append(self, values: dict, old_values: dict): 146 | self.table.rows.append(list(map(lambda key: self.color(values[key], old_values.get(key, None)), values.keys()))) 147 | 148 | def update(self, training_monitors: dict, validation_monitors: dict, test_monitors: Optional[dict] = None): 149 | self.table.clear() 150 | self.append(training_monitors, self._training_monitors) 151 | self.append(validation_monitors, self._validation_monitors) 152 | 153 | if test_monitors is not None: 154 | self.append(test_monitors, self._test_monitors) 155 | self.table.rows.header = ["Training", "Validation", "Test"] 156 | else: 157 | self.table.rows.header = ["Training", "Validation"] 158 | 159 | self.table.columns.header = training_monitors.keys() 160 | 161 | self._training_monitors = dict(training_monitors) 162 | self._validation_monitors = dict(validation_monitors) 163 | self._test_monitors = dict(test_monitors) if test_monitors is not None else {} 164 | 165 | def show(self): 166 | self.table._compute_width() 167 | width = self.table._width - 2 if self.table._width == self._max_width else self.table._width + 11 168 | topline = "".join(["+", "-" * width, "+"]) 169 | print(topline) 170 | spc = (len(topline) - 2 - len(self.model_name)) / 2 171 | print("%s%s%s%s%s" % ("|", " " * math.ceil(spc), self.model_name, " " * math.floor(spc), "|")) 172 | print(self.table) 173 | 174 | def color(self, value, old_value): 175 | if old_value is not None: 176 | if value > old_value: 177 | return crayons.green("{} \u2197".format(value), bold=True) 178 | if value < old_value: 179 | return crayons.red("{} \u2198".format(value), bold=True) 180 | 181 | return crayons.white("{} \u2192".format(value), bold=True) 182 | -------------------------------------------------------------------------------- /kerosene/training/events.py: -------------------------------------------------------------------------------- 1 | from kerosene.events import TemporalEvent, Moment, Frequency, Phase, BaseEvent 2 | 
from kerosene.training import Status 3 | 4 | 5 | class Event(BaseEvent): 6 | ON_TRAINING_BEGIN = "training_begin" 7 | ON_TRAINING_END = "training_end" 8 | ON_VALID_BEGIN = "valid_begin" 9 | ON_VALID_END = "valid_end" 10 | ON_TEST_BEGIN = "test_begin" 11 | ON_TEST_END = "test_end" 12 | ON_EPOCH_BEGIN = "epoch_begin" 13 | ON_EPOCH_END = "epoch_end" 14 | ON_TRAIN_EPOCH_BEGIN = "train_epoch_begin" 15 | ON_TRAIN_EPOCH_END = "train_epoch_end" 16 | ON_VALID_EPOCH_BEGIN = "valid_epoch_begin" 17 | ON_VALID_EPOCH_END = "valid_epoch_end" 18 | ON_TEST_EPOCH_BEGIN = "test_epoch_begin" 19 | ON_TEST_EPOCH_END = "test_epoch_end" 20 | ON_BATCH_BEGIN = "batch_begin" 21 | ON_TRAIN_BATCH_BEGIN = "train_batch_begin" 22 | ON_TRAIN_BATCH_END = "train_batch_end" 23 | ON_VALID_BATCH_BEGIN = "valid_batch_begin" 24 | ON_VALID_BATCH_END = "valid_batch_end" 25 | ON_TEST_BATCH_BEGIN = "test_batch_begin" 26 | ON_TEST_BATCH_END = "test_batch_end" 27 | ON_BATCH_END = "batch_end" 28 | ON_FINALIZE = "finalizing" 29 | 30 | def __call__(self, moment: Moment): 31 | return TemporalEvent(self, moment) 32 | 33 | 34 | class BatchEventPublisherMixin(object): 35 | @property 36 | def iteration_and_phase(self): 37 | if self.status is Status.TRAINING: 38 | iteration = self.current_train_step 39 | phase = Phase.TRAINING 40 | elif self.status is Status.VALIDATING: 41 | iteration = self.current_valid_step 42 | phase = Phase.VALIDATION 43 | else: 44 | iteration = self.current_test_step 45 | phase = Phase.TEST 46 | 47 | return iteration, phase 48 | 49 | def on_batch_begin(self): 50 | pass 51 | 52 | def on_batch_end(self): 53 | pass 54 | 55 | def on_train_batch_begin(self): 56 | pass 57 | 58 | def on_train_batch_end(self): 59 | pass 60 | 61 | def on_valid_batch_begin(self): 62 | pass 63 | 64 | def on_valid_batch_end(self): 65 | pass 66 | 67 | def on_test_batch_begin(self): 68 | pass 69 | 70 | def on_test_batch_end(self): 71 | pass 72 | 73 | def _on_batch_begin(self): 74 | self.on_batch_begin() 75 | 76 | iteration, phase = self.iteration_and_phase 77 | 78 | self.fire(Event.ON_BATCH_BEGIN(Moment(iteration, Frequency.STEP, phase))) 79 | 80 | def _on_batch_end(self): 81 | self.on_batch_end() 82 | 83 | iteration, phase = self.iteration_and_phase 84 | 85 | self.fire(Event.ON_BATCH_END(Moment(iteration, Frequency.STEP, phase)), self.step_monitors(phase)) 86 | 87 | def _on_train_batch_begin(self): 88 | self.on_train_batch_begin() 89 | self.fire(Event.ON_TRAIN_BATCH_BEGIN(Moment(self.current_train_step, Frequency.STEP, Phase.TRAINING))) 90 | 91 | def _on_train_batch_end(self): 92 | self.on_train_batch_end() 93 | self.fire(Event.ON_TRAIN_BATCH_END(Moment(self.current_train_step, Frequency.STEP, Phase.TRAINING)), 94 | self.step_monitors(Phase.TRAINING)) 95 | 96 | def _on_valid_batch_begin(self): 97 | self.on_valid_batch_begin() 98 | self.fire(Event.ON_VALID_BATCH_BEGIN(Moment(self.current_valid_step, Frequency.STEP, Phase.VALIDATION))) 99 | 100 | def _on_valid_batch_end(self): 101 | self.on_valid_batch_end() 102 | self.fire(Event.ON_VALID_BATCH_END(Moment(self.current_valid_step, Frequency.STEP, Phase.VALIDATION)), 103 | self.step_monitors(Phase.VALIDATION)) 104 | 105 | def _on_test_batch_begin(self): 106 | self.on_test_batch_begin() 107 | self.fire(Event.ON_TEST_BATCH_BEGIN(Moment(self.current_test_step, Frequency.STEP, Phase.TEST))) 108 | 109 | def _on_test_batch_end(self): 110 | self.on_test_batch_end() 111 | self.fire(Event.ON_TEST_BATCH_END(Moment(self.current_test_step, Frequency.STEP, Phase.TEST)), 112 | self.step_monitors(Phase.TEST)) 113 | 
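# --- Illustrative sketch (not part of the library) ---------------------------
# How the batch hooks above are meant to be used: a concrete trainer overrides
# the public on_* hooks, while the private _on_* wrappers build a Moment and
# fire the matching TemporalEvent so registered handlers receive
# (event, monitors, trainer). The VerboseTrainer below is a hypothetical
# example; only the hook names and the fire() sequence come from this file.
#
#     class VerboseTrainer(SimpleTrainer):
#         def on_train_batch_end(self):
#             # runs inside _on_train_batch_end(), just before
#             # Event.ON_TRAIN_BATCH_END is fired to the handlers
#             print("finished train step", self.current_train_step)
# ------------------------------------------------------------------------------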
114 | 115 | class EpochEventPublisherMixin(object): 116 | @property 117 | def epoch_and_phase(self): 118 | if self.status is Status.TRAINING: 119 | phase = Phase.TRAINING 120 | elif self.status is Status.VALIDATING: 121 | phase = Phase.VALIDATION 122 | else: 123 | phase = Phase.TEST 124 | 125 | return self.epoch, phase 126 | 127 | @property 128 | def phase(self): 129 | if self.status is Status.TRAINING: 130 | phase = Phase.TRAINING 131 | elif self.status is Status.VALIDATING: 132 | phase = Phase.VALIDATION 133 | else: 134 | phase = Phase.TEST 135 | 136 | return phase 137 | 138 | def on_epoch_begin(self): 139 | pass 140 | 141 | def on_epoch_end(self): 142 | pass 143 | 144 | def on_train_epoch_begin(self): 145 | pass 146 | 147 | def on_train_epoch_end(self): 148 | pass 149 | 150 | def on_valid_epoch_begin(self): 151 | pass 152 | 153 | def on_valid_epoch_end(self): 154 | pass 155 | 156 | def on_test_epoch_begin(self): 157 | pass 158 | 159 | def on_test_epoch_end(self): 160 | pass 161 | 162 | def _on_epoch_begin(self): 163 | self._reset_model_trainers() 164 | self.on_epoch_begin() 165 | 166 | epoch, phase = self.epoch_and_phase 167 | 168 | self.fire(Event.ON_EPOCH_BEGIN(Moment(epoch, Frequency.EPOCH, phase))) 169 | 170 | def _on_epoch_end(self): 171 | self.on_epoch_end() 172 | 173 | epoch, phase = self.epoch_and_phase 174 | 175 | self.fire(Event.ON_EPOCH_END(Moment(epoch, Frequency.EPOCH, phase)), self.epoch_monitors(phase)) 176 | 177 | def _on_train_epoch_begin(self): 178 | self._status = Status.TRAINING 179 | self.on_train_epoch_begin() 180 | self.fire(Event.ON_TRAIN_EPOCH_BEGIN(Moment(self.epoch, Frequency.EPOCH, Phase.TRAINING))) 181 | 182 | def _on_train_epoch_end(self): 183 | self.on_train_epoch_end() 184 | self.fire(Event.ON_TRAIN_EPOCH_END(Moment(self.epoch, Frequency.EPOCH, Phase.TRAINING)), 185 | self.epoch_monitors(Phase.TRAINING)) 186 | 187 | def _on_valid_epoch_begin(self): 188 | self.on_valid_epoch_begin() 189 | self.fire(Event.ON_VALID_EPOCH_BEGIN(Moment(self.epoch, Frequency.EPOCH, Phase.VALIDATION))) 190 | 191 | def _on_valid_epoch_end(self): 192 | self.on_valid_epoch_end() 193 | self.fire(Event.ON_VALID_EPOCH_END(Moment(self.epoch, Frequency.EPOCH, Phase.VALIDATION)), 194 | self.epoch_monitors(Phase.VALIDATION)) 195 | 196 | def _on_test_epoch_begin(self): 197 | self.on_test_epoch_begin() 198 | self.fire(Event.ON_TEST_EPOCH_BEGIN(Moment(self.epoch, Frequency.EPOCH, Phase.TEST))) 199 | 200 | def _on_test_epoch_end(self): 201 | self._current_valid_batch = 0 202 | self._current_test_batch = 0 203 | self.on_test_epoch_end() 204 | self.fire(Event.ON_TEST_EPOCH_END(Moment(self.epoch, Frequency.EPOCH, Phase.TEST)), 205 | self.epoch_monitors(Phase.TEST)) 206 | 207 | 208 | class TrainingPhaseEventPublisherMixin(object): 209 | def on_training_begin(self): 210 | pass 211 | 212 | def on_training_end(self): 213 | pass 214 | 215 | def on_valid_begin(self): 216 | pass 217 | 218 | def on_valid_end(self): 219 | pass 220 | 221 | def on_test_begin(self): 222 | pass 223 | 224 | def on_test_end(self): 225 | pass 226 | 227 | def finalize(self): 228 | pass 229 | 230 | def _on_training_begin(self): 231 | self.on_training_begin() 232 | self._status = Status.TRAINING 233 | self.fire(Event.ON_TRAINING_BEGIN(Moment(0, Frequency.PHASE, Phase.TRAINING))) 234 | 235 | def _on_training_end(self): 236 | self.on_training_end() 237 | self.scheduler_step() 238 | self.fire(Event.ON_TRAINING_END(Moment(0, Frequency.PHASE, Phase.TRAINING))) 239 | 240 | def _on_valid_begin(self): 241 | self.on_valid_begin() 242 | 
self._status = Status.VALIDATING 243 | self.fire(Event.ON_VALID_BEGIN(Moment(0, Frequency.PHASE, Phase.VALIDATION))) 244 | 245 | def _on_valid_end(self): 246 | self.on_valid_end() 247 | self.fire(Event.ON_VALID_END(Moment(0, Frequency.PHASE, Phase.VALIDATION))) 248 | 249 | def _on_test_begin(self): 250 | self.on_test_begin() 251 | self._status = Status.TESTING 252 | self.fire(Event.ON_TEST_BEGIN(Moment(0, Frequency.PHASE, Phase.TEST))) 253 | 254 | def _on_test_end(self): 255 | self.on_test_end() 256 | self.fire(Event.ON_TEST_END(Moment(0, Frequency.PHASE, Phase.TEST))) 257 | 258 | def _finalize(self): 259 | for model_trainer in self.model_trainers: 260 | model_trainer.finalize() 261 | self.finalize() 262 | self.status = Status.FINALIZED 263 | -------------------------------------------------------------------------------- /kerosene/loggers/visdom/plots.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 Kerosene Authors. All Rights Reserved. 3 | # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | from abc import ABC, abstractmethod 17 | 18 | from visdom import Visdom 19 | 20 | from kerosene.loggers.visdom import PlotType 21 | from kerosene.loggers.visdom.data import VisdomData 22 | 23 | 24 | class VisdomPlot(ABC): 25 | 26 | def __init__(self, visdom: Visdom): 27 | self._visdom = visdom 28 | self._window = None 29 | 30 | @property 31 | def visdom(self): 32 | return self._visdom 33 | 34 | @property 35 | def window(self): 36 | return self._window 37 | 38 | @abstractmethod 39 | def update(self, visdom_data: VisdomData): 40 | raise NotImplementedError() 41 | 42 | 43 | class ImagePlot(VisdomPlot): 44 | def __init__(self, visdom: Visdom): 45 | super().__init__(visdom) 46 | 47 | def update(self, visdom_data: VisdomData): 48 | if self._window is None: 49 | self._window = self.visdom.image(img=visdom_data.y, **visdom_data.params) 50 | else: 51 | self.visdom.image(img=visdom_data.y, win=self._window, **visdom_data.params) 52 | 53 | 54 | class ImagesPlot(VisdomPlot): 55 | def __init__(self, visdom: Visdom): 56 | super().__init__(visdom) 57 | 58 | def update(self, visdom_data: VisdomData): 59 | if self._window is None: 60 | self._window = self._visdom.images(tensor=visdom_data.y, **visdom_data.params) 61 | else: 62 | self._visdom.images(tensor=visdom_data.y, win=self._window, **visdom_data.params) 63 | 64 | 65 | class LinePlot(VisdomPlot): 66 | def __init__(self, visdom): 67 | super().__init__(visdom) 68 | 69 | def update(self, visdom_data: VisdomData): 70 | if self._window is None: 71 | self._window = self._visdom.line(X=visdom_data.x, Y=visdom_data.y, **visdom_data.params) 72 | else: 73 | self._visdom.line(X=visdom_data.x, Y=visdom_data.y, win=self._window, update='append', 74 | name=visdom_data.params['opts']['name']) 75 | 76 | 77 | class PiePlot(VisdomPlot): 78 | def __init__(self, visdom): 79 | super().__init__(visdom) 80 | 81 | def 
update(self, visdom_data: VisdomData): 82 | if self._window is None: 83 | self._window = self._visdom.pie(X=visdom_data.y, **visdom_data.params) 84 | else: 85 | self._visdom.pie(X=visdom_data.y, win=self._window, **visdom_data.params) 86 | 87 | 88 | class TextPlot(VisdomPlot): 89 | def __init__(self, visdom): 90 | super().__init__(visdom) 91 | 92 | def update(self, visdom_data: VisdomData): 93 | if self._window is None: 94 | self._window = self._visdom.text(visdom_data.y) 95 | 96 | else: 97 | self._visdom.text(visdom_data.y, win=self._window) 98 | 99 | 100 | class HistogramPlot(VisdomPlot): 101 | 102 | def __init__(self, visdom: Visdom): 103 | super().__init__(visdom) 104 | 105 | def update(self, visdom_data: VisdomData): 106 | if self._window is None: 107 | self._window = self._visdom.histogram(X=visdom_data.y, **visdom_data.params) 108 | else: 109 | self._visdom.histogram(X=visdom_data.y, win=self._window, **visdom_data.params) 110 | 111 | 112 | class ScatterPlot(VisdomPlot): 113 | 114 | def __init__(self, visdom: Visdom): 115 | super().__init__(visdom) 116 | 117 | def update(self, visdom_data: VisdomData): 118 | if self._window is None: 119 | self._window = self._visdom.scatter(X=visdom_data.x, Y=visdom_data.y, **visdom_data.params) 120 | else: 121 | self._visdom.scatter(X=visdom_data.x, Y=visdom_data.y, win=self._window, **visdom_data.params) 122 | 123 | 124 | class StemPlot(VisdomPlot): 125 | 126 | def __init__(self, visdom: Visdom): 127 | super().__init__(visdom) 128 | 129 | def update(self, visdom_data: VisdomData): 130 | if self._window is None: 131 | self._window = self._visdom.stem(X=visdom_data.x, Y=visdom_data.y, **visdom_data.params) 132 | else: 133 | self._visdom.stem(X=visdom_data.x, Y=visdom_data.y, win=self._window, **visdom_data.params) 134 | 135 | 136 | class HeatmapPlot(VisdomPlot): 137 | 138 | def __init__(self, visdom: Visdom): 139 | super().__init__(visdom) 140 | 141 | def update(self, visdom_data: VisdomData): 142 | if self._window is None: 143 | self._window = self._visdom.heatmap(X=visdom_data.y, **visdom_data.params) 144 | else: 145 | self._visdom.heatmap(X=visdom_data.y, win=self._window, **visdom_data.params) 146 | 147 | 148 | class BoxPlot(VisdomPlot): 149 | 150 | def __init__(self, visdom: Visdom): 151 | super().__init__(visdom) 152 | 153 | def update(self, visdom_data: VisdomData): 154 | if self._window is None: 155 | self._window = self._visdom.boxplot(X=visdom_data.y, **visdom_data.params) 156 | else: 157 | self._visdom.boxplot(X=visdom_data.y, win=self._window, **visdom_data.params) 158 | 159 | 160 | class SurfacePlot(VisdomPlot): 161 | 162 | def __init__(self, visdom: Visdom): 163 | super().__init__(visdom) 164 | 165 | def update(self, visdom_data: VisdomData): 166 | if self._window is None: 167 | self._window = self._visdom.surf(X=visdom_data.y, **visdom_data.params) 168 | else: 169 | self._visdom.surf(X=visdom_data.y, win=self._window, **visdom_data.params) 170 | 171 | 172 | class ContourPlot(VisdomPlot): 173 | 174 | def __init__(self, visdom: Visdom): 175 | super().__init__(visdom) 176 | 177 | def update(self, visdom_data: VisdomData): 178 | if self._window is None: 179 | self._window = self._visdom.contour(X=visdom_data.y, **visdom_data.params) 180 | else: 181 | self._visdom.contour(X=visdom_data.y, win=self._window, **visdom_data.params) 182 | 183 | 184 | class QuiverPlot(VisdomPlot): 185 | 186 | def __init__(self, visdom: Visdom): 187 | super().__init__(visdom) 188 | 189 | def update(self, visdom_data: VisdomData): 190 | if self._window is None: 
191 | self._window = self._visdom.quiver(X=visdom_data.x, Y=visdom_data.y, **visdom_data.params) 192 | else: 193 | self._visdom.quiver(X=visdom_data.x, Y=visdom_data.y, win=self._window, **visdom_data.params) 194 | 195 | 196 | class MeshPlot(VisdomPlot): 197 | 198 | def __init__(self, visdom: Visdom): 199 | super().__init__(visdom) 200 | 201 | def update(self, visdom_data: VisdomData): 202 | if self._window is None: 203 | self._window = self._visdom.mesh(X=visdom_data.x, Y=visdom_data.y, **visdom_data.params) 204 | else: 205 | self._visdom.mesh(X=visdom_data.x, Y=visdom_data.y, win=self._window, **visdom_data.params) 206 | 207 | 208 | class BarPlot(VisdomPlot): 209 | def __init__(self, visdom): 210 | super().__init__(visdom) 211 | 212 | def update(self, visdom_data: VisdomData): 213 | if self._window is None: 214 | self._window = self._visdom.bar(X=visdom_data.x, Y=visdom_data.y, **visdom_data.params) 215 | else: 216 | self._visdom.bar(X=visdom_data.x, Y=visdom_data.y, win=self._window, **visdom_data.params) 217 | 218 | 219 | class MatplotlibPlot(VisdomPlot): 220 | def __init__(self, visdom): 221 | super().__init__(visdom) 222 | 223 | def update(self, visdom_data: VisdomData): 224 | if self._window is None: 225 | self._window = self._visdom.matplot(plot=visdom_data.y, **visdom_data.params) 226 | else: 227 | self._visdom.matplot(plot=visdom_data.y, win=self._window, **visdom_data.params) 228 | 229 | 230 | class VisdomPlotFactory(object): 231 | 232 | def __init__(self): 233 | self._plot = { 234 | PlotType.LINE_PLOT: LinePlot, 235 | PlotType.IMAGE_PLOT: ImagePlot, 236 | PlotType.IMAGES_PLOT: ImagesPlot, 237 | PlotType.PIE_PLOT: PiePlot, 238 | PlotType.TEXT_PLOT: TextPlot, 239 | PlotType.BAR_PLOT: BarPlot, 240 | PlotType.HISTOGRAM_PLOT: HistogramPlot, 241 | PlotType.SCATTER_PLOT: ScatterPlot, 242 | PlotType.STEM_PLOT: StemPlot, 243 | PlotType.HEATMAP_PLOT: HeatmapPlot, 244 | PlotType.BOX_PLOT: BoxPlot, 245 | PlotType.SURFACE_PLOT: SurfacePlot, 246 | PlotType.CONTOUR_PLOT: ContourPlot, 247 | PlotType.QUIVER_PLOT: QuiverPlot, 248 | PlotType.MESH_PLOT: MeshPlot, 249 | PlotType.MATPLOTLIB_PLOT: MatplotlibPlot 250 | } 251 | 252 | def create_plot(self, visdom, plot_type: PlotType): 253 | return self._plot[plot_type](visdom) 254 | -------------------------------------------------------------------------------- /kerosene/configs/configs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 Kerosene Authors. All Rights Reserved. 3 | # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | import abc 17 | import os 18 | from socket import socket 19 | from typing import List 20 | 21 | import torch 22 | 23 | try: 24 | from torch.distributed import is_nccl_available 25 | 26 | NCCL_AVAILABLE = is_nccl_available() 27 | except ImportError: 28 | NCCL_AVAILABLE = False 29 | 30 | from kerosene.configs.exceptions import InvalidConfigurationError 31 | from kerosene.utils.devices import get_devices, on_multiple_gpus, num_gpus, on_cpu 32 | 33 | 34 | class HtmlConfiguration(object): 35 | 36 | @abc.abstractmethod 37 | def to_html(self): 38 | raise NotImplementedError() 39 | 40 | 41 | class DatasetConfiguration(HtmlConfiguration): 42 | def __init__(self, config_dict): 43 | for key in config_dict: 44 | setattr(self, key, config_dict[key]) 45 | 46 | def to_html(self): 47 | configuration_values = '\n'.join("
<p>%s: %s</p>" % item for item in vars(self).items()) 48 | return "<h2>Dataset Configuration</h2>
\n {}".format(configuration_values) 49 | 50 | 51 | class Configuration(object): 52 | def __init__(self, name, type, params): 53 | self._name = name 54 | self._type = type 55 | self._params = params 56 | 57 | @property 58 | def name(self): 59 | return self._name 60 | 61 | @name.setter 62 | def name(self, value): 63 | self._name = value 64 | 65 | @property 66 | def type(self): 67 | return self._type 68 | 69 | @type.setter 70 | def type(self, value): 71 | self._type = value 72 | 73 | @property 74 | def params(self): 75 | return self._params 76 | 77 | @params.setter 78 | def params(self, value): 79 | self._params = value 80 | 81 | def update(self, config_dict): 82 | if "type" in config_dict: 83 | self._type = config_dict["type"] 84 | 85 | if "params" in config_dict: 86 | self._params.update(config_dict["params"]) 87 | 88 | @classmethod 89 | def from_dict(cls, name, config_dict): 90 | return cls(name, config_dict["type"], config_dict.get("params", {})) 91 | 92 | 93 | class ConfigurationList(Configuration): 94 | def __init__(self, configurations: List[Configuration]): 95 | super().__init__(list(map(lambda config: config.name, configurations)), 96 | list(map(lambda config: config.type, configurations)), 97 | list(map(lambda config: config.params, configurations))) 98 | self._configurations = configurations 99 | 100 | def update(self, config_dict): 101 | for configuration in self._configurations: 102 | if configuration.name in config_dict: 103 | configuration.update(config_dict[configuration.name]) 104 | 105 | def __len__(self): 106 | return len(self._configurations) 107 | 108 | def __getitem__(self, index): 109 | return self._configurations[index] 110 | 111 | def __iter__(self): 112 | return iter(self._configurations) 113 | 114 | 115 | class ModelConfiguration(Configuration, HtmlConfiguration): 116 | def __init__(self, model_name, model_type, model_params, path, optimizer_config: Configuration, 117 | scheduler_config: Configuration, criterions_configs: ConfigurationList, 118 | metrics_configs: ConfigurationList, gradient_clipping_config: Configuration): 119 | super().__init__(model_name, model_type, model_params) 120 | self._path = path 121 | self._optimizer_config = optimizer_config 122 | self._scheduler_config = scheduler_config 123 | self._criterions_configs = criterions_configs 124 | self._metrics_configs = metrics_configs 125 | self._gradient_clipping_config = gradient_clipping_config 126 | 127 | @property 128 | def path(self): 129 | return self._path 130 | 131 | @property 132 | def optimizer_config(self): 133 | return self._optimizer_config 134 | 135 | @property 136 | def scheduler_config(self): 137 | return self._scheduler_config 138 | 139 | @property 140 | def criterions_configs(self): 141 | return self._criterions_configs 142 | 143 | @property 144 | def metrics_configs(self): 145 | return self._metrics_configs 146 | 147 | @property 148 | def gradient_clipping_config(self): 149 | return self._gradient_clipping_config 150 | 151 | def update(self, config_dict): 152 | if "type" in config_dict: 153 | self._type = config_dict["type"] 154 | 155 | if "params" in config_dict: 156 | self._params.update(config_dict["params"]) 157 | 158 | if "optimizer" in config_dict: 159 | self._optimizer_config.update(config_dict["optimizer"]) 160 | 161 | if "scheduler" in config_dict: 162 | self._scheduler_config.update(config_dict["scheduler"]) 163 | 164 | if "criterions" in config_dict: 165 | self._criterions_configs.update(config_dict["criterions"]) 166 | 167 | if "metrics" in config_dict: 168 | 
self._metrics_configs.update(config_dict["metrics"]) 169 | 170 | @classmethod 171 | def from_dict(cls, model_name, config_dict): 172 | try: 173 | optimizer_config = Configuration.from_dict("", config_dict["optimizer"]) 174 | scheduler_config = Configuration.from_dict("", config_dict["scheduler"]) 175 | criterion_configs = ConfigurationList( 176 | [Configuration.from_dict(criterion, config_dict["criterion"][criterion]) for criterion 177 | in config_dict["criterion"].keys()]) 178 | 179 | if "metrics" in config_dict.keys(): 180 | metric_configs = ConfigurationList( 181 | [Configuration.from_dict(metric, config_dict["metrics"][metric]) for metric in 182 | config_dict["metrics"].keys()]) 183 | else: 184 | metric_configs = None 185 | 186 | if "gradients" in config_dict.keys(): 187 | gradient_clipping_config = Configuration.from_dict("", config_dict["gradients"]) 188 | else: 189 | gradient_clipping_config = None 190 | 191 | return cls(model_name, config_dict["type"], config_dict.get("params"), config_dict.get("path"), 192 | optimizer_config, scheduler_config, criterion_configs, metric_configs, gradient_clipping_config) 193 | except KeyError as e: 194 | raise InvalidConfigurationError( 195 | "The provided model configuration is invalid. The section {} is missing.".format(e)) 196 | 197 | def to_html(self): 198 | configuration_values = '\n'.join("
<p>%s: %s</p>" % item for item in vars(self).items()) 199 | return "<h2>Model Configuration</h2> \n <h4>{}</h4>
\n {}".format(self.name, configuration_values) 200 | 201 | 202 | class RunConfiguration(HtmlConfiguration): 203 | 204 | def __init__(self, use_amp: bool = True, amp_opt_level: str = 'O1', local_rank: int = 0, 205 | world_size: int = num_gpus()): 206 | self._use_amp = use_amp 207 | self._amp_opt_level = amp_opt_level 208 | self._local_rank = local_rank 209 | self._world_size = world_size 210 | 211 | self._devices = get_devices() 212 | self._device = self._devices[self._local_rank] 213 | 214 | if not on_cpu(self._device): 215 | torch.cuda.set_device(self._device) 216 | self._initialize_ddp_process_group() 217 | 218 | @property 219 | def use_amp(self): 220 | return self._use_amp 221 | 222 | @property 223 | def amp_opt_level(self): 224 | return self._amp_opt_level 225 | 226 | @property 227 | def local_rank(self): 228 | return self._local_rank 229 | 230 | @property 231 | def world_size(self): 232 | return self._world_size 233 | 234 | @property 235 | def devices(self): 236 | return self._devices 237 | 238 | @property 239 | def device(self): 240 | return self._device 241 | 242 | @device.setter 243 | def device(self, device): 244 | self._device = device 245 | 246 | @staticmethod 247 | def _get_random_free_port(): 248 | with socket() as s: 249 | s.bind(("", 0)) 250 | return s.getsockname()[1] 251 | 252 | def _initialize_ddp_process_group(self): 253 | if on_multiple_gpus(self._devices): 254 | if NCCL_AVAILABLE: 255 | if os.environ.get("MASTER_ADDR") is None: 256 | os.environ["MASTER_ADDR"] = "127.0.0.1" 257 | if os.environ.get("MASTER_PORT") is None: 258 | os.environ["MASTER_PORT"] = str(self._get_random_free_port()) 259 | if os.environ.get("WORLD_SIZE") is None: 260 | os.environ["WORLD_SIZE"] = str(self._world_size) 261 | torch.distributed.init_process_group(backend='nccl', init_method='env://', 262 | world_size=int(os.environ["WORLD_SIZE"]), rank=self._local_rank) 263 | else: 264 | raise Exception("NCCL not available and required for multi-GPU training.") 265 | 266 | def to_html(self): 267 | configuration_values = '\n'.join("
<p>%s: %s</p>" % item for item in vars(self).items()) 268 | return "<h2>Run Configuration</h2>
\n {}".format(configuration_values) 269 | 270 | 271 | class TrainerConfiguration(HtmlConfiguration): 272 | def __init__(self, config_dict): 273 | for key in config_dict: 274 | setattr(self, key, config_dict[key]) 275 | 276 | def to_html(self): 277 | configuration_values = '\n'.join("
<p>%s: %s</p>" % item for item in vars(self).items()) 278 | return "<h2>Training Configuration</h2>
\n {}".format(configuration_values) 279 | -------------------------------------------------------------------------------- /tests/events/handlers/test_checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import unittest 4 | from unittest.mock import PropertyMock 5 | 6 | import mock 7 | import mockito 8 | import torch 9 | from hamcrest import * 10 | from ignite.metrics import Accuracy 11 | from torch import nn 12 | from torch.optim import Optimizer, lr_scheduler 13 | 14 | from kerosene.events import MonitorMode, Moment, Frequency, Phase, TemporalEvent, Monitor 15 | from kerosene.events.handlers.checkpoints import Checkpoint 16 | from kerosene.nn.utils.gradients import GradientClippingStrategy 17 | from kerosene.training.events import Event 18 | from kerosene.training.trainers import SimpleTrainer, ModelTrainer, ModelTrainerList 19 | 20 | 21 | class ModelCheckpointIfBetterTest(unittest.TestCase): 22 | MODEL_NAME = "test_model" 23 | TRAINER_NAME = "test_trainer" 24 | SAVE_PATH = "tests/output/" 25 | 26 | def setUp(self) -> None: 27 | self._model_mock = mockito.mock(nn.Module) 28 | self._criterion_mock = {"CrossEntropyLoss": mockito.mock(nn.CrossEntropyLoss)} 29 | self._optimizer_mock = mockito.mock(Optimizer) 30 | self._scheduler_mock = mockito.mock(lr_scheduler) 31 | self._metric_computer_mock = {"Accuracy": mockito.mock(Accuracy)} 32 | self._gradient_clipping_strategy = mockito.mock(GradientClippingStrategy) 33 | 34 | self._model_trainer = ModelTrainer(self.MODEL_NAME, self._model_mock, self._criterion_mock, 35 | self._optimizer_mock, self._scheduler_mock, self._metric_computer_mock, 36 | self._gradient_clipping_strategy) 37 | 38 | self._trainer_mock = mockito.mock(SimpleTrainer) 39 | self._trainer_mock.epoch = 0 40 | self._trainer_mock.model_trainers = ModelTrainerList([self._model_trainer]) 41 | 42 | def tearDown(self) -> None: 43 | if os.path.exists(self.SAVE_PATH): 44 | shutil.rmtree(self.SAVE_PATH) 45 | 46 | @mock.patch("kerosene.training.trainers.ModelTrainer.optimizer_state", new_callable=PropertyMock) 47 | @mock.patch("kerosene.training.trainers.ModelTrainer.model_state", new_callable=PropertyMock) 48 | def test_should_not_save_model_with_higher_valid_loss(self, model_state_mock, optimizer_states_mock): 49 | model_state_mock.return_value = dict() 50 | optimizer_states_mock.return_value = list(dict()) 51 | moment = Moment(200, Frequency.EPOCH, Phase.VALIDATION) 52 | 53 | handler_mock = mockito.spy(Checkpoint(self.SAVE_PATH, self.MODEL_NAME, "MSELoss", 0.01, MonitorMode.MIN)) 54 | 55 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}}, 56 | Phase.VALIDATION: {Monitor.METRICS: {}, 57 | Monitor.LOSS: {"MSELoss": torch.tensor([0.5])}}, 58 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}} 59 | 60 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock) 61 | 62 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}}, 63 | Phase.VALIDATION: {Monitor.METRICS: {}, 64 | Monitor.LOSS: {"MSELoss": torch.tensor([0.6])}}, 65 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}} 66 | 67 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock) 68 | 69 | assert_that(not os.path.exists(os.path.join(self.SAVE_PATH, self.MODEL_NAME, self.MODEL_NAME + ".tar"))) 70 | 71 | @mock.patch("kerosene.training.trainers.ModelTrainer.optimizer_state", new_callable=PropertyMock) 72 | 
@mock.patch("kerosene.training.trainers.ModelTrainer.model_state", new_callable=PropertyMock) 73 | def test_should_not_save_model_with_higher_valid_losses(self, model_state_mock, optimizer_states_mock): 74 | model_state_mock.return_value = dict() 75 | optimizer_states_mock.return_value = list(dict()) 76 | moment = Moment(200, Frequency.EPOCH, Phase.VALIDATION) 77 | 78 | handler_mock = mockito.spy( 79 | Checkpoint(self.SAVE_PATH, self.MODEL_NAME, ["MSELoss", "L1Loss"], 0.01, MonitorMode.MIN)) 80 | 81 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}}, 82 | Phase.VALIDATION: {Monitor.METRICS: {}, 83 | Monitor.LOSS: {"MSELoss": torch.tensor([0.5]), 84 | "L1Loss": torch.tensor([0.5])}}, 85 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}} 86 | 87 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock) 88 | 89 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}}, 90 | Phase.VALIDATION: {Monitor.METRICS: {}, 91 | Monitor.LOSS: {"MSELoss": torch.tensor([0.5]), 92 | "L1Loss": torch.tensor([0.6])}}, 93 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}} 94 | 95 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock) 96 | 97 | assert_that(not os.path.exists(os.path.join(self.SAVE_PATH, self.MODEL_NAME, self.MODEL_NAME + ".tar"))) 98 | 99 | @mock.patch("kerosene.training.trainers.ModelTrainer.optimizer_state", new_callable=PropertyMock) 100 | @mock.patch("kerosene.training.trainers.ModelTrainer.model_state", new_callable=PropertyMock) 101 | def test_should_save_model_with_lower_valid_loss(self, model_state_mock, optimizer_states_mock): 102 | model_state_mock.return_value = dict() 103 | optimizer_states_mock.return_value = list(dict()) 104 | moment = Moment(200, Frequency.EPOCH, Phase.VALIDATION) 105 | 106 | handler_mock = mockito.spy(Checkpoint(self.SAVE_PATH, self.MODEL_NAME, "MSELoss", 0.01, MonitorMode.MIN)) 107 | 108 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}}, 109 | Phase.VALIDATION: {Monitor.METRICS: {}, 110 | Monitor.LOSS: {"MSELoss": torch.tensor([0.5])}}, 111 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}} 112 | 113 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock) 114 | 115 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}}, 116 | Phase.VALIDATION: {Monitor.METRICS: {}, 117 | Monitor.LOSS: {"MSELoss": torch.tensor([0.3])}}, 118 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}} 119 | 120 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock) 121 | 122 | assert_that(os.path.exists(os.path.join(self.SAVE_PATH, self.MODEL_NAME, self.MODEL_NAME + ".tar"))) 123 | 124 | @mock.patch("kerosene.training.trainers.ModelTrainer.optimizer_state", new_callable=PropertyMock) 125 | @mock.patch("kerosene.training.trainers.ModelTrainer.model_state", new_callable=PropertyMock) 126 | def test_should_save_model_with_higher_valid_metric(self, model_state_mock, optimizer_states_mock): 127 | model_state_mock.return_value = dict() 128 | optimizer_states_mock.return_value = list(dict()) 129 | moment = Moment(200, Frequency.EPOCH, Phase.VALIDATION) 130 | 131 | handler_mock = mockito.spy(Checkpoint(self.SAVE_PATH, self.MODEL_NAME, "Accuracy", 0.01, MonitorMode.MAX)) 132 | 133 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}}, 134 | Phase.VALIDATION: {Monitor.METRICS: 
{"Accuracy": torch.tensor([0.5])}, 135 | Monitor.LOSS: {}}, 136 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}} 137 | 138 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock) 139 | 140 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}}, 141 | Phase.VALIDATION: {Monitor.METRICS: {"Accuracy": torch.tensor([0.6])}, 142 | Monitor.LOSS: {}}, 143 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}} 144 | 145 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock) 146 | 147 | assert_that(os.path.exists(os.path.join(self.SAVE_PATH, self.MODEL_NAME, self.MODEL_NAME + ".tar"))) 148 | 149 | @mock.patch("kerosene.training.trainers.ModelTrainer.optimizer_state", new_callable=PropertyMock) 150 | @mock.patch("kerosene.training.trainers.ModelTrainer.model_state", new_callable=PropertyMock) 151 | def test_should_not_save_model_with_lower_valid_metric(self, model_state_mock, optimizer_states_mock): 152 | model_state_mock.return_value = dict() 153 | optimizer_states_mock.return_value = list(dict()) 154 | moment = Moment(200, Frequency.EPOCH, Phase.VALIDATION) 155 | 156 | handler_mock = mockito.spy(Checkpoint(self.SAVE_PATH, self.MODEL_NAME, "Accuracy", 0.01, MonitorMode.MAX)) 157 | 158 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}}, 159 | Phase.VALIDATION: {Monitor.METRICS: {"Accuracy": torch.tensor([0.5])}, 160 | Monitor.LOSS: {}}, 161 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}} 162 | 163 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock) 164 | 165 | monitors = {self.MODEL_NAME: {Phase.TRAINING: {Monitor.METRICS: {}, Monitor.LOSS: {}}, 166 | Phase.VALIDATION: {Monitor.METRICS: {"Accuracy": torch.tensor([0.4])}, 167 | Monitor.LOSS: {}}, 168 | Phase.TEST: {Monitor.METRICS: {}, Monitor.LOSS: {}}}} 169 | 170 | handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors, self._trainer_mock) 171 | 172 | assert_that(not os.path.exists(os.path.join(self.SAVE_PATH, self.MODEL_NAME, self.MODEL_NAME + ".tar"))) 173 | -------------------------------------------------------------------------------- /kerosene/events/handlers/visdom.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2019 Kerosene Authors. All Rights Reserved. 3 | # 4 | # Licensed under the MIT License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | import logging 17 | from abc import ABC 18 | from typing import Union, List 19 | 20 | from kerosene.events import TemporalEvent, Monitor, BaseEvent 21 | from kerosene.events.handlers.base_handler import EventHandler 22 | from kerosene.loggers.visdom import PlotType 23 | from kerosene.loggers.visdom.visdom import VisdomLogger, VisdomData 24 | from kerosene.training.events import Event 25 | from kerosene.training.trainers import Trainer, ModelTrainer 26 | 27 | 28 | class BaseVisdomHandler(EventHandler, ABC): 29 | def __init__(self, supported_events: List[Union[BaseEvent, TemporalEvent]], visdom_logger: VisdomLogger, every=1): 30 | super().__init__(supported_events, every) 31 | self._visdom_logger = visdom_logger 32 | 33 | @property 34 | def visdom_logger(self): 35 | return self._visdom_logger 36 | 37 | def create_visdom_data(self, *args, **kwargs): 38 | raise NotImplementedError() 39 | 40 | def flatten(self, list_of_visdom_data): 41 | return [item for sublist in list_of_visdom_data for item in sublist] 42 | 43 | 44 | class PlotMonitors(BaseVisdomHandler): 45 | SUPPORTED_EVENTS = [Event.ON_EPOCH_END, Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END, Event.ON_TEST_BATCH_END, 46 | Event.ON_BATCH_END] 47 | 48 | def __init__(self, visdom_logger: VisdomLogger, every=1): 49 | super().__init__(self.SUPPORTED_EVENTS, visdom_logger, every) 50 | self._plot_losses = PlotLosses(visdom_logger, every) 51 | self._plot_metrics = PlotMetrics(visdom_logger, every) 52 | 53 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer): 54 | if self.should_handle(event): 55 | self._plot_losses(event, monitors, trainer) 56 | self._plot_metrics(event, monitors, trainer) 57 | 58 | def create_visdom_data(self, event, model_name, monitors): 59 | data = self._plot_losses.create_visdom_data(event, model_name, monitors) 60 | data.extend(self._plot_metrics.create_visdom_data(event, model_name, monitors)) 61 | return data 62 | 63 | class PlotLosses(BaseVisdomHandler): 64 | SUPPORTED_EVENTS = [Event.ON_EPOCH_END, Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END, Event.ON_TEST_BATCH_END, 65 | Event.ON_BATCH_END] 66 | 67 | def __init__(self, visdom_logger: VisdomLogger, every=1): 68 | super().__init__(self.SUPPORTED_EVENTS, visdom_logger, every) 69 | 70 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer): 71 | data = list() 72 | 73 | if self.should_handle(event): 74 | for model_name, monitor in monitors.items(): 75 | data.extend(self.create_visdom_data(event, model_name, monitor[event.phase][Monitor.LOSS])) 76 | 77 | if data is not None: 78 | self.visdom_logger(data) 79 | 80 | def create_visdom_data(self, event, model_name, monitors): 81 | return [VisdomData(model_name, loss_name, PlotType.LINE_PLOT, event.frequency, [[event.iteration]], 82 | [[loss_value]], params={'opts': {'xlabel': str(event.frequency), 'ylabel': loss_name, 83 | 'title': "{} {} per {}".format(model_name, loss_name, 84 | str(event.frequency)), 85 | 'name': str(event.phase), 86 | 'legend': [str(event.phase)]}}) 87 | for loss_name, loss_value in monitors.items()] 88 | 89 | 90 | class PlotMetrics(BaseVisdomHandler): 91 | SUPPORTED_EVENTS = [Event.ON_EPOCH_END, Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END, Event.ON_TEST_BATCH_END, 92 | Event.ON_BATCH_END] 93 | 94 | def __init__(self, visdom_logger: VisdomLogger, every=1): 95 | super().__init__(self.SUPPORTED_EVENTS, visdom_logger, every) 96 | 97 | def __call__(self, event: 
TemporalEvent, monitors: dict, trainer: Trainer): 98 | data = list() 99 | 100 | if self.should_handle(event): 101 | for model_name, monitor in monitors.items(): 102 | data.extend(self.create_visdom_data(event, model_name, monitor[event.phase][Monitor.METRICS])) 103 | 104 | if data is not None: 105 | self.visdom_logger(data) 106 | 107 | def create_visdom_data(self, event, model_name, monitors): 108 | return [VisdomData(model_name, metric_name, PlotType.LINE_PLOT, event.frequency, [[event.iteration]], 109 | [[metric_value]], params={'opts': {'xlabel': str(event.frequency), 'ylabel': metric_name, 110 | 'title': "{} {} per {}".format(model_name, metric_name, 111 | str(event.frequency)), 112 | 'name': str(event.phase), 'legend': [str(event.phase)]}}) 113 | for metric_name, metric_value in monitors.items()] 114 | 115 | 116 | class PlotCustomVariables(BaseVisdomHandler): 117 | LOGGER = logging.getLogger("PlotCustomVariables") 118 | SUPPORTED_EVENTS = [Event.ON_EPOCH_END, Event.ON_TRAIN_EPOCH_END, Event.ON_VALID_EPOCH_END, Event.ON_TEST_EPOCH_END, 119 | Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END, Event.ON_TEST_BATCH_END, Event.ON_BATCH_END] 120 | 121 | def __init__(self, visdom_logger: VisdomLogger, variable_name, plot_type: PlotType, params, every=1): 122 | super().__init__(self.SUPPORTED_EVENTS, visdom_logger, every) 123 | self._variable_name = variable_name 124 | self._plot_type = plot_type 125 | self._params = params 126 | 127 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer): 128 | data = None 129 | 130 | if self.should_handle(event): 131 | try: 132 | data = self.create_visdom_data(event, trainer) 133 | except KeyError: 134 | self.LOGGER.warning( 135 | "Unable to plot custom variable: {}. The variable is not in the trainer's custom variable dict.".format( 136 | self._variable_name)) 137 | 138 | if data is not None: 139 | self.visdom_logger(data) 140 | 141 | def create_visdom_data(self, event: TemporalEvent, trainer): 142 | if self._plot_type == PlotType.LINE_PLOT and "name" not in self._params['opts'].keys(): 143 | self._params['opts']['name'] = str(event.phase) 144 | 145 | return [VisdomData(trainer.name, self._variable_name, self._plot_type, event.frequency, 146 | [event.iteration], trainer.custom_variables[self._variable_name], self._params)] 147 | 148 | 149 | class PlotLR(BaseVisdomHandler): 150 | SUPPORTED_EVENTS = [Event.ON_EPOCH_END, Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END, Event.ON_TEST_BATCH_END, 151 | Event.ON_BATCH_END] 152 | 153 | def __init__(self, visdom_logger: VisdomLogger, every=1): 154 | super().__init__(self.SUPPORTED_EVENTS, visdom_logger, every) 155 | 156 | def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer): 157 | data = None 158 | 159 | if self.should_handle(event): 160 | data = list(map( 161 | lambda model_trainer: self.create_visdom_data(event, model_trainer), trainer.model_trainers)) 162 | 163 | if data is not None: 164 | self.visdom_logger(data) 165 | 166 | def create_visdom_data(self, event, model_trainer: ModelTrainer): 167 | return VisdomData(model_trainer.name, "Learning Rate", PlotType.LINE_PLOT, event.frequency, [event.iteration], 168 | model_trainer.optimizer_lr, 169 | params={'opts': {'xlabel': str(event.frequency), 170 | 'ylabel': "Learning Rate", 171 | 'title': "{} {} per {}".format(model_trainer.name, "Learning Rate", 172 | str(event.frequency)), 173 | 'name': model_trainer.name, 174 | 'legend': [model_trainer.name]}}) 175 | 176 | 177 | class PlotAvgGradientPerLayer(BaseVisdomHandler): 
178 |     SUPPORTED_EVENTS = [Event.ON_EPOCH_END, Event.ON_TRAIN_BATCH_END, Event.ON_VALID_BATCH_END, Event.ON_TEST_BATCH_END,
179 |                         Event.ON_BATCH_END]
180 | 
181 |     def __init__(self, visdom_logger: VisdomLogger, every=1):
182 |         super().__init__(self.SUPPORTED_EVENTS, visdom_logger, every)
183 | 
184 |     def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer):
185 |         data = None
186 | 
187 |         if self.should_handle(event):
188 |             data = list(
189 |                 map(lambda model_trainer: self.create_visdom_data(event, model_trainer), trainer.model_trainers))
190 | 
191 |         if data is not None:
192 |             self.visdom_logger(data)
193 | 
194 |     def create_visdom_data(self, event: TemporalEvent, model_trainer: ModelTrainer):
195 |         avg_grads, layers = [], []
196 | 
197 |         for n, p in model_trainer.named_parameters():
198 |             if p.requires_grad and ("bias" not in n):
199 |                 layers.append(n)
200 |                 if p.grad is not None:
201 |                     avg_grads.append(p.grad.abs().mean().item())
202 |                 else:
203 |                     avg_grads.append(0)
204 | 
205 |         return VisdomData(model_trainer.name, "Gradient Flow", PlotType.BAR_PLOT, event.frequency, y=layers,
206 |                           x=avg_grads, params={'opts': {'xlabel': "Layers", 'ylabel': "Avg. Gradients",
207 |                                                         'title': "{} {} per {}".format(model_trainer.name,
208 |                                                                                        "Avg. Gradient", "Layer"),
209 |                                                         'marginbottom': 200}})
210 | 
211 | 
212 | class SaveVisdomEnvToFile(BaseVisdomHandler):
213 |     SUPPORTED_EVENTS = [Event.ON_EPOCH_END, Event.ON_TRAIN_EPOCH_END, Event.ON_VALID_EPOCH_END, Event.ON_TEST_EPOCH_END,
214 |                         Event.ON_FINALIZE]
215 | 
216 |     def __init__(self, visdom_logger: VisdomLogger, params, every=1):
217 |         super().__init__(self.SUPPORTED_EVENTS, visdom_logger, every)
218 |         self._params = params
219 | 
220 |     def create_visdom_data(self, *args, **kwargs):
221 |         pass
222 | 
223 |     def __call__(self, event: TemporalEvent, monitors: dict, trainer: Trainer):
224 |         if self.should_handle(event):
225 |             self._visdom_logger.save()
226 | 
--------------------------------------------------------------------------------
/tests/nn/test_criterions.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | import numpy as np
4 | import torch
5 | from hamcrest import *
6 | 
7 | from kerosene.nn.criterions import DiceLoss, GeneralizedDiceLoss, WeightedCrossEntropyLoss, TverskyLoss
8 | from kerosene.utils.tensors import to_onehot
9 | 
10 | 
11 | def get_y_true_y_pred():
12 |     # Generate an image with labels 0 (background), 1, 2
13 |     # 3 classes:
14 |     y_true = np.zeros((30, 30), dtype=int)
15 |     y_true[1:11, 1:11] = 1
16 |     y_true[15:25, 15:25] = 2
17 | 
18 |     y_pred = np.zeros((30, 30), dtype=int)
19 |     y_pred[5:15, 1:11] = 1
20 |     y_pred[20:30, 20:30] = 2
21 |     return y_true, y_pred
22 | 
23 | 
24 | def compute_tensor_y_true_y_logits(y_true, y_pred):
25 |     # Create torch.tensor from numpy
26 |     y_true_tensor = torch.from_numpy(y_true).unsqueeze(0).type(torch.long)
27 |     # Create logits torch.tensor:
28 |     num_classes = max(np.max(y_true), np.max(y_pred)) + 1
29 |     y_probas = np.zeros((num_classes,) + y_true.shape)
30 |     for i in range(num_classes):
31 |         y_probas[i, (y_pred == i)] = 1.0
32 |     y_logits = torch.from_numpy(y_probas).unsqueeze(0).type(torch.float32)
33 |     return y_true_tensor, y_logits
34 | 
35 | 
36 | def compute_dice_truth(y_true, y_pred):
37 |     true_res = [0, 0, 0]
38 |     for index in range(3):
39 |         bin_y_true = y_true == index
40 |         bin_y_pred = y_pred == index
41 |         intersection = bin_y_true & bin_y_pred
42 |         true_res[index] = 2.0 * intersection.sum() / (bin_y_pred.sum() +
bin_y_true.sum())
43 |     return true_res
44 | 
45 | 
46 | def compute_generalized_dice_loss_truth(y_true, y_pred):
47 |     true_res = [0, 0, 0]
48 |     for index in range(3):
49 |         bin_y_true = y_true == index
50 |         bin_y_pred = y_pred == index
51 |         weights = (1.0 / (np.sum(bin_y_true) * np.sum(bin_y_true) + 1e-15))
52 |         intersection = (bin_y_true & bin_y_pred)
53 |         true_res[index] = 2 * intersection.sum() * weights / (((bin_y_pred.sum() + bin_y_true.sum()) * weights) + 1e-15)
54 |     return true_res
55 | 
56 | 
57 | class TestDiceLoss(unittest.TestCase):
58 |     INVALID_VALUE_1 = -1
59 |     INVALID_VALUE_2 = "STEVE JOBS"
60 |     INVALID_VALUE_3 = 10
61 |     INVALID_VALUE_4 = 11
62 | 
63 |     def setUp(self):
64 |         self.y_true, self.y_pred = get_y_true_y_pred()
65 |         self.y_true_tensor, self.y_logits = compute_tensor_y_true_y_logits(self.y_true, self.y_pred)
66 |         self.dice = compute_dice_truth(self.y_true, self.y_pred)
67 |         self.mean_dice_loss = np.subtract(1.0, np.mean(self.dice))
68 | 
69 |     def test_should_raise_exception_with_bad_values(self):
70 |         dice_loss = DiceLoss()
71 |         assert_that(calling(dice_loss.forward).with_args(inputs=None, targets=None),
72 |                     raises(AttributeError))
73 |         assert_that(calling(dice_loss.forward).with_args(inputs=self.y_logits, targets=None),
74 |                     raises(AttributeError))
75 |         assert_that(calling(dice_loss.forward).with_args(inputs=None, targets=self.y_true_tensor),
76 |                     raises(AttributeError))
77 | 
78 |     def test_should_compute_dice(self):
79 |         dice_loss = DiceLoss(reduction=None)
80 |         loss = dice_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
81 | 
82 |         np.testing.assert_almost_equal(loss.numpy(), np.subtract(1.0, self.dice))
83 | 
84 |     def test_should_compute_dice_for_multiclass_with_ignored_index(self):
85 |         for ignore_index in range(3):
86 |             dice_loss = DiceLoss(reduction=None, ignore_index=ignore_index)
87 |             res = dice_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
88 |             true_res = np.subtract(1.0, self.dice[:ignore_index] + self.dice[ignore_index + 1:])
89 |             np.testing.assert_almost_equal(res.numpy(), true_res, err_msg="{}: {} vs {}".format(ignore_index, res, true_res))
90 | 
91 |     def test_should_compute_mean_dice(self):
92 |         dice_loss = DiceLoss(reduction="mean")
93 |         loss = dice_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
94 | 
95 |         np.testing.assert_almost_equal(loss.numpy(), self.mean_dice_loss)
96 | 
97 |     def test_should_compute_mean_dice_for_multiclass_with_ignored_index(self):
98 |         for ignore_index in range(3):
99 |             dice_loss = DiceLoss(ignore_index=ignore_index)
100 |             res = dice_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
101 |             true_res = np.subtract(1.0, self.dice[:ignore_index] + self.dice[ignore_index + 1:]).mean()
102 |             np.testing.assert_almost_equal(res.numpy(), true_res, err_msg="{}: {} vs {}".format(ignore_index, res, true_res))
103 | 
104 | 
105 | class TestGeneralizedDiceLoss(unittest.TestCase):
106 |     INVALID_REDUCTION = "sum"
107 |     INVALID_INDEX = -1
108 | 
109 |     def setUp(self):
110 |         self.y_true, self.y_pred = get_y_true_y_pred()
111 |         self.y_true_tensor, self.y_logits = compute_tensor_y_true_y_logits(self.y_true, self.y_pred)
112 |         self.generalized_dice_loss = compute_generalized_dice_loss_truth(self.y_true, self.y_pred)
113 |         self.mean_generalized_dice_loss = np.subtract(1.0, np.mean(self.generalized_dice_loss))
114 | 
115 |     def test_should_raise_exception_with_bad_values(self):
116 |         generalized_dice_loss = GeneralizedDiceLoss()
117 | 
assert_that(calling(GeneralizedDiceLoss).with_args(reduction=self.INVALID_REDUCTION),
118 |                     raises(NotImplementedError))
119 |         assert_that(calling(generalized_dice_loss.forward).with_args(inputs=None, targets=None),
120 |                     raises(AttributeError))
121 |         assert_that(calling(generalized_dice_loss.forward).with_args(inputs=self.y_logits, targets=None),
122 |                     raises(AttributeError))
123 |         assert_that(calling(generalized_dice_loss.forward).with_args(inputs=None, targets=self.y_true_tensor),
124 |                     raises(AttributeError))
125 | 
126 |     def test_should_raise_exception_with_bad_ignore_index_values(self):
127 |         generalized_dice_loss = GeneralizedDiceLoss(ignore_index=self.INVALID_INDEX)
128 | 
129 |         assert_that(calling(generalized_dice_loss.forward).with_args(inputs=self.y_logits,
130 |                                                                      targets=to_onehot(self.y_true_tensor,
131 |                                                                                        num_classes=3)),
132 |                     raises(IndexError))
133 | 
134 |     def test_should_compute_generalized_dice(self):
135 |         generalized_dice_loss = GeneralizedDiceLoss()
136 |         loss = generalized_dice_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
137 |         np.testing.assert_almost_equal(loss.numpy(), self.mean_generalized_dice_loss)
138 | 
139 |     def test_should_compute_generalized_dice_for_multiclass_with_ignored_index(self):
140 |         for ignore_index in range(3):
141 |             generalized_dice_loss = GeneralizedDiceLoss(reduction=None, ignore_index=ignore_index)
142 |             res = generalized_dice_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
143 |             true_res = np.subtract(1.0, self.generalized_dice_loss[:ignore_index] + self.generalized_dice_loss[
144 |                                                                                     ignore_index + 1:])
145 |             np.testing.assert_almost_equal(res.numpy(), true_res, err_msg="{}: {} vs {}".format(ignore_index, res, true_res))
146 | 
147 |     def test_should_compute_mean_generalized_dice(self):
148 |         dice_loss = GeneralizedDiceLoss()
149 |         loss = dice_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
150 | 
151 |         np.testing.assert_almost_equal(loss.numpy(), self.mean_generalized_dice_loss)
152 | 
153 |     def test_should_compute_mean_generalized_dice_for_multiclass_with_ignored_index(self):
154 |         for ignore_index in range(3):
155 |             dice_loss = GeneralizedDiceLoss(ignore_index=ignore_index)
156 |             res = dice_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
157 |             true_res = np.subtract(1.0, self.generalized_dice_loss[:ignore_index] + self.generalized_dice_loss[
158 |                                                                                     ignore_index + 1:]).mean()
159 |             np.testing.assert_almost_equal(res.numpy(), true_res, err_msg="{}: {} vs {}".format(ignore_index, res, true_res))
160 | 
161 | 
162 | class TestWeightedCrossEntropy(unittest.TestCase):
163 |     WEIGHTED_CROSS_ENTROPY_LOSS_TRUTH = 1.0808
164 | 
165 |     def setUp(self):
166 |         self.y_true, self.y_pred = get_y_true_y_pred()
167 |         self.y_true_tensor, self.y_logits = compute_tensor_y_true_y_logits(self.y_true, self.y_pred)
168 | 
169 |     def test_should_raise_exception_with_bad_values(self):
170 |         weighted_cross_entropy_loss = WeightedCrossEntropyLoss()
171 |         assert_that(calling(weighted_cross_entropy_loss.forward).with_args(inputs=None, targets=None),
172 |                     raises(AttributeError))
173 |         assert_that(calling(weighted_cross_entropy_loss.forward).with_args(inputs=self.y_logits, targets=None),
174 |                     raises(AttributeError))
175 |         assert_that(calling(weighted_cross_entropy_loss.forward).with_args(inputs=None, targets=self.y_true_tensor),
176 |                     raises(AttributeError))
177 | 
178 |     def test_should_compute_weights(self):
179 |         weights = WeightedCrossEntropyLoss.compute_class_weights(self.y_logits)
180 | 
np.testing.assert_almost_equal(weights.numpy(), np.array([0.2857143, 8.0, 8.0]), decimal=7)
181 | 
182 |     def test_should_return_loss(self):
183 |         weighted_cross_entropy_loss = WeightedCrossEntropyLoss()
184 |         loss = weighted_cross_entropy_loss.forward(self.y_logits, self.y_true_tensor)
185 |         np.testing.assert_almost_equal(loss.numpy(), self.WEIGHTED_CROSS_ENTROPY_LOSS_TRUTH)
186 | 
187 | 
188 | class TestTverskyLoss(unittest.TestCase):
189 |     INVALID_VALUE_1 = -1
190 |     INVALID_VALUE_2 = "STEVE JOBS"
191 |     INVALID_VALUE_3 = 10
192 |     INVALID_VALUE_4 = 11
193 | 
194 |     def setUp(self):
195 |         self.y_true, self.y_pred = get_y_true_y_pred()
196 |         self.y_true_tensor, self.y_logits = compute_tensor_y_true_y_logits(self.y_true, self.y_pred)
197 |         self.dice = compute_dice_truth(self.y_true, self.y_pred)
198 |         self.mean_dice_loss = np.subtract(1.0, np.mean(self.dice))
199 | 
200 |     def test_should_raise_exception_with_bad_values(self):
201 |         tversky_loss = TverskyLoss()
202 |         assert_that(calling(tversky_loss.forward).with_args(inputs=None, targets=None),
203 |                     raises(AttributeError))
204 |         assert_that(calling(tversky_loss.forward).with_args(inputs=self.y_logits, targets=None),
205 |                     raises(AttributeError))
206 |         assert_that(calling(tversky_loss.forward).with_args(inputs=None, targets=self.y_true_tensor),
207 |                     raises(AttributeError))
208 | 
209 |     def test_should_compute_tversky_index(self):
210 |         tversky_loss = TverskyLoss(reduction=None)  # with alpha = beta = 0.5, the Tversky index equals the Dice coefficient
211 |         loss = tversky_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
212 | 
213 |         np.testing.assert_almost_equal(loss.numpy(), np.subtract(1.0, self.dice))
214 | 
215 |     def test_should_compute_dice_for_multiclass_with_ignored_index(self):
216 |         for ignore_index in range(3):
217 |             tversky_loss = TverskyLoss(reduction=None, ignore_index=ignore_index)
218 |             res = tversky_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
219 |             true_res = np.subtract(1.0, self.dice[:ignore_index] + self.dice[ignore_index + 1:])
220 |             np.testing.assert_almost_equal(res.numpy(), true_res, err_msg="{}: {} vs {}".format(ignore_index, res, true_res))
221 | 
222 |     def test_should_compute_mean_dice(self):
223 |         tversky_loss = TverskyLoss(reduction="mean")
224 |         loss = tversky_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
225 | 
226 |         np.testing.assert_almost_equal(loss.numpy(), self.mean_dice_loss)
227 | 
228 |     def test_should_compute_mean_dice_for_multiclass_with_ignored_index(self):
229 |         for ignore_index in range(3):
230 |             tversky_loss = TverskyLoss(ignore_index=ignore_index)
231 |             res = tversky_loss.forward(self.y_logits, to_onehot(self.y_true_tensor, num_classes=3))
232 |             true_res = np.subtract(1.0, self.dice[:ignore_index] + self.dice[ignore_index + 1:]).mean()
233 |             np.testing.assert_almost_equal(res.numpy(), true_res, err_msg="{}: {} vs {}".format(ignore_index, res, true_res))
234 | 
--------------------------------------------------------------------------------
/kerosene/metrics/metrics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # https://opensource.org/licenses/MIT 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | from enum import Enum 17 | from typing import Union, Tuple 18 | 19 | import torch 20 | from ignite.metrics import Accuracy, Precision, MeanAbsoluteError, MeanPairwiseDistance, MeanSquaredError, \ 21 | Recall, RootMeanSquaredError, TopKCategoricalAccuracy, Metric, IoU, ConfusionMatrix, MetricsLambda 22 | 23 | from kerosene.utils.constants import EPSILON 24 | from kerosene.utils.tensors import flatten, to_onehot 25 | 26 | 27 | class MetricType(Enum): 28 | Accuracy = "Accuracy" 29 | ConfusionMatrix = "ConfusionMatrix" 30 | DiceCoefficient = "Dice" 31 | GeneralizedDice = "GeneralizedDice" 32 | IoU = "IoU" 33 | MeanAbsoluteError = "MeanAbsoluteError" 34 | MeanPairwiseDistance = "MeanPairwiseDistance" 35 | MeanSquaredError = "MeanSquaredError" 36 | Precision = "Precision" 37 | Recall = "Recall" 38 | RootMeanSquaredError = "RootMeanSquaredError" 39 | TopKCategoricalAccuracy = "TopKCategoricalAccuracy" 40 | VariableAccumulation = "VariableAccumulation" 41 | 42 | def __str__(self): 43 | return self.value 44 | 45 | 46 | class MetricFactory(object): 47 | def __init__(self): 48 | self._metrics = { 49 | "Accuracy": Accuracy, 50 | "ConfusionMatrix": ConfusionMatrix, 51 | "Dice": Dice, 52 | "GeneralizedDice": GeneralizedDice, 53 | "IoU": IoU_, 54 | "MeanAbsoluteError": MeanAbsoluteError, 55 | "MeanPairwiseDistance": MeanPairwiseDistance, 56 | "MeanSquaredError": MeanSquaredError, 57 | "Precision": Precision, 58 | "Recall": Recall, 59 | "RootMeanSquaredError": RootMeanSquaredError, 60 | "TopKCategoricalAccuracy": TopKCategoricalAccuracy 61 | } 62 | 63 | def create(self, metric_type: Union[str, MetricType], params): 64 | return self._metrics[str(metric_type)](**params) if params is not None else self._metrics[str(metric_type)]() 65 | 66 | def register(self, metric: str, creator: Metric): 67 | """ 68 | Add a new metric. 69 | Args: 70 | metric (str): Metric's name. 71 | creator: A torch or ignite module object wrapping the new custom metric function. 72 | """ 73 | self._metrics[metric] = creator 74 | 75 | 76 | class Dice(Metric): 77 | """ 78 | The Dice Metric. 79 | """ 80 | SUPPORTED_REDUCTIONS = [None, "mean"] 81 | 82 | def __init__(self, num_classes: int, reduction: Union[None, str] = "mean", average: str = None, 83 | weight: torch.Tensor = None, ignore_index: int = -100, 84 | output_transform: callable = lambda x: x) -> None: 85 | """ 86 | Metric initializer. 87 | Args: 88 | num_classes (int): The number of classes in the problem. In case of images, num_classes should also count the background index 0. 89 | average (str, optional): Confusion matrix values averaging schema: None, "samples", "recall", "precision". 90 | Default is None. If `average="samples"` then confusion matrix values are normalized by the number of seen 91 | samples. If `average="recall"` then confusion matrix values are normalized such that diagonal values 92 | represent class recalls. If `average="precision"` then confusion matrix values are normalized such that 93 | diagonal values represent class precisions. 
94 | reduction (str): The type of reduction to apply (e.g. 'mean'). 95 | ignore_index (int, optional): To ignore an index in Dice computation. 96 | output_transform (callable, optional): a callable that is used to transform the 97 | output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and 98 | you want to compute the metric with respect to one of the outputs. 99 | """ 100 | if reduction not in self.SUPPORTED_REDUCTIONS: 101 | raise NotImplementedError("Reduction type not supported.") 102 | self._num_classes = num_classes 103 | self._ignore_index = ignore_index 104 | self._reduction = reduction 105 | self._weight = weight 106 | self._cm = ConfusionMatrix(num_classes=num_classes, average=average, output_transform=output_transform) 107 | self._metric = self.create_dice_metric(self._cm) 108 | super(Dice, self).__init__(output_transform=output_transform) 109 | 110 | def reset(self) -> None: 111 | """ 112 | Reset the confusion matrix object. 113 | """ 114 | self._cm.confusion_matrix = torch.zeros(self._num_classes, self._num_classes, dtype=torch.float64) 115 | 116 | def compute(self) -> torch.Tensor: 117 | """ 118 | Compute the metric. 119 | Returns: 120 | :obj:`torch.Tensor`: The dice coefficient for each class. 121 | """ 122 | return self._metric.compute() 123 | 124 | def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: 125 | """ 126 | Update the confusion matrix with output values. 127 | Args: 128 | output (tuple of :obj:`torch.Tensor`): A tuple containing predictions and ground truth of the form `(y_pred, y)`. 129 | """ 130 | self._cm.update(output) 131 | 132 | def create_dice_metric(self, cm: ConfusionMatrix): 133 | """ 134 | Computes the Sørensen–Dice Coefficient (https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient) 135 | Args: 136 | cm (:obj:`ignite.metrics.ConfusionMatrix`): A confusion matrix representing the classification of data. 137 | Returns: 138 | array or float: The Sørensen–Dice Coefficient for each class or the mean Sørensen–Dice Coefficient. 139 | """ 140 | # Increase floating point precision 141 | cm = cm.type(torch.float64) 142 | dice = 2 * cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) + EPSILON) 143 | 144 | if self._ignore_index != -100: 145 | def remove_index(dice_vector): 146 | try: 147 | indices = list(range(len(dice_vector))) 148 | indices.remove(self._ignore_index) 149 | return dice_vector[indices] 150 | except ValueError as e: 151 | raise IndexError( 152 | "'ignore_index' must be non-negative, and lower than the number of classes in confusion matrix, but {} was given. ".format( 153 | self._ignore_index)) 154 | 155 | dice = MetricsLambda(remove_index, dice) 156 | 157 | if self._weight is not None: 158 | def multiply_weights(dice_vector): 159 | return self._weight * dice_vector 160 | 161 | dice = MetricsLambda(multiply_weights, dice) 162 | 163 | if self._reduction == "mean": 164 | dice = dice.mean() 165 | 166 | return dice 167 | 168 | 169 | class GeneralizedDice(Metric): 170 | """ 171 | The Generalized Dice Metric. 172 | """ 173 | SUPPORTED_REDUCTIONS = [None, "mean"] 174 | 175 | def __init__(self, num_classes: int, reduction: str = None, average: str = None, 176 | ignore_index: int = -100, output_transform: callable = lambda x: x) -> None: 177 | """ 178 | Metric initializer. 179 | Args: 180 | num_classes (int): The number of classes in the problem. In case of images, num_classes should also count the background index 0. 
181 | average (str, optional): Confusion matrix values averaging schema: None, "samples", "recall", "precision". 182 | Default is None. If `average="samples"` then confusion matrix values are normalized by the number of seen 183 | samples. If `average="recall"` then confusion matrix values are normalized such that diagonal values 184 | represent class recalls. If `average="precision"` then confusion matrix values are normalized such that 185 | diagonal values represent class precisions. 186 | reduction (str): The type of reduction to apply (e.g. 'mean'). 187 | ignore_index (int, optional): To ignore an index in Dice computation. 188 | output_transform (callable, optional): a callable that is used to transform the 189 | output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and 190 | you want to compute the metric with respect to one of the outputs. 191 | """ 192 | if reduction not in self.SUPPORTED_REDUCTIONS: 193 | raise NotImplementedError("Reduction type not supported.") 194 | self._num_classes = num_classes 195 | self._average = average 196 | self._ignore_index = ignore_index 197 | self._metric = None 198 | self._reduction = reduction 199 | self._cm = ConfusionMatrix(num_classes=num_classes, average=average, 200 | output_transform=output_transform) 201 | super(GeneralizedDice, self).__init__(output_transform=output_transform) 202 | 203 | def reset(self) -> None: 204 | """ 205 | Reset the confusion matrix object. 206 | """ 207 | self._cm.confusion_matrix = torch.zeros(self._num_classes, self._num_classes, dtype=torch.float64) 208 | 209 | def compute(self) -> torch.Tensor: 210 | """ 211 | Compute the metric. 212 | Returns: 213 | :obj:`torch.Tensor`: The dice coefficient for each class. 214 | """ 215 | return self._metric.compute() 216 | 217 | def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: 218 | """ 219 | Update the confusion matrix with output values. 220 | Args: 221 | output (tuple of :obj:`torch.Tensor`): A tuple containing predictions and ground truth of the form `(y_pred, y)`. 222 | """ 223 | flattened_targets = flatten(to_onehot(output[1], output[0].size(1))).double() 224 | ones = torch.Tensor().new_ones((output[0].size(1),), dtype=torch.double, device=flattened_targets.device) 225 | weights = ones / torch.pow(flattened_targets.sum(-1), 2).clamp(min=EPSILON) 226 | 227 | self._metric = self.create_generalized_dice_metric(self._cm, weights) 228 | 229 | if self._reduction == "mean": 230 | self._metric = self._metric.mean() 231 | elif self._reduction is None: 232 | pass 233 | else: 234 | raise NotImplementedError("Reduction method not implemented.") 235 | 236 | self._cm.update(output) 237 | 238 | def create_generalized_dice_metric(self, cm: ConfusionMatrix, weight: torch.Tensor): 239 | """ 240 | Computes the Sørensen–Dice Coefficient (https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient) 241 | Args: 242 | cm (:obj:`ignite.metrics.ConfusionMatrix`): A confusion matrix representing the classification of data. 243 | weight (:obj:`torch.Tensor`): A weight vector which length equals to the number of classes. 244 | Returns: 245 | ignite.Metric: The Generalized Dice Coefficient Metric object. 
246 |         """
247 | 
248 |         # Increase floating point precision
249 |         cm = cm.type(torch.float64)
250 |         dice = 2 * (cm.diag() * weight) / (((cm.sum(dim=1) + cm.sum(dim=0)) * weight) + EPSILON)
251 | 
252 |         if self._ignore_index != -100:
253 |             def remove_index(dice_vector):
254 |                 try:
255 |                     indices = list(range(len(dice_vector)))
256 |                     indices.remove(self._ignore_index)
257 |                     return dice_vector[indices]
258 |                 except ValueError as e:
259 |                     raise IndexError(
260 |                         "'ignore_index' must be non-negative, and lower than the number of classes in confusion matrix, but {} was given. ".format(
261 |                             self._ignore_index))
262 | 
263 |             return MetricsLambda(remove_index, dice)
264 |         else:
265 |             return dice
266 | 
267 | 
268 | class IoU_(Metric):
269 |     """
270 |     The IoU Metric.
271 |     """
272 |     SUPPORTED_REDUCTIONS = [None, "mean"]
273 | 
274 |     def __init__(self, num_classes: int, reduction: Union[None, str] = "mean", average: str = None,
275 |                  weight: torch.Tensor = None, ignore_index: int = -100,
276 |                  output_transform: callable = lambda x: x) -> None:
277 |         """
278 |         Metric initializer.
279 |         Args:
280 |             num_classes (int): The number of classes in the problem. In case of images, num_classes should also count the background index 0.
281 |             average (str, optional): Confusion matrix values averaging schema: None, "samples", "recall", "precision".
282 |                 Default is None. If `average="samples"` then confusion matrix values are normalized by the number of seen
283 |                 samples. If `average="recall"` then confusion matrix values are normalized such that diagonal values
284 |                 represent class recalls. If `average="precision"` then confusion matrix values are normalized such that
285 |                 diagonal values represent class precisions.
286 |             reduction (str): The type of reduction to apply (e.g. 'mean').
287 |             ignore_index (int, optional): To ignore an index in IoU computation.
288 |             output_transform (callable, optional): a callable that is used to transform the
289 |                 output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and
290 |                 you want to compute the metric with respect to one of the outputs.
291 |         """
292 |         if reduction not in self.SUPPORTED_REDUCTIONS:
293 |             raise NotImplementedError("Reduction type not supported.")
294 |         self._num_classes = num_classes
295 |         self._ignore_index = ignore_index
296 |         self._reduction = reduction
297 |         self._weight = weight
298 |         self._cm = ConfusionMatrix(num_classes=num_classes, average=average, output_transform=output_transform)
299 |         self._metric = self.create_iou_metric(self._cm)
300 |         super(IoU_, self).__init__(output_transform=output_transform)
301 | 
302 |     def reset(self) -> None:
303 |         """
304 |         Reset the confusion matrix object.
305 |         """
306 |         self._cm.confusion_matrix = torch.zeros(self._num_classes, self._num_classes, dtype=torch.float64)
307 | 
308 |     def compute(self) -> torch.Tensor:
309 |         """
310 |         Compute the metric.
311 |         Returns:
312 |             :obj:`torch.Tensor`: The IoU for each class.
313 |         """
314 |         return self._metric.compute()
315 | 
316 |     def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
317 |         """
318 |         Update the confusion matrix with output values.
319 |         Args:
320 |             output (tuple of :obj:`torch.Tensor`): A tuple containing predictions and ground truth of the form `(y_pred, y)`.
321 |         """
322 |         self._cm.update(output)
323 | 
324 |     def create_iou_metric(self, cm: ConfusionMatrix):
325 |         """
326 |         Computes the Intersection over Union (https://en.wikipedia.org/wiki/Jaccard_index)
327 |         Args:
328 |             cm (:obj:`ignite.metrics.ConfusionMatrix`): A confusion matrix representing the classification of data.
329 |         Returns:
330 |             array or float: The IoU for each class or the mean IoU.
331 |         """
332 | 
333 |         metric = IoU(cm, ignore_index=self._ignore_index) if self._ignore_index != -100 else IoU(cm)  # ignite's IoU expects no ignore_index (None), not the -100 sentinel
334 | 
335 |         if self._reduction == "mean":
336 |             metric = metric.mean()
337 |         return metric
338 | 
339 | 
340 | def to_tensor(value):
341 |     if not isinstance(value, torch.Tensor):
342 |         return torch.tensor([value])
343 |     else:
344 |         return value
345 | 
--------------------------------------------------------------------------------
/kerosene/nn/criterions.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2019 Kerosene Authors. All Rights Reserved.
3 | #
4 | # Licensed under the MIT License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     https://opensource.org/licenses/MIT
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from enum import Enum
17 | from typing import Union
18 | 
19 | import torch
20 | from ignite.metrics import MetricsLambda
21 | from torch import nn
22 | from torch.nn.modules.loss import _Loss, _WeightedLoss
23 | 
24 | from kerosene.utils.constants import EPSILON
25 | from kerosene.utils.tensors import flatten
26 | 
27 | 
28 | class CriterionType(Enum):
29 |     DiceLoss = "DiceLoss"
30 |     GeneralizedDiceLoss = "GeneralizedDiceLoss"
31 |     TverskyLoss = "TverskyLoss"
32 |     FocalTverskyLoss = "FocalTverskyLoss"
33 |     BCELoss = "BCELoss"
34 |     BCEWithLogitsLoss = "BCEWithLogitsLoss"
35 |     PoissonNLLLoss = "PoissonNLLLoss"
36 |     CosineEmbeddingLoss = "CosineEmbeddingLoss"
37 |     CrossEntropyLoss = "CrossEntropyLoss"
38 |     CTCLoss = "CTCLoss"
39 |     HingeEmbeddingLoss = "HingeEmbeddingLoss"
40 |     KLDivLoss = "KLDivLoss"
41 |     L1Loss = "L1Loss"
42 |     MSELoss = "MSELoss"
43 |     MarginRankingLoss = "MarginRankingLoss"
44 |     MultiLabelMarginLoss = "MultiLabelMarginLoss"
45 |     MultiLabelSoftMarginLoss = "MultiLabelSoftMarginLoss"
46 |     MultiMarginLoss = "MultiMarginLoss"
47 |     NLLLoss = "NLLLoss"
48 |     SmoothL1Loss = "SmoothL1Loss"
49 |     SoftMarginLoss = "SoftMarginLoss"
50 |     TripletMarginLoss = "TripletMarginLoss"
51 |     def __str__(self): return self.value  # mirrors MetricType so CriterionFactory.create() can resolve str(criterion_type)
52 | 
53 | class CriterionFactory(object):
54 |     def __init__(self):
55 |         super(CriterionFactory, self).__init__()
56 |         self._criterion = {
57 |             "DiceLoss": DiceLoss,
58 |             "GeneralizedDiceLoss": GeneralizedDiceLoss,
59 |             "TverskyLoss": TverskyLoss,
60 |             "FocalTverskyLoss": FocalTverskyLoss,
61 |             "BCELoss": nn.BCELoss,
62 |             "BCEWithLogitsLoss": nn.BCEWithLogitsLoss,
63 |             "PoissonNLLLoss": nn.PoissonNLLLoss,
64 |             "CosineEmbeddingLoss": nn.CosineEmbeddingLoss,
65 |             "CrossEntropyLoss": nn.CrossEntropyLoss,
66 |             "CTCLoss": nn.CTCLoss,
67 |             "HingeEmbeddingLoss": nn.HingeEmbeddingLoss,
68 |             "KLDivLoss": nn.KLDivLoss,
69 |             "L1Loss": nn.L1Loss,
70 | 
"MSELoss": nn.MSELoss, 71 | "MarginRankingLoss": nn.MarginRankingLoss, 72 | "MultiLabelMarginLoss": nn.MultiLabelMarginLoss, 73 | "MultiLabelSoftMarginLoss": nn.MultiLabelSoftMarginLoss, 74 | "MultiMarginLoss": nn.MultiMarginLoss, 75 | "NLLLoss": nn.NLLLoss, 76 | "SmoothL1Loss": nn.SmoothL1Loss, 77 | "SoftMarginLoss": nn.SoftMarginLoss, 78 | "TripletMarginLoss": nn.TripletMarginLoss, 79 | } 80 | 81 | def create(self, criterion_type: Union[str, CriterionType], params): 82 | return self._criterion[str(criterion_type)](**params) if params is not None else self._criterion[ 83 | str(criterion_type)]() 84 | 85 | def register(self, function: str, creator: _Loss): 86 | """ 87 | Add a new criterion. 88 | Args: 89 | function (str): Criterion's name. 90 | creator (:obj:`torch.nn.Module`): A torch module object wrapping the new custom criterion function. 91 | """ 92 | self._criterion[function] = creator 93 | 94 | 95 | class DiceLoss(_WeightedLoss): 96 | """ 97 | The Sørensen-Dice Loss. 98 | """ 99 | SUPPORTED_REDUCTIONS = [None, "mean"] 100 | 101 | def __init__(self, reduction: Union[None, str] = "mean", ignore_index: int = -100, weight: torch.Tensor = None): 102 | if reduction not in self.SUPPORTED_REDUCTIONS: 103 | raise NotImplementedError("Reduction type not supported.") 104 | super(DiceLoss, self).__init__(weight=weight, reduction=reduction) 105 | self._ignore_index = ignore_index 106 | 107 | if self.weight is not None: 108 | if self.weight.requires_grad is not False: 109 | self.weight.requires_grad = False 110 | 111 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor): 112 | """ 113 | Computes the Sørensen–Dice loss. 114 | Note that PyTorch optimizers minimize a loss. In this case, we would like to maximize the dice loss so we 115 | return the negated dice loss. 116 | Args: 117 | inputs (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The model prediction on which the loss has to 118 | be computed. 119 | targets (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The ground truth. 120 | Returns: 121 | :obj:`torch.Tensor`: The Sørensen–Dice loss for each class or reduced according to reduction method. 122 | """ 123 | if not inputs.size() == targets.size(): 124 | raise ValueError("'Inputs' and 'Targets' must have the same shape.") 125 | 126 | inputs = flatten(inputs) 127 | targets = flatten(targets).float() 128 | 129 | # Compute per channel Dice Coefficient 130 | intersection = (inputs * targets).sum(-1) 131 | 132 | if self.weight is not None: 133 | intersection = self.weight * intersection 134 | 135 | cardinality = (inputs + targets).sum(-1) 136 | 137 | ones = torch.Tensor().new_ones((inputs.size(0),), dtype=torch.float, device=inputs.device) 138 | 139 | dice = ones - (2.0 * intersection / cardinality.clamp(min=EPSILON)) 140 | 141 | if self._ignore_index != -100: 142 | def ignore_index_fn(dice_vector): 143 | try: 144 | indices = list(range(len(dice_vector))) 145 | indices.remove(self._ignore_index) 146 | return dice_vector[indices] 147 | except ValueError as e: 148 | raise IndexError( 149 | "'ignore_index' must be non-negative, and lower than the number of classes in confusion matrix, but {} was given. ".format( 150 | self._ignore_index)) 151 | 152 | dice = MetricsLambda(ignore_index_fn, dice).compute() 153 | 154 | if self.reduction == "mean": 155 | dice = dice.mean() 156 | 157 | return dice 158 | 159 | 160 | class GeneralizedDiceLoss(_Loss): 161 | """ 162 | The Generalized Sørensen-Dice Loss. 
163 |     """
164 |     SUPPORTED_REDUCTIONS = [None, "mean"]
165 | 
166 |     def __init__(self, reduction: Union[None, str] = "mean", ignore_index: int = -100):
167 |         if reduction not in self.SUPPORTED_REDUCTIONS:
168 |             raise NotImplementedError("Reduction type not supported.")
169 |         super(GeneralizedDiceLoss, self).__init__(reduction=reduction)
170 |         self._ignore_index = ignore_index
171 | 
172 |     def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
173 |         """
174 |         Computes the Generalized Sørensen–Dice loss.
175 |         Note that PyTorch optimizers minimize a loss, while the Dice coefficient should be maximized. The value
176 |         returned is therefore one minus the class-weighted Dice coefficient.
177 |         Args:
178 |             inputs (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The model prediction on which the loss has to
179 |                 be computed.
180 |             targets (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The ground truth.
181 |         Returns:
182 |             :obj:`torch.Tensor`: The Generalized Sørensen–Dice loss for each class or reduced according to reduction method.
183 |         """
184 |         if not inputs.size() == targets.size():
185 |             raise ValueError("'Inputs' and 'Targets' must have the same shape.")
186 | 
187 |         inputs = flatten(inputs)
188 |         targets = flatten(targets).float()
189 |         ones = torch.Tensor().new_ones((inputs.size(0),), dtype=torch.float, device=inputs.device)
190 |         class_weights = ones / torch.pow(targets.sum(-1), 2).clamp(min=EPSILON)
191 | 
192 |         # Compute per channel Dice Coefficient
193 |         intersect = (inputs * targets).sum(-1) * class_weights
194 | 
195 |         denominator = (inputs + targets).sum(-1) * class_weights
196 | 
197 |         dice = ones - (2.0 * intersect / denominator.clamp(min=EPSILON))
198 | 
199 |         if self._ignore_index != -100:
200 |             def ignore_index_fn(dice_vector):
201 |                 try:
202 |                     indices = list(range(len(dice_vector)))
203 |                     indices.remove(self._ignore_index)
204 |                     return dice_vector[indices]
205 |                 except ValueError as e:
206 |                     raise IndexError(
207 |                         "'ignore_index' must be non-negative, and lower than the number of classes in confusion matrix, but {} was given. ".format(
208 |                             self._ignore_index))
209 | 
210 |             dice = MetricsLambda(ignore_index_fn, dice).compute()
211 | 
212 |         if self.reduction == "mean":
213 |             dice = dice.mean()
214 | 
215 |         return dice
216 | 
217 | 
218 | class TverskyLoss(_WeightedLoss):
219 |     SUPPORTED_REDUCTIONS = [None, "mean"]
220 | 
221 |     def __init__(self, reduction: Union[None, str] = "mean", ignore_index: int = -100, weight: torch.Tensor = None,
222 |                  alpha: float = 0.5, beta: float = 0.5):
223 |         if reduction not in self.SUPPORTED_REDUCTIONS:
224 |             raise NotImplementedError("Reduction type not supported.")
225 |         super(TverskyLoss, self).__init__(weight=weight, reduction=reduction)
226 |         self._ignore_index = ignore_index
227 |         self._alpha = alpha
228 |         self._beta = beta
229 | 
230 |     def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
231 |         """
232 |         Computes the Tversky loss based on https://arxiv.org/pdf/1706.05721.pdf
233 |         Note that PyTorch optimizers minimize a loss, while the Tversky index should be maximized. The value
234 |         returned is therefore one minus the Tversky index.
235 |         Args:
236 |             inputs (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The model prediction on which the loss has to
237 |                 be computed.
238 |             targets (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The ground truth.
239 |         Returns:
240 |             :obj:`torch.Tensor`: The Tversky loss for each class or reduced according to reduction method.
241 |         """
242 |         if not inputs.size() == targets.size():
243 |             raise ValueError("'Inputs' and 'Targets' must have the same shape.")
244 | 
245 |         inputs = flatten(inputs)
246 |         targets = flatten(targets).float()
247 |         ones = torch.Tensor().new_ones((inputs.size()), dtype=torch.float, device=inputs.device)
248 | 
249 |         P_G = (inputs * targets).sum(-1)
250 |         if self.weight is not None:
251 |             P_G = self.weight * P_G
252 | 
253 |         P_NG = (inputs * (ones - targets)).sum(-1)
254 |         NP_G = ((ones - inputs) * targets).sum(-1)
255 | 
256 |         ones = torch.Tensor().new_ones((inputs.size(0),), dtype=torch.float, device=inputs.device)
257 |         tversky = P_G / (P_G + self._alpha * P_NG + self._beta * NP_G + EPSILON)
258 | 
259 |         tversky_loss = ones - tversky
260 | 
261 |         if self._ignore_index != -100:
262 |             def ignore_index_fn(tversky_vector):
263 |                 try:
264 |                     indices = list(range(len(tversky_vector)))
265 |                     indices.remove(self._ignore_index)
266 |                     return tversky_vector[indices]
267 |                 except ValueError as e:
268 |                     raise IndexError(
269 |                         "'ignore_index' must be non-negative, and lower than the number of classes in confusion matrix, but {} was given. ".format(
270 |                             self._ignore_index))
271 | 
272 |             tversky_loss = MetricsLambda(ignore_index_fn, tversky_loss).compute()
273 | 
274 |         if self.reduction == "mean":
275 |             tversky_loss = tversky_loss.mean()
276 | 
277 |         return tversky_loss
278 | 
279 | 
280 | class FocalTverskyLoss(_WeightedLoss):
281 |     SUPPORTED_REDUCTIONS = [None, "mean"]
282 | 
283 |     def __init__(self, reduction: Union[None, str] = "mean", ignore_index: int = -100, weight: torch.Tensor = None,
284 |                  alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0):
285 |         if reduction not in self.SUPPORTED_REDUCTIONS:
286 |             raise NotImplementedError("Reduction type not supported.")
287 |         super(FocalTverskyLoss, self).__init__(weight=weight, reduction=reduction)
288 |         self._ignore_index = ignore_index
289 |         self._alpha = alpha
290 |         self._beta = beta
291 |         self._gamma = gamma
292 | 
293 |     def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
294 |         """
295 |         Computes the Focal Tversky Loss based on https://arxiv.org/pdf/1708.02002.pdf.
296 |         Note that PyTorch optimizers minimize a loss, while the Tversky index should be maximized. The value
297 |         returned is therefore a focal-weighted version of one minus the Tversky index.
298 |         Args:
299 |             inputs (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The model prediction on which the loss has to
300 |                 be computed.
301 |             targets (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The ground truth.
302 |         Returns:
303 |             :obj:`torch.Tensor`: The Focal Tversky loss for each class or reduced according to reduction method.
304 |         """
305 |         if not inputs.size() == targets.size():
306 |             raise ValueError("'Inputs' and 'Targets' must have the same shape.")
307 | 
308 |         inputs = flatten(inputs)
309 |         targets = flatten(targets).float()
310 |         ones = torch.Tensor().new_ones((inputs.size()), dtype=torch.float, device=inputs.device)
311 | 
312 |         P_G = (inputs * targets).sum(-1)
313 |         if self.weight is not None:
314 |             P_G = self.weight * P_G
315 | 
316 |         P_NG = (inputs * (ones - targets)).sum(-1)
317 |         NP_G = ((ones - inputs) * targets).sum(-1)
318 | 
319 |         ones = torch.Tensor().new_ones((inputs.size(0),), dtype=torch.float, device=inputs.device)
320 |         tversky = P_G / (P_G + self._alpha * P_NG + self._beta * NP_G + EPSILON)
321 | 
322 |         tversky = ones - (torch.pow((ones - tversky), 1 / self._gamma))
323 | 
324 |         if self._ignore_index != -100:
325 |             def ignore_index_fn(tversky_vector):
326 |                 try:
327 |                     indices = list(range(len(tversky_vector)))
328 |                     indices.remove(self._ignore_index)
329 |                     return tversky_vector[indices]
330 |                 except ValueError as e:
331 |                     raise IndexError(
332 |                         "'ignore_index' must be non-negative, and lower than the number of classes in confusion matrix, but {} was given. ".format(
333 |                             self._ignore_index))
334 | 
335 |             tversky = MetricsLambda(ignore_index_fn, tversky).compute()
336 | 
337 |         if self.reduction == "mean":
338 |             tversky = tversky.mean()
339 | 
340 |         return tversky
341 | 
342 | 
343 | class WeightedCrossEntropyLoss(_Loss):
344 | 
345 |     def __init__(self, ignore_index: int = -100):
346 |         super(WeightedCrossEntropyLoss, self).__init__()
347 |         self._ignore_index = ignore_index
348 | 
349 |     def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
350 |         """
351 |         Computes the Weighted Cross Entropy Loss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf
352 |         Args:
353 |             inputs (:obj:`torch.Tensor`): A tensor of shape (B, C, ..). The model's prediction on which the loss has to be computed.
354 |             targets (:obj:`torch.Tensor`): A tensor of shape (B, ..). The ground truth.
355 |         Returns:
356 |             :obj:`torch.Tensor`: The weighted Cross-Entropy loss value.
357 |         """
358 |         class_weights = self.compute_class_weights(inputs)
359 | 
360 |         return torch.nn.functional.cross_entropy(inputs, targets, weight=class_weights,
361 |                                                  ignore_index=self._ignore_index)
362 | 
363 |     @staticmethod
364 |     def compute_class_weights(inputs: torch.Tensor):
365 |         """
366 |         Compute weights for each class as described in https://arxiv.org/pdf/1707.03237.pdf
367 |         Args:
368 |             inputs: (:obj:`torch.Tensor`): A tensor of shape (B, C, ..). The model's prediction on which the loss has to be computed.
369 |         Returns:
370 |             :obj:`torch.Tensor`: A tensor containing class weights.
371 |         """
372 |         flattened_inputs = flatten(inputs)
373 |         class_weights = (flattened_inputs.shape[1] - flattened_inputs.sum(-1)) / flattened_inputs.sum(-1)
374 |         return class_weights
375 | 
--------------------------------------------------------------------------------
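Usage sketch: a minimal example of the CriterionFactory and DiceLoss defined in kerosene/nn/criterions.py above, assuming the kerosene package is importable. The shapes follow the (B, C, ..) docstring convention, and the "DiceLoss" string key matches the factory dict; to_onehot comes from kerosene.utils.tensors, as used in the tests above.

import torch

from kerosene.nn.criterions import CriterionFactory
from kerosene.utils.tensors import to_onehot

factory = CriterionFactory()
dice_loss = factory.create("DiceLoss", {"reduction": "mean"})  # CriterionType.DiceLoss also resolves via __str__

predictions = torch.softmax(torch.randn(2, 3, 16, 16), dim=1)  # (B, C, H, W) class probabilities
targets = torch.randint(0, 3, (2, 16, 16))                     # (B, H, W) integer labels
loss = dice_loss(predictions, to_onehot(targets, num_classes=3))  # scalar: 1 - mean Dice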
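A small numeric check of WeightedCrossEntropyLoss.compute_class_weights, mirroring test_should_compute_weights above: with one-hot predictions over 900 pixels where class 0 covers 700 pixels and classes 1 and 2 cover 100 pixels each, the weights come out to (900 - 700) / 700 ≈ 0.2857 and (900 - 100) / 100 = 8.0, so rare classes receive proportionally larger weights.

import torch

from kerosene.nn.criterions import WeightedCrossEntropyLoss

inputs = torch.zeros(1, 3, 30, 30)   # (B, C, H, W) one-hot predictions
inputs[0, 0] = 1.0                   # everything starts as background (class 0)
inputs[0, 0, 1:11, 1:11] = 0.0
inputs[0, 1, 1:11, 1:11] = 1.0       # 100 pixels predicted as class 1
inputs[0, 0, 20:30, 20:30] = 0.0
inputs[0, 2, 20:30, 20:30] = 1.0     # 100 pixels predicted as class 2

weights = WeightedCrossEntropyLoss.compute_class_weights(inputs)
# -> tensor([0.2857, 8.0000, 8.0000])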
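A hedged sketch of the confusion-matrix-based Dice metric from kerosene/metrics/metrics.py, assuming the usual ignite update/compute cycle: update() is fed (y_pred, y) batches, where y_pred carries per-class scores and y carries integer labels, and compute() returns the (optionally reduced) value.

import torch

from kerosene.metrics.metrics import Dice

metric = Dice(num_classes=3, reduction=None)  # reduction=None -> one Dice score per class

y_pred = torch.softmax(torch.randn(4, 3, 16, 16), dim=1)  # (B, C, H, W) scores
y = torch.randint(0, 3, (4, 16, 16))                      # (B, H, W) ground truth labels

metric.update((y_pred, y))
per_class_dice = metric.compute()  # tensor with 3 Dice coefficients
metric.reset()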