├── tests ├── __init__.py ├── conftest.py └── test_schedulers.py ├── albu_scheduler ├── __init__.py └── schedulers.py ├── Makefile ├── pyproject.toml ├── LICENSE └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.fixture(scope="module") 5 | def image(): 6 | """Placeholder image; the scheduler tests only check that it is forwarded unchanged to the mocked transforms.""" 7 | return "IMAGE" 8 | -------------------------------------------------------------------------------- /albu_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from albu_scheduler.schedulers import ( 2 | BaseTransformScheduler, 3 | TransformMultiStepScheduler, 4 | TransformSchedulerOnPlateau, 5 | ) 6 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | POETRY ?= $(HOME)/.local/bin/poetry 2 | 3 | .PHONY: install-poetry 4 | install-poetry: 5 | curl -sSL https://install.python-poetry.org | python3 - 6 | 7 | .PHONY: install-packages 8 | install-packages: 9 | $(POETRY) install 10 | 11 | .PHONY: install 12 | install: install-poetry install-packages 13 | 14 | .PHONY: fmt 15 | fmt: 16 | $(POETRY) run isort . 17 | $(POETRY) run black .
18 | 19 | .PHONY: test 20 | test: 21 | $(POETRY) run pytest 22 | 23 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "albu_scheduler" 3 | version = "0.1.0" 4 | description = "Scheduler for albumentations transforms" 5 | authors = ["Kiryl Liaushun "] 6 | license = "MIT" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.8" 10 | albumentations = "^0.5.2" 11 | 12 | [tool.poetry.dev-dependencies] 13 | black = "^20.8b1" 14 | isort = "^5.8.0" 15 | pytest = "^6.2.3" 16 | 17 | [build-system] 18 | requires = ["poetry-core>=1.0.0"] 19 | build-backend = "poetry.core.masonry.api" 20 | 21 | [tool.isort] 22 | profile = "black" 23 | skip = ".eggs,.pip-cache,.poetry,venv,.venv,libs" 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 KiriLev 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # albu_scheduler 2 | Scheduler for [albumentations](https://github.com/albumentations-team/albumentations) transforms based on [PyTorch schedulers](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) interface 3 | 4 | # Usage 5 | ## TransformMultiStepScheduler 6 | ```python 7 | import albumentations as A 8 | 9 | from albu_scheduler import TransformMultiStepScheduler 10 | 11 | transform_1 = A.Compose([ 12 | A.RandomCrop(width=256, height=256), 13 | A.HorizontalFlip(p=0.5), 14 | A.RandomBrightnessContrast(p=0.2), 15 | ]) 16 | transform_2 = A.Compose([ 17 | A.RandomCrop(width=128, height=128), 18 | A.VerticalFlip(p=0.5), 19 | ]) 20 | 21 | scheduled_transform = TransformMultiStepScheduler(transforms=[transform_1, transform_2], 22 | milestones=[0, 10]) 23 | dataset = Dataset(transform=scheduled_transform) 24 | 25 | for epoch in range(100): 26 | train(...) 27 | validate(...) 28 | scheduled_transform.step() 29 | ``` 30 | ## TransformSchedulerOnPlateau 31 | ```python 32 | from albu_scheduler import TransformSchedulerOnPlateau 33 | 34 | scheduled_transform = TransformSchedulerOnPlateau(transforms=[transform_1, transform_2], 35 | mode="max", 36 | patience=5) 37 | 38 | dataset = Dataset(transform=scheduled_transform) 39 | for epoch in range(100): 40 | train(...) 41 | score = validate(...) 
42 | scheduled_transform.step(score) 43 | ``` 44 | 45 | # Installation 46 | ```bash 47 | git clone https://github.com/KiriLev/albu_scheduler 48 | cd albu_scheduler 49 | make install 50 | ``` -------------------------------------------------------------------------------- /tests/test_schedulers.py: -------------------------------------------------------------------------------- 1 | from unittest import mock 2 | 3 | import pytest 4 | 5 | from albu_scheduler import TransformMultiStepScheduler, TransformSchedulerOnPlateau 6 | 7 | 8 | class TestTransformStepScheduler: 9 | def test_ok(self, image): 10 | transforms = [mock.MagicMock() for _ in range(4)] 11 | 12 | scheduled_transform = TransformMultiStepScheduler( 13 | transforms=transforms, milestones=[0, 5, 10] 14 | ) 15 | scheduled_transform(image=image) 16 | transforms[0].assert_called_with(image=image) 17 | 18 | for _ in range(5): 19 | scheduled_transform.step() 20 | 21 | scheduled_transform(image=image) 22 | transforms[1].assert_called_with(image=image) 23 | 24 | transforms[2].assert_not_called() 25 | transforms[3].assert_not_called() 26 | 27 | def test_no_zero_milestone(self): 28 | transforms = [mock.MagicMock() for _ in range(4)] 29 | 30 | scheduled_transform = TransformMultiStepScheduler( 31 | transforms=transforms, milestones=[5, 10] 32 | ) 33 | assert scheduled_transform.cur_transform.__class__.__name__ == "NoOp" 34 | 35 | def test_too_much_milestones_fails(self): 36 | transforms = [mock.MagicMock()] 37 | milestones = [i for i in range(100)] 38 | with pytest.raises(ValueError): 39 | TransformMultiStepScheduler(transforms=transforms, milestones=milestones) 40 | 41 | 42 | class TestTransformSchedulerOnPlateau: 43 | @pytest.mark.parametrize( 44 | "mode, metric_values", 45 | [("max", [1, 2, 3, 3, 3, 1]), ("min", [10, 9, 8, 8, 8, 100])], 46 | ) 47 | def test_ok(self, image, mode, metric_values): 48 | transforms = [mock.MagicMock() for _ in range(4)] 49 | 50 | scheduled_transform = TransformSchedulerOnPlateau( 51 | 
transforms=transforms, mode=mode, patience=2 52 | ) 53 | scheduled_transform(image=image) 54 | transforms[0].assert_called_with(image=image) 55 | 56 | for metric_value in metric_values[:-1]: 57 | scheduled_transform.step(metric_value) 58 | scheduled_transform(image=image) 59 | transforms[0].assert_called_with(image=image) 60 | transforms[0].reset_mock() 61 | 62 | scheduled_transform.step(metric_values[-1]) 63 | scheduled_transform(image=image) 64 | transforms[1].assert_called_with(image=image) 65 | 66 | for _ in range(100): 67 | scheduled_transform(image=image) 68 | transforms[2].assert_not_called() 69 | -------------------------------------------------------------------------------- /albu_scheduler/schedulers.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from albumentations import BasicTransform, NoOp 4 | 5 | 6 | class BaseTransformScheduler: 7 | def __init__(self, **kwargs): 8 | pass 9 | 10 | def __call__(self, **kwargs): 11 | return self.cur_transform(**kwargs) 12 | 13 | def step(self, **kwargs): 14 | pass 15 | 16 | 17 | class TransformMultiStepScheduler(BaseTransformScheduler): 18 | """Selects matching transform once the number of epoch reaches 19 | one of the milestones. 20 | 21 | Args: 22 | transforms (list): Transforms to schedule. 23 | milestones (list): List of epoch indices. 24 | verbose (bool): If ``True``, prints a message to stdout for 25 | each update. Default: ``False``. 
26 | Example: 27 | >>> # transform = A.NoOp() if 0 <= epoch < 5 28 | >>> # transform = transform_1 if 5 <= epoch < 30 29 | >>> # transform = transform_2 if 30 <= epoch < 80 30 | >>> # transform = transform_3 if epoch >= 80 31 | >>> 32 | >>> scheduled_transform = TransformMultiStepScheduler(transforms=[transform_1, transform_2, transform_3], 33 | >>> milestones=[5, 30, 80]) 34 | >>> train_dataset = Dataset(transform=scheduled_transform) 35 | >>> val_dataset = Dataset() 36 | >>> 37 | >>> for epoch in range(100): 38 | >>> train(train_dataset) 39 | >>> validate(val_dataset) 40 | >>> scheduled_transform.step() 41 | """ 42 | 43 | def __init__( 44 | self, 45 | transforms: List[BasicTransform], 46 | milestones: List[int], 47 | verbose: bool = False, 48 | ) -> None: 49 | super().__init__() 50 | if len(milestones) > len(transforms): 51 | raise ValueError( 52 | "Length of milestones can't be greater than number of transforms" 53 | ) 54 | self.epoch_to_transform: Dict[int, BasicTransform] = { 55 | epoch_num: aug for epoch_num, aug in zip(milestones, transforms) 56 | } 57 | if 0 not in self.epoch_to_transform: 58 | self.epoch_to_transform[0] = NoOp() 59 | self._step = 0 60 | self.cur_transform: BasicTransform = self.epoch_to_transform[0] 61 | self.verbose: bool = verbose 62 | 63 | def step(self, **kwargs) -> None: 64 | self._step += 1 65 | if self._step in self.epoch_to_transform: 66 | self.cur_transform = self.epoch_to_transform[self._step] 67 | if self.verbose: 68 | print(f"Changing aug at epoch={self._step}") 69 | 70 | 71 | class TransformSchedulerOnPlateau(BaseTransformScheduler): 72 | """Selects next transform when a metric has stopped improving. 73 | This scheduler reads a metrics quantity and if no improvement 74 | is seen for a 'patience' number of epochs, next transform in list is selected. 75 | 76 | Args: 77 | transforms (list): Transforms to schedule. 
78 | patience (int): Number of epochs with no improvement after 79 | which next transform will be chosen (if there is). For example, if 80 | `patience = 2`, then we will ignore the first 2 epochs 81 | with no improvement, and will only switch transforms after the 82 | 3rd epoch if the loss still hasn't improved then. 83 | Default: 5. 84 | mode (str): One of `min`, `max`. In `min` mode, transform 85 | will be switched when the quantity monitored has stopped 86 | decreasing; in `max` mode it will be switched when the 87 | quantity monitored has stopped increasing. Default: 'min'. 88 | verbose (bool): If ``True``, prints a message to stdout for 89 | each update. Default: ``False``. 90 | Example: 91 | >>> 92 | >>> scheduled_transform = TransformSchedulerOnPlateau(transforms=[transform_1, transform_2, transform_3], 93 | >>> mode="max", 94 | >>> plateau=10) 95 | >>> train_dataset = Dataset(transform=scheduled_transform) 96 | >>> val_dataset = Dataset() 97 | >>> 98 | >>> for epoch in range(100): 99 | >>> train(dataset) 100 | >>> val_score = validate(val_dataset) 101 | >>> # Note that step should be called after validate() 102 | >>> scheduled_transform.step(val_score) 103 | """ 104 | 105 | def __init__( 106 | self, 107 | transforms: List[BasicTransform], 108 | patience: int, 109 | mode: str = "min", 110 | verbose: bool = False, 111 | ) -> None: 112 | super().__init__() 113 | self.mode = mode 114 | self.transforms = transforms 115 | self.patience = patience 116 | self.verbose = verbose 117 | 118 | self._step = 0 119 | self._cur_transform_ind = 0 120 | self.cur_transform = self.transforms[self._cur_transform_ind] 121 | 122 | self.best = 0.0 if self.mode == "max" else float("inf") 123 | self.num_bad_epochs = 0 124 | 125 | def is_better(self, left, right): 126 | if self.mode == "max": 127 | return left > right 128 | if self.mode == "min": 129 | return left < right 130 | 131 | def step(self, metric, **kwargs) -> None: 132 | current = float(metric) 133 | 134 | if 
self.is_better(current, self.best): 135 | self.best = current 136 | self.num_bad_epochs = 0 137 | else: 138 | self.num_bad_epochs += 1 139 | 140 | if ( 141 | self.num_bad_epochs > self.patience 142 | and self._cur_transform_ind < len(self.transforms) - 1 143 | ): 144 | self._cur_transform_ind += 1 145 | self.cur_transform = self.transforms[self._cur_transform_ind] 146 | self.num_bad_epochs = 0 147 | if self.verbose: 148 | print(f"Changing aug to transforms[{self._cur_transform_ind}]") 149 | --------------------------------------------------------------------------------