├── .gitignore ├── LICENSE ├── README.md ├── config_segmentation.json ├── datasets ├── __init__.py ├── base_dataset.py └── segmentation2d_dataset.py ├── losses └── __init__.py ├── models ├── __init__.py ├── base_model.py └── segmentation_model.py ├── optimizers └── __init__.py ├── requirements.txt ├── train.py ├── utils ├── __init__.py └── visualizer.py └── validate.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Visual Studio Code cache 2 | .vscode/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # celery beat schedule file 96 | celerybeat-schedule 97 | 98 | # SageMath parsed files 99 | *.sage.py 100 | 101 | # Environments 102 | .env 103 | .venv 104 | env/ 105 | venv/ 106 | ENV/ 107 | env.bak/ 108 | venv.bak/ 109 | 110 | # Spyder project settings 111 | .spyderproject 112 | .spyproject 113 | 114 | # Rope project settings 115 | .ropeproject 116 | 117 | # mkdocs documentation 118 | /site 119 | 120 | # mypy 121 | .mypy_cache/ 122 | .dmypy.json 123 | dmypy.json 124 | 125 | # Pyre type checker 126 | .pyre/ 127 | 128 | # Pickle files 129 | .pickle -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Branislav Holländer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorchProjectFramework 2 | 3 | A basic framework for your PyTorch projects. -------------------------------------------------------------------------------- /config_segmentation.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_dataset_params": { 3 | "dataset_name": "segmentation2d", 4 | "dataset_path": "", 5 | "loader_params": { 6 | "batch_size": 24, 7 | "shuffle": true, 8 | "num_workers": 4, 9 | "pin_memory": true 10 | }, 11 | "input_size": [200, 200] 12 | }, 13 | "val_dataset_params": { 14 | "dataset_name": "segmentation2d", 15 | "dataset_path": "", 16 | "loader_params": { 17 | "batch_size": 24, 18 | "shuffle": false, 19 | "num_workers": 4, 20 | "pin_memory": true 21 | }, 22 | "input_size": [200, 200] 23 | }, 24 | "model_params": { 25 | "model_name": "segmentation2d", 26 | "is_train": true, 27 | "max_epochs": 40, 28 | "lr": 0.01, 29 | "export_path": "", 30 | "checkpoint_path": "", 31 | "load_checkpoint": -1, 32 | "lr_policy": "step", 33 | "lr_decay_iters": 10 34 | }, 35 | "visualization_params": { 36 | "name": "2d segmentation" 37 | }, 38 | "printout_freq": 10, 39 | "model_update_freq": 1 40 | } -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing. 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | """ 5 | import importlib 6 | from torch.utils import data 7 | from datasets.base_dataset import BaseDataset 8 | 9 | 10 | def find_dataset_using_name(dataset_name): 11 | """Import the module "data/[dataset_name]_dataset.py". 12 | 13 | In the file, the class called DatasetNameDataset() will 14 | be instantiated. It has to be a subclass of BaseDataset, 15 | and it is case-insensitive. 16 | """ 17 | dataset_filename = "datasets." + dataset_name + "_dataset" 18 | datasetlib = importlib.import_module(dataset_filename) 19 | 20 | dataset = None 21 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 22 | for name, cls in datasetlib.__dict__.items(): 23 | if name.lower() == target_dataset_name.lower() \ 24 | and issubclass(cls, BaseDataset): 25 | dataset = cls 26 | 27 | if dataset is None: 28 | raise NotImplementedError('In {0}.py, there should be a subclass of BaseDataset with class name that matches {1} in lowercase.'.format(dataset_filename, target_dataset_name)) 29 | 30 | return dataset 31 | 32 | 33 | def create_dataset(configuration): 34 | """Create a dataset given the configuration (loaded from the json file). 35 | 36 | This function wraps the class CustomDatasetDataLoader. 37 | This is the main interface between this package and train.py/validate.py 38 | 39 | Example: 40 | from datasets import create_dataset 41 | dataset = create_dataset(configuration) 42 | """ 43 | data_loader = CustomDatasetDataLoader(configuration) 44 | dataset = data_loader.load_data() 45 | return dataset 46 | 47 | 48 | class CustomDatasetDataLoader(): 49 | """Wrapper class of Dataset class that performs multi-threaded data loading 50 | according to the configuration. 51 | """ 52 | def __init__(self, configuration): 53 | self.configuration = configuration 54 | dataset_class = find_dataset_using_name(configuration['dataset_name']) 55 | self.dataset = dataset_class(configuration) 56 | print("dataset [{0}] was created".format(type(self.dataset).__name__)) 57 | 58 | # if we use custom collation, define it as a staticmethod in the dataset class 59 | custom_collate_fn = getattr(self.dataset, "collate_fn", None) 60 | if callable(custom_collate_fn): 61 | self.dataloader = data.DataLoader(self.dataset, **configuration['loader_params'], collate_fn=custom_collate_fn) 62 | else: 63 | self.dataloader = data.DataLoader(self.dataset, **configuration['loader_params']) 64 | 65 | 66 | def load_data(self): 67 | return self 68 | 69 | 70 | def get_custom_dataloader(self, custom_configuration): 71 | """Get a custom dataloader (e.g. for exporting the model). 72 | This dataloader may use different configurations than the 73 | default train_dataloader and val_dataloader. 74 | """ 75 | custom_collate_fn = getattr(self.dataset, "collate_fn", None) 76 | if callable(custom_collate_fn): 77 | custom_dataloader = data.DataLoader(self.dataset, **self.configuration['loader_params'], collate_fn=custom_collate_fn) 78 | else: 79 | custom_dataloader = data.DataLoader(self.dataset, **self.configuration['loader_params']) 80 | return custom_dataloader 81 | 82 | 83 | def __len__(self): 84 | """Return the number of data in the dataset. 85 | """ 86 | return len(self.dataset) 87 | 88 | 89 | def __iter__(self): 90 | """Return a batch of data. 91 | """ 92 | for data in self.dataloader: 93 | yield data 94 | -------------------------------------------------------------------------------- /datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. Also 2 | includes some transformation functions. 3 | """ 4 | from abc import ABC, abstractmethod 5 | import cv2 6 | import numpy as np 7 | import torch.utils.data as data 8 | from albumentations import Resize, Compose, ToFloat 9 | 10 | 11 | class BaseDataset(data.Dataset, ABC): 12 | """This class is an abstract base class (ABC) for datasets. 13 | """ 14 | 15 | def __init__(self, configuration): 16 | """Initialize the class; save the configuration in the class. 17 | """ 18 | self.configuration = configuration 19 | 20 | @abstractmethod 21 | def __len__(self): 22 | """Return the total number of images in the dataset.""" 23 | return 0 24 | 25 | @abstractmethod 26 | def __getitem__(self, index): 27 | """Return a data point (usually data and labels in 28 | a supervised setting). 29 | """ 30 | pass 31 | 32 | def pre_epoch_callback(self, epoch): 33 | """Callback to be called before every epoch. 34 | """ 35 | pass 36 | 37 | def post_epoch_callback(self, epoch): 38 | """Callback to be called after every epoch. 39 | """ 40 | pass 41 | 42 | 43 | def get_transform(opt, method=cv2.INTER_LINEAR): 44 | transform_list = [] 45 | if 'preprocess' in opt: 46 | if 'resize' in opt['preprocess']: 47 | transform_list.append(Resize(opt['input_size'][0], opt['input_size'][1], method)) 48 | 49 | if 'tofloat' in opt and opt['tofloat'] == True: 50 | transform_list.append(ToFloat()) 51 | 52 | return Compose(transform_list) -------------------------------------------------------------------------------- /datasets/segmentation2d_dataset.py: -------------------------------------------------------------------------------- 1 | from datasets.base_dataset import get_transform 2 | from datasets.base_dataset import BaseDataset 3 | import torch 4 | 5 | 6 | class Segmentation2DDataset(BaseDataset): 7 | """Represents a 2D segmentation dataset. 8 | 9 | Input params: 10 | configuration: Configuration dictionary. 11 | """ 12 | def __init__(self, configuration): 13 | super().__init__(configuration) 14 | 15 | 16 | def __getitem__(self, index): 17 | # get source image as x 18 | # get labels as y 19 | return (x, y) 20 | 21 | def __len__(self): 22 | # return the size of the dataset 23 | return 1 24 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/branislav1991/PyTorchProjectFramework/c2e2e9d391060a11f9151f021adc96c27ad8a894/losses/__init__.py -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, configuration). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 8 | -- : calculate loss, gradients, and update network weights. 9 | 10 | In the function <__init__>, you need to define four lists: 11 | -- self.network_names (str list): define networks used in our training. 12 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. 13 | """ 14 | 15 | import importlib 16 | from models.base_model import BaseModel 17 | from torch.optim import lr_scheduler 18 | 19 | 20 | def find_model_using_name(model_name): 21 | """Import the module "models/[model_name]_model.py". 22 | 23 | In the file, the class called DatasetNameModel() will 24 | be instantiated. It has to be a subclass of BaseModel, 25 | and it is case-insensitive. 26 | """ 27 | model_filename = "models." + model_name + "_model" 28 | modellib = importlib.import_module(model_filename) 29 | model = None 30 | target_model_name = model_name.replace('_', '') + 'model' 31 | for name, cls in modellib.__dict__.items(): 32 | if name.lower() == target_model_name.lower() \ 33 | and issubclass(cls, BaseModel): 34 | model = cls 35 | 36 | if model is None: 37 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 38 | exit(0) 39 | 40 | return model 41 | 42 | 43 | def create_model(configuration): 44 | """Create a model given the configuration. 45 | 46 | This is the main interface between this package and train.py/validate.py 47 | """ 48 | model = find_model_using_name(configuration['model_name']) 49 | instance = model(configuration) 50 | print("model [{0}] was created".format(type(instance).__name__)) 51 | return instance -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from utils import get_scheduler 4 | from utils import transfer_to_device 5 | from collections import OrderedDict 6 | from abc import ABC, abstractmethod 7 | 8 | 9 | class BaseModel(ABC): 10 | """This class is an abstract base class (ABC) for models. 11 | """ 12 | 13 | def __init__(self, configuration): 14 | """Initialize the BaseModel class. 15 | 16 | Parameters: 17 | configuration: Configuration dictionary. 18 | 19 | When creating your custom class, you need to implement your own initialization. 20 | In this fucntion, you should first call 21 | Then, you need to define these lists: 22 | -- self.network_names (str list): define networks used in our training. 23 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 24 | """ 25 | self.configuration = configuration 26 | self.is_train = configuration['is_train'] 27 | self.use_cuda = torch.cuda.is_available() 28 | self.device = torch.device('cuda:0') if self.use_cuda else torch.device('cpu') 29 | torch.backends.cudnn.benchmark = True 30 | self.save_dir = configuration['checkpoint_path'] 31 | self.network_names = [] 32 | self.loss_names = [] 33 | self.optimizers = [] 34 | self.visual_names = [] 35 | 36 | 37 | def set_input(self, input): 38 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 39 | The implementation here is just a basic setting of input and label. You may implement 40 | other functionality in your own model. 41 | """ 42 | self.input = transfer_to_device(input[0], self.device) 43 | self.label = transfer_to_device(input[1], self.device) 44 | 45 | 46 | @abstractmethod 47 | def forward(self): 48 | """Run forward pass; called by both functions and .""" 49 | pass 50 | 51 | @abstractmethod 52 | def optimize_parameters(self): 53 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 54 | pass 55 | 56 | def setup(self): 57 | """Load and print networks; create schedulers. 58 | """ 59 | if self.configuration['load_checkpoint'] >= 0: 60 | last_checkpoint = self.configuration['load_checkpoint'] 61 | else: 62 | last_checkpoint = -1 63 | 64 | if last_checkpoint >= 0: 65 | # enable restarting training 66 | self.load_networks(last_checkpoint) 67 | if self.is_train: 68 | self.load_optimizers(last_checkpoint) 69 | for o in self.optimizers: 70 | o.param_groups[0]['lr'] = o.param_groups[0]['initial_lr'] # reset learning rate 71 | 72 | self.schedulers = [get_scheduler(optimizer, self.configuration) for optimizer in self.optimizers] 73 | 74 | if last_checkpoint > 0: 75 | for s in self.schedulers: 76 | for _ in range(last_checkpoint): 77 | s.step() 78 | 79 | self.print_networks() 80 | 81 | def train(self): 82 | """Make models train mode during test time.""" 83 | for name in self.network_names: 84 | if isinstance(name, str): 85 | net = getattr(self, 'net' + name) 86 | net.train() 87 | 88 | def eval(self): 89 | """Make models eval mode during test time.""" 90 | for name in self.network_names: 91 | if isinstance(name, str): 92 | net = getattr(self, 'net' + name) 93 | net.eval() 94 | 95 | def test(self): 96 | """Forward function used in test time. 97 | 98 | This function wraps function in no_grad() so we don't save intermediate steps for backprop 99 | """ 100 | with torch.no_grad(): 101 | self.forward() 102 | 103 | 104 | def update_learning_rate(self): 105 | """Update learning rates for all the networks; called at the end of every epoch""" 106 | for scheduler in self.schedulers: 107 | scheduler.step() 108 | 109 | lr = self.optimizers[0].param_groups[0]['lr'] 110 | print('learning rate = {0:.7f}'.format(lr)) 111 | 112 | 113 | def save_networks(self, epoch): 114 | """Save all the networks to the disk. 115 | """ 116 | for name in self.network_names: 117 | if isinstance(name, str): 118 | save_filename = '{0}_net_{1}.pth'.format(epoch, name) 119 | save_path = os.path.join(self.save_dir, save_filename) 120 | net = getattr(self, 'net' + name) 121 | 122 | if self.use_cuda: 123 | torch.save(net.cpu().state_dict(), save_path) 124 | net.to(self.device) 125 | else: 126 | torch.save(net.cpu().state_dict(), save_path) 127 | 128 | 129 | def load_networks(self, epoch): 130 | """Load all the networks from the disk. 131 | """ 132 | for name in self.network_names: 133 | if isinstance(name, str): 134 | load_filename = '{0}_net_{1}.pth'.format(epoch, name) 135 | load_path = os.path.join(self.save_dir, load_filename) 136 | net = getattr(self, 'net' + name) 137 | if isinstance(net, torch.nn.DataParallel): 138 | net = net.module 139 | print('loading the model from {0}'.format(load_path)) 140 | state_dict = torch.load(load_path, map_location=self.device) 141 | if hasattr(state_dict, '_metadata'): 142 | del state_dict._metadata 143 | 144 | net.load_state_dict(state_dict) 145 | 146 | 147 | def save_optimizers(self, epoch): 148 | """Save all the optimizers to the disk for restarting training. 149 | """ 150 | for i, optimizer in enumerate(self.optimizers): 151 | save_filename = '{0}_optimizer_{1}.pth'.format(epoch, i) 152 | save_path = os.path.join(self.save_dir, save_filename) 153 | 154 | torch.save(optimizer.state_dict(), save_path) 155 | 156 | 157 | def load_optimizers(self, epoch): 158 | """Load all the optimizers from the disk. 159 | """ 160 | for i, optimizer in enumerate(self.optimizers): 161 | load_filename = '{0}_optimizer_{1}.pth'.format(epoch, i) 162 | load_path = os.path.join(self.save_dir, load_filename) 163 | print('loading the optimizer from {0}'.format(load_path)) 164 | state_dict = torch.load(load_path) 165 | if hasattr(state_dict, '_metadata'): 166 | del state_dict._metadata 167 | optimizer.load_state_dict(state_dict) 168 | 169 | 170 | def print_networks(self): 171 | """Print the total number of parameters in the network and network architecture. 172 | """ 173 | print('Networks initialized') 174 | for name in self.network_names: 175 | if isinstance(name, str): 176 | net = getattr(self, 'net' + name) 177 | num_params = 0 178 | for param in net.parameters(): 179 | num_params += param.numel() 180 | print(net) 181 | print('[Network {0}] Total number of parameters : {1:.3f} M'.format(name, num_params / 1e6)) 182 | 183 | 184 | def set_requires_grad(self, requires_grad=False): 185 | """Set requies_grad for all the networks to avoid unnecessary computations. 186 | """ 187 | for name in self.network_names: 188 | if isinstance(name, str): 189 | net = getattr(self, 'net' + name) 190 | for param in net.parameters(): 191 | param.requires_grad = requires_grad 192 | 193 | 194 | def get_current_losses(self): 195 | """Return traning losses / errors. train.py will print out these errors on console""" 196 | errors_ret = OrderedDict() 197 | for name in self.loss_names: 198 | if isinstance(name, str): 199 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number 200 | return errors_ret 201 | 202 | 203 | def pre_epoch_callback(self, epoch): 204 | pass 205 | 206 | 207 | def post_epoch_callback(self, epoch, visualizer): 208 | pass 209 | 210 | 211 | def get_hyperparam_result(self): 212 | """Returns the final training result for hyperparameter tuning (e.g. best 213 | validation loss). 214 | """ 215 | pass 216 | 217 | 218 | def export(self): 219 | """Exports all the networks of the model using JIT tracing. Requires that the 220 | input is set. 221 | """ 222 | for name in self.network_names: 223 | if isinstance(name, str): 224 | net = getattr(self, 'net' + name) 225 | export_path = os.path.join(self.configuration['export_path'], 'exported_net_{}.pth'.format(name)) 226 | if isinstance(self.input, list): # we have to modify the input for tracing 227 | self.input = [tuple(self.input)] 228 | traced_script_module = torch.jit.trace(net, self.input) 229 | traced_script_module.save(export_path) 230 | 231 | 232 | def get_current_visuals(self): 233 | """Return visualization images. train.py will display these images.""" 234 | visual_ret = OrderedDict() 235 | for name in self.visual_names: 236 | if isinstance(name, str): 237 | visual_ret[name] = getattr(self, name) 238 | return visual_ret -------------------------------------------------------------------------------- /models/segmentation_model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from models.base_model import BaseModel 3 | from optimizers.radam import RAdam 4 | import random 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from sklearn.metrics import accuracy_score 9 | import sys 10 | 11 | class double_conv(nn.Module): 12 | def __init__(self, in_ch, out_ch): 13 | super(double_conv, self).__init__() 14 | self.conv = nn.Sequential( 15 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 16 | nn.BatchNorm2d(out_ch), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 19 | nn.BatchNorm2d(out_ch), 20 | nn.ReLU(inplace=True) 21 | ) 22 | 23 | def forward(self, x): 24 | x = self.conv(x) 25 | return x 26 | 27 | class inconv(nn.Module): 28 | def __init__(self, in_ch, out_ch): 29 | super(inconv, self).__init__() 30 | self.conv = double_conv(in_ch, out_ch) 31 | 32 | def forward(self, x): 33 | x = self.conv(x) 34 | return x 35 | 36 | class down(nn.Module): 37 | def __init__(self, in_ch, out_ch): 38 | super(down, self).__init__() 39 | self.mpconv = nn.Sequential( 40 | nn.MaxPool2d(2), 41 | double_conv(in_ch, out_ch) 42 | ) 43 | 44 | def forward(self, x): 45 | x = self.mpconv(x) 46 | return x 47 | 48 | class up(nn.Module): 49 | def __init__(self, in_ch, out_ch, bilinear=True): 50 | super(up, self).__init__() 51 | 52 | self.convtrans = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) 53 | self.bilinear = bilinear 54 | 55 | self.conv = double_conv(in_ch, out_ch) 56 | 57 | def forward(self, x1, x2): 58 | if self.bilinear: 59 | x1 = nn.functional.interpolate(x1, scale_factor=2, mode='bilinear', align_corners=True) 60 | else: 61 | x1 = self.convtrans(x1) 62 | 63 | # input is CHW 64 | diffY = x2.size()[2] - x1.size()[2] 65 | diffX = x2.size()[3] - x1.size()[3] 66 | 67 | x1 = F.pad(x1, (diffX // 2, diffX - diffX//2, 68 | diffY // 2, diffY - diffY//2)) 69 | 70 | x = torch.cat([x2, x1], dim=1) 71 | x = self.conv(x) 72 | return x 73 | 74 | class outconv(nn.Module): 75 | def __init__(self, in_ch, out_ch): 76 | super(outconv, self).__init__() 77 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 78 | 79 | def forward(self, x): 80 | x = self.conv(x) 81 | return x 82 | 83 | 84 | 85 | class UNet(nn.Module): 86 | """Standard U-Net architecture network. 87 | 88 | Input params: 89 | n_channels: Number of input channels (usually 1 for a grayscale image). 90 | n_classes: Number of output channels (2 for binary segmentation). 91 | """ 92 | def __init__(self, n_channels, n_classes): 93 | super().__init__() 94 | self.inc = inconv(n_channels, 64) 95 | self.down1 = down(64, 128) 96 | self.down2 = down(128, 256) 97 | self.down3 = down(256, 512) 98 | self.down4 = down(512, 512) 99 | self.up1 = up(1024, 256) 100 | self.up2 = up(512, 128) 101 | self.up3 = up(256, 64) 102 | self.up4 = up(128, 64) 103 | self.outc = outconv(64, n_classes) 104 | 105 | def forward(self, x): 106 | x1 = self.inc(x) 107 | x2 = self.down1(x1) 108 | x3 = self.down2(x2) 109 | x4 = self.down3(x3) 110 | x5 = self.down4(x4) 111 | x = self.up1(x5, x4) 112 | x = self.up2(x, x3) 113 | x = self.up3(x, x2) 114 | x = self.up4(x, x1) 115 | x = self.outc(x) 116 | return x 117 | 118 | 119 | class Segmentation2DModel(BaseModel): 120 | def __init__(self, configuration): 121 | """Initialize the model. 122 | """ 123 | super().__init__(configuration) 124 | 125 | self.loss_names = ['segmentation'] 126 | self.network_names = ['unet'] 127 | 128 | self.netunet = UNet(1, 2) 129 | self.netunet = self.netunet.to(self.device) 130 | if self.is_train: # only defined during training time 131 | self.criterion_loss = torch.nn.CrossEntropyLoss() 132 | self.optimizer = torch.optim.Adam(self.netunet.parameters(), lr=configuration['lr']) 133 | self.optimizers = [self.optimizer] 134 | 135 | # storing predictions and labels for validation 136 | self.val_predictions = [] 137 | self.val_labels = [] 138 | self.val_images = [] 139 | 140 | 141 | def forward(self): 142 | """Run forward pass. 143 | """ 144 | self.output = self.netunet(self.input) 145 | 146 | 147 | def backward(self): 148 | """Calculate losses; called in every training iteration. 149 | """ 150 | self.loss_segmentation = self.criterion_loss(self.output, self.label) 151 | 152 | 153 | def optimize_parameters(self): 154 | """Calculate gradients and update network weights. 155 | """ 156 | self.loss_segmentation.backward() # calculate gradients 157 | self.optimizer.step() 158 | self.optimizer.zero_grad() 159 | torch.cuda.empty_cache() 160 | 161 | 162 | def test(self): 163 | super().test() # run the forward pass 164 | 165 | # save predictions and labels as flat tensors 166 | self.val_images.append(self.input) 167 | self.val_predictions.append(self.output) 168 | self.val_labels.append(self.label) 169 | 170 | 171 | def post_epoch_callback(self, epoch, visualizer): 172 | self.val_predictions = torch.cat(self.val_predictions, dim=0) 173 | predictions = torch.argmax(self.val_predictions, dim=1) 174 | predictions = torch.flatten(predictions).cpu() 175 | 176 | self.val_labels = torch.cat(self.val_labels, dim=0) 177 | labels = torch.flatten(self.val_labels).cpu() 178 | 179 | self.val_images = torch.squeeze(torch.cat(self.val_images, dim=0)).cpu() 180 | 181 | # Calculate and show accuracy 182 | val_accuracy = accuracy_score(labels, predictions) 183 | 184 | metrics = OrderedDict() 185 | metrics['accuracy'] = val_accuracy 186 | 187 | visualizer.plot_current_validation_metrics(epoch, metrics) 188 | print('Validation accuracy: {0:.3f}'.format(val_accuracy)) 189 | 190 | # Here you may do something else with the validation data such as 191 | # displaying the validation images or calculating the ROC curve 192 | 193 | self.val_images = [] 194 | self.val_predictions = [] 195 | self.val_labels = [] -------------------------------------------------------------------------------- /optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/branislav1991/PyTorchProjectFramework/c2e2e9d391060a11f9151f021adc96c27ad8a894/optimizers/__init__.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | torch 3 | albumentations 4 | visdom 5 | numpy 6 | matplotlib 7 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datasets import create_dataset 3 | from utils import parse_configuration 4 | import math 5 | from models import create_model 6 | import time 7 | from utils.visualizer import Visualizer 8 | 9 | """Performs training of a specified model. 10 | 11 | Input params: 12 | config_file: Either a string with the path to the JSON 13 | system-specific config file or a dictionary containing 14 | the system-specific, dataset-specific and 15 | model-specific settings. 16 | export: Whether to export the final model (default=True). 17 | """ 18 | def train(config_file, export=True): 19 | print('Reading config file...') 20 | configuration = parse_configuration(config_file) 21 | 22 | print('Initializing dataset...') 23 | train_dataset = create_dataset(configuration['train_dataset_params']) 24 | train_dataset_size = len(train_dataset) 25 | print('The number of training samples = {0}'.format(train_dataset_size)) 26 | 27 | val_dataset = create_dataset(configuration['val_dataset_params']) 28 | val_dataset_size = len(val_dataset) 29 | print('The number of validation samples = {0}'.format(val_dataset_size)) 30 | 31 | print('Initializing model...') 32 | model = create_model(configuration['model_params']) 33 | model.setup() 34 | 35 | print('Initializing visualization...') 36 | visualizer = Visualizer(configuration['visualization_params']) # create a visualizer that displays images and plots 37 | 38 | starting_epoch = configuration['model_params']['load_checkpoint'] + 1 39 | num_epochs = configuration['model_params']['max_epochs'] 40 | 41 | for epoch in range(starting_epoch, num_epochs): 42 | epoch_start_time = time.time() # timer for entire epoch 43 | train_dataset.dataset.pre_epoch_callback(epoch) 44 | model.pre_epoch_callback(epoch) 45 | 46 | train_iterations = len(train_dataset) 47 | train_batch_size = configuration['train_dataset_params']['loader_params']['batch_size'] 48 | 49 | model.train() 50 | for i, data in enumerate(train_dataset): # inner loop within one epoch 51 | visualizer.reset() 52 | 53 | model.set_input(data) # unpack data from dataset and apply preprocessing 54 | model.forward() 55 | model.backward() 56 | 57 | if i % configuration['model_update_freq'] == 0: 58 | model.optimize_parameters() # calculate loss functions, get gradients, update network weights 59 | 60 | if i % configuration['printout_freq'] == 0: 61 | losses = model.get_current_losses() 62 | visualizer.print_current_losses(epoch, num_epochs, i, math.floor(train_iterations / train_batch_size), losses) 63 | visualizer.plot_current_losses(epoch, float(i) / math.floor(train_iterations / train_batch_size), losses) 64 | 65 | model.eval() 66 | for i, data in enumerate(val_dataset): 67 | model.set_input(data) 68 | model.test() 69 | 70 | model.post_epoch_callback(epoch, visualizer) 71 | train_dataset.dataset.post_epoch_callback(epoch) 72 | 73 | print('Saving model at the end of epoch {0}'.format(epoch)) 74 | model.save_networks(epoch) 75 | model.save_optimizers(epoch) 76 | 77 | print('End of epoch {0} / {1} \t Time Taken: {2} sec'.format(epoch, num_epochs, time.time() - epoch_start_time)) 78 | 79 | model.update_learning_rate() # update learning rates every epoch 80 | 81 | if export: 82 | print('Exporting model') 83 | model.eval() 84 | custom_configuration = configuration['train_dataset_params'] 85 | custom_configuration['loader_params']['batch_size'] = 1 # set batch size to 1 for tracing 86 | dl = train_dataset.get_custom_dataloader(custom_configuration) 87 | sample_input = next(iter(dl)) # sample input from the training dataset 88 | model.set_input(sample_input) 89 | model.export() 90 | 91 | return model.get_hyperparam_result() 92 | 93 | if __name__ == '__main__': 94 | import multiprocessing 95 | multiprocessing.set_start_method('spawn', True) 96 | 97 | parser = argparse.ArgumentParser(description='Perform model training.') 98 | parser.add_argument('configfile', help='path to the configfile') 99 | 100 | args = parser.parse_args() 101 | train(args.configfile) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import numpy as np 4 | import os 5 | from pathlib import Path 6 | import torch 7 | from torch.optim import lr_scheduler 8 | 9 | 10 | def transfer_to_device(x, device): 11 | """Transfers pytorch tensors or lists of tensors to GPU. This 12 | function is recursive to be able to deal with lists of lists. 13 | """ 14 | if isinstance(x, list): 15 | for i in range(len(x)): 16 | x[i] = transfer_to_device(x[i], device) 17 | else: 18 | x = x.to(device) 19 | return x 20 | 21 | 22 | def parse_configuration(config_file): 23 | """Loads config file if a string was passed 24 | and returns the input if a dictionary was passed. 25 | """ 26 | if isinstance(config_file, str): 27 | with open(config_file) as json_file: 28 | return json.load(json_file) 29 | else: 30 | return config_file 31 | 32 | 33 | def get_scheduler(optimizer, configuration, last_epoch=-1): 34 | """Return a learning rate scheduler. 35 | """ 36 | if configuration['lr_policy'] == 'step': 37 | scheduler = lr_scheduler.StepLR(optimizer, step_size=configuration['lr_decay_iters'], gamma=0.3, last_epoch=last_epoch) 38 | else: 39 | return NotImplementedError('learning rate policy [{0}] is not implemented'.format(configuration['lr_policy'])) 40 | return scheduler 41 | 42 | 43 | def stack_all(list, dim=0): 44 | """Stack all iterables of torch tensors in a list (i.e. [[(tensor), (tensor)], [(tensor), (tensor)]]) 45 | """ 46 | return [torch.stack(s, dim) for s in list] -------------------------------------------------------------------------------- /utils/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | from subprocess import Popen, PIPE 4 | import utils 5 | import visdom 6 | 7 | 8 | class Visualizer(): 9 | """This class includes several functions that can display images and print logging information. 10 | """ 11 | 12 | def __init__(self, configuration): 13 | """Initialize the Visualizer class. 14 | 15 | Input params: 16 | configuration -- stores all the configurations 17 | """ 18 | self.configuration = configuration # cache the option 19 | self.display_id = 0 20 | self.name = configuration['name'] 21 | 22 | self.ncols = 0 23 | self.vis = visdom.Visdom() 24 | if not self.vis.check_connection(): 25 | self.create_visdom_connections() 26 | 27 | 28 | def reset(self): 29 | """Reset the visualization. 30 | """ 31 | pass 32 | 33 | 34 | def create_visdom_connections(self): 35 | """If the program could not connect to Visdom server, this function will start a new server at the default port. 36 | """ 37 | cmd = sys.executable + ' -m visdom.server' 38 | print('\n\nCould not connect to Visdom server. \n Trying to start a server....') 39 | print('Command: %s' % cmd) 40 | Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) 41 | 42 | 43 | def plot_current_losses(self, epoch, counter_ratio, losses): 44 | """Display the current losses on visdom display: dictionary of error labels and values. 45 | 46 | Input params: 47 | epoch: Current epoch. 48 | counter_ratio: Progress (percentage) in the current epoch, between 0 to 1. 49 | losses: Training losses stored in the format of (name, float) pairs. 50 | """ 51 | if not hasattr(self, 'loss_plot_data'): 52 | self.loss_plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} 53 | self.loss_plot_data['X'].append(epoch + counter_ratio) 54 | self.loss_plot_data['Y'].append([losses[k] for k in self.loss_plot_data['legend']]) 55 | x = np.squeeze(np.stack([np.array(self.loss_plot_data['X'])] * len(self.loss_plot_data['legend']), 1), axis=1) 56 | y = np.squeeze(np.array(self.loss_plot_data['Y']), axis=1) 57 | try: 58 | self.vis.line( 59 | X=x, 60 | Y=y, 61 | opts={ 62 | 'title': self.name + ' loss over time', 63 | 'legend': self.loss_plot_data['legend'], 64 | 'xlabel': 'epoch', 65 | 'ylabel': 'loss'}, 66 | win=self.display_id) 67 | except ConnectionError: 68 | self.create_visdom_connections() 69 | 70 | 71 | def plot_current_validation_metrics(self, epoch, metrics): 72 | """Display the current validation metrics on visdom display: dictionary of error labels and values. 73 | 74 | Input params: 75 | epoch: Current epoch. 76 | losses: Validation metrics stored in the format of (name, float) pairs. 77 | """ 78 | if not hasattr(self, 'val_plot_data'): 79 | self.val_plot_data = {'X': [], 'Y': [], 'legend': list(metrics.keys())} 80 | self.val_plot_data['X'].append(epoch) 81 | self.val_plot_data['Y'].append([metrics[k] for k in self.val_plot_data['legend']]) 82 | x = np.squeeze(np.stack([np.array(self.val_plot_data['X'])] * len(self.val_plot_data['legend']), 1), axis=1) 83 | y = np.squeeze(np.array(self.val_plot_data['Y']), axis=1) 84 | try: 85 | self.vis.line( 86 | X=x, 87 | Y=y, 88 | opts={ 89 | 'title': self.name + ' over time', 90 | 'legend': self.val_plot_data['legend'], 91 | 'xlabel': 'epoch', 92 | 'ylabel': 'metric'}, 93 | win=self.display_id+1) 94 | except ConnectionError: 95 | self.create_visdom_connections() 96 | 97 | 98 | def plot_roc_curve(self, fpr, tpr, thresholds): 99 | """Display the ROC curve. 100 | 101 | Input params: 102 | fpr: False positive rate (1 - specificity). 103 | tpr: True positive rate (sensitivity). 104 | thresholds: Thresholds for the curve. 105 | """ 106 | try: 107 | self.vis.line( 108 | X=fpr, 109 | Y=tpr, 110 | opts={ 111 | 'title': 'ROC Curve', 112 | 'xlabel': '1 - specificity', 113 | 'ylabel': 'sensitivity', 114 | 'fillarea': True}, 115 | win=self.display_id+2) 116 | except ConnectionError: 117 | self.create_visdom_connections() 118 | 119 | 120 | def show_validation_images(self, images): 121 | """Display validation images. The images have to be in the form of a tensor with 122 | [(image, label, prediction), (image, label, prediction), ...] in the 0-th dimension. 123 | """ 124 | # zip the images together so that always the image is followed by label is followed by prediction 125 | images = images.permute(1,0,2,3) 126 | images = images.reshape((images.shape[0]*images.shape[1],images.shape[2],images.shape[3])) 127 | 128 | # add a channel dimension to the tensor since the excepted format by visdom is (B,C,H,W) 129 | images = images[:,None,:,:] 130 | 131 | try: 132 | self.vis.images(images, win=self.display_id+3, nrow=3) 133 | except ConnectionError: 134 | self.create_visdom_connections() 135 | 136 | 137 | def print_current_losses(self, epoch, max_epochs, iter, max_iters, losses): 138 | """Print current losses on console. 139 | 140 | Input params: 141 | epoch: Current epoch. 142 | max_epochs: Maximum number of epochs. 143 | iter: Iteration in epoch. 144 | max_iters: Number of iterations in epoch. 145 | losses: Training losses stored in the format of (name, float) pairs 146 | """ 147 | message = '[epoch: {}/{}, iter: {}/{}] '.format(epoch, max_epochs, iter, max_iters) 148 | for k, v in losses.items(): 149 | message += '{0}: {1:.6f} '.format(k, v) 150 | 151 | print(message) # print the message 152 | -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datasets import create_dataset 3 | from utils import parse_configuration 4 | from models import create_model 5 | import os 6 | 7 | """Performs validation of a specified model. 8 | 9 | Input params: 10 | config_file: Either a string with the path to the JSON 11 | system-specific config file or a dictionary containing 12 | the system-specific, dataset-specific and 13 | model-specific settings. 14 | """ 15 | def validate(config_file): 16 | print('Reading config file...') 17 | configuration = parse_configuration(config_file) 18 | 19 | print('Initializing dataset...') 20 | val_dataset = create_dataset(configuration['val_dataset_params']) 21 | val_dataset_size = len(val_dataset) 22 | print('The number of validation samples = {0}'.format(val_dataset_size)) 23 | 24 | print('Initializing model...') 25 | model = create_model(configuration['model_params']) 26 | model.setup() 27 | model.eval() 28 | 29 | model.pre_epoch_callback(configuration['model_params']['load_checkpoint']) 30 | 31 | for i, data in enumerate(val_dataset): 32 | model.set_input(data) # unpack data from data loader 33 | model.test() # run inference 34 | 35 | model.post_epoch_callback(configuration['model_params']['load_checkpoint']) 36 | 37 | if __name__ == '__main__': 38 | parser = argparse.ArgumentParser(description='Perform model validation.') 39 | parser.add_argument('configfile', help='path to the configfile') 40 | 41 | args = parser.parse_args() 42 | validate(args.configfile) 43 | --------------------------------------------------------------------------------