├── .gitignore ├── LICENSE ├── README.md ├── inria_submit.py ├── lib ├── augmentations.py ├── common.py ├── datasets │ ├── Inria.py │ ├── dsb2018.py │ └── shapes.py ├── losses.py ├── metrics.py ├── models │ ├── afterburner.py │ ├── dilated_linknet.py │ ├── dilated_resnet.py │ ├── duc_hdc.py │ ├── gcn152.py │ ├── linknet.py │ ├── linknext.py │ ├── psp_net.py │ ├── squeezenet.py │ ├── tiramisu.py │ ├── unet.py │ ├── unet11.py │ ├── unet16.py │ ├── unet_abn.py │ ├── wider_resnet.py │ └── zf_unet.py ├── modules │ └── abn │ │ ├── __init__.py │ │ ├── bn.py │ │ ├── dense.py │ │ ├── functions.py │ │ ├── misc.py │ │ └── residual.py ├── numpy_losses.py ├── tiles.py └── train_utils.py ├── loss_plot.png ├── plot.py ├── plot_loss.py ├── results └── dsb2018_bce_all.png ├── run_all.cmd ├── test.py ├── torch_train.py ├── torch_train_ab.py └── torch_train_reg.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # Jupyter Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # SageMath parsed files 79 | *.sage.py 80 | 81 | # dotenv 82 | .env 83 | 84 | # virtualenv 85 | .venv 86 | venv/ 87 | ENV/ 88 | 89 | # Spyder project settings 90 | .spyderproject 91 | .spyproject 92 | 93 | # Rope project settings 94 | .ropeproject 95 | 96 | # mkdocs documentation 97 | /site 98 | 99 | # mypy 100 | .mypy_cache/ 101 | .idea/ 102 | experiments/ 103 | runs/ 104 | submits/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 BloodAxe 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Segmentation networks benchmark

Evaluation framework for testing segmentation networks in PyTorch.
Which segmentation network should you choose for your next Kaggle competition? This benchmark knows the answer!

# Deprecation notice

*This repository is not maintained. Please refer to https://github.com/BloodAxe/pytorch-toolbelt instead.*

## What is all this code about?

It shows the pros & cons of many existing segmentation networks implemented in Keras and PyTorch for different applications (biomedical, satellite, autonomous driving, etc.).
Briefly, it does the following:

```
for model in [Unet, Tiramisu, DenseNet, ...]:
    for dataset in [COCO, LUNA, STARE, ...]:
        for optimizer in [SGD, Adam]:
            history = train(model, dataset, optimizer)
            results.append(history)

summarize(results)
```
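
In this repository the loop is driven by concrete entry points (`torch_train.py` for training, `inria_submit.py` for inference). The sketch below makes the pseudocode above a bit more concrete; the `get_model`/`get_dataset` factory names and signatures are assumptions in the spirit of the helpers in those scripts, not the exact CLI:

```
import itertools

results = []
for model_name, dataset_name, optimizer_name in itertools.product(
        ['unet', 'linknet', 'tiramisu'],   # model zoo under lib/models
        ['dsb2018', 'inria'],              # datasets under lib/datasets
        ['sgd', 'adam']):
    # get_model/get_dataset are assumed factories modeled on
    # torch_train.get_model and inria_submit.get_dataset
    model = get_model(model_name)
    train_ds, valid_ds, _ = get_dataset(dataset_name, 'data', grayscale=False, patch_size=224)
    history = train(model, train_ds, valid_ds, optimizer=optimizer_name)
    results.append((model_name, dataset_name, optimizer_name, history))
```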

## Roadmap

- [x] Write Keras train pipeline
- [x] Write PyTorch train pipeline

### Models

- [x] Add ZF_UNET model (https://github.com/ZFTurbo/ZF_UNET_224_Pretrained_Model)
- [x] Add LinkNet model
- [x] Add Tiramisu model (https://github.com/0bserver07/One-Hundred-Layers-Tiramisu)
- [ ] Add SegCaps model
- [x] Add VGG11, VGG16, AlbuNet models (https://github.com/ternaus/TernausNet)
- [x] Add FCDenseNet model (https://github.com/bfortuner/pytorch_tiramisu)

### Datasets

- [x] Add DSB2018 (stage1) dataset
- [ ] Add COCO dataset
- [ ] Add STARE dataset
- [ ] Add LUNA16 dataset
- [ ] Add Inria dataset
- [ ] Add Cityscapes dataset
- [ ] Add PASCAL VOC2012 dataset

### Reporting

- [ ] Add fancy plots


# Credits

* https://github.com/ZFTurbo/ZF_UNET_224_Pretrained_Model
* https://github.com/ternaus/TernausNet
* https://github.com/0bserver07/One-Hundred-Layers-Tiramisu
* https://github.com/bfortuner/pytorch_tiramisu
* https://raw.githubusercontent.com/ZijunDeng/pytorch-semantic-segmentation
* https://github.com/mapillary/inplace_abn
--------------------------------------------------------------------------------
/inria_submit.py:
--------------------------------------------------------------------------------
import argparse
from multiprocessing.pool import Pool

import cv2
import os.path
import sys

from torch.backends import cudnn

from lib import augmentations as aug
import numpy as np
import pandas as pd
import torch
from torch import nn
from tensorboardX import SummaryWriter
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from tqdm import tqdm
import torch_train as TT
from lib.common import find_in_dir, read_rgb, InMemoryDataset
from lib.datasets.Inria import INRIA, INRIA_MEAN, INRIA_STD
from lib.datasets.dsb2018 import DSB2018Sliced
from lib.losses import JaccardLoss, FocalLossBinary, BCEWithLogitsLossAndSmoothJaccard, BCEWithSigmoidLoss
from lib.metrics import JaccardScore, PixelAccuracy
from lib.models import linknet, unet16, unet11
from lib.models.duc_hdc import ResNetDUCHDC, ResNetDUC
from lib.models.gcn152 import GCN152, GCN34
from lib.models.psp_net import PSPNet
from lib.models.tiramisu import FCDenseNet67
from lib.models.unet import UNet
from lib.models.zf_unet import ZF_UNET
from lib.tiles import ImageSlicer
from lib.train_utils import AverageMeter, auto_file

tqdm.monitor_interval = 0  # Workaround for https://github.com/tqdm/tqdm/issues/481


def get_dataset(dataset_name, dataset_dir, grayscale, patch_size, keep_in_mem=False):
    dataset_name = dataset_name.lower()

    if dataset_name == 'inria':
        return INRIA(dataset_dir, grayscale, patch_size, keep_in_mem)

    if dataset_name == 'dsb2018':
        return DSB2018Sliced(dataset_dir, grayscale, patch_size)

    raise ValueError(dataset_name)


def train(model, loss, optimizer, dataloader, epoch: int, metrics={}, summary_writer=None):
    """Runs one training epoch. (The original name `preduct` was a typo.)"""
    losses = AverageMeter()

    train_scores = {}
    for key, _ in metrics.items():
        train_scores[key] = AverageMeter()

    with torch.set_grad_enabled(True):
        model.train()
        n_batches = len(dataloader)
        with tqdm(total=n_batches) as tq:
            tq.set_description('Train')
            x = None
            y = None
            outputs = None
            batch_loss = None

            for batch_index, (x, y) in enumerate(dataloader):
                x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = model(x)

                batch_loss = loss(outputs, y)

                batch_size = x.size(0)
                # Backprop the loss summed over the batch (mean loss scaled by batch size)
                (batch_size * batch_loss).backward()

                optimizer.step()

                # Batch train end
                # Log train progress

                batch_loss_val = batch_loss.cpu().item()
                if summary_writer is not None:
                    summary_writer.add_scalar('train/batch/loss', batch_loss_val, epoch * n_batches + batch_index)

                    # Plot gradient absmax to see if there are any gradient explosions
                    grad_max = 0
                    for name, param in model.named_parameters():
                        if param.grad is not None:
                            grad_max = max(grad_max, param.grad.abs().max().cpu().item())
                    summary_writer.add_scalar('train/grad/global_max', grad_max, epoch * n_batches + batch_index)

                losses.update(batch_loss_val)

                for key, metric in metrics.items():
                    score = metric(outputs, y).cpu().item()
                    train_scores[key].update(score)

                    if summary_writer is not None:
                        summary_writer.add_scalar('train/batch/' + key, score, epoch * n_batches + batch_index)

                tq.set_postfix(loss='{:.3f}'.format(losses.avg), **train_scores)
                tq.update()

            # End of train epoch
            if summary_writer is not None:
                summary_writer.add_image('train/image', make_grid(x.cpu(), normalize=True), epoch)
                summary_writer.add_image('train/y_true', make_grid(y.cpu(), normalize=True), epoch)
                summary_writer.add_image('train/y_pred', make_grid(outputs.sigmoid().cpu(), normalize=True), epoch)
                summary_writer.add_scalar('train/epoch/loss', losses.avg, epoch)
                for key, value in train_scores.items():
                    summary_writer.add_scalar('train/epoch/' + key, value.avg, epoch)

                # Plot histogram of parameters after each epoch
                for name, param in model.named_parameters():
                    if param.grad is not None:
                        # Plot weights
                        param_data = param.data.cpu().numpy()
                        summary_writer.add_histogram('model/' + name, param_data, epoch, bins='doane')

            # for m in model.modules():
            #     if isinstance(m, nn.Conv2d):
            #         weights = m.weights.data.numpy()

            del x, y, outputs, batch_loss

    return losses, train_scores


def validate(model, loss, dataloader, epoch: int, metrics=dict(), summary_writer: SummaryWriter = None):
    losses = AverageMeter()

    valid_scores = {}
    for key, _ in metrics.items():
        valid_scores[key] = AverageMeter()

    with torch.set_grad_enabled(False):
        model.eval()

        n_batches = len(dataloader)
        with tqdm(total=n_batches) as tq:
            tq.set_description('Validation')

            x = None
            y = None
            outputs = None
            batch_loss = None

            for batch_index, (x, y) in enumerate(dataloader):
                x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)

                # forward only; gradients are disabled during validation
                outputs = model(x)
                batch_loss = loss(outputs, y)

                # Log validation progress

                batch_loss_val = batch_loss.cpu().item()
                if summary_writer is not None:
                    summary_writer.add_scalar('val/batch/loss', batch_loss_val, epoch * n_batches + batch_index)

                losses.update(batch_loss_val)

                for key, metric in metrics.items():
                    score = metric(outputs, y).cpu().item()
                    valid_scores[key].update(score)

                    if summary_writer is not None:
                        summary_writer.add_scalar('val/batch/' + key, score, epoch * n_batches + batch_index)

                tq.set_postfix(loss='{:.3f}'.format(losses.avg), **valid_scores)
                tq.update()

            if summary_writer is not None:
                summary_writer.add_image('val/image', make_grid(x.cpu(), normalize=True), epoch)
                summary_writer.add_image('val/y_true', make_grid(y.cpu(), normalize=True), epoch)
                summary_writer.add_image('val/y_pred', make_grid(outputs.sigmoid().cpu(), normalize=True), epoch)
                summary_writer.add_scalar('val/epoch/loss', losses.avg, epoch)
                for key, value in valid_scores.items():
                    summary_writer.add_scalar('val/epoch/' + key, value.avg, epoch)

            del x, y, outputs, batch_loss

    return losses, valid_scores


def save_snapshot(model: nn.Module, optimizer: Optimizer, loss: float, epoch: int, train_history: pd.DataFrame, snapshot_file: str):
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss,
        'train_history': train_history.to_dict(),
        'args': ' '.join(sys.argv[1:])
    }, snapshot_file)


def restore_snapshot(model: nn.Module, optimizer: Optimizer, snapshot_file: str):
    checkpoint = torch.load(snapshot_file)
    start_epoch = checkpoint['epoch'] + 1
    best_loss = checkpoint['loss']
    model.load_state_dict(checkpoint['model'])

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])

    train_history = pd.DataFrame.from_dict(checkpoint['train_history'])

    return start_epoch, train_history, best_loss

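
# Usage sketch for the snapshot helpers above. The checkpoint name, learning
# rate and optimizer choice are illustrative assumptions; the real entry
# points wire these up from command-line arguments.
def _resume_example(model: nn.Module):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    start_epoch, history, best_loss = restore_snapshot(model, optimizer, auto_file('best.pth'))
    # ... train for some more epochs, then persist the new state ...
    save_snapshot(model, optimizer, best_loss, start_epoch, history, 'resumed.pth')
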
def predict_full(image, model, test_transform):
    image, pad = aug.pad(image, 32, borderType=cv2.BORDER_REPLICATE)
    image, _ = test_transform(image)
    images = list(aug.tta_d4_aug([image]))
    predicts = []

    for image in images:
        image = torch.from_numpy(np.moveaxis(image, -1, 0)).float().unsqueeze(dim=0)
        image = image.cuda(non_blocking=True)
        y = model(image)
        y = torch.sigmoid(y).cpu().numpy()
        y = np.moveaxis(y, 1, -1)
        y = np.squeeze(y)
        predicts.append(y)

    # tta_d4_deaug returns a list, not a generator, so index it instead of
    # calling next() on it (next() on a list raises TypeError)
    mask = aug.tta_d4_deaug(predicts)[0]
    mask = aug.unpad(mask, pad)
    return mask


def predict_tiled(image, model, test_transform, patch_size, batch_size):
    image, _ = test_transform(image)

    slicer = ImageSlicer(image.shape, patch_size, patch_size // 2, weight='pyramid')
    patches = slicer.split(image)

    patches = aug.tta_d4_aug(patches)
    testset = InMemoryDataset(patches, None)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, pin_memory=True, drop_last=False)

    patches_pred = []
    for batch_index, x in enumerate(testloader):
        x = x.cuda(non_blocking=True)
        y = model(x)
        y = torch.sigmoid(y).cpu().numpy()
        y = np.moveaxis(y, 1, -1)
        patches_pred.extend(y)

    patches_pred = aug.tta_d4_deaug(patches_pred)
    mask = slicer.merge(patches_pred, dtype=np.float32)
    return mask


def main():
    cudnn.benchmark = True

    parser = argparse.ArgumentParser()

    parser.add_argument('-g', '--grayscale', action='store_true', help='Whether to use grayscale image instead of RGB')
    parser.add_argument('-m', '--model', required=True, type=str, help='Name of the model')
    parser.add_argument('-c', '--checkpoint', required=True, type=str, help='Name of the model checkpoint')
    parser.add_argument('-p', '--patch-size', type=int, default=224)
    parser.add_argument('-b', '--batch-size', type=int, default=1, help='Batch size during inference, e.g.
-b 64') 270 | parser.add_argument('-dd', '--data-dir', type=str, default='data', help='Root directory where datasets are located.') 271 | parser.add_argument('-x', '--experiment', type=str, help='Name of the experiment') 272 | parser.add_argument('-f', '--full', action='store_true') 273 | 274 | args = parser.parse_args() 275 | 276 | if args.experiment is None: 277 | args.experiment = 'inria_%s_%d_%s' % (args.model, args.patch_size, 'gray' if args.grayscale else 'rgb') 278 | 279 | experiment_dir = os.path.join('submits', args.experiment) 280 | os.makedirs(experiment_dir, exist_ok=True) 281 | 282 | model = TT.get_model(args.model, patch_size=args.patch_size, num_channels=1 if args.grayscale else 3).cuda() 283 | start_epoch, train_history, best_loss = TT.restore_snapshot(model, None, auto_file(args.checkpoint)) 284 | print('Using weights from epoch', start_epoch - 1, best_loss) 285 | 286 | test_transform = aug.Sequential([ 287 | aug.ImageOnly(aug.NormalizeImage(mean=INRIA_MEAN, std=INRIA_STD)), 288 | ]) 289 | 290 | x = sorted(find_in_dir(os.path.join(args.data_dir, 'images'))) 291 | # x = x[:10] 292 | 293 | model.eval() 294 | with torch.no_grad(): 295 | 296 | for test_fname in tqdm(x, total=len(x)): 297 | image = read_rgb(test_fname) 298 | basename = os.path.splitext(os.path.basename(test_fname))[0] 299 | 300 | if args.full: 301 | mask = predict_full(image, model, test_transform) 302 | else: 303 | mask = predict_tiled(image, model, test_transform, args.patch_size, args.batch_size) 304 | 305 | mask = ((mask > 0.5) * 255).astype(np.uint8) 306 | cv2.imwrite(os.path.join(experiment_dir, basename + '.tif'), mask) 307 | 308 | 309 | if __name__ == '__main__': 310 | main() 311 | -------------------------------------------------------------------------------- /lib/augmentations.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import random 3 | import torch 4 | 5 | import cv2 6 | import numpy as np 7 | import math 8 | 9 | 10 | class Sequential: 11 | def __init__(self, transforms): 12 | self.transforms = transforms 13 | 14 | def __call__(self, x, mask=None): 15 | for t in self.transforms: 16 | x, mask = t(x, mask) 17 | return x, mask 18 | 19 | 20 | class OneOf: 21 | def __init__(self, transforms, prob=.5): 22 | self.transforms = transforms 23 | self.prob = prob 24 | 25 | def __call__(self, x, mask=None): 26 | if random.random() < self.prob: 27 | t = random.choice(self.transforms) 28 | t.prob = 1. 29 | x, mask = t(x, mask) 30 | return x, mask 31 | 32 | 33 | class OneOrOther: 34 | def __init__(self, first, second, prob=.5): 35 | self.first = first 36 | first.prob = 1. 37 | self.second = second 38 | second.prob = 1. 
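        # Both children are forced to always fire (prob = 1); the outer
        # `prob` below only chooses between them: `first` is applied with
        # probability `prob`, otherwise `second` is applied.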
        self.prob = prob

    def __call__(self, x, mask=None):
        if random.random() < self.prob:
            x, mask = self.first(x, mask)
        else:
            x, mask = self.second(x, mask)
        return x, mask


class ImageOnly:
    def __init__(self, trans):
        self.trans = trans

    def __call__(self, x, mask=None):
        return self.trans(x), mask


class MaskOnly:
    def __init__(self, trans):
        self.trans = trans

    def __call__(self, x, mask):
        return x, self.trans(mask)


class RandomGrayscale:
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, img):
        if random.random() < self.prob:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        return img


class RandomInvert:
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, img):
        if random.random() < self.prob:
            img = img.max() - img
        return img


class MakeBinary:
    def __call__(self, x):
        dt = x.dtype
        x = x > 0
        return x.astype(dt)


class VerticalFlip:
    def __init__(self, prob=.5):
        self.prob = prob

    def __call__(self, img, mask=None):
        if random.random() < self.prob:
            img = np.flipud(img).copy()
            if mask is not None:
                mask = np.flipud(mask).copy()
        return img, mask


class HorizontalFlip:
    def __init__(self, prob=.5):
        self.prob = prob

    def __call__(self, img, mask=None):
        if random.random() < self.prob:
            img = np.fliplr(img).copy()
            if mask is not None:
                mask = np.fliplr(mask).copy()
        return img, mask


class Transpose:
    def __init__(self, prob=.5):
        self.prob = prob

    def __call__(self, img, mask=None):
        if random.random() < self.prob:
            img = img.transpose(1, 0, 2).copy()
            if mask is not None:
                mask = mask.transpose(1, 0).copy()
        return img, mask


class RandomRotate90:
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, img, mask=None):
        if random.random() < self.prob:
            # randint is inclusive on both ends; 0..3 already covers all four
            # rotations (the original 0..4 drew the identity twice)
            factor = random.randint(0, 3)
            img = np.rot90(img, factor).copy()
            if mask is not None:
                mask = np.rot90(mask, factor).copy()
        return img, mask


class Rotate:
    def __init__(self, limit=90, prob=.5):
        self.prob = prob
        self.limit = limit

    def __call__(self, img, mask=None):
        if random.random() < self.prob:
            angle = random.uniform(-self.limit, self.limit)

            height, width = img.shape[0:2]
            mat = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1.0)
            img = cv2.warpAffine(img, mat, (width, height),
                                 flags=cv2.INTER_LINEAR,
                                 borderMode=cv2.BORDER_REFLECT_101)
            if mask is not None:
                mask = cv2.warpAffine(mask, mat, (width, height),
                                      flags=cv2.INTER_LINEAR,
                                      borderMode=cv2.BORDER_REFLECT_101)

        return img, mask


class Shift:
    def __init__(self, limit=4, prob=.5):
        self.limit = limit
        self.prob = prob

    def __call__(self, img, mask=None):
        if random.random() < self.prob:
            limit = self.limit
            dx = round(random.uniform(-limit, limit))
            dy = round(random.uniform(-limit, limit))

            height, width, channel = img.shape
            y1 = limit + 1 + dy
            y2 = y1 + height
            x1 = limit + 1 + dx
            x2 = x1 + width

            img1 = cv2.copyMakeBorder(img, limit + 1, limit + 1, limit + 1, limit + 1, borderType=cv2.BORDER_REFLECT_101)
            img = img1[y1:y2, x1:x2, :].copy()
            if mask is not None:
                msk1 = cv2.copyMakeBorder(mask, limit + 1, limit + 1, limit + 1, limit + 1, borderType=cv2.BORDER_REFLECT_101)
                # masks are 2-D here; indexing a third axis would raise IndexError
                mask = msk1[y1:y2, x1:x2].copy()

        return img, mask


class ShiftScale:
    def __init__(self, limit=4, prob=.25):
        self.limit = limit
        self.prob = prob

    def __call__(self, img, mask=None):
        limit = self.limit
        if random.random() < self.prob:
            height, width, channel = img.shape
            assert (width == height)
            size0 = width
            size1 = width + 2 * limit
            size = round(random.uniform(size0, size1))

            dx = round(random.uniform(0, size1 - size))
            dy = round(random.uniform(0, size1 - size))

            y1 = dy
            y2 = y1 + size
            x1 = dx
            x2 = x1 + size

            img1 = cv2.copyMakeBorder(img, limit, limit, limit, limit, borderType=cv2.BORDER_REFLECT_101)
            img = (img1[y1:y2, x1:x2, :] if size == size0
                   else cv2.resize(img1[y1:y2, x1:x2, :], (size0, size0), interpolation=cv2.INTER_LINEAR))

            if mask is not None:
                msk1 = cv2.copyMakeBorder(mask, limit, limit, limit, limit, borderType=cv2.BORDER_REFLECT_101)
                # masks are 2-D, so slice without the channel axis
                mask = (msk1[y1:y2, x1:x2] if size == size0
                        else cv2.resize(msk1[y1:y2, x1:x2], (size0, size0), interpolation=cv2.INTER_LINEAR))

        return img, mask


class ShiftScaleRotate:
    def __init__(self, shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, prob=0.5):
        self.shift_limit = shift_limit
        self.scale_limit = scale_limit
        self.rotate_limit = rotate_limit
        self.prob = prob

    def __call__(self, img, mask=None):
        if random.random() < self.prob:
            height, width, channel = img.shape

            angle = random.uniform(-self.rotate_limit, self.rotate_limit)
            scale = random.uniform(1 - self.scale_limit, 1 + self.scale_limit)
            # round after scaling by the image size; rounding the raw fraction
            # first (as the original code did) always produced a zero shift
            # for shift limits below 0.5
            dx = round(random.uniform(-self.shift_limit, self.shift_limit) * width)
            dy = round(random.uniform(-self.shift_limit, self.shift_limit) * height)

            cc = math.cos(angle / 180 * math.pi) * scale
            ss = math.sin(angle / 180 * math.pi) * scale
            rotate_matrix = np.array([[cc, -ss], [ss, cc]])

            box0 = np.array([[0, 0], [width, 0], [width, height], [0, height], ])
            box1 = box0 - np.array([width / 2, height / 2])
            box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy])

            box0 = box0.astype(np.float32)
            box1 = box1.astype(np.float32)
            mat = cv2.getPerspectiveTransform(box0, box1)
            img = cv2.warpPerspective(img, mat, (width, height),
                                      flags=cv2.INTER_LINEAR,
                                      borderMode=cv2.BORDER_REFLECT_101)
            if mask is not None:
                mask = cv2.warpPerspective(mask, mat, (width, height),
                                           flags=cv2.INTER_LINEAR,
                                           borderMode=cv2.BORDER_REFLECT_101)

        return img, mask


class CenterCrop:
    def __init__(self, height, width):
        self.height = height
        self.width = width

    def __call__(self, img, mask=None):
        h, w, c = img.shape
        dy = (h - self.height) // 2
        dx = (w - self.width) // 2

        y1 = dy
        y2 = y1 + self.height
        x1 = dx
        x2 = x1 + self.width
        img = img[y1:y2, x1:x2].copy()
        if mask is not None:
            mask = mask[y1:y2, x1:x2].copy()

return img, mask 281 | 282 | 283 | class RandomCrop(object): 284 | """Crop the given Image at a random location. 285 | 286 | Args: 287 | size (sequence or int): Desired output size of the crop. If size is an 288 | int instead of sequence like (h, w), a square crop (size, size) is 289 | made. 290 | padding (int or sequence, optional): Optional padding on each border 291 | of the image. Default is 0, i.e no padding. If a sequence of length 292 | 4 is provided, it is used to pad left, top, right, bottom borders 293 | respectively. 294 | """ 295 | 296 | def __init__(self, size, padding=0): 297 | if isinstance(size, numbers.Number): 298 | self.size = (int(size), int(size)) 299 | else: 300 | self.size = size 301 | 302 | self.padding = padding 303 | 304 | @staticmethod 305 | def get_params(img, output_size): 306 | """Get parameters for ``crop`` for a random crop. 307 | 308 | Args: 309 | img (PIL Image): Image to be cropped. 310 | output_size (tuple): Expected output size of the crop. 311 | 312 | Returns: 313 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 314 | """ 315 | h, w = img.shape[:2] 316 | th, tw = output_size 317 | if w == tw and h == th: 318 | return 0, 0, h, w 319 | 320 | i = random.randint(0, h - th) 321 | j = random.randint(0, w - tw) 322 | return i, j, th, tw 323 | 324 | def __call__(self, x, mask=None): 325 | """ 326 | Args: 327 | img: Image to be cropped. 328 | 329 | Returns: 330 | : Cropped image. 331 | """ 332 | if self.padding > 0: 333 | x = np.pad(x, self.padding, 'constant') 334 | 335 | i, j, h, w = self.get_params(x, self.size) 336 | 337 | x = x[i:i + h, j:j + w].copy() 338 | 339 | if mask is not None: 340 | if self.padding > 0: 341 | mask = np.pad(mask, self.padding, 'constant') 342 | mask = mask[i:i + h, j:j + w].copy() 343 | 344 | return x, mask 345 | 346 | 347 | def clip(img, dtype, maxval): 348 | return np.clip(img, 0, maxval).astype(dtype) 349 | 350 | 351 | class RandomFilter: 352 | """ 353 | blur sharpen, etc 354 | """ 355 | 356 | def __init__(self, limit=.5, prob=.5): 357 | self.limit = limit 358 | self.prob = prob 359 | 360 | def __call__(self, img): 361 | if random.random() < self.prob: 362 | alpha = self.limit * random.uniform(0, 1) 363 | kernel = np.ones((3, 3), np.float32) / 9 * 0.2 364 | 365 | colored = img[..., :3] 366 | colored = alpha * cv2.filter2D(colored, -1, kernel) + (1 - alpha) * colored 367 | maxval = np.max(img[..., :3]) 368 | dtype = img.dtype 369 | img[..., :3] = clip(colored, dtype, maxval) 370 | 371 | return img 372 | 373 | 374 | # https://github.com/pytorch/vision/pull/27/commits/659c854c6971ecc5b94dca3f4459ef2b7e42fb70 375 | # color augmentation 376 | 377 | # brightness, contrast, saturation------------- 378 | # from mxnet code, see: https://github.com/dmlc/mxnet/blob/master/python/mxnet/image.py 379 | 380 | class RandomBrightness: 381 | def __init__(self, limit=0.1, prob=0.5): 382 | self.limit = limit 383 | self.prob = prob 384 | 385 | def __call__(self, img): 386 | if random.random() < self.prob: 387 | alpha = 1.0 + self.limit * random.uniform(-1, 1) 388 | 389 | maxval = np.max(img[..., :3]) 390 | dtype = img.dtype 391 | img[..., :3] = clip(alpha * img[..., :3], dtype, maxval) 392 | return img 393 | 394 | 395 | class RandomContrast: 396 | def __init__(self, limit=.1, prob=.5): 397 | self.limit = limit 398 | self.prob = prob 399 | 400 | def __call__(self, img): 401 | if random.random() < self.prob: 402 | alpha = 1.0 + self.limit * random.uniform(-1, 1) 403 | 404 | gray = cv2.cvtColor(img[:, :, :3], cv2.COLOR_BGR2GRAY) 
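            # The next line collapses `gray` to a scalar:
            # (3.0 * (1 - alpha) / gray.size) * sum(gray) = 3 * (1 - alpha) * mean(gray),
            # so the update below blends alpha * img with a constant offset
            # around the mean gray level. In the mxnet code this ports from,
            # `gray` had three channels and the factor 3 cancelled against
            # gray.size; with a single-channel gray the offset is 3x larger.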
405 | gray = (3.0 * (1.0 - alpha) / gray.size) * np.sum(gray) 406 | maxval = np.max(img[..., :3]) 407 | dtype = img.dtype 408 | img[:, :, :3] = clip(alpha * img[:, :, :3] + gray, dtype, maxval) 409 | return img 410 | 411 | 412 | class RandomSaturation: 413 | def __init__(self, limit=0.3, prob=0.5): 414 | self.limit = limit 415 | self.prob = prob 416 | 417 | def __call__(self, img): 418 | # dont work :( 419 | if random.random() < self.prob: 420 | maxval = np.max(img[..., :3]) 421 | dtype = img.dtype 422 | alpha = 1.0 + random.uniform(-self.limit, self.limit) 423 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 424 | gray = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR) 425 | img[..., :3] = alpha * img[..., :3] + (1.0 - alpha) * gray 426 | img[..., :3] = clip(img[..., :3], dtype, maxval) 427 | return img 428 | 429 | 430 | class RandomHueSaturationValue: 431 | def __init__(self, hue_shift_limit=(-10, 10), sat_shift_limit=(-25, 25), val_shift_limit=(-25, 25), prob=0.5): 432 | self.hue_shift_limit = hue_shift_limit 433 | self.sat_shift_limit = sat_shift_limit 434 | self.val_shift_limit = val_shift_limit 435 | self.prob = prob 436 | 437 | def __call__(self, image): 438 | if random.random() < self.prob: 439 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 440 | h, s, v = cv2.split(image) 441 | hue_shift = np.random.uniform(self.hue_shift_limit[0], self.hue_shift_limit[1]) 442 | h = cv2.add(h, hue_shift) 443 | sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1]) 444 | s = cv2.add(s, sat_shift) 445 | val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1]) 446 | v = cv2.add(v, val_shift) 447 | image = cv2.merge((h, s, v)) 448 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 449 | return image 450 | 451 | 452 | class NormalizeImage: 453 | def __init__(self, scale=1. / 255., mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 454 | self.scale = float(scale) 455 | self.mean = np.array(mean, dtype=np.float32) 456 | self.std = np.array(std, dtype=np.float32) 457 | 458 | def __call__(self, x): 459 | x = (x * self.scale - self.mean) / self.std 460 | return x 461 | 462 | 463 | class CLAHE: 464 | def __init__(self, clipLimit=2.0, tileGridSize=(8, 8)): 465 | self.clipLimit = clipLimit 466 | self.tileGridSize = tileGridSize 467 | 468 | def __call__(self, im): 469 | img_yuv = cv2.cvtColor(im, cv2.COLOR_BGR2YUV) 470 | clahe = cv2.createCLAHE(clipLimit=self.clipLimit, tileGridSize=self.tileGridSize) 471 | img_yuv[:, :, 0] = clahe.apply(img_yuv[:, :, 0]) 472 | img_output = cv2.cvtColor(img_yuv, cv2.COLOR_YUV2BGR) 473 | return img_output 474 | 475 | 476 | def tta_d4_aug(images): 477 | res = [] 478 | for image in images: 479 | res.extend([ 480 | image, 481 | np.rot90(image, 1), 482 | np.rot90(image, 2), 483 | np.rot90(image, 3), 484 | 485 | np.fliplr(image), 486 | np.fliplr(np.rot90(image, 1)), 487 | np.fliplr(np.rot90(image, 2)), 488 | np.fliplr(np.rot90(image, 3)) 489 | ]) 490 | 491 | return res 492 | 493 | 494 | def tta_d4_deaug(image_list): 495 | assert len(image_list) % 8 == 0 496 | res = [] 497 | one_over_8 = float(1./8.) 
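    # Predictions come in groups of 8, in the exact order produced by
    # tta_d4_aug: identity, rot90 k=1..3, then the same four mirrored.
    # Undo each transform (undo the flip, then apply the inverse rotation)
    # and average the aligned predictions.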
498 | 499 | for i in range(0, len(image_list), 8): 500 | img = (image_list[i + 0] + 501 | np.rot90(image_list[i + 1], -1)+ 502 | np.rot90(image_list[i + 2], -2)+ 503 | np.rot90(image_list[i + 3], -3)+ 504 | np.fliplr(image_list[i + 4])+ 505 | np.rot90(np.fliplr(image_list[i + 5]), -1)+ 506 | np.rot90(np.fliplr(image_list[i + 6]), -2)+ 507 | np.rot90(np.fliplr(image_list[i + 7]), -3)) * one_over_8 508 | 509 | res.append(img) 510 | 511 | return res 512 | 513 | def pad(image, pad_size: int, **kwargs): 514 | rows, cols = image.shape[:2] 515 | 516 | pad_rows = rows % pad_size 517 | pad_cols = cols % pad_size 518 | 519 | if pad_rows == 0 and pad_cols == 0: 520 | return image, (0, 0, 0, 0) 521 | 522 | pad_rows = pad_size - pad_rows 523 | pad_cols = pad_size - pad_cols 524 | 525 | pad_top = pad_rows // 2 526 | pad_btm = pad_rows - pad_top 527 | 528 | pad_left = pad_cols // 2 529 | pad_right = pad_cols - pad_left 530 | 531 | image = cv2.copyMakeBorder(image, pad_top, pad_btm, pad_left, pad_right, **kwargs) 532 | return image, (pad_top, pad_btm, pad_left, pad_right) 533 | 534 | 535 | def unpad(image, pad): 536 | pad_top, pad_btm, pad_left, pad_right = pad 537 | rows, cols = image.shape[:2] 538 | return image[pad_top:rows - pad_btm, pad_left: cols - pad_right] 539 | -------------------------------------------------------------------------------- /lib/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset, ConcatDataset 8 | from torchvision.utils import make_grid 9 | 10 | from lib.tiles import ImageSlicer 11 | 12 | cuda_is_available = torch.cuda.is_available() 13 | 14 | 15 | def maybe_cuda(x): 16 | return x.cuda() if cuda_is_available else x 17 | 18 | 19 | def count_parameters(model): 20 | total = sum(p.numel() for p in model.parameters()) 21 | trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) 22 | return total, trainable 23 | 24 | 25 | def show_landmarks_batch(data): 26 | x, y = data 27 | 28 | grid_x = make_grid(x, normalize=True, scale_each=True) 29 | grid_y = make_grid(y, normalize=True, scale_each=True) 30 | f, (ax1, ax2) = plt.subplots(2, 1) 31 | 32 | ax1.imshow(grid_x.numpy().transpose((1, 2, 0))) 33 | ax2.imshow(grid_y.numpy().transpose((1, 2, 0))) 34 | 35 | plt.title('Batch from dataloader') 36 | plt.show() 37 | 38 | 39 | def find_in_dir(dirname): 40 | return [os.path.join(dirname, fname) for fname in os.listdir(dirname)] 41 | 42 | 43 | def read_rgb(fname): 44 | x = cv2.imread(fname, cv2.IMREAD_COLOR) 45 | return x 46 | 47 | 48 | def read_mask(fname): 49 | x = cv2.imread(fname, cv2.IMREAD_GRAYSCALE) 50 | return x 51 | 52 | 53 | class InMemoryDataset(Dataset): 54 | def __init__(self, images, masks, transform=None): 55 | self.images = images 56 | self.masks = masks 57 | self.transform = transform 58 | 59 | def __getitem__(self, index): 60 | i = self.images[index].copy() 61 | 62 | if self.masks is not None: 63 | m = self.masks[index].copy() 64 | else: 65 | m = None 66 | 67 | if self.transform is not None: 68 | i, m = self.transform(i, m) 69 | 70 | i = torch.from_numpy(np.moveaxis(i, -1, 0)).float() 71 | 72 | if self.masks is not None: 73 | m = torch.from_numpy(np.expand_dims(m, 0)).long() 74 | return i, m 75 | else: 76 | return i 77 | 78 | def __len__(self): 79 | return len(self.images) 80 | 81 | 82 | class ImageMaskDataset(Dataset): 83 | def __init__(self, image_filenames, 
target_filenames, image_loader, target_loader, transform=None, 84 | load_in_ram=False): 85 | if len(image_filenames) != len(target_filenames): 86 | raise ValueError('Number of images does not corresponds to number of targets') 87 | 88 | if load_in_ram: 89 | self.image_filenames = [image_loader(fname) for fname in image_filenames] 90 | self.target_filenames = [target_loader(fname) for fname in target_filenames] 91 | self.image_loader = lambda x: x 92 | self.target_loader = lambda x: x 93 | else: 94 | self.image_filenames = image_filenames 95 | self.target_filenames = target_filenames 96 | self.image_loader = image_loader 97 | self.target_loader = target_loader 98 | 99 | self.transform = transform 100 | 101 | def __len__(self): 102 | return len(self.image_filenames) 103 | 104 | def __getitem__(self, index): 105 | image = self.image_loader(self.image_filenames[index]) 106 | mask = self.target_loader(self.target_filenames[index]) 107 | 108 | if self.transform is not None: 109 | image, mask = self.transform(image, mask) 110 | 111 | image = torch.from_numpy(np.moveaxis(image, -1, 0).copy()).float() 112 | mask = torch.from_numpy(np.expand_dims(mask, 0)).long() 113 | return image, mask 114 | 115 | 116 | class TiledImageDataset(Dataset): 117 | def __init__(self, image_fname, mask_fname, tile_size, tile_step=0, image_margin=0, transform=None, 118 | target_shape=None, 119 | keep_in_mem=False): 120 | self.image_fname = image_fname 121 | self.mask_fname = mask_fname 122 | 123 | self.image = None 124 | self.mask = None 125 | 126 | if target_shape is None or keep_in_mem: 127 | image = read_rgb(image_fname) 128 | mask = read_mask(mask_fname) 129 | if image.shape[0] != mask.shape[0] or image.shape[1] != mask.shape[1]: 130 | raise ValueError() 131 | 132 | target_shape = image.shape 133 | 134 | if keep_in_mem: 135 | self.image = image 136 | self.mask = mask 137 | 138 | if tile_step <= 0: 139 | tile_step = tile_size // 2 140 | 141 | self.slicer = ImageSlicer(target_shape, tile_size, tile_step, image_margin) 142 | self.transform = transform 143 | 144 | def __len__(self): 145 | return len(self.slicer.crops) 146 | 147 | def __getitem__(self, index): 148 | image = self.image if self.image is not None else read_rgb(self.image_fname) 149 | mask = self.mask if self.mask is not None else read_mask(self.mask_fname) 150 | 151 | image = self.slicer.cut_patch(image, index).copy() 152 | mask = self.slicer.cut_patch(mask, index).copy() 153 | 154 | if self.transform is not None: 155 | image, mask = self.transform(image, mask) 156 | 157 | image = torch.from_numpy(np.moveaxis(image, -1, 0).copy()).float() 158 | mask = torch.from_numpy(np.expand_dims(mask, 0)).long() 159 | return image, mask 160 | 161 | 162 | class TiledImagesDataset(ConcatDataset): 163 | def __init__(self, image_filenames, target_filenames, tile_size, tile_step=0, image_margin=0, 164 | target_shape=None, 165 | transform=None, 166 | keep_in_mem=False): 167 | if len(image_filenames) != len(target_filenames): 168 | raise ValueError('Number of images does not corresponds to number of targets') 169 | 170 | datasets = [ 171 | TiledImageDataset(image, mask, tile_size, tile_step, image_margin, transform, target_shape=target_shape, keep_in_mem=keep_in_mem) for 172 | image, mask in zip(image_filenames, target_filenames)] 173 | super().__init__(datasets) 174 | -------------------------------------------------------------------------------- /lib/datasets/Inria.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
cv2
import numpy as np

from sklearn.model_selection import train_test_split
from tqdm import tqdm

from lib import augmentations as aug
from lib.common import find_in_dir, TiledImagesDataset, read_rgb, read_mask, ImageMaskDataset
from lib.tiles import ImageSlicer


def compute_mean_std(dataset):
    """
    https://stats.stackexchange.com/questions/25848/how-to-sum-a-standard-deviation
    """
    one_over_255 = float(1. / 255.)

    global_mean = np.zeros(3, dtype=np.float64)
    global_var = np.zeros(3, dtype=np.float64)

    n_items = len(dataset)

    for image_fname in dataset:
        x = read_rgb(image_fname) * one_over_255
        mean, stddev = cv2.meanStdDev(x)

        global_mean += np.squeeze(mean)
        global_var += np.squeeze(stddev) ** 2

    # Average the per-image variances before taking the square root; the
    # original `np.sqrt(global_var)` skipped the division by n_items, which
    # is how STD values greater than 1 (like the constants below) arise for
    # data scaled to [0, 1]
    return global_mean / n_items, np.sqrt(global_var / n_items)


INRIA_MEAN = [0.40273115, 0.45046371, 0.42960134]
# Note: these STD values look like the output of the unnormalized formula
# described above (a true std of [0, 1] data cannot exceed 1); they are kept
# as-is because any trained checkpoints expect this exact normalization
INRIA_STD = [3.15086464, 3.29831641, 3.63201004]


def INRIA(dataset_dir, grayscale, patch_size, keep_in_mem, small=False):
    x = sorted(find_in_dir(os.path.join(dataset_dir, 'images')))
    y = sorted(find_in_dir(os.path.join(dataset_dir, 'gt')))

    if small:
        x = x[:4]
        y = y[:4]

    x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=1234, test_size=0.1)

    train_transform = aug.Sequential([
        aug.ImageOnly(aug.RandomGrayscale(1.0 if grayscale else 0.5)),
        aug.ImageOnly(aug.RandomBrightness()),
        aug.ImageOnly(aug.RandomContrast()),
        aug.VerticalFlip(),
        aug.HorizontalFlip(),
        aug.ShiftScaleRotate(rotate_limit=15),
        aug.ImageOnly(aug.NormalizeImage(mean=INRIA_MEAN, std=INRIA_STD)),
        aug.MaskOnly(aug.MakeBinary())
    ])

    test_transform = aug.Sequential([
        aug.ImageOnly(aug.NormalizeImage(mean=INRIA_MEAN, std=INRIA_STD)),
        aug.MaskOnly(aug.MakeBinary())
    ])

    train = TiledImagesDataset(x_train, y_train, patch_size, target_shape=(5000, 5000), transform=train_transform, keep_in_mem=keep_in_mem)
    test = TiledImagesDataset(x_test, y_test, patch_size, target_shape=(5000, 5000), transform=test_transform, keep_in_mem=keep_in_mem)
    num_classes = 1
    return train, test, num_classes


def INRIASliced(dataset_dir, grayscale):
    x = sorted(find_in_dir(os.path.join(dataset_dir, 'images')))
    y = sorted(find_in_dir(os.path.join(dataset_dir, 'gt')))
    image_id = [os.path.basename(fname).split('_')[0] for fname in x]

    unique_image_id = np.unique(image_id)
    location = [basename[:6] for basename in unique_image_id]  # Geocode is first 6 characters
    train_id, test_id = train_test_split(unique_image_id, random_state=1234, test_size=0.1, stratify=location)

    xy_train = [(image_fname, mask_fname) for image_fname, mask_fname, image_id in zip(x, y, image_id) if image_id in train_id]
    xy_test = [(image_fname, mask_fname) for image_fname, mask_fname, image_id in zip(x, y, image_id) if image_id in test_id]

    x_train, y_train = zip(*xy_train)
    x_test, y_test = zip(*xy_test)

    train_transform = aug.Sequential([
        aug.ImageOnly(aug.RandomGrayscale(1.0 if grayscale else 0.5)),
        aug.ImageOnly(aug.RandomBrightness()),
        aug.ImageOnly(aug.RandomContrast()),
        aug.VerticalFlip(),
        aug.HorizontalFlip(),
        aug.ShiftScaleRotate(rotate_limit=15),
        aug.ImageOnly(aug.NormalizeImage(mean=INRIA_MEAN, std=INRIA_STD)),
        aug.MaskOnly(aug.MakeBinary())
    ])

test_transform = aug.Sequential([ 97 | aug.ImageOnly(aug.NormalizeImage(mean=INRIA_MEAN, std=INRIA_STD)), 98 | aug.MaskOnly(aug.MakeBinary()) 99 | ]) 100 | 101 | train = ImageMaskDataset(x_train, y_train, image_loader=read_rgb, target_loader=read_mask, transform=train_transform, load_in_ram=False) 102 | test = ImageMaskDataset(x_test, y_test, image_loader=read_rgb, target_loader=read_mask, transform=test_transform, load_in_ram=False) 103 | 104 | num_classes = 1 105 | return train, test, num_classes 106 | 107 | 108 | def cut_dataset_in_patches(data_dir, output_dir, patch_size): 109 | x = sorted(find_in_dir(os.path.join(data_dir, 'images'))) 110 | y = sorted(find_in_dir(os.path.join(data_dir, 'gt'))) 111 | 112 | out_img = os.path.join(output_dir, 'images') 113 | out_msk = os.path.join(output_dir, 'gt') 114 | os.makedirs(out_img, exist_ok=True) 115 | os.makedirs(out_msk, exist_ok=True) 116 | 117 | slicer = ImageSlicer((5000, 5000), patch_size, patch_size // 2) 118 | 119 | for image_fname, mask_fname in tqdm(zip(x, y), total=len(x)): 120 | image = read_rgb(image_fname) 121 | mask = read_mask(mask_fname) 122 | 123 | basename = os.path.basename(image_fname) 124 | basename = os.path.splitext(basename)[0] 125 | 126 | for index, patch in enumerate(slicer.split(image)): 127 | cv2.imwrite(os.path.join(out_img, '%s_%d.tif' % (basename, index)), patch) 128 | 129 | for index, patch in enumerate(slicer.split(mask)): 130 | cv2.imwrite(os.path.join(out_msk, '%s_%d.tif' % (basename, index)), patch) 131 | 132 | 133 | if __name__ == '__main__': 134 | cut_dataset_in_patches('d:/datasets/inria/train', 'd:/datasets/inria-train-1024', 1024) 135 | # dataset_dir = 'e:/datasets/inria' 136 | # train = sorted(find_in_dir(os.path.join(dataset_dir, 'train', 'images'))) 137 | # test = sorted(find_in_dir(os.path.join(dataset_dir, 'test', 'images'))) 138 | # print('train', compute_mean_std(train)) 139 | # print('test', compute_mean_std(test)) 140 | # print('both', compute_mean_std(train + test)) 141 | -------------------------------------------------------------------------------- /lib/datasets/dsb2018.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from sklearn.model_selection import train_test_split 5 | 6 | from lib.tiles import ImageSlicer 7 | from lib.common import find_in_dir, ImageMaskDataset, read_rgb, read_mask, InMemoryDataset 8 | from lib import augmentations as aug 9 | 10 | 11 | def DSB2018(dataset_dir, grayscale, patch_size): 12 | """ 13 | Returns train & test dataset or DSB2018 14 | :param dataset_dir: 15 | :param grayscale: 16 | :param patch_size: 17 | :return: 18 | """ 19 | 20 | images = find_in_dir(os.path.join(dataset_dir, 'images')) 21 | masks = find_in_dir(os.path.join(dataset_dir, 'masks')) 22 | 23 | x_train, x_test, y_train, y_test = train_test_split(images, masks, random_state=1234, test_size=0.1) 24 | 25 | train_transform = aug.Sequential([ 26 | aug.RandomCrop(patch_size), 27 | # aug.ImageOnly(aug.RandomGrayscale()), 28 | # aug.ImageOnly(aug.RandomInvert()), 29 | aug.ImageOnly(aug.NormalizeImage()), 30 | # aug.ImageOnly(aug.RandomBrightness()), 31 | # aug.ImageOnly(aug.RandomContrast()), 32 | # aug.RandomRotate90(), 33 | # aug.VerticalFlip(), 34 | # aug.HorizontalFlip(), 35 | # aug.ShiftScaleRotate(), 36 | aug.MaskOnly(aug.MakeBinary()) 37 | ]) 38 | 39 | test_transform = aug.Sequential([ 40 | aug.CenterCrop(patch_size, patch_size), 41 | # aug.ImageOnly(aug.RandomGrayscale(1.0 if grayscale else 0.0)), 42 | 
aug.ImageOnly(aug.NormalizeImage()), 43 | aug.MaskOnly(aug.MakeBinary()), 44 | ]) 45 | 46 | train = ImageMaskDataset(x_train, y_train, image_loader=read_rgb, target_loader=read_mask, transform=train_transform, load_in_ram=False) 47 | test = ImageMaskDataset(x_test, y_test, image_loader=read_rgb, target_loader=read_mask, transform=test_transform, load_in_ram=False) 48 | num_classes = 1 49 | return train, test, num_classes 50 | 51 | 52 | def DSB2018Sliced(dataset_dir, grayscale, patch_size): 53 | """ 54 | Returns train & test dataset or DSB2018 55 | :param dataset_dir: 56 | :param grayscale: 57 | :param patch_size: 58 | :return: 59 | """ 60 | 61 | images = [read_rgb(x) for x in find_in_dir(os.path.join(dataset_dir, 'images'))] 62 | masks = [read_mask(x) for x in find_in_dir(os.path.join(dataset_dir, 'masks'))] 63 | 64 | image_ids = [] 65 | patch_images = [] 66 | patch_masks = [] 67 | 68 | for image_id, (image, mask) in enumerate(zip(images, masks)): 69 | slicer = ImageSlicer(image.shape, patch_size, patch_size // 2) 70 | 71 | patch_images.extend(slicer.split(image)) 72 | patch_masks.extend(slicer.split(mask)) 73 | image_ids.extend([image_id] * len(slicer.crops)) 74 | 75 | x_train, x_test, y_train, y_test = train_test_split(patch_images, patch_masks, random_state=1234, test_size=0.1, stratify=image_ids) 76 | 77 | train_transform = aug.Sequential([ 78 | # aug.ImageOnly(aug.RandomGrayscale()), 79 | # aug.ImageOnly(aug.RandomInvert()), 80 | aug.ImageOnly(aug.NormalizeImage()), 81 | # aug.ImageOnly(aug.RandomBrightness()), 82 | # aug.ImageOnly(aug.RandomContrast()), 83 | aug.RandomRotate90(), 84 | aug.VerticalFlip(), 85 | aug.HorizontalFlip(), 86 | aug.ShiftScaleRotate(rotate_limit=15), 87 | aug.MaskOnly(aug.MakeBinary()) 88 | ]) 89 | 90 | test_transform = aug.Sequential([ 91 | aug.ImageOnly(aug.NormalizeImage()), 92 | aug.MaskOnly(aug.MakeBinary()) 93 | ]) 94 | 95 | train = InMemoryDataset(x_train, y_train, transform=train_transform) 96 | test = InMemoryDataset(x_test, y_test, transform=test_transform) 97 | num_classes = 1 98 | return train, test, num_classes 99 | -------------------------------------------------------------------------------- /lib/datasets/shapes.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | import lib.augmentations as aug 8 | 9 | 10 | def gen_random_image(patch_size): 11 | img = np.zeros((patch_size, patch_size, 3), dtype=np.uint8) 12 | mask = np.zeros((patch_size, patch_size), dtype=np.uint8) 13 | 14 | # Background 15 | dark_color0 = random.randint(0, 100) 16 | dark_color1 = random.randint(0, 100) 17 | dark_color2 = random.randint(0, 100) 18 | img[:, :, 0] = dark_color0 19 | img[:, :, 1] = dark_color1 20 | img[:, :, 2] = dark_color2 21 | 22 | # Object 23 | light_color0 = random.randint(dark_color0 + 1, 255) 24 | light_color1 = random.randint(dark_color1 + 1, 255) 25 | light_color2 = random.randint(dark_color2 + 1, 255) 26 | center_0 = random.randint(0, patch_size) 27 | center_1 = random.randint(0, patch_size) 28 | r1 = random.randint(10, 56) 29 | r2 = random.randint(10, 56) 30 | cv2.ellipse(img, (center_0, center_1), (r1, r2), 0, 0, 360, (light_color0, light_color1, light_color2), -1) 31 | cv2.ellipse(mask, (center_0, center_1), (r1, r2), 0, 0, 360, 1, -1) 32 | 33 | # White noise 34 | density = random.uniform(0, 0.1) 35 | for i in range(patch_size): 36 | for j in range(patch_size): 37 | if random.random() < density: 
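                # Speckle noise: with probability `density` a pixel is
                # replaced by a uniformly random color. A vectorized
                # equivalent of the whole loop would be:
                #   noise = np.random.rand(patch_size, patch_size) < density
                #   img[noise] = np.random.randint(0, 256, (noise.sum(), 3))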
                img[i, j, 0] = random.randint(0, 255)
                img[i, j, 1] = random.randint(0, 255)
                img[i, j, 2] = random.randint(0, 255)

    return img, mask


class ShapesDataset(Dataset):
    def __init__(self, steps, patch_size, transform=aug.ImageOnly(aug.NormalizeImage())):
        self.transform = transform
        self.patch_size = patch_size
        self.steps = steps

    def __len__(self):
        return self.steps

    def __getitem__(self, item):
        image, mask = gen_random_image(self.patch_size)
        image, mask = self.transform(image, mask)

        image = torch.from_numpy(np.moveaxis(image, -1, 0).copy()).float()
        mask = torch.from_numpy(np.expand_dims(mask, 0)).long()
        return image, mask


def SHAPES(patch_size):
    """
    https://github.com/ZFTurbo/ZF_UNET_224_Pretrained_Model/blob/master/train_infinite_generator.py
    (the URL above was mangled by a search-and-replace of 224 -> patch_size)
    :param patch_size:
    :return:
    """
    return ShapesDataset(1024, patch_size), ShapesDataset(128, patch_size), 1
--------------------------------------------------------------------------------
/lib/losses.py:
--------------------------------------------------------------------------------
import torch
from torch import nn, Tensor
from torch.nn import functional as F, BCEWithLogitsLoss, NLLLoss
from torch.nn.modules.loss import _Loss


class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, output, target):
        prediction = F.sigmoid(output)
        intersection = torch.sum(prediction * target)
        union = torch.sum(prediction) + torch.sum(target) + 1e-7
        return 1 - 2 * intersection / union


class JaccardLoss(_Loss):
    def __init__(self):
        super(JaccardLoss, self).__init__()

    def forward(self, output, target):
        output = F.sigmoid(output)
        intersection = torch.sum(output * target)
        union = torch.sum(output) + torch.sum(target)

        jac = intersection / (union - intersection + 1e-7)
        return 1 - jac


class SmoothJaccardLoss(_Loss):
    def __init__(self, smooth=100):
        super(SmoothJaccardLoss, self).__init__()
        self.smooth = smooth

    def forward(self, output, target):
        output = F.sigmoid(output)
        target = target.float()
        intersection = torch.sum(output * target)
        union = torch.sum(output) + torch.sum(target)

        jac = (intersection + self.smooth) / (union - intersection + self.smooth)
        return 1 - jac


class BCEWithSigmoidLoss(_Loss):
    def __init__(self, size_average=True, reduce=True):
        super().__init__(size_average=size_average, reduce=reduce)

    def forward(self, outputs, targets):
        # binary_cross_entropy_with_logits applies the sigmoid internally;
        # the original code passed F.logsigmoid(outputs) here, which squashed
        # the outputs twice and flattened the loss surface
        targets = targets.float()
        return F.binary_cross_entropy_with_logits(outputs, targets, size_average=self.size_average, reduce=self.reduce)


class BCEWithLogitsLossAndSmoothJaccard(_Loss):
    """
    Loss defined as BCE + SmoothJaccardLoss
    Vladimir Iglovikov, Sergey Mushinskiy, Vladimir Osin,
    Satellite Imagery Feature Detection using Deep Convolutional Neural Network: A Kaggle Competition
    arXiv:1706.06169
    """

    def __init__(self, bce_weight=1, jaccard_weight=0.5):
        super(BCEWithLogitsLossAndSmoothJaccard, self).__init__()
        self.bce_loss = BCEWithSigmoidLoss()
        self.jac_loss = SmoothJaccardLoss()

        self.bce_weight = bce_weight
        self.jaccard_weight = jaccard_weight

    def forward(self, outputs, targets):
        loss1 = self.bce_loss(outputs, targets) * self.bce_weight
        loss2 = self.jac_loss(outputs, targets) * self.jaccard_weight
        return (loss1 + loss2) / (self.bce_weight + self.jaccard_weight)

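
# A minimal usage sketch for the combined loss above; the shapes and the
# 1:0.5 weighting are illustrative assumptions, not a recommended setup.
def _combined_loss_example():
    criterion = BCEWithLogitsLossAndSmoothJaccard(bce_weight=1, jaccard_weight=0.5)
    logits = torch.randn(2, 1, 64, 64)              # raw model outputs, no sigmoid
    targets = torch.randint(0, 2, (2, 1, 64, 64))   # binary ground-truth masks
    return criterion(logits, targets)               # scalar weighted average of BCE and Jaccard
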

class FocalLossBinary(_Loss):
    """Focal loss puts more weight on harder examples.
    https://github.com/warmspringwinds/pytorch-segmentation-detection/blob/master/pytorch_segmentation_detection/losses.py
    Expects raw logits as input.
    """

    def __init__(self, gamma=2, size_average=True, reduce=True):
        super(FocalLossBinary, self).__init__(size_average=size_average, reduce=reduce)
        self.gamma = gamma

    def forward(self, outputs: Tensor, targets: Tensor):

        # binary_cross_entropy_with_logits applies the sigmoid itself; the
        # original code first ran outputs through F.logsigmoid, squashing twice
        logpt = -F.binary_cross_entropy_with_logits(outputs, targets.float(), reduce=False)
        pt = torch.exp(logpt)

        # compute the loss
        loss = -((1 - pt).pow(self.gamma)) * logpt

        # averaging (or not) loss
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()


class FocalLossMulti(_Loss):
    """Focal loss puts more weight on harder examples.
    https://github.com/warmspringwinds/pytorch-segmentation-detection/blob/master/pytorch_segmentation_detection/losses.py
    Expects log_softmax outputs when from_logits=True, raw scores otherwise.
    """

    def __init__(self, gamma=2, size_average=True, reduce=True, ignore_index=-100, from_logits=False):
        super(FocalLossMulti, self).__init__(size_average=size_average, reduce=reduce)
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.from_logits = from_logits

    def forward(self, outputs: Tensor, targets: Tensor):

        if not self.from_logits:
            outputs = F.log_softmax(outputs, dim=1)

        logpt = -F.nll_loss(outputs, targets, ignore_index=self.ignore_index, reduce=False)
        pt = torch.exp(logpt)

        # compute the loss
        loss = -((1 - pt).pow(self.gamma)) * logpt

        # averaging (or not) loss
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

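
# Both focal variants implement FL(p_t) = -(1 - p_t)^gamma * log(p_t) from
# Lin et al., "Focal Loss for Dense Object Detection" (arXiv:1708.02002):
# `logpt` recovers log(p_t) from the cross-entropy/NLL term, and the
# (1 - p_t)^gamma factor down-weights easy, well-classified pixels.
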
182 | union = jaccard_output.sum() + jaccard_target.sum()
183 | jac = (intersection + smooth) / (union - intersection + smooth)
184 | loss[cls_indx] = 1 - jac
185 | 
186 | if self.class_weights is not None:
187 | loss = loss * self.class_weights.to(outputs.device)
188 | 
189 | if self.reduce:
190 | return loss.sum()
191 | 
192 | return loss
193 | 
194 | 
195 | class FocalAndJaccardLossMulti(_Loss):
196 | def __init__(self, jaccard_weight=1, class_weights=None, ignore_index=-1):
197 | super(FocalAndJaccardLossMulti, self).__init__()
198 | 
199 | if class_weights is not None:
200 | nll_weight = torch.from_numpy(class_weights).float()
201 | else:
202 | nll_weight = None
203 | 
204 | self.focal_loss = FocalLossMulti(ignore_index=ignore_index, from_logits=True)
205 | self.jaccard_loss = JaccardLossMulti(ignore_index=ignore_index, from_logits=True, weight=nll_weight)
206 | self.jaccard_weight = jaccard_weight
207 | 
208 | def forward(self, outputs, targets):
209 | outputs = F.log_softmax(outputs, dim=1)
210 | focal_loss = self.focal_loss(outputs, targets)
211 | jac_loss = self.jaccard_loss(outputs, targets)
212 | return (focal_loss + jac_loss) / (1 + self.jaccard_weight)
213 | 
214 | 
215 | class NLLLAndJaccardLossMulti(_Loss):
216 | def __init__(self, jaccard_weight=1, class_weights=None, ignore_index=-1):
217 | super(NLLLAndJaccardLossMulti, self).__init__()
218 | 
219 | if class_weights is not None:
220 | nll_weight = torch.from_numpy(class_weights).float()
221 | else:
222 | nll_weight = None
223 | 
224 | self.nll_loss = nn.NLLLoss(weight=nll_weight, ignore_index=ignore_index)
225 | self.jaccard_loss = JaccardLossMulti(ignore_index=ignore_index, from_logits=True, weight=nll_weight)
226 | self.jaccard_weight = jaccard_weight
227 | 
228 | def forward(self, outputs, targets):
229 | outputs = F.log_softmax(outputs, dim=1)
230 | nll_loss = self.nll_loss(outputs, targets)
231 | jac_loss = self.jaccard_loss(outputs, targets)
232 | return (nll_loss + jac_loss) / (1 + self.jaccard_weight)
233 | 
-------------------------------------------------------------------------------- /lib/metrics.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | 
4 | from torch.nn.modules.loss import _Loss
5 | import numpy as np
6 | import torch.nn.functional as F
7 | 
8 | 
9 | class JaccardScore(_Loss):
10 | def __init__(self):
11 | super(JaccardScore, self).__init__()
12 | 
13 | def forward(self, output, target):
14 | output = F.sigmoid(output)
15 | target = target.float()
16 | 
17 | intersection = (output * target).sum()
18 | union = output.sum() + target.sum()
19 | jac = intersection / (union - intersection + 1e-7)
20 | return jac
21 | 
22 | def __str__(self):
23 | return 'JaccardScore'
24 | 
25 | 
26 | class PixelAccuracy(_Loss):
27 | def __init__(self):
28 | super(PixelAccuracy, self).__init__()
29 | 
30 | def forward(self, output, target):
31 | output = F.sigmoid(output) > 0.5
32 | target = target.byte()
33 | 
34 | n_true = torch.eq(output, target)
35 | n_all = torch.numel(target)
36 | n_true = n_true.sum()
37 | if n_true == 0:
38 | return n_true.float()
39 | 
40 | return n_true.float() / n_all
41 | 
42 | def __str__(self):
43 | return 'PixelAccuracy'
44 | 
-------------------------------------------------------------------------------- /lib/models/afterburner.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from
lib.models.unet import UNet 6 | 7 | 8 | class Afterburner(nn.Module): 9 | def __init__(self,n_channels=1): 10 | super().__init__() 11 | self.unet = UNet(n_channels=n_channels, n_classes=1) 12 | 13 | def forward(self, x): 14 | return self.unet(x) 15 | 16 | -------------------------------------------------------------------------------- /lib/models/dilated_linknet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torchvision import models 4 | import torchvision 5 | from torch.nn import functional as F 6 | 7 | from lib.models.dilated_resnet import dilated_resnet34 8 | 9 | 10 | class DecoderBlockLinkNet(nn.Module): 11 | def __init__(self, in_channels, n_filters): 12 | super().__init__() 13 | 14 | self.relu = nn.ReLU(inplace=True) 15 | 16 | # B, C, H, W -> B, C/4, H, W 17 | self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1) 18 | self.norm1 = nn.BatchNorm2d(in_channels // 4) 19 | 20 | # B, C/4, H, W -> B, C/4, 2 * H, 2 * W 21 | self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, kernel_size=4, 22 | stride=2, padding=1, output_padding=0) 23 | self.norm2 = nn.BatchNorm2d(in_channels // 4) 24 | 25 | # B, C/4, H, W -> B, C, H, W 26 | self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1) 27 | self.norm3 = nn.BatchNorm2d(n_filters) 28 | 29 | def forward(self, x): 30 | x = self.conv1(x) 31 | x = self.norm1(x) 32 | x = self.relu(x) 33 | x = self.deconv2(x) 34 | x = self.norm2(x) 35 | x = self.relu(x) 36 | x = self.conv3(x) 37 | x = self.norm3(x) 38 | x = self.relu(x) 39 | return x 40 | 41 | 42 | class DilatedLinkNet34(nn.Module): 43 | def __init__(self, num_classes=1, num_channels=3, pretrained=True): 44 | super().__init__() 45 | assert num_channels == 3 46 | self.num_classes = num_classes 47 | filters = [64, 128, 256, 512] 48 | resnet = dilated_resnet34(pretrained=pretrained) 49 | 50 | self.firstconv = resnet.conv1 51 | self.firstbn = resnet.bn1 52 | self.firstrelu = resnet.relu 53 | self.firstmaxpool = resnet.maxpool 54 | self.encoder1 = resnet.layer1 55 | self.encoder2 = resnet.layer2 56 | self.encoder3 = resnet.layer3 57 | self.encoder4 = resnet.layer4 58 | 59 | # Decoder 60 | self.decoder4 = DecoderBlockLinkNet(filters[3], filters[2]) 61 | self.decoder3 = DecoderBlockLinkNet(filters[2], filters[1]) 62 | self.decoder2 = DecoderBlockLinkNet(filters[1], filters[0]) 63 | self.decoder1 = DecoderBlockLinkNet(filters[0], filters[0]) 64 | 65 | # Final Classifier 66 | self.finaldrop1 = nn.Dropout2d(p=0.5) 67 | self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 3, stride=2) 68 | self.finalrelu1 = nn.ReLU(inplace=True) 69 | self.finalconv2 = nn.Conv2d(32, 32, 3) 70 | self.finalrelu2 = nn.ReLU(inplace=True) 71 | self.finalconv3 = nn.Conv2d(32, num_classes, 2, padding=1) 72 | 73 | # noinspection PyCallingNonCallable 74 | def forward(self, x): 75 | # Encoder 76 | x = self.firstconv(x) 77 | x = self.firstbn(x) 78 | x = self.firstrelu(x) 79 | x = self.firstmaxpool(x) 80 | e1 = self.encoder1(x) 81 | e2 = self.encoder2(e1) 82 | e3 = self.encoder3(e2) 83 | e4 = self.encoder4(e3) 84 | 85 | # Decoder with Skip Connections 86 | d4 = self.decoder4(e4) + e3 87 | d3 = self.decoder3(d4) + e2 88 | d2 = self.decoder2(d3) + e1 89 | d1 = self.decoder1(d2) 90 | 91 | # Final Classification 92 | # d1 = self.finaldrop1(d1) # Added dropout 93 | f1 = self.finaldeconv1(d1) 94 | f2 = self.finalrelu1(f1) 95 | f3 = self.finalconv2(f2) 96 | f4 = self.finalrelu2(f3) 97 | f5 = self.finalconv3(f4) 98 | 99 | return f5 100 | 
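Note: `dilated_resnet34` keeps `layer3` and `layer4` at stride 1 with dilation, so its feature maps stay at 1/8 resolution from `layer2` onwards (see the `DilatedResNet` docstring below), while every `DecoderBlockLinkNet` doubles resolution. As written, the skip additions in `DilatedLinkNet34.forward` therefore mix tensors of different spatial sizes: `decoder4(e4)` comes out at 1/4 resolution while the `e3` skip is still at 1/8. A minimal sketch that makes the mismatch visible, assuming the repository root is on `PYTHONPATH`:

```python
import torch

from lib.models.dilated_linknet import DilatedLinkNet34

net = DilatedLinkNet34(pretrained=False).eval()
with torch.no_grad():
    try:
        print(net(torch.rand(1, 3, 256, 256)).shape)
    except RuntimeError as e:
        # Expected with the dilated encoder: decoder4 upsamples e4 from 1/8
        # to 1/4 resolution, while the e3 skip tensor remains at 1/8.
        print('skip connection size mismatch:', e)
```

Dropping the transposed convolutions from the first two decoder blocks, or constructing the backbone with `dilated=False`, would bring the encoder and decoder strides back into agreement.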
-------------------------------------------------------------------------------- /lib/models/dilated_resnet.py: -------------------------------------------------------------------------------- 1 | """Dilated ResNet""" 2 | import math 3 | import torch 4 | import torch.utils.model_zoo as model_zoo 5 | import torch.nn as nn 6 | 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | "3x3 convolution with padding" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | """ResNet BasicBlock 25 | """ 26 | expansion = 1 27 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, previous_dilation=1, 28 | norm_layer=None): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, 31 | padding=dilation, dilation=dilation, bias=False) 32 | self.bn1 = norm_layer(planes) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, 35 | padding=previous_dilation, dilation=previous_dilation, bias=False) 36 | self.bn2 = norm_layer(planes) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(x) 52 | 53 | out += residual 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | """ResNet Bottleneck 61 | """ 62 | # pylint: disable=unused-argument 63 | expansion = 4 64 | def __init__(self, inplanes, planes, stride=1, dilation=1, 65 | downsample=None, previous_dilation=1, norm_layer=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = norm_layer(planes) 69 | self.conv2 = nn.Conv2d( 70 | planes, planes, kernel_size=3, stride=stride, 71 | padding=dilation, dilation=dilation, bias=False) 72 | self.bn2 = norm_layer(planes) 73 | self.conv3 = nn.Conv2d( 74 | planes, planes * 4, kernel_size=1, bias=False) 75 | self.bn3 = norm_layer(planes * 4) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.downsample = downsample 78 | self.dilation = dilation 79 | self.stride = stride 80 | 81 | def _sum_each(self, x, y): 82 | assert(len(x) == len(y)) 83 | z = [] 84 | for i in range(len(x)): 85 | z.append(x[i]+y[i]) 86 | return z 87 | 88 | def forward(self, x): 89 | residual = x 90 | 91 | out = self.conv1(x) 92 | out = self.bn1(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv2(out) 96 | out = self.bn2(out) 97 | out = self.relu(out) 98 | 99 | out = self.conv3(out) 100 | out = self.bn3(out) 101 | 102 | if self.downsample is not None: 103 | residual = self.downsample(x) 104 | 105 | out += residual 106 | out = self.relu(out) 107 | 108 | return out 109 | 110 | 111 | class DilatedResNet(nn.Module): 112 | """Dilated Pre-trained ResNet Model, which preduces the stride of 8 featuremaps at conv5. 
113 | 
114 | Parameters
115 | ----------
116 | block : Block
117 | Class for the residual block. Options are BasicBlock and Bottleneck.
118 | layers : list of int
119 | Numbers of layers in each block
120 | num_classes : int, default 1000
121 | Number of classification classes.
122 | dilated : bool, default True
123 | Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
124 | typically used in Semantic Segmentation.
125 | norm_layer : object
126 | Normalization layer used in backbone network (default: :class:`torch.nn.BatchNorm2d`);
127 | a synchronized cross-GPU BatchNormalization layer can be plugged in here.
128 | 
129 | Reference:
130 | 
131 | - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
132 | 
133 | - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." ICLR 2016.
134 | """
135 | # pylint: disable=unused-variable
136 | def __init__(self, block, layers, num_classes=1000, dilated=True, norm_layer=nn.BatchNorm2d):
137 | self.inplanes = 64
138 | super(DilatedResNet, self).__init__()
139 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
140 | bias=False)
141 | self.bn1 = norm_layer(64)
142 | self.relu = nn.ReLU(inplace=True)
143 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
144 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
145 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
146 | if dilated:
147 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
148 | dilation=2, norm_layer=norm_layer)
149 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
150 | dilation=4, norm_layer=norm_layer)
151 | else:
152 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
153 | norm_layer=norm_layer)
154 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
155 | norm_layer=norm_layer)
156 | self.avgpool = nn.AvgPool2d(7)
157 | self.fc = nn.Linear(512 * block.expansion, num_classes)
158 | 
159 | for m in self.modules():
160 | if isinstance(m, nn.Conv2d):
161 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
162 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 163 | elif isinstance(m, norm_layer): 164 | m.weight.data.fill_(1) 165 | m.bias.data.zero_() 166 | 167 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None): 168 | downsample = None 169 | if stride != 1 or self.inplanes != planes * block.expansion: 170 | downsample = nn.Sequential( 171 | nn.Conv2d(self.inplanes, planes * block.expansion, 172 | kernel_size=1, stride=stride, bias=False), 173 | norm_layer(planes * block.expansion), 174 | ) 175 | 176 | layers = [] 177 | if dilation == 1 or dilation == 2: 178 | layers.append(block(self.inplanes, planes, stride, dilation=1, 179 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer)) 180 | elif dilation == 4: 181 | layers.append(block(self.inplanes, planes, stride, dilation=2, 182 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer)) 183 | else: 184 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 185 | 186 | self.inplanes = planes * block.expansion 187 | for i in range(1, blocks): 188 | layers.append(block(self.inplanes, planes, dilation=dilation, previous_dilation=dilation, 189 | norm_layer=norm_layer)) 190 | 191 | return nn.Sequential(*layers) 192 | 193 | def forward(self, x): 194 | x = self.conv1(x) 195 | x = self.bn1(x) 196 | x = self.relu(x) 197 | x = self.maxpool(x) 198 | 199 | x = self.layer1(x) 200 | x = self.layer2(x) 201 | x = self.layer3(x) 202 | x = self.layer4(x) 203 | 204 | x = self.avgpool(x) 205 | x = x.view(x.size(0), -1) 206 | x = self.fc(x) 207 | 208 | return x 209 | 210 | 211 | def resnet18(pretrained=False, **kwargs): 212 | """Constructs a ResNet-18 model. 213 | 214 | Args: 215 | pretrained (bool): If True, returns a model pre-trained on ImageNet 216 | """ 217 | model = DilatedResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 218 | if pretrained: 219 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 220 | return model 221 | 222 | 223 | def dilated_resnet34(pretrained=False, **kwargs): 224 | """Constructs a ResNet-34 model. 225 | 226 | Args: 227 | pretrained (bool): If True, returns a model pre-trained on ImageNet 228 | """ 229 | model = DilatedResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 230 | if pretrained: 231 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 232 | return model 233 | 234 | 235 | def resnet50(pretrained=False, root='~/.encoding/models', **kwargs): 236 | """Constructs a ResNet-50 model. 237 | 238 | Args: 239 | pretrained (bool): If True, returns a model pre-trained on ImageNet 240 | """ 241 | model = DilatedResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 242 | if pretrained: 243 | from ..models.model_store import get_model_file 244 | model.load_state_dict(torch.load( 245 | get_model_file('resnet50', root=root)), strict=False) 246 | return model 247 | 248 | 249 | def resnet101(pretrained=False, root='~/.encoding/models', **kwargs): 250 | """Constructs a ResNet-101 model. 251 | 252 | Args: 253 | pretrained (bool): If True, returns a model pre-trained on ImageNet 254 | """ 255 | model = DilatedResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 256 | if pretrained: 257 | from ..models.model_store import get_model_file 258 | model.load_state_dict(torch.load( 259 | get_model_file('resnet101', root=root)), strict=False) 260 | return model 261 | 262 | 263 | def resnet152(pretrained=False, root='~/.encoding/models', **kwargs): 264 | """Constructs a ResNet-152 model. 
265 | 266 | Args: 267 | pretrained (bool): If True, returns a model pre-trained on ImageNet 268 | """ 269 | model = DilatedResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 270 | if pretrained: 271 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 272 | return model -------------------------------------------------------------------------------- /lib/models/duc_hdc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision import models 4 | from torch.nn import functional as F 5 | 6 | 7 | class _DenseUpsamplingConvModule(nn.Module): 8 | def __init__(self, down_factor, in_dim, num_classes): 9 | super(_DenseUpsamplingConvModule, self).__init__() 10 | upsample_dim = (down_factor ** 2) * num_classes 11 | self.conv = nn.Conv2d(in_dim, upsample_dim, kernel_size=3, padding=1) 12 | self.bn = nn.BatchNorm2d(upsample_dim) 13 | self.relu = nn.ReLU(inplace=True) 14 | self.pixel_shuffle = nn.PixelShuffle(down_factor) 15 | 16 | def forward(self, x): 17 | x = self.conv(x) 18 | x = self.bn(x) 19 | x = self.relu(x) 20 | x = self.pixel_shuffle(x) 21 | return x 22 | 23 | 24 | class ResNetDUC(nn.Module): 25 | # the size of image should be multiple of 8 26 | def __init__(self, num_classes, pretrained=True): 27 | super(ResNetDUC, self).__init__() 28 | resnet = models.resnet152(pretrained=pretrained) 29 | 30 | self.num_classes = num_classes 31 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool) 32 | self.layer1 = resnet.layer1 33 | self.layer2 = resnet.layer2 34 | self.layer3 = resnet.layer3 35 | self.layer4 = resnet.layer4 36 | 37 | for n, m in self.layer3.named_modules(): 38 | if 'conv2' in n: 39 | m.dilation = (2, 2) 40 | m.padding = (2, 2) 41 | m.stride = (1, 1) 42 | elif 'downsample.0' in n: 43 | m.stride = (1, 1) 44 | for n, m in self.layer4.named_modules(): 45 | if 'conv2' in n: 46 | m.dilation = (4, 4) 47 | m.padding = (4, 4) 48 | m.stride = (1, 1) 49 | elif 'downsample.0' in n: 50 | m.stride = (1, 1) 51 | 52 | self.duc = _DenseUpsamplingConvModule(8, 2048, num_classes) 53 | 54 | def forward(self, x): 55 | x = self.layer0(x) 56 | x = self.layer1(x) 57 | x = self.layer2(x) 58 | x = self.layer3(x) 59 | x = self.layer4(x) 60 | x = self.duc(x) 61 | return x 62 | 63 | 64 | class ResNetDUCHDC(nn.Module): 65 | # the size of image should be multiple of 8 66 | def __init__(self, num_classes, pretrained=True): 67 | super(ResNetDUCHDC, self).__init__() 68 | resnet = models.resnet152(pretrained=pretrained) 69 | 70 | self.num_classes = num_classes 71 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool) 72 | self.layer1 = resnet.layer1 73 | self.layer2 = resnet.layer2 74 | self.layer3 = resnet.layer3 75 | self.layer4 = resnet.layer4 76 | 77 | for n, m in self.layer3.named_modules(): 78 | if 'conv2' in n or 'downsample.0' in n: 79 | m.stride = (1, 1) 80 | for n, m in self.layer4.named_modules(): 81 | if 'conv2' in n or 'downsample.0' in n: 82 | m.stride = (1, 1) 83 | layer3_group_config = [1, 2, 5, 9] 84 | for idx in range(len(self.layer3)): 85 | self.layer3[idx].conv2.dilation = (layer3_group_config[idx % 4], layer3_group_config[idx % 4]) 86 | self.layer3[idx].conv2.padding = (layer3_group_config[idx % 4], layer3_group_config[idx % 4]) 87 | layer4_group_config = [5, 9, 17] 88 | for idx in range(len(self.layer4)): 89 | self.layer4[idx].conv2.dilation = (layer4_group_config[idx], layer4_group_config[idx]) 90 | self.layer4[idx].conv2.padding = 
(layer4_group_config[idx], layer4_group_config[idx]) 91 | 92 | self.duc = _DenseUpsamplingConvModule(8, 2048, num_classes) 93 | 94 | def forward(self, x): 95 | x = self.layer0(x) 96 | x = self.layer1(x) 97 | x = self.layer2(x) 98 | x = self.layer3(x) 99 | x = self.layer4(x) 100 | x = self.duc(x) 101 | 102 | return x -------------------------------------------------------------------------------- /lib/models/gcn152.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torchvision import models 5 | 6 | 7 | # many are borrowed from https://github.com/ycszen/pytorch-ss/blob/master/gcn.py 8 | # https://raw.githubusercontent.com/ZijunDeng/pytorch-semantic-segmentation/master/models/gcn.py 9 | class _GlobalConvModule(nn.Module): 10 | def __init__(self, in_dim, out_dim, kernel_size): 11 | super(_GlobalConvModule, self).__init__() 12 | pad0 = (kernel_size[0] - 1) // 2 13 | pad1 = (kernel_size[1] - 1) // 2 14 | # kernel size had better be odd number so as to avoid alignment error 15 | super(_GlobalConvModule, self).__init__() 16 | 17 | self.pre_drop = nn.Dropout2d(p=0.1) 18 | self.conv_l1 = nn.Conv2d(in_dim, out_dim, kernel_size=(kernel_size[0], 1), 19 | padding=(pad0, 0)) 20 | self.conv_l2 = nn.Conv2d(out_dim, out_dim, kernel_size=(1, kernel_size[1]), 21 | padding=(0, pad1)) 22 | self.conv_r1 = nn.Conv2d(in_dim, out_dim, kernel_size=(1, kernel_size[1]), 23 | padding=(0, pad1)) 24 | self.conv_r2 = nn.Conv2d(out_dim, out_dim, kernel_size=(kernel_size[0], 1), 25 | padding=(pad0, 0)) 26 | 27 | def forward(self, x): 28 | x = self.pre_drop(x) 29 | x_l = self.conv_l1(x) 30 | x_l = self.conv_l2(x_l) 31 | x_r = self.conv_r1(x) 32 | x_r = self.conv_r2(x_r) 33 | x = x_l + x_r 34 | return x 35 | 36 | 37 | class _BoundaryRefineModule(nn.Module): 38 | def __init__(self, dim): 39 | super(_BoundaryRefineModule, self).__init__() 40 | self.relu = nn.ReLU(inplace=True) 41 | self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) 42 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) 43 | 44 | def forward(self, x): 45 | residual = self.conv1(x) 46 | residual = self.relu(residual) 47 | residual = self.conv2(residual) 48 | out = x + residual 49 | return out 50 | 51 | 52 | def initialize_weights(*models): 53 | for model in models: 54 | for module in model.modules(): 55 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): 56 | nn.init.kaiming_normal_(module.weight) 57 | if module.bias is not None: 58 | module.bias.data.zero_() 59 | elif isinstance(module, nn.BatchNorm2d): 60 | module.weight.data.fill_(1) 61 | module.bias.data.zero_() 62 | 63 | 64 | class GCN34(nn.Module): 65 | def __init__(self, num_classes, input_size, pretrained=True): 66 | super(GCN34, self).__init__() 67 | self.input_size = input_size 68 | self.num_classes = num_classes 69 | resnet = models.resnet34(pretrained) 70 | 71 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu) 72 | self.layer1 = nn.Sequential(resnet.maxpool, resnet.layer1) 73 | self.layer2 = resnet.layer2 74 | self.layer3 = resnet.layer3 75 | self.layer4 = resnet.layer4 76 | 77 | self.gcm1 = _GlobalConvModule(512, num_classes, (7, 7)) 78 | self.gcm2 = _GlobalConvModule(256, num_classes, (7, 7)) 79 | self.gcm3 = _GlobalConvModule(128, num_classes, (7, 7)) 80 | self.gcm4 = _GlobalConvModule(64, num_classes, (7, 7)) 81 | 82 | self.brm1 = _BoundaryRefineModule(num_classes) 83 | self.brm2 = _BoundaryRefineModule(num_classes) 84 | 
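# brm1-brm4 refine the four GCM outputs; brm5-brm9 refine the fused, progressively upsampled maps in forward()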
self.brm3 = _BoundaryRefineModule(num_classes) 85 | self.brm4 = _BoundaryRefineModule(num_classes) 86 | self.brm5 = _BoundaryRefineModule(num_classes) 87 | self.brm6 = _BoundaryRefineModule(num_classes) 88 | self.brm7 = _BoundaryRefineModule(num_classes) 89 | self.brm8 = _BoundaryRefineModule(num_classes) 90 | self.brm9 = _BoundaryRefineModule(num_classes) 91 | 92 | initialize_weights(self.gcm1, self.gcm2, self.gcm3, self.gcm4, self.brm1, self.brm2, self.brm3, 93 | self.brm4, self.brm5, self.brm6, self.brm7, self.brm8, self.brm9) 94 | 95 | def forward(self, x): 96 | # if x: 512 97 | fm0 = self.layer0(x) # 256 98 | fm1 = self.layer1(fm0) # 128 99 | fm2 = self.layer2(fm1) # 64 100 | fm3 = self.layer3(fm2) # 32 101 | fm4 = self.layer4(fm3) # 16 102 | 103 | gcfm1 = self.brm1(self.gcm1(fm4)) # 16 104 | gcfm2 = self.brm2(self.gcm2(fm3)) # 32 105 | gcfm3 = self.brm3(self.gcm3(fm2)) # 64 106 | gcfm4 = self.brm4(self.gcm4(fm1)) # 128 107 | 108 | fs1 = self.brm5(F.upsample(gcfm1, fm3.size()[2:], mode='bilinear', align_corners=True) + gcfm2) # 32 109 | fs2 = self.brm6(F.upsample(fs1, fm2.size()[2:], mode='bilinear', align_corners=True) + gcfm3) # 64 110 | fs3 = self.brm7(F.upsample(fs2, fm1.size()[2:], mode='bilinear', align_corners=True) + gcfm4) # 128 111 | fs4 = self.brm8(F.upsample(fs3, fm0.size()[2:], mode='bilinear', align_corners=True)) # 256 112 | out = self.brm9(F.upsample(fs4, self.input_size, mode='bilinear', align_corners=True)) # 512 113 | 114 | return out 115 | 116 | 117 | class GCN152(nn.Module): 118 | def __init__(self, num_classes, input_size, pretrained=True): 119 | super(GCN152, self).__init__() 120 | self.input_size = input_size 121 | self.num_classes = num_classes 122 | resnet = models.resnet152(pretrained) 123 | 124 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu) 125 | self.layer1 = nn.Sequential(resnet.maxpool, resnet.layer1) 126 | self.layer2 = resnet.layer2 127 | self.layer3 = resnet.layer3 128 | self.layer4 = resnet.layer4 129 | 130 | self.gcm1 = _GlobalConvModule(2048, num_classes, (7, 7)) 131 | self.gcm2 = _GlobalConvModule(1024, num_classes, (7, 7)) 132 | self.gcm3 = _GlobalConvModule(512, num_classes, (7, 7)) 133 | self.gcm4 = _GlobalConvModule(256, num_classes, (7, 7)) 134 | 135 | self.brm1 = _BoundaryRefineModule(num_classes) 136 | self.brm2 = _BoundaryRefineModule(num_classes) 137 | self.brm3 = _BoundaryRefineModule(num_classes) 138 | self.brm4 = _BoundaryRefineModule(num_classes) 139 | self.brm5 = _BoundaryRefineModule(num_classes) 140 | self.brm6 = _BoundaryRefineModule(num_classes) 141 | self.brm7 = _BoundaryRefineModule(num_classes) 142 | self.brm8 = _BoundaryRefineModule(num_classes) 143 | self.brm9 = _BoundaryRefineModule(num_classes) 144 | 145 | initialize_weights(self.gcm1, self.gcm2, self.gcm3, self.gcm4, self.brm1, self.brm2, self.brm3, 146 | self.brm4, self.brm5, self.brm6, self.brm7, self.brm8, self.brm9) 147 | 148 | def forward(self, x): 149 | # if x: 512 150 | fm0 = self.layer0(x) # 256 151 | fm1 = self.layer1(fm0) # 128 152 | fm2 = self.layer2(fm1) # 64 153 | fm3 = self.layer3(fm2) # 32 154 | fm4 = self.layer4(fm3) # 16 155 | 156 | gcfm1 = self.brm1(self.gcm1(fm4)) # 16 157 | gcfm2 = self.brm2(self.gcm2(fm3)) # 32 158 | gcfm3 = self.brm3(self.gcm3(fm2)) # 64 159 | gcfm4 = self.brm4(self.gcm4(fm1)) # 128 160 | 161 | fs1 = self.brm5(F.upsample(gcfm1, fm3.size()[2:], mode='bilinear', align_corners=True) + gcfm2) # 32 162 | fs2 = self.brm6(F.upsample(fs1, fm2.size()[2:], mode='bilinear', align_corners=True) + gcfm3) # 64 163 | fs3 = 
self.brm7(F.upsample(fs2, fm1.size()[2:], mode='bilinear', align_corners=True) + gcfm4) # 128 164 | fs4 = self.brm8(F.upsample(fs3, fm0.size()[2:], mode='bilinear', align_corners=True)) # 256 165 | out = self.brm9(F.upsample(fs4, self.input_size, mode='bilinear', align_corners=True)) # 512 166 | 167 | return out 168 | -------------------------------------------------------------------------------- /lib/models/linknet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torchvision import models 3 | 4 | 5 | class DecoderBlockLinkNet(nn.Module): 6 | def __init__(self, in_channels, n_filters): 7 | super().__init__() 8 | 9 | from lib.modules.abn import InPlaceABN 10 | 11 | # B, C, H, W -> B, C/4, H, W 12 | self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1) 13 | self.abn1 = InPlaceABN(in_channels // 4) 14 | 15 | # B, C/4, H, W -> B, C/4, 2 * H, 2 * W 16 | self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, kernel_size=4, stride=2, padding=1, output_padding=0) 17 | self.abn2 = InPlaceABN(in_channels // 4) 18 | 19 | # B, C/4, H, W -> B, C, H, W 20 | self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1) 21 | self.abn3 = InPlaceABN(n_filters) 22 | 23 | def forward(self, x): 24 | x = self.conv1(x) 25 | x = self.abn1(x) 26 | x = self.deconv2(x) 27 | x = self.abn2(x) 28 | x = self.conv3(x) 29 | x = self.abn3(x) 30 | return x 31 | 32 | 33 | class LinkNet34(nn.Module): 34 | def __init__(self, num_classes=1, num_channels=3, pretrained=True): 35 | super().__init__() 36 | assert num_channels == 3 37 | self.num_classes = num_classes 38 | filters = [64, 128, 256, 512] 39 | resnet = models.resnet34(pretrained=pretrained) 40 | 41 | self.firstconv = resnet.conv1 42 | self.firstbn = resnet.bn1 43 | self.firstrelu = resnet.relu 44 | self.firstmaxpool = resnet.maxpool 45 | self.encoder1 = resnet.layer1 46 | self.encoder2 = resnet.layer2 47 | self.encoder3 = resnet.layer3 48 | self.encoder4 = resnet.layer4 49 | 50 | # Decoder 51 | self.decoder4 = DecoderBlockLinkNet(filters[3], filters[2]) 52 | self.decoder3 = DecoderBlockLinkNet(filters[2], filters[1]) 53 | self.decoder2 = DecoderBlockLinkNet(filters[1], filters[0]) 54 | self.decoder1 = DecoderBlockLinkNet(filters[0], filters[0]) 55 | 56 | # Final Classifier 57 | self.finaldrop1 = nn.Dropout2d(p=0.5) 58 | self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 3, stride=2) 59 | self.finalrelu1 = nn.LeakyReLU(inplace=True) 60 | self.finalconv2 = nn.Conv2d(32, 32, 3) 61 | self.finalrelu2 = nn.LeakyReLU(inplace=True) 62 | self.finalconv3 = nn.Conv2d(32, num_classes, 2, padding=1) 63 | 64 | # noinspection PyCallingNonCallable 65 | def forward(self, x): 66 | # Encoder 67 | x = self.firstconv(x) 68 | x = self.firstbn(x) 69 | x = self.firstrelu(x) 70 | x = self.firstmaxpool(x) 71 | e1 = self.encoder1(x) 72 | e2 = self.encoder2(e1) 73 | e3 = self.encoder3(e2) 74 | e4 = self.encoder4(e3) 75 | 76 | # Decoder with Skip Connections 77 | d4 = self.decoder4(e4) + e3 78 | d3 = self.decoder3(d4) + e2 79 | d2 = self.decoder2(d3) + e1 80 | d1 = self.decoder1(d2) 81 | 82 | # Final Classification 83 | d1 = self.finaldrop1(d1) # Added dropout 84 | f1 = self.finaldeconv1(d1) 85 | f2 = self.finalrelu1(f1) 86 | f3 = self.finalconv2(f2) 87 | f4 = self.finalrelu2(f3) 88 | f5 = self.finalconv3(f4) 89 | 90 | return f5 91 | -------------------------------------------------------------------------------- /lib/models/linknext.py: 
--------------------------------------------------------------------------------
1 | """
2 | https://arxiv.org/abs/1611.05431
3 | official code:
4 | https://github.com/facebookresearch/ResNeXt
5 | """
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from collections import OrderedDict
10 | 
11 | from torch.autograd import Variable
12 | 
13 | """
14 | NOTICE:
15 | BasicBlock_B is not implemented
16 | BasicBlock_C is the recommended variant
17 | The full architecture built from BasicBlock_A is not implemented.
18 | """
19 | 
20 | 
21 | class ResBottleBlock(nn.Module):
22 | 
23 | def __init__(self, in_planes, bottleneck_width=4, stride=1, expansion=1):
24 | super(ResBottleBlock, self).__init__()
25 | self.conv0 = nn.Conv2d(in_planes, bottleneck_width, 1, stride=1, bias=False)
26 | self.bn0 = nn.BatchNorm2d(bottleneck_width)
27 | self.conv1 = nn.Conv2d(bottleneck_width, bottleneck_width, 3, stride=stride, padding=1, bias=False)
28 | self.bn1 = nn.BatchNorm2d(bottleneck_width)
29 | self.conv2 = nn.Conv2d(bottleneck_width, expansion * in_planes, 1, bias=False)
30 | self.bn2 = nn.BatchNorm2d(expansion * in_planes)
31 | 
32 | self.shortcut = nn.Sequential()
33 | if stride != 1 or expansion != 1:
34 | self.shortcut = nn.Sequential(
35 | nn.Conv2d(in_planes, in_planes * expansion, 1, stride=stride, bias=False)
36 | )
37 | 
38 | def forward(self, x):
39 | out = F.relu(self.bn0(self.conv0(x)))
40 | out = F.relu(self.bn1(self.conv1(out)))
41 | out = self.bn2(self.conv2(out))
42 | out += self.shortcut(x)
43 | out = F.relu(out)
44 | return out
45 | 
46 | 
47 | class BasicBlock_A(nn.Module):
48 | def __init__(self, in_planes, num_paths=32, bottleneck_width=4, expansion=1, stride=1):
49 | super(BasicBlock_A, self).__init__()
50 | self.num_paths = num_paths
51 | for i in range(num_paths):
52 | setattr(self, 'path' + str(i), self._make_path(in_planes, bottleneck_width, stride, expansion))
53 | 
54 | # self.paths=self._make_path(in_planes,bottleneck_width,stride,expansion)
55 | self.conv0 = nn.Conv2d(in_planes * expansion, expansion * in_planes, 1, stride=1, bias=False)
56 | self.bn0 = nn.BatchNorm2d(in_planes * expansion)
57 | 
58 | self.shortcut = nn.Sequential()
59 | if stride != 1 or expansion != 1:
60 | self.shortcut = nn.Sequential(
61 | nn.Conv2d(in_planes, in_planes * expansion, 1, stride=stride, bias=False)
62 | )
63 | 
64 | def forward(self, x):
65 | out = self.path0(x)
66 | for i in range(1, self.num_paths):
67 | if hasattr(self, 'path' + str(i)):
68 | out = out + getattr(self, 'path' + str(i))(x)  # accumulate every path, not just evaluate it
69 | # out+=self.paths(x)
70 | # getattr
71 | # out = torch.sum(out, dim=1)
72 | out = self.bn0(out)
73 | out += self.shortcut(x)
74 | out = F.relu(out)
75 | return out
76 | 
77 | def _make_path(self, in_planes, bottleneck_width, stride, expansion):
78 | layers = []
79 | layers.append(ResBottleBlock(
80 | in_planes, bottleneck_width, stride, expansion))
81 | return nn.Sequential(*layers)
82 | 
83 | 
84 | class BasicBlock_C(nn.Module):
85 | """
86 | increasing cardinality is a more effective way of
87 | gaining accuracy than going deeper or wider
88 | """
89 | 
90 | def __init__(self, in_planes, bottleneck_width=4, cardinality=32, stride=1, expansion=2):
91 | super(BasicBlock_C, self).__init__()
92 | inner_width = cardinality * bottleneck_width
93 | self.expansion = expansion
94 | self.basic = nn.Sequential(OrderedDict(
95 | [
96 | ('conv1_0', nn.Conv2d(in_planes, inner_width, 1, stride=1, bias=False)),
97 | ('bn1', nn.BatchNorm2d(inner_width)),
98 | ('act0', nn.ReLU()),
99 | ('conv3_0',
100 | nn.Conv2d(inner_width, inner_width, 3, stride=stride, padding=1, groups=cardinality, bias=False)),
101 | ('bn2', nn.BatchNorm2d(inner_width)),
102 | ('act1', nn.ReLU()),
103 | ('conv1_1', nn.Conv2d(inner_width, inner_width * self.expansion, 1, stride=1, bias=False)),
104 | ('bn3', nn.BatchNorm2d(inner_width * self.expansion))
105 | ]
106 | ))
107 | self.shortcut = nn.Sequential()
108 | if stride != 1 or in_planes != inner_width * self.expansion:
109 | self.shortcut = nn.Sequential(
110 | nn.Conv2d(in_planes, inner_width * self.expansion, 1, stride=stride, bias=False)
111 | )
112 | self.bn0 = nn.BatchNorm2d(self.expansion * inner_width)
113 | 
114 | def forward(self, x):
115 | out = self.basic(x)
116 | out += self.shortcut(x)
117 | out = F.relu(self.bn0(out))
118 | return out
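# Note: the 3x3 convolution in BasicBlock_C uses groups=cardinality, i.e. 32 parallel
# 4-channel paths in the default configuration -- the grouped-convolution form of the
# explicit path list in BasicBlock_A.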
119 | 
120 | 
121 | class ResNeXt(nn.Module):
122 | def __init__(self, num_blocks, cardinality, bottleneck_width, expansion=2, num_classes=10):
123 | super(ResNeXt, self).__init__()
124 | self.cardinality = cardinality
125 | self.bottleneck_width = bottleneck_width
126 | self.in_planes = 64
127 | self.expansion = expansion
128 | self.n_out_filters = [0, 0, 0, 0]
129 | 
130 | self.conv0 = nn.Conv2d(3, self.in_planes, kernel_size=3, stride=1, padding=1)
131 | self.bn0 = nn.BatchNorm2d(self.in_planes)
132 | self.pool0 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
133 | self.layer1, self.n_out_filters[0] = self._make_layer(num_blocks[0], 1)
134 | self.layer2, self.n_out_filters[1] = self._make_layer(num_blocks[1], 2)
135 | self.layer3, self.n_out_filters[2] = self._make_layer(num_blocks[2], 2)
136 | self.layer4, self.n_out_filters[3] = self._make_layer(num_blocks[3], 2)
137 | self.linear = nn.Linear(self.cardinality * self.bottleneck_width, num_classes)
138 | 
139 | def forward(self, x):
140 | out = F.relu(self.bn0(self.conv0(x)))
141 | # out = self.pool0(out)
142 | out = self.layer1(out)
143 | out = self.layer2(out)
144 | out = self.layer3(out)
145 | out = self.layer4(out)
146 | out = F.avg_pool2d(out, 4)
147 | out = out.view(out.size(0), -1)
148 | out = self.linear(out)
149 | return out
150 | 
151 | def _make_layer(self, num_blocks, stride):
152 | strides = [stride] + [1] * (num_blocks - 1)
153 | layers = []
154 | for stride in strides:
155 | layers.append(BasicBlock_C(self.in_planes, self.bottleneck_width, self.cardinality, stride, self.expansion))
156 | self.in_planes = self.expansion * self.bottleneck_width * self.cardinality
157 | self.bottleneck_width *= 2
158 | return nn.Sequential(*layers), self.in_planes
159 | 
160 | 
161 | def resnext26_2x64d():
162 | return ResNeXt(num_blocks=[2, 2, 2, 2], cardinality=2, bottleneck_width=64)
163 | 
164 | 
165 | def resnext26_4x32d():
166 | return ResNeXt(num_blocks=[2, 2, 2, 2], cardinality=4, bottleneck_width=32)
167 | 
168 | 
169 | def resnext26_8x16d():
170 | return ResNeXt(num_blocks=[2, 2, 2, 2], cardinality=8, bottleneck_width=16)
171 | 
172 | 
173 | def resnext26_16x8d():
174 | return ResNeXt(num_blocks=[2, 2, 2, 2], cardinality=16, bottleneck_width=8)
175 | 
176 | 
177 | def resnext26_32x4d():
178 | return ResNeXt(num_blocks=[2, 2, 2, 2], cardinality=32, bottleneck_width=4)
179 | 
180 | 
181 | def resnext26_64x2d():
182 | return ResNeXt(num_blocks=[2, 2, 2, 2], cardinality=64, bottleneck_width=2)
183 | 
184 | 
185 | def resnext50_2x64d():
186 | return ResNeXt(num_blocks=[3, 4, 6, 3], cardinality=2, bottleneck_width=64)
187 | 
188 | 
189 | def resnext50_32x4d():
190 | return ResNeXt(num_blocks=[3, 4, 6, 3], cardinality=32, bottleneck_width=4)
191 | 
192 | 
193 | class
DecoderBlockLinkNet(nn.Module): 194 | def __init__(self, in_channels, n_filters, drop_rate=0.): 195 | super().__init__() 196 | 197 | self.relu = nn.ReLU(inplace=True) 198 | 199 | # B, C, H, W -> B, C/4, H, W 200 | self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1) 201 | self.norm1 = nn.BatchNorm2d(in_channels // 4) 202 | 203 | # B, C/4, H, W -> B, C/4, 2 * H, 2 * W 204 | self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, kernel_size=4, 205 | stride=2, padding=1, output_padding=0) 206 | self.norm2 = nn.BatchNorm2d(in_channels // 4) 207 | 208 | # B, C/4, H, W -> B, C, H, W 209 | self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1) 210 | self.norm3 = nn.BatchNorm2d(n_filters) 211 | self.drop = nn.Dropout2d(drop_rate) 212 | 213 | def forward(self, x): 214 | x = self.conv1(x) 215 | x = self.norm1(x) 216 | x = self.relu(x) 217 | x = self.deconv2(x) 218 | x = self.norm2(x) 219 | x = self.relu(x) 220 | x = self.conv3(x) 221 | x = self.norm3(x) 222 | x = self.relu(x) 223 | x = self.drop(x) 224 | return x 225 | 226 | 227 | class LinkNext(nn.Module): 228 | def __init__(self, num_classes=1, num_channels=3): 229 | super().__init__() 230 | assert num_channels == 3 231 | self.num_classes = num_classes 232 | resnext = resnext26_2x64d() 233 | filters = resnext.n_out_filters 234 | 235 | self.encoder0 = nn.Sequential(resnext.conv0, resnext.bn0, nn.ReLU()) 236 | 237 | # self.firstmaxpool = resnext.maxpool 238 | self.encoder1 = resnext.layer1 239 | self.encoder2 = resnext.layer2 240 | self.encoder3 = resnext.layer3 241 | self.encoder4 = resnext.layer4 242 | 243 | # Decoder 244 | self.decoder4 = DecoderBlockLinkNet(filters[3], filters[2], drop_rate=0.1) 245 | self.decoder3 = DecoderBlockLinkNet(filters[2], filters[1], drop_rate=0.2) 246 | self.decoder2 = DecoderBlockLinkNet(filters[1], filters[0], drop_rate=0.3) 247 | self.decoder1 = DecoderBlockLinkNet(filters[0], filters[0], drop_rate=0.4) 248 | 249 | # Final Classifier 250 | self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 3, stride=2) 251 | self.finalrelu1 = nn.ReLU(inplace=True) 252 | self.finalconv2 = nn.Conv2d(32, 32, 3) 253 | self.finalrelu2 = nn.ReLU(inplace=True) 254 | self.finalconv3 = nn.Conv2d(32, num_classes, 2, padding=1) 255 | 256 | def forward(self, x): 257 | # Encoder 258 | x = self.encoder0(x) 259 | 260 | e1 = self.encoder1(x) 261 | e2 = self.encoder2(e1) 262 | e3 = self.encoder3(e2) 263 | e4 = self.encoder4(e3) 264 | 265 | # Decoder with Skip Connections 266 | d4 = self.decoder4(e4) + e3 267 | d3 = self.decoder3(d4) + e2 268 | d2 = self.decoder2(d3) + e1 269 | d1 = self.decoder1(d2) 270 | 271 | # Final Classification 272 | f1 = self.finaldeconv1(d1) 273 | f2 = self.finalrelu1(f1) 274 | f3 = self.finalconv2(f2) 275 | f4 = self.finalrelu2(f3) 276 | f5 = self.finalconv3(f4) 277 | 278 | return f5 279 | -------------------------------------------------------------------------------- /lib/models/psp_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torchvision import models 5 | import numpy as np 6 | 7 | 8 | def initialize_weights(*models): 9 | for model in models: 10 | for module in model.modules(): 11 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): 12 | nn.init.kaiming_normal_(module.weight) 13 | if module.bias is not None: 14 | module.bias.data.zero_() 15 | elif isinstance(module, nn.BatchNorm2d): 16 | module.weight.data.fill_(1) 17 | module.bias.data.zero_() 18 | 19 
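# PSPNet head: the pyramid module below pools the 2048-channel, stride-8 backbone features
# to 1x1, 2x2, 3x3 and 6x6 grids, projects each to 512 channels, upsamples back and
# concatenates with the input (hence the 4096 channels entering the final block).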
| 20 | class _PyramidPoolingModule(nn.Module): 21 | def __init__(self, in_dim, reduction_dim, setting): 22 | super(_PyramidPoolingModule, self).__init__() 23 | self.features = [] 24 | for s in setting: 25 | self.features.append(nn.Sequential( 26 | nn.AdaptiveAvgPool2d(s), 27 | nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), 28 | nn.BatchNorm2d(reduction_dim, momentum=.95), 29 | nn.ReLU(inplace=True) 30 | )) 31 | self.features = nn.ModuleList(self.features) 32 | 33 | def forward(self, x): 34 | x_size = x.size() 35 | out = [x] 36 | for f in self.features: 37 | out.append(F.upsample(f(x), x_size[2:], mode='bilinear')) 38 | out = torch.cat(out, 1) 39 | return out 40 | 41 | 42 | class PSPNet(nn.Module): 43 | def __init__(self, num_classes, pretrained=True, use_aux=True): 44 | super(PSPNet, self).__init__() 45 | self.use_aux = use_aux 46 | self.num_classes = num_classes 47 | 48 | resnet = models.resnet101(pretrained) 49 | 50 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool) 51 | self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 52 | 53 | for n, m in self.layer3.named_modules(): 54 | if 'conv2' in n: 55 | m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) 56 | elif 'downsample.0' in n: 57 | m.stride = (1, 1) 58 | for n, m in self.layer4.named_modules(): 59 | if 'conv2' in n: 60 | m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) 61 | elif 'downsample.0' in n: 62 | m.stride = (1, 1) 63 | 64 | self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6)) 65 | self.final = nn.Sequential( 66 | nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False), 67 | nn.BatchNorm2d(512, momentum=.95), 68 | nn.ReLU(inplace=True), 69 | nn.Dropout(0.1), 70 | nn.Conv2d(512, num_classes, kernel_size=1) 71 | ) 72 | 73 | if use_aux: 74 | self.aux_logits = nn.Conv2d(1024, num_classes, kernel_size=1) 75 | initialize_weights(self.aux_logits) 76 | 77 | initialize_weights(self.ppm, self.final) 78 | 79 | def forward(self, x): 80 | x_size = x.size() 81 | x = self.layer0(x) 82 | x = self.layer1(x) 83 | x = self.layer2(x) 84 | x = self.layer3(x) 85 | if self.training and self.use_aux: 86 | aux = self.aux_logits(x) 87 | x = self.layer4(x) 88 | x = self.ppm(x) 89 | x = self.final(x) 90 | 91 | if self.training and self.use_aux: 92 | out = F.upsample(x, x_size[2:], mode='bilinear'), F.upsample(aux, x_size[2:], mode='bilinear') 93 | else: 94 | out = F.upsample(x, x_size[2:], mode='bilinear') 95 | 96 | return out 97 | -------------------------------------------------------------------------------- /lib/models/squeezenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.nn.functional as F 6 | 7 | class Fire(nn.Module): 8 | 9 | def __init__(self, inplanes, squeeze_planes, 10 | expand1x1_planes, expand3x3_planes): 11 | super(Fire, self).__init__() 12 | self.inplanes = inplanes 13 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) 14 | self.squeeze_activation = nn.ELU(inplace=True) 15 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, 16 | kernel_size=1) 17 | self.expand1x1_activation = nn.ELU(inplace=True) 18 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, 19 | kernel_size=3, padding=1) 20 | self.expand3x3_activation = nn.ELU(inplace=True) 21 | 22 | def forward(self, x): 23 | x = self.squeeze_activation(self.squeeze(x)) 24 | return 
torch.cat([ 25 | self.expand1x1_activation(self.expand1x1(x)), 26 | self.expand3x3_activation(self.expand3x3(x)) 27 | ], 1) 28 | 29 | class DFire(nn.Module): 30 | 31 | def __init__(self, inplanes, squeeze_planes, 32 | expand1x1_planes, expand3x3_planes): 33 | super(DFire, self).__init__() 34 | self.inplanes = inplanes 35 | 36 | self.expand1x1 = nn.Conv2d(inplanes, expand1x1_planes, 37 | kernel_size=1) 38 | self.expand1x1_activation = nn.ELU(inplace=True) 39 | self.expand3x3 = nn.Conv2d(inplanes, expand3x3_planes, 40 | kernel_size=3, padding=1) 41 | self.expand3x3_activation = nn.ELU(inplace=True) 42 | 43 | self.squeeze = nn.Conv2d(expand3x3_planes + expand1x1_planes, squeeze_planes, kernel_size=1) 44 | self.squeeze_activation = nn.ELU(inplace=True) 45 | 46 | def forward(self, x): 47 | x = torch.cat([ 48 | self.expand1x1_activation(self.expand1x1(x)), 49 | self.expand3x3_activation(self.expand3x3(x)) 50 | ], 1) 51 | x = self.squeeze_activation(self.squeeze(x)) 52 | return x 53 | 54 | class SharpMaskBypass(nn.Module): 55 | def __init__(self, enc_features, dec_features, num_classes): 56 | super(SharpMaskBypass, self).__init__() 57 | self.conv1 = nn.Sequential(nn.Conv2d(enc_features, 32, kernel_size=3, padding=1), nn.ELU(inplace=True)) 58 | self.conv2 = nn.Sequential(nn.Conv2d(32 + dec_features, num_classes, kernel_size=3, padding=1), nn.ELU(inplace=True)) 59 | # TODO: Init weights stddev = 0.0001 60 | 61 | def forward(self, from_enc, from_dec): 62 | x = self.conv1(from_enc) 63 | x = torch.cat((x, from_dec), dim=1) 64 | x = self.conv2(x) 65 | return x 66 | 67 | 68 | class SqueezeNet(nn.Module): 69 | def __init__(self, in_channels=3, num_classes=1): 70 | super(SqueezeNet, self).__init__() 71 | 72 | self.conv1 = nn.Conv2d(in_channels, 96, kernel_size=3, padding=1) 73 | self.pool1 = nn.MaxPool2d(2, 2) 74 | 75 | self.fire2 = Fire(96, 16, 64, 64) 76 | self.fire3 = Fire(128, 16, 64, 64) 77 | self.pool3 = nn.MaxPool2d(2, 2) 78 | 79 | self.fire4 = Fire(128, 48, 128, 128) 80 | self.fire5 = Fire(256, 48, 128, 128) 81 | self.pool5 = nn.MaxPool2d(2, 2) 82 | 83 | self.fire6 = Fire(256, 48, 192, 192) 84 | self.fire7 = Fire(384, 48, 192, 192) 85 | 86 | self.fire8 = Fire(384, 64, 256, 256) 87 | self.fire9 = Fire(512, 64, 256, 256) 88 | 89 | self.conv10 = nn.Sequential(nn.Conv2d(512, 1024, kernel_size=1), nn.ELU(inplace=True)) 90 | self.dconv10 = nn.Sequential(nn.Conv2d(1024, 512, kernel_size=1), nn.ELU(inplace=True)) 91 | 92 | # Decoder 93 | self.dfire9 = DFire(512, 512, 256, 256) 94 | self.dfire8 = DFire(512, 384, 256, 256) 95 | self.dfire7 = DFire(384, 384, 192, 192) 96 | self.dfire6 = DFire(384, 256, 192, 192) 97 | self.dfire5 = DFire(256, 256, 128, 128) 98 | self.dfire4 = DFire(256, 128, 128, 128) 99 | self.dfire3 = DFire(128, 128, 64, 64) 100 | self.dfire2 = DFire(128, 96, 48, 48) 101 | 102 | self.dconv1 = nn.Conv2d(96, num_classes, kernel_size=1) 103 | 104 | # 105 | # self.upscore2 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1) 106 | # 107 | # self.sharpmask3 = SharpMaskBypass(256, num_classes, num_classes) 108 | # 109 | # self.upscore4 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1) 110 | # 111 | # self.sharpmask2 = SharpMaskBypass(128, num_classes, num_classes) 112 | # 113 | # self.upscore8 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1) 114 | # 115 | # self.sharpmask1 = SharpMaskBypass(64, num_classes, num_classes) 116 | 117 | def forward(self, x): 118 | conv1 = self.conv1(x) 119 | pool1 = 
self.pool1(conv1) 120 | 121 | fire2 = self.fire2(pool1) 122 | fire3 = self.fire3(fire2) 123 | fire4 = self.fire4(fire3) 124 | pool4 = self.pool3(fire4) 125 | 126 | fire5 = self.fire5(pool4) 127 | fire6 = self.fire6(fire5) 128 | fire7 = self.fire7(fire6) 129 | fire8 = self.fire8(fire7) 130 | 131 | pool8 = self.pool5(fire8) 132 | 133 | fire9 = self.fire9(pool8) 134 | center = self.dconv10(self.conv10(fire9)) 135 | dfire9 = self.dfire9(center) 136 | 137 | dfire9 = F.upsample(dfire9, scale_factor=2, mode='nearest') 138 | dfire8 = self.dfire8(dfire9 + fire8) 139 | dfire7 = self.dfire7(dfire8) 140 | dfire6 = self.dfire6(dfire7) 141 | dfire5 = self.dfire5(dfire6) 142 | 143 | dfire5 = F.upsample(dfire5, scale_factor=2, mode='nearest') 144 | dfire4 = self.dfire4(dfire5 + fire4) 145 | dfire3 = self.dfire3(dfire4) 146 | dfire2 = self.dfire2(dfire3) 147 | 148 | dfire2 = F.upsample(dfire2, scale_factor=2, mode='nearest') 149 | dconv1 = self.dconv1(dfire2 + conv1) 150 | 151 | return dconv1 152 | 153 | # drop9 = self.drop9(fire9) 154 | # score_fr = self.score_fr(drop9) 155 | # 156 | # upscore2 = self.upscore2(score_fr) 157 | # sharpmask3 = self.sharpmask3(fire5, upscore2) 158 | # 159 | # upscore4 = self.upscore4(sharpmask3) 160 | # sharpmask2 = self.sharpmask2(fire3, upscore4) 161 | # 162 | # upscore8 = self.upscore8(sharpmask2) 163 | # sharpmask1 = self.sharpmask1(conv1, upscore8) 164 | # 165 | # return sharpmask1 166 | 167 | 168 | if __name__ == '__main__': 169 | arch = SqueezeNet(in_channels=3, num_classes=1) 170 | 171 | x = torch.rand((4,3,512,512)) 172 | y = arch(x) -------------------------------------------------------------------------------- /lib/models/tiramisu.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | 8 | 9 | class DenseLayer(nn.Sequential): 10 | def __init__(self, in_channels, growth_rate): 11 | super().__init__() 12 | self.add_module('norm', nn.BatchNorm2d(in_channels)) 13 | self.add_module('relu', nn.ReLU(True)) 14 | self.add_module('conv', nn.Conv2d(in_channels, growth_rate, kernel_size=3, 15 | stride=1, padding=1, bias=True)) 16 | self.add_module('drop', nn.Dropout2d(0.2)) 17 | 18 | def forward(self, x): 19 | return super().forward(x) 20 | 21 | 22 | class DenseBlock(nn.Module): 23 | def __init__(self, in_channels, growth_rate, n_layers, upsample=False): 24 | super().__init__() 25 | self.upsample = upsample 26 | self.layers = nn.ModuleList([DenseLayer( 27 | in_channels + i * growth_rate, growth_rate) 28 | for i in range(n_layers)]) 29 | 30 | def forward(self, x): 31 | if self.upsample: 32 | new_features = [] 33 | # we pass all previous activations into each dense layer normally 34 | # But we only store each dense layer's output in the new_features array 35 | for layer in self.layers: 36 | out = layer(x) 37 | x = torch.cat([x, out], 1) 38 | new_features.append(out) 39 | return torch.cat(new_features, 1) 40 | else: 41 | for layer in self.layers: 42 | out = layer(x) 43 | x = torch.cat([x, out], 1) # 1 = channel axis 44 | return x 45 | 46 | 47 | class TransitionDown(nn.Sequential): 48 | def __init__(self, in_channels): 49 | super().__init__() 50 | self.add_module('norm', nn.BatchNorm2d(num_features=in_channels)) 51 | self.add_module('relu', nn.ReLU(inplace=True)) 52 | self.add_module('conv', nn.Conv2d(in_channels, in_channels, 53 | kernel_size=1, stride=1, 54 | padding=0, bias=True)) 55 | self.add_module('drop', nn.Dropout2d(0.2)) 56 | 
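# the 2x2 max-pooling below halves the spatial resolution between dense blocks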
self.add_module('maxpool', nn.MaxPool2d(2)) 57 | 58 | def forward(self, x): 59 | return super().forward(x) 60 | 61 | 62 | class TransitionUp(nn.Module): 63 | def __init__(self, in_channels, out_channels): 64 | super().__init__() 65 | self.convTrans = nn.ConvTranspose2d( 66 | in_channels=in_channels, out_channels=out_channels, 67 | kernel_size=3, stride=2, padding=0, bias=True) 68 | 69 | def forward(self, x, skip): 70 | out = self.convTrans(x) 71 | out = center_crop(out, skip.size(2), skip.size(3)) 72 | out = torch.cat([out, skip], 1) 73 | return out 74 | 75 | 76 | class Bottleneck(nn.Sequential): 77 | def __init__(self, in_channels, growth_rate, n_layers): 78 | super().__init__() 79 | self.add_module('bottleneck', DenseBlock( 80 | in_channels, growth_rate, n_layers, upsample=True)) 81 | 82 | def forward(self, x): 83 | return super().forward(x) 84 | 85 | 86 | def center_crop(layer, max_height, max_width): 87 | _, _, h, w = layer.size() 88 | xy1 = (w - max_width) // 2 89 | xy2 = (h - max_height) // 2 90 | return layer[:, :, xy2:(xy2 + max_height), xy1:(xy1 + max_width)] 91 | 92 | 93 | class FCDenseNet(nn.Module): 94 | def __init__(self, in_channels=3, down_blocks=(5, 5, 5, 5, 5), 95 | up_blocks=(5, 5, 5, 5, 5), bottleneck_layers=5, 96 | growth_rate=16, out_chans_first_conv=48, n_classes=12): 97 | super().__init__() 98 | self.num_classes = n_classes 99 | self.down_blocks = down_blocks 100 | self.up_blocks = up_blocks 101 | cur_channels_count = 0 102 | skip_connection_channel_counts = [] 103 | 104 | ## First Convolution ## 105 | 106 | self.add_module('firstconv', nn.Conv2d(in_channels=in_channels, 107 | out_channels=out_chans_first_conv, kernel_size=3, 108 | stride=1, padding=1, bias=True)) 109 | cur_channels_count = out_chans_first_conv 110 | 111 | ##################### 112 | # Downsampling path # 113 | ##################### 114 | 115 | self.denseBlocksDown = nn.ModuleList([]) 116 | self.transDownBlocks = nn.ModuleList([]) 117 | for i in range(len(down_blocks)): 118 | self.denseBlocksDown.append( 119 | DenseBlock(cur_channels_count, growth_rate, down_blocks[i])) 120 | cur_channels_count += (growth_rate * down_blocks[i]) 121 | skip_connection_channel_counts.insert(0, cur_channels_count) 122 | self.transDownBlocks.append(TransitionDown(cur_channels_count)) 123 | 124 | ##################### 125 | # Bottleneck # 126 | ##################### 127 | 128 | self.add_module('bottleneck', Bottleneck(cur_channels_count, 129 | growth_rate, bottleneck_layers)) 130 | prev_block_channels = growth_rate * bottleneck_layers 131 | cur_channels_count += prev_block_channels 132 | 133 | ####################### 134 | # Upsampling path # 135 | ####################### 136 | 137 | self.transUpBlocks = nn.ModuleList([]) 138 | self.denseBlocksUp = nn.ModuleList([]) 139 | for i in range(len(up_blocks) - 1): 140 | self.transUpBlocks.append(TransitionUp(prev_block_channels, prev_block_channels)) 141 | cur_channels_count = prev_block_channels + skip_connection_channel_counts[i] 142 | 143 | self.denseBlocksUp.append(DenseBlock( 144 | cur_channels_count, growth_rate, up_blocks[i], 145 | upsample=True)) 146 | prev_block_channels = growth_rate * up_blocks[i] 147 | cur_channels_count += prev_block_channels 148 | 149 | ## Final DenseBlock ## 150 | 151 | self.transUpBlocks.append(TransitionUp( 152 | prev_block_channels, prev_block_channels)) 153 | cur_channels_count = prev_block_channels + skip_connection_channel_counts[-1] 154 | 155 | self.denseBlocksUp.append(DenseBlock( 156 | cur_channels_count, growth_rate, up_blocks[-1], 157 | 
upsample=False)) 158 | cur_channels_count += growth_rate * up_blocks[-1] 159 | 160 | ## Softmax ## 161 | 162 | self.finalConv = nn.Conv2d(in_channels=cur_channels_count, 163 | out_channels=n_classes, kernel_size=1, stride=1, 164 | padding=0, bias=True) 165 | 166 | self.softmax = nn.LogSoftmax(dim=1) 167 | 168 | def forward(self, x): 169 | out = self.firstconv(x) 170 | 171 | skip_connections = [] 172 | for i in range(len(self.down_blocks)): 173 | out = self.denseBlocksDown[i](out) 174 | skip_connections.append(out) 175 | out = self.transDownBlocks[i](out) 176 | 177 | out = self.bottleneck(out) 178 | for i in range(len(self.up_blocks)): 179 | skip = skip_connections.pop() 180 | out = self.transUpBlocks[i](out, skip) 181 | out = self.denseBlocksUp[i](out) 182 | 183 | out = self.finalConv(out) 184 | return out 185 | 186 | 187 | def FCDenseNet57(n_classes): 188 | return FCDenseNet( 189 | in_channels=3, down_blocks=(4, 4, 4, 4, 4), 190 | up_blocks=(4, 4, 4, 4, 4), bottleneck_layers=4, 191 | growth_rate=12, out_chans_first_conv=48, n_classes=n_classes) 192 | 193 | 194 | def FCDenseNet67(n_classes): 195 | return FCDenseNet( 196 | in_channels=3, down_blocks=(5, 5, 5, 5, 5), 197 | up_blocks=(5, 5, 5, 5, 5), bottleneck_layers=5, 198 | growth_rate=16, out_chans_first_conv=48, n_classes=n_classes) 199 | 200 | 201 | def FCDenseNet103(n_classes): 202 | return FCDenseNet( 203 | in_channels=3, down_blocks=(4, 5, 7, 10, 12), 204 | up_blocks=(12, 10, 7, 5, 4), bottleneck_layers=15, 205 | growth_rate=16, out_chans_first_conv=48, n_classes=n_classes) 206 | -------------------------------------------------------------------------------- /lib/models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class double_conv(nn.Module): 7 | '''(conv => BN => ReLU) * 2''' 8 | 9 | def __init__(self, in_ch, out_ch): 10 | super(double_conv, self).__init__() 11 | self.conv = nn.Sequential( 12 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 13 | nn.BatchNorm2d(out_ch), 14 | nn.ReLU(inplace=True), 15 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 16 | nn.BatchNorm2d(out_ch), 17 | nn.ReLU(inplace=True) 18 | ) 19 | 20 | def forward(self, x): 21 | x = self.conv(x) 22 | return x 23 | 24 | 25 | class inconv(nn.Module): 26 | def __init__(self, in_ch, out_ch): 27 | super(inconv, self).__init__() 28 | self.conv = double_conv(in_ch, out_ch) 29 | 30 | def forward(self, x): 31 | x = self.conv(x) 32 | return x 33 | 34 | 35 | class down(nn.Module): 36 | def __init__(self, in_ch, out_ch): 37 | super(down, self).__init__() 38 | self.mpconv = nn.Sequential( 39 | nn.MaxPool2d(2), 40 | double_conv(in_ch, out_ch) 41 | ) 42 | 43 | def forward(self, x): 44 | x = self.mpconv(x) 45 | return x 46 | 47 | 48 | class up(nn.Module): 49 | def __init__(self, in_ch, out_ch, upsample=True): 50 | super(up, self).__init__() 51 | 52 | if upsample: 53 | self.up = nn.Upsample(scale_factor=2, mode='nearest') 54 | else: 55 | self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2) 56 | 57 | self.conv = double_conv(in_ch, out_ch) 58 | 59 | def forward(self, x1, x2): 60 | x1 = self.up(x1) 61 | diffX = x1.size()[2] - x2.size()[2] 62 | diffY = x1.size()[3] - x2.size()[3] 63 | x2 = F.pad(x2, (diffX // 2, int(diffX / 2), 64 | diffY // 2, int(diffY / 2))) 65 | x = torch.cat([x2, x1], dim=1) 66 | x = self.conv(x) 67 | return x 68 | 69 | 70 | class outconv(nn.Module): 71 | def __init__(self, in_ch, out_ch): 72 | super(outconv, self).__init__() 73 
| self.conv = nn.Conv2d(in_ch, out_ch, 1)
74 |
75 | def forward(self, x):
76 | x = self.conv(x)
77 | return x
78 |
79 |
80 | class UNet(nn.Module):
81 | def __init__(self, n_channels=3, n_classes=1, n_filters=32, upsample=True):
82 | super(UNet, self).__init__()
83 | self.inc = inconv(n_channels, n_filters)
84 | self.down1 = down(n_filters, n_filters*2)
85 | self.down2 = down(n_filters*2, n_filters*4)
86 | self.down3 = down(n_filters*4, n_filters*8)
87 | self.down4 = down(n_filters*8, n_filters*8)
88 | self.up1 = up(n_filters*16, n_filters*4, upsample=upsample)
89 | self.up2 = up(n_filters*8, n_filters*2, upsample=upsample)
90 | self.up3 = up(n_filters*4, n_filters, upsample=upsample)
91 | self.up4 = up(n_filters*2, n_filters, upsample=upsample)
92 | self.finaldrop = nn.Dropout2d(p=0.5)
93 | self.outc = outconv(n_filters, n_classes)
94 |
95 | def forward(self, x):
96 | x1 = self.inc(x)
97 | x2 = self.down1(x1)
98 | x3 = self.down2(x2)
99 | x4 = self.down3(x3)
100 | x5 = self.down4(x4)
101 | x = self.up1(x5, x4)
102 | x = self.up2(x, x3)
103 | x = self.up3(x, x2)
104 | x = self.up4(x, x1)
105 | x = self.finaldrop(x)
106 | x = self.outc(x)
107 | return x
108 |
-------------------------------------------------------------------------------- /lib/models/unet11.py: --------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | from torchvision import models
4 | import torchvision
5 | from torch.nn import functional as F
6 |
7 |
8 | def conv3x3(in_, out):
9 | return nn.Conv2d(in_, out, 3, padding=1)
10 |
11 |
12 | class ConvRelu(nn.Module):
13 | def __init__(self, in_: int, out: int):
14 | super().__init__()
15 | self.conv = conv3x3(in_, out)
16 | self.activation = nn.ReLU(inplace=True)
17 |
18 | def forward(self, x):
19 | x = self.conv(x)
20 | x = self.activation(x)
21 | return x
22 |
23 | class DecoderBlock(nn.Module):
24 | """
25 | Parameters for the deconvolution were chosen to avoid checkerboard artifacts, following
26 | https://distill.pub/2016/deconv-checkerboard/
27 | """
28 |
29 | def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
30 | super(DecoderBlock, self).__init__()
31 | self.in_channels = in_channels
32 |
33 | if is_deconv:
34 | self.block = nn.Sequential(
35 | ConvRelu(in_channels, middle_channels),
36 | nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
37 | padding=1),
38 | nn.ReLU(inplace=True)
39 | )
40 | else:
41 | self.block = nn.Sequential(
42 | nn.Upsample(scale_factor=2, mode='bilinear'),
43 | ConvRelu(in_channels, middle_channels),
44 | ConvRelu(middle_channels, out_channels),
45 | )
46 |
47 | def forward(self, x):
48 | return self.block(x)
49 |
50 |
51 | class UNet11(nn.Module):
52 | def __init__(self, num_classes=1, num_filters=32, pretrained=False):
53 | """
54 | :param num_classes: number of output channels of the final 1x1 convolution
55 | :param num_filters: base number of decoder filters
56 | :param pretrained:
57 | False - no pre-trained network used
58 | vgg - encoder pre-trained with VGG11
59 | """
60 | super().__init__()
61 | self.pool = nn.MaxPool2d(2, 2)
62 |
63 | self.num_classes = num_classes
64 |
65 | if pretrained == 'vgg':
66 | self.encoder = models.vgg11(pretrained=True).features
67 | else:
68 | self.encoder = models.vgg11(pretrained=False).features
69 |
70 | self.relu = nn.ReLU(inplace=True)
71 | self.conv1 = nn.Sequential(self.encoder[0],
72 | self.relu)
73 |
74 | self.conv2 = nn.Sequential(self.encoder[3],
75 | self.relu)
76 |
77 | self.conv3 = nn.Sequential(
78 | self.encoder[6],
79 | self.relu,
80 | self.encoder[8],
81 |
self.relu,
82 | )
83 | self.conv4 = nn.Sequential(
84 | self.encoder[11],
85 | self.relu,
86 | self.encoder[13],
87 | self.relu,
88 | )
89 |
90 | self.conv5 = nn.Sequential(
91 | self.encoder[16],
92 | self.relu,
93 | self.encoder[18],
94 | self.relu,
95 | )
96 |
97 | self.center = DecoderBlock(256 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=True)
98 | self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=True)
99 | self.dec4 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 4, is_deconv=True)
100 | self.dec3 = DecoderBlock(256 + num_filters * 4, num_filters * 4 * 2, num_filters * 2, is_deconv=True)
101 | self.dec2 = DecoderBlock(128 + num_filters * 2, num_filters * 2 * 2, num_filters, is_deconv=True)
102 | self.dec1 = ConvRelu(64 + num_filters, num_filters)
103 |
104 | self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
105 |
106 | def forward(self, x):
107 | conv1 = self.conv1(x)
108 | conv2 = self.conv2(self.pool(conv1))
109 | conv3 = self.conv3(self.pool(conv2))
110 | conv4 = self.conv4(self.pool(conv3))
111 | conv5 = self.conv5(self.pool(conv4))
112 | center = self.center(self.pool(conv5))
113 |
114 | dec5 = self.dec5(torch.cat([center, conv5], 1))
115 | dec4 = self.dec4(torch.cat([dec5, conv4], 1))
116 | dec3 = self.dec3(torch.cat([dec4, conv3], 1))
117 | dec2 = self.dec2(torch.cat([dec3, conv2], 1))
118 | dec1 = self.dec1(torch.cat([dec2, conv1], 1))
119 |
120 | x_out = self.final(dec1)
121 |
122 | return x_out
123 |
-------------------------------------------------------------------------------- /lib/models/unet16.py: --------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | from torchvision import models
4 | import torchvision
5 | from torch.nn import functional as F
6 |
7 |
8 | def conv3x3(in_, out):
9 | return nn.Conv2d(in_, out, 3, padding=1)
10 |
11 |
12 | class ConvRelu(nn.Module):
13 | def __init__(self, in_: int, out: int):
14 | super().__init__()
15 | self.conv = conv3x3(in_, out)
16 | self.activation = nn.ReLU(inplace=True)
17 |
18 | def forward(self, x):
19 | x = self.conv(x)
20 | x = self.activation(x)
21 | return x
22 |
23 |
24 | class DecoderBlock(nn.Module):
25 | """
26 | Parameters for the deconvolution were chosen to avoid checkerboard artifacts, following
27 | https://distill.pub/2016/deconv-checkerboard/
28 | """
29 |
30 | def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
31 | super(DecoderBlock, self).__init__()
32 | self.in_channels = in_channels
33 |
34 | if is_deconv:
35 | self.block = nn.Sequential(
36 | ConvRelu(in_channels, middle_channels),
37 | nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
38 | padding=1),
39 | nn.ReLU(inplace=True)
40 | )
41 | else:
42 | self.block = nn.Sequential(
43 | nn.Upsample(scale_factor=2, mode='bilinear'),
44 | ConvRelu(in_channels, middle_channels),
45 | ConvRelu(middle_channels, out_channels),
46 | )
47 |
48 | def forward(self, x):
49 | return self.block(x)
50 |
51 |
52 | class UNet16(nn.Module):
53 | def __init__(self, num_classes=1, num_filters=32, pretrained=False):
54 | """
55 | :param num_classes: number of output channels of the final 1x1 convolution
56 | :param num_filters: base number of decoder filters
57 | :param pretrained:
58 | False - no pre-trained network used
59 | vgg - encoder pre-trained with VGG16
60 | """
61 | super().__init__()
62 | self.num_classes = num_classes
63 |
64 | self.pool = nn.MaxPool2d(2, 2)
65 |
66 | if pretrained == 'vgg':
67 | self.encoder =
torchvision.models.vgg16(pretrained=True).features 68 | else: 69 | self.encoder = torchvision.models.vgg16(pretrained=False).features 70 | 71 | self.relu = nn.ReLU(inplace=True) 72 | 73 | self.conv1 = nn.Sequential(self.encoder[0], 74 | self.relu, 75 | self.encoder[2], 76 | self.relu) 77 | 78 | self.conv2 = nn.Sequential(self.encoder[5], 79 | self.relu, 80 | self.encoder[7], 81 | self.relu) 82 | 83 | self.conv3 = nn.Sequential(self.encoder[10], 84 | self.relu, 85 | self.encoder[12], 86 | self.relu, 87 | self.encoder[14], 88 | self.relu) 89 | 90 | self.conv4 = nn.Sequential(self.encoder[17], 91 | self.relu, 92 | self.encoder[19], 93 | self.relu, 94 | self.encoder[21], 95 | self.relu) 96 | 97 | self.conv5 = nn.Sequential(self.encoder[24], 98 | self.relu, 99 | self.encoder[26], 100 | self.relu, 101 | self.encoder[28], 102 | self.relu) 103 | 104 | self.center = DecoderBlock(512, num_filters * 8 * 2, num_filters * 8) 105 | 106 | self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8) 107 | self.dec4 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8) 108 | self.dec3 = DecoderBlock(256 + num_filters * 8, num_filters * 4 * 2, num_filters * 2) 109 | self.dec2 = DecoderBlock(128 + num_filters * 2, num_filters * 2 * 2, num_filters) 110 | self.dec1 = ConvRelu(64 + num_filters, num_filters) 111 | self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1) 112 | 113 | def forward(self, x): 114 | conv1 = self.conv1(x) 115 | conv2 = self.conv2(self.pool(conv1)) 116 | conv3 = self.conv3(self.pool(conv2)) 117 | conv4 = self.conv4(self.pool(conv3)) 118 | conv5 = self.conv5(self.pool(conv4)) 119 | 120 | center = self.center(self.pool(conv5)) 121 | 122 | dec5 = self.dec5(torch.cat([center, conv5], 1)) 123 | 124 | dec4 = self.dec4(torch.cat([dec5, conv4], 1)) 125 | dec3 = self.dec3(torch.cat([dec4, conv3], 1)) 126 | dec2 = self.dec2(torch.cat([dec3, conv2], 1)) 127 | dec1 = self.dec1(torch.cat([dec2, conv1], 1)) 128 | 129 | x_out = self.final(dec1) 130 | 131 | return x_out 132 | -------------------------------------------------------------------------------- /lib/models/unet_abn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class double_conv(nn.Module): 7 | '''(conv => BN => ReLU) * 2''' 8 | 9 | def __init__(self, in_ch, out_ch): 10 | from lib.modules.abn import InPlaceABN 11 | 12 | super(double_conv, self).__init__() 13 | self.conv = nn.Sequential( 14 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 15 | InPlaceABN(out_ch), 16 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 17 | InPlaceABN(out_ch), 18 | ) 19 | 20 | def forward(self, x): 21 | x = self.conv(x) 22 | return x 23 | 24 | 25 | class inconv(nn.Module): 26 | def __init__(self, in_ch, out_ch): 27 | super(inconv, self).__init__() 28 | self.conv = double_conv(in_ch, out_ch) 29 | 30 | def forward(self, x): 31 | x = self.conv(x) 32 | return x 33 | 34 | 35 | class down(nn.Module): 36 | def __init__(self, in_ch, out_ch): 37 | super(down, self).__init__() 38 | self.mpconv = nn.Sequential( 39 | nn.MaxPool2d(2), 40 | double_conv(in_ch, out_ch) 41 | ) 42 | 43 | def forward(self, x): 44 | x = self.mpconv(x) 45 | return x 46 | 47 | 48 | class up(nn.Module): 49 | def __init__(self, in_ch, out_ch, upsample=True): 50 | super(up, self).__init__() 51 | 52 | if upsample: 53 | self.up = nn.Upsample(scale_factor=2, mode='nearest') 54 | else: 55 | self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 
2, stride=2)
56 |
57 | self.conv = double_conv(in_ch, out_ch)
58 |
59 | def forward(self, x1, x2):
60 | x1 = self.up(x1)
61 | diffH = x1.size()[2] - x2.size()[2]  # dim 2 is height, dim 3 is width
62 | diffW = x1.size()[3] - x2.size()[3]
63 | x2 = F.pad(x2, (diffW // 2, diffW - diffW // 2,  # F.pad pads the last dim first: (left, right, top, bottom)
64 | diffH // 2, diffH - diffH // 2))  # 'diff - diff // 2' keeps odd differences from being dropped
65 | x = torch.cat([x2, x1], dim=1)
66 | x = self.conv(x)
67 | return x
68 |
69 |
70 | class outconv(nn.Module):
71 | def __init__(self, in_ch, out_ch):
72 | super(outconv, self).__init__()
73 | self.conv = nn.Conv2d(in_ch, out_ch, 1)
74 |
75 | def forward(self, x):
76 | x = self.conv(x)
77 | return x
78 |
79 |
80 | class UNetABN(nn.Module):
81 | def __init__(self, n_channels=3, n_classes=1, n_filters=32, upsample=True):
82 | super(UNetABN, self).__init__()
83 | self.inc = inconv(n_channels, n_filters)
84 | self.down1 = down(n_filters, n_filters * 2)
85 | self.down2 = down(n_filters * 2, n_filters * 4)
86 | self.down3 = down(n_filters * 4, n_filters * 8)
87 | self.down4 = down(n_filters * 8, n_filters * 8)
88 | self.up1 = up(n_filters * 16, n_filters * 4, upsample=upsample)
89 | self.up2 = up(n_filters * 8, n_filters * 2, upsample=upsample)
90 | self.up3 = up(n_filters * 4, n_filters, upsample=upsample)
91 | self.up4 = up(n_filters * 2, n_filters, upsample=upsample)
92 | self.finaldrop = nn.Dropout2d(p=0.5)
93 | self.outc = outconv(n_filters, n_classes)
94 |
95 | def forward(self, x):
96 | x1 = self.inc(x)
97 | x2 = self.down1(x1)
98 | x3 = self.down2(x2)
99 | x4 = self.down3(x3)
100 | x5 = self.down4(x4)
101 | x = self.up1(x5, x4)
102 | x = self.up2(x, x3)
103 | x = self.up3(x, x2)
104 | x = self.up4(x, x1)
105 | x = self.finaldrop(x)
106 | x = self.outc(x)
107 | return x
108 |
-------------------------------------------------------------------------------- /lib/models/wider_resnet.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 | import sys
8 | import numpy as np
9 |
10 |
11 | def conv3x3(in_planes, out_planes, stride=1):
12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
13 |
14 |
15 | def conv_init(m):
16 | classname = m.__class__.__name__
17 | if classname.find('Conv') != -1:
18 | init.xavier_uniform_(m.weight, gain=np.sqrt(2))  # underscore variants; the non-underscore initializers are deprecated
19 | init.constant_(m.bias, 0)
20 | elif classname.find('BatchNorm') != -1:
21 | init.constant_(m.weight, 1)
22 | init.constant_(m.bias, 0)
23 |
24 |
25 | class wide_basic(nn.Module):
26 | def __init__(self, in_planes, planes, dropout_rate, stride=1):
27 | super(wide_basic, self).__init__()
28 | self.bn1 = nn.BatchNorm2d(in_planes)
29 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
30 | self.dropout = nn.Dropout(p=dropout_rate)
31 | self.bn2 = nn.BatchNorm2d(planes)
32 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
33 |
34 | self.shortcut = nn.Sequential()
35 | if stride != 1 or in_planes != planes:
36 | self.shortcut = nn.Sequential(
37 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
38 | )
39 |
40 | def forward(self, x):
41 | out = self.dropout(self.conv1(F.relu(self.bn1(x))))
42 | out = self.conv2(F.relu(self.bn2(out)))
43 | out += self.shortcut(x)
44 |
45 | return out
46 |
47 |
48 | class Wide_ResNet(nn.Module):
49 | def __init__(self, depth, widen_factor, dropout_rate, num_classes):
50 | super(Wide_ResNet, self).__init__()
51 |
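Returning to the VGG-based decoders above (lib/models/unet11.py and lib/models/unet16.py), typical construction and encoder freezing look like this (an illustrative sketch, not repo code; `pretrained='vgg'` is the flag those constructors check):

```python
import torch
from lib.models.unet16 import UNet16

model = UNet16(num_classes=1, num_filters=32, pretrained='vgg')

# Optionally freeze the pre-trained VGG encoder for the first fine-tuning epochs.
for param in model.encoder.parameters():
    param.requires_grad = False

logits = model(torch.randn(2, 3, 224, 224))  # raw logits of shape (2, 1, 224, 224)
```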
self.in_planes = 16
52 |
53 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
54 | n = (depth - 4) // 6  # integer division: n is used as a list-repeat count in _wide_layer
55 | k = widen_factor
56 |
57 | nStages = [16, 16 * k, 32 * k, 64 * k]
58 |
59 | self.conv1 = conv3x3(3, nStages[0])
60 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1)
61 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2)
62 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2)
63 | self.bn1 = nn.BatchNorm2d(nStages[3])  # was missing, but forward() below uses self.bn1
64 | self.linear = nn.Linear(nStages[3], num_classes)  # was missing, but forward() below uses self.linear
65 |
66 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
67 | strides = [stride] + [1] * (num_blocks - 1)
68 | layers = []
69 |
70 | for stride in strides:
71 | layers.append(block(self.in_planes, planes, dropout_rate, stride))
72 | self.in_planes = planes
73 |
74 | return nn.Sequential(*layers)
75 |
76 | def forward(self, x):
77 | out = self.conv1(x)
78 | out = self.layer1(out)
79 | out = self.layer2(out)
80 | out = self.layer3(out)
81 | out = F.relu(self.bn1(out))
82 | out = F.avg_pool2d(out, 8)
83 | out = out.view(out.size(0), -1)
84 | out = self.linear(out)
85 |
86 | return out
87 |
-------------------------------------------------------------------------------- /lib/models/zf_unet.py: --------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | from torch.nn import functional as F
4 |
5 | class _Conv3BN(nn.Module):
6 | def __init__(self, in_: int, out: int, bn=False):
7 | super().__init__()
8 | self.conv = nn.Conv2d(in_, out, 3, padding=1)
9 | self.bn = nn.BatchNorm2d(out) if bn else None
10 | self.activation = nn.ReLU(inplace=True)
11 |
12 | def forward(self, x):
13 | x = self.conv(x)
14 | if self.bn is not None:
15 | x = self.bn(x)
16 | x = self.activation(x)
17 | return x
18 |
19 |
20 | class _DoubleConvModule(nn.Module):
21 | def __init__(self, in_: int, out: int, dropout_val, batch_norm):
22 | super().__init__()
23 | self.l1 = _Conv3BN(in_, out, batch_norm)
24 | self.l2 = _Conv3BN(out, out, batch_norm)
25 | self.dropout = nn.Dropout2d(p=dropout_val)
26 |
27 | def forward(self, x):
28 | x = self.l1(x)
29 | x = self.l2(x)
30 | if self.dropout is not None:
31 | x = self.dropout(x)
32 | return x
33 |
34 |
35 | class ZF_UNET(nn.Module):
36 | def __init__(self, dropout_val=0.2, batch_norm=True, input_channels=3, num_classes=1, filters=32):
37 | super(ZF_UNET, self).__init__()
38 |
39 | self.num_classes = num_classes
40 |
41 | self.pool = nn.MaxPool2d(2)
42 | self.unpool = nn.Upsample(scale_factor=2)
43 |
44 | self.conv_224 = _DoubleConvModule(input_channels, filters, dropout_val, batch_norm)
45 | self.conv_112 = _DoubleConvModule(filters, 2 * filters, dropout_val, batch_norm)
46 | self.conv_56 = _DoubleConvModule(2 * filters, 4 * filters, dropout_val, batch_norm)
47 | self.conv_28 = _DoubleConvModule(4 * filters, 8 * filters, dropout_val, batch_norm)
48 | self.conv_14 = _DoubleConvModule(8 * filters, 16 * filters, dropout_val, batch_norm)
49 |
50 | self.conv_7 = _DoubleConvModule(16 * filters, 32 * filters, dropout_val, batch_norm)
51 |
52 | self.up_conv_14 = _DoubleConvModule(32 * filters + 16 * filters, 16 * filters, dropout_val, batch_norm)
53 | self.up_conv_28 = _DoubleConvModule(16 * filters + 8 * filters, 8 * filters, dropout_val, batch_norm)
54 | self.up_conv_56 = _DoubleConvModule(8 * filters + 4 * filters, 4 * filters, dropout_val, batch_norm)
55 | self.up_conv_112 = _DoubleConvModule(4 * filters + 2 * filters, 2 * filters, dropout_val, batch_norm)
56 | self.up_conv_224 = _DoubleConvModule(2 *
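A smoke test for the repaired Wide_ResNet (an illustrative sketch, not repo code; the assert above requires depth = 6n + 4, and `F.avg_pool2d(out, 8)` assumes CIFAR-sized 32x32 inputs):

```python
import torch
from lib.models.wider_resnet import Wide_ResNet

net = Wide_ResNet(depth=28, widen_factor=10, dropout_rate=0.3, num_classes=10)
out = net(torch.randn(4, 3, 32, 32))  # 32x32 -> 8x8 after the two stride-2 stages
print(out.shape)                      # torch.Size([4, 10])
```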
filters + filters, filters, dropout_val, batch_norm) 57 | 58 | self.conv_final = nn.Conv2d(filters, num_classes, 1) 59 | 60 | def forward(self, x): 61 | conv_224 = self.conv_224(x) 62 | pool_112 = self.pool(conv_224) 63 | 64 | conv_112 = self.conv_112(pool_112) 65 | pool_56 = self.pool(conv_112) 66 | 67 | conv_56 = self.conv_56(pool_56) 68 | pool_28 = self.pool(conv_56) 69 | 70 | conv_28 = self.conv_28(pool_28) 71 | pool_14 = self.pool(conv_28) 72 | 73 | conv_14 = self.conv_14(pool_14) 74 | pool_7 = self.pool(conv_14) 75 | 76 | conv_7 = self.conv_7(pool_7) 77 | 78 | up_14 = torch.cat([self.unpool(conv_7), conv_14], dim=1) 79 | up_conv_14 = self.up_conv_14(up_14) 80 | 81 | up_28 = torch.cat([self.unpool(up_conv_14), conv_28], dim=1) 82 | up_conv_28 = self.up_conv_28(up_28) 83 | 84 | up_56 = torch.cat([self.unpool(up_conv_28), conv_56], dim=1) 85 | up_conv_56 = self.up_conv_56(up_56) 86 | 87 | up_112 = torch.cat([self.unpool(up_conv_56), conv_112], dim=1) 88 | up_conv_112 = self.up_conv_112(up_112) 89 | 90 | up_224 = torch.cat([self.unpool(up_conv_112), conv_224],dim=1) 91 | up_conv_224 = self.up_conv_224(up_224) 92 | 93 | out = self.conv_final(up_conv_224) 94 | 95 | return out 96 | -------------------------------------------------------------------------------- /lib/modules/abn/__init__.py: -------------------------------------------------------------------------------- 1 | from .bn import ABN, InPlaceABN, InPlaceABNWrapper, InPlaceABNSync, InPlaceABNSyncWrapper 2 | from .misc import GlobalAvgPool2d 3 | from .residual import IdentityResidualBlock 4 | from .dense import DenseModule 5 | -------------------------------------------------------------------------------- /lib/modules/abn/bn.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict, Iterable 2 | from itertools import repeat 3 | 4 | try: 5 | # python 3 6 | from queue import Queue 7 | except ImportError: 8 | # python 2 9 | from Queue import Queue 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | from .functions import inplace_abn, inplace_abn_sync 15 | 16 | 17 | def _pair(x): 18 | if isinstance(x, Iterable): 19 | return x 20 | return tuple(repeat(x, 2)) 21 | 22 | 23 | class ABN(nn.Sequential): 24 | """Activated Batch Normalization 25 | 26 | This gathers a `BatchNorm2d` and an activation function in a single module 27 | """ 28 | 29 | def __init__(self, num_features, activation=nn.ReLU(inplace=True), **kwargs): 30 | """Creates an Activated Batch Normalization module 31 | 32 | Parameters 33 | ---------- 34 | num_features : int 35 | Number of feature channels in the input and output. 36 | activation : nn.Module 37 | Module used as an activation function. 38 | kwargs 39 | All other arguments are forwarded to the `BatchNorm2d` constructor. 40 | """ 41 | super(ABN, self).__init__(OrderedDict([ 42 | ("bn", nn.BatchNorm2d(num_features, **kwargs)), 43 | ("act", activation) 44 | ])) 45 | 46 | 47 | class InPlaceABN(nn.Module): 48 | """InPlace Activated Batch Normalization""" 49 | 50 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): 51 | """Creates an InPlace Activated Batch Normalization module 52 | 53 | Parameters 54 | ---------- 55 | num_features : int 56 | Number of feature channels in the input and output. 57 | eps : float 58 | Small constant to prevent numerical issues. 59 | momentum : float 60 | Momentum factor applied to compute running statistics as. 
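``running_stat = (1 - momentum) * running_stat + momentum * batch_stat`` (this matches the running-statistics update in `functions.py`).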
61 | affine : bool 62 | If `True` apply learned scale and shift transformation after normalization. 63 | activation : str 64 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 65 | slope : float 66 | Negative slope for the `leaky_relu` activation. 67 | """ 68 | super(InPlaceABN, self).__init__() 69 | self.num_features = num_features 70 | self.affine = affine 71 | self.eps = eps 72 | self.momentum = momentum 73 | self.activation = activation 74 | self.slope = slope 75 | if self.affine: 76 | self.weight = nn.Parameter(torch.ones(num_features)) 77 | self.bias = nn.Parameter(torch.zeros(num_features)) 78 | else: 79 | self.register_parameter('weight', None) 80 | self.register_parameter('bias', None) 81 | self.register_buffer('running_mean', torch.zeros(num_features)) 82 | self.register_buffer('running_var', torch.ones(num_features)) 83 | self.reset_parameters() 84 | 85 | def reset_parameters(self): 86 | nn.init.constant_(self.running_mean, 0) 87 | nn.init.constant_(self.running_var, 1) 88 | if self.affine: 89 | nn.init.constant_(self.weight, 1) 90 | nn.init.constant_(self.bias, 0) 91 | 92 | def forward(self, x): 93 | return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var, 94 | self.training, self.momentum, self.eps, self.activation, self.slope) 95 | 96 | def __repr__(self): 97 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 98 | ' affine={affine}, activation={activation}' 99 | if self.activation == "leaky_relu": 100 | rep += ' slope={slope})' 101 | else: 102 | rep += ')' 103 | return rep.format(name=self.__class__.__name__, **self.__dict__) 104 | 105 | 106 | class InPlaceABNSync(nn.Module): 107 | """InPlace Activated Batch Normalization with cross-GPU synchronization 108 | 109 | This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DataParallel`. 110 | """ 111 | 112 | def __init__(self, num_features, devices=None, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", 113 | slope=0.01): 114 | """Creates a synchronized, InPlace Activated Batch Normalization module 115 | 116 | Parameters 117 | ---------- 118 | num_features : int 119 | Number of feature channels in the input and output. 120 | devices : list of int or None 121 | IDs of the GPUs that will run the replicas of this module. 122 | eps : float 123 | Small constant to prevent numerical issues. 124 | momentum : float 125 | Momentum factor applied to compute running statistics as. 126 | affine : bool 127 | If `True` apply learned scale and shift transformation after normalization. 128 | activation : str 129 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 130 | slope : float 131 | Negative slope for the `leaky_relu` activation. 
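Example (an illustrative sketch; it assumes the compiled `inplace_abn` CUDA extension and at least one GPU)::

    import torch
    from torch import nn
    from lib.modules.abn import InPlaceABNSync

    block = nn.Sequential(
        nn.Conv2d(3, 64, 3, padding=1, bias=False),
        InPlaceABNSync(64),  # batch statistics are synchronized across replicas
    ).cuda()
    block = nn.DataParallel(block)  # replicated with the nn.DataParallel mechanism
    out = block(torch.randn(8, 3, 64, 64).cuda())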
132 | """ 133 | super(InPlaceABNSync, self).__init__() 134 | self.num_features = num_features 135 | self.devices = devices if devices else list(range(torch.cuda.device_count())) 136 | self.affine = affine 137 | self.eps = eps 138 | self.momentum = momentum 139 | self.activation = activation 140 | self.slope = slope 141 | if self.affine: 142 | self.weight = nn.Parameter(torch.ones(num_features)) 143 | self.bias = nn.Parameter(torch.zeros(num_features)) 144 | else: 145 | self.register_parameter('weight', None) 146 | self.register_parameter('bias', None) 147 | self.register_buffer('running_mean', torch.zeros(num_features)) 148 | self.register_buffer('running_var', torch.ones(num_features)) 149 | self.reset_parameters() 150 | 151 | # Initialize queues 152 | self.worker_ids = self.devices[1:] 153 | self.master_queue = Queue(len(self.worker_ids)) 154 | self.worker_queues = [Queue(1) for _ in self.worker_ids] 155 | 156 | def reset_parameters(self): 157 | nn.init.constant_(self.running_mean, 0) 158 | nn.init.constant_(self.running_var, 1) 159 | if self.affine: 160 | nn.init.constant_(self.weight, 1) 161 | nn.init.constant_(self.bias, 0) 162 | 163 | def forward(self, x): 164 | if x.get_device() == self.devices[0]: 165 | # Master mode 166 | extra = { 167 | "is_master": True, 168 | "master_queue": self.master_queue, 169 | "worker_queues": self.worker_queues, 170 | "worker_ids": self.worker_ids 171 | } 172 | else: 173 | # Worker mode 174 | extra = { 175 | "is_master": False, 176 | "master_queue": self.master_queue, 177 | "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())] 178 | } 179 | 180 | return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var, 181 | extra, self.training, self.momentum, self.eps, self.activation, self.slope) 182 | 183 | def __repr__(self): 184 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 185 | ' affine={affine}, devices={devices}, activation={activation}' 186 | if self.activation == "leaky_relu": 187 | rep += ' slope={slope})' 188 | else: 189 | rep += ')' 190 | return rep.format(name=self.__class__.__name__, **self.__dict__) 191 | 192 | 193 | class InPlaceABNWrapper(nn.Module): 194 | """Wrapper module to make `InPlaceABN` compatible with `ABN`""" 195 | 196 | def __init__(self, *args, **kwargs): 197 | super(InPlaceABNWrapper, self).__init__() 198 | self.bn = InPlaceABN(*args, **kwargs) 199 | 200 | def forward(self, input): 201 | return self.bn(input) 202 | 203 | 204 | class InPlaceABNSyncWrapper(nn.Module): 205 | """Wrapper module to make `InPlaceABNSync` compatible with `ABN`""" 206 | 207 | def __init__(self, *args, **kwargs): 208 | super(InPlaceABNSyncWrapper, self).__init__() 209 | self.bn = InPlaceABNSync(*args, **kwargs) 210 | 211 | def forward(self, input): 212 | return self.bn(input) 213 | -------------------------------------------------------------------------------- /lib/modules/abn/dense.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .bn import ABN 7 | 8 | 9 | class DenseModule(nn.Module): 10 | def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1): 11 | super(DenseModule, self).__init__() 12 | self.in_channels = in_channels 13 | self.growth = growth 14 | self.layers = layers 15 | 16 | self.convs1 = nn.ModuleList() 17 | self.convs3 = nn.ModuleList() 18 | for i in range(self.layers): 19 | 
self.convs1.append(nn.Sequential(OrderedDict([ 20 | ("bn", norm_act(in_channels)), 21 | ("conv", nn.Conv2d(in_channels, self.growth * bottleneck_factor, 1, bias=False)) 22 | ]))) 23 | self.convs3.append(nn.Sequential(OrderedDict([ 24 | ("bn", norm_act(self.growth * bottleneck_factor)), 25 | ("conv", nn.Conv2d(self.growth * bottleneck_factor, self.growth, 3, padding=dilation, bias=False, 26 | dilation=dilation)) 27 | ]))) 28 | in_channels += self.growth 29 | 30 | @property 31 | def out_channels(self): 32 | return self.in_channels + self.growth * self.layers 33 | 34 | def forward(self, x): 35 | inputs = [x] 36 | for i in range(self.layers): 37 | x = torch.cat(inputs, dim=1) 38 | x = self.convs1[i](x) 39 | x = self.convs3[i](x) 40 | inputs += [x] 41 | 42 | return torch.cat(inputs, dim=1) 43 | -------------------------------------------------------------------------------- /lib/modules/abn/functions.py: -------------------------------------------------------------------------------- 1 | import inplace_abn as backend 2 | import torch.autograd as autograd 3 | import torch.cuda.comm as comm 4 | from torch.autograd.function import once_differentiable 5 | 6 | # Activation names 7 | ACT_LEAKY_RELU = "leaky_relu" 8 | ACT_ELU = "elu" 9 | ACT_NONE = "none" 10 | 11 | 12 | def _check(fn, *args, **kwargs): 13 | success = fn(*args, **kwargs) 14 | if not success: 15 | raise RuntimeError("CUDA Error encountered in {}".format(fn)) 16 | 17 | 18 | def _broadcast_shape(x): 19 | out_size = [] 20 | for i, s in enumerate(x.size()): 21 | if i != 1: 22 | out_size.append(1) 23 | else: 24 | out_size.append(s) 25 | return out_size 26 | 27 | 28 | def _reduce(x): 29 | if len(x.size()) == 2: 30 | return x.sum(dim=0) 31 | else: 32 | n, c = x.size()[0:2] 33 | return x.contiguous().view((n, c, -1)).sum(2).sum(0) 34 | 35 | 36 | def _count_samples(x): 37 | count = 1 38 | for i, s in enumerate(x.size()): 39 | if i != 1: 40 | count *= s 41 | return count 42 | 43 | 44 | def _act_forward(ctx, x): 45 | if ctx.activation == ACT_LEAKY_RELU: 46 | backend.leaky_relu_forward(x, ctx.slope) 47 | elif ctx.activation == ACT_ELU: 48 | backend.elu_forward(x) 49 | elif ctx.activation == ACT_NONE: 50 | pass 51 | 52 | 53 | def _act_backward(ctx, x, dx): 54 | if ctx.activation == ACT_LEAKY_RELU: 55 | backend.leaky_relu_backward(x, dx, ctx.slope) 56 | elif ctx.activation == ACT_ELU: 57 | backend.elu_backward(x, dx) 58 | elif ctx.activation == ACT_NONE: 59 | pass 60 | 61 | 62 | class InPlaceABN(autograd.Function): 63 | @staticmethod 64 | def forward(ctx, x, weight, bias, running_mean, running_var, 65 | training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01): 66 | # Save context 67 | ctx.training = training 68 | ctx.momentum = momentum 69 | ctx.eps = eps 70 | ctx.activation = activation 71 | ctx.slope = slope 72 | ctx.affine = weight is not None and bias is not None 73 | 74 | # Prepare inputs 75 | count = _count_samples(x) 76 | x = x.contiguous() 77 | weight = weight.contiguous() if ctx.affine else x.new_empty(0) 78 | bias = bias.contiguous() if ctx.affine else x.new_empty(0) 79 | 80 | if ctx.training: 81 | mean, var = backend.mean_var(x) 82 | 83 | # Update running stats 84 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 85 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1)) 86 | 87 | # Mark in-place modified tensors 88 | ctx.mark_dirty(x, running_mean, running_var) 89 | else: 90 | mean, var = running_mean.contiguous(), running_var.contiguous() 91 | ctx.mark_dirty(x) 92 | 93 | 
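# The activation is applied in place on the normalized tensor. backward() below
# reconstructs what it needs from the post-activation output saved in ctx, so no
# pre-activation copy has to be kept around - this is where the memory saving of
# in-place ABN comes from.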
# BN forward + activation 94 | backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps) 95 | _act_forward(ctx, x) 96 | 97 | # Output 98 | ctx.var = var 99 | ctx.save_for_backward(x, var, weight, bias) 100 | return x 101 | 102 | @staticmethod 103 | @once_differentiable 104 | def backward(ctx, dz): 105 | z, var, weight, bias = ctx.saved_tensors 106 | dz = dz.contiguous() 107 | 108 | # Undo activation 109 | _act_backward(ctx, z, dz) 110 | 111 | if ctx.training: 112 | edz, eydz = backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps) 113 | else: 114 | # TODO: implement simplified CUDA backward for inference mode 115 | edz = dz.new_zeros(dz.size(1)) 116 | eydz = dz.new_zeros(dz.size(1)) 117 | 118 | dx, dweight, dbias = backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps) 119 | dweight = dweight if ctx.affine else None 120 | dbias = dbias if ctx.affine else None 121 | 122 | return dx, dweight, dbias, None, None, None, None, None, None, None 123 | 124 | 125 | class InPlaceABNSync(autograd.Function): 126 | @classmethod 127 | def forward(cls, ctx, x, weight, bias, running_mean, running_var, 128 | extra, training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01): 129 | # Save context 130 | cls._parse_extra(ctx, extra) 131 | ctx.training = training 132 | ctx.momentum = momentum 133 | ctx.eps = eps 134 | ctx.activation = activation 135 | ctx.slope = slope 136 | ctx.affine = weight is not None and bias is not None 137 | 138 | # Prepare inputs 139 | count = _count_samples(x) * (ctx.master_queue.maxsize + 1) 140 | x = x.contiguous() 141 | weight = weight.contiguous() if ctx.affine else x.new_empty(0) 142 | bias = bias.contiguous() if ctx.affine else x.new_empty(0) 143 | 144 | if ctx.training: 145 | mean, var = backend.mean_var(x) 146 | 147 | if ctx.is_master: 148 | means, vars = [mean.unsqueeze(0)], [var.unsqueeze(0)] 149 | for _ in range(ctx.master_queue.maxsize): 150 | mean_w, var_w = ctx.master_queue.get() 151 | ctx.master_queue.task_done() 152 | means.append(mean_w.unsqueeze(0)) 153 | vars.append(var_w.unsqueeze(0)) 154 | 155 | means = comm.gather(means) 156 | vars = comm.gather(vars) 157 | 158 | mean = means.mean(0) 159 | var = (vars + (mean - means) ** 2).mean(0) 160 | 161 | tensors = comm.broadcast_coalesced((mean, var), [mean.get_device()] + ctx.worker_ids) 162 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 163 | queue.put(ts) 164 | else: 165 | ctx.master_queue.put((mean, var)) 166 | mean, var = ctx.worker_queue.get() 167 | ctx.worker_queue.task_done() 168 | 169 | # Update running stats 170 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 171 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1)) 172 | 173 | # Mark in-place modified tensors 174 | ctx.mark_dirty(x, running_mean, running_var) 175 | else: 176 | mean, var = running_mean.contiguous(), running_var.contiguous() 177 | ctx.mark_dirty(x) 178 | 179 | # BN forward + activation 180 | backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps) 181 | _act_forward(ctx, x) 182 | 183 | # Output 184 | ctx.var = var 185 | ctx.save_for_backward(x, var, weight, bias) 186 | return x 187 | 188 | @staticmethod 189 | @once_differentiable 190 | def backward(ctx, dz): 191 | z, var, weight, bias = ctx.saved_tensors 192 | dz = dz.contiguous() 193 | 194 | # Undo activation 195 | _act_backward(ctx, z, dz) 196 | 197 | if ctx.training: 198 | edz, eydz = backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps) 199 | 200 | if ctx.is_master: 201 | 
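# The master gathers per-GPU gradient statistics (edz, eydz), averages them, and
# broadcasts the result back to the workers, mirroring the mean/var exchange in forward().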
edzs, eydzs = [edz], [eydz] 202 | for _ in range(len(ctx.worker_queues)): 203 | edz_w, eydz_w = ctx.master_queue.get() 204 | ctx.master_queue.task_done() 205 | edzs.append(edz_w) 206 | eydzs.append(eydz_w) 207 | 208 | edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1) 209 | eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1) 210 | 211 | tensors = comm.broadcast_coalesced((edz, eydz), [edz.get_device()] + ctx.worker_ids) 212 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 213 | queue.put(ts) 214 | else: 215 | ctx.master_queue.put((edz, eydz)) 216 | edz, eydz = ctx.worker_queue.get() 217 | ctx.worker_queue.task_done() 218 | else: 219 | edz = dz.new_zeros(dz.size(1)) 220 | eydz = dz.new_zeros(dz.size(1)) 221 | 222 | dx, dweight, dbias = backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps) 223 | dweight = dweight if ctx.affine else None 224 | dbias = dbias if ctx.affine else None 225 | 226 | return dx, dweight, dbias, None, None, None, None, None, None, None, None 227 | 228 | @staticmethod 229 | def _parse_extra(ctx, extra): 230 | ctx.is_master = extra["is_master"] 231 | if ctx.is_master: 232 | ctx.master_queue = extra["master_queue"] 233 | ctx.worker_queues = extra["worker_queues"] 234 | ctx.worker_ids = extra["worker_ids"] 235 | else: 236 | ctx.master_queue = extra["master_queue"] 237 | ctx.worker_queue = extra["worker_queue"] 238 | 239 | 240 | inplace_abn = InPlaceABN.apply 241 | inplace_abn_sync = InPlaceABNSync.apply 242 | 243 | __all__ = ["inplace_abn", "inplace_abn_sync"] 244 | -------------------------------------------------------------------------------- /lib/modules/abn/misc.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class GlobalAvgPool2d(nn.Module): 5 | def __init__(self): 6 | """Global average pooling over the input's spatial dimensions""" 7 | super(GlobalAvgPool2d, self).__init__() 8 | 9 | def forward(self, inputs): 10 | in_size = inputs.size() 11 | return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2) 12 | -------------------------------------------------------------------------------- /lib/modules/abn/residual.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.nn as nn 4 | 5 | from .bn import ABN 6 | 7 | 8 | class IdentityResidualBlock(nn.Module): 9 | def __init__(self, 10 | in_channels, 11 | channels, 12 | stride=1, 13 | dilation=1, 14 | groups=1, 15 | norm_act=ABN, 16 | dropout=None): 17 | """Configurable identity-mapping residual block 18 | 19 | Parameters 20 | ---------- 21 | in_channels : int 22 | Number of input channels. 23 | channels : list of int 24 | Number of channels in the internal feature maps. Can either have two or three elements: if three construct 25 | a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then 26 | `3 x 3` then `1 x 1` convolutions. 27 | stride : int 28 | Stride of the first `3 x 3` convolution 29 | dilation : int 30 | Dilation to apply to the `3 x 3` convolutions. 31 | groups : int 32 | Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with 33 | bottleneck blocks. 34 | norm_act : callable 35 | Function to create normalization / activation Module. 36 | dropout: callable 37 | Function to create Dropout Module. 
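Example (an illustrative sketch; three channel values build a pre-activation bottleneck, and stride=2 triggers the 1x1 projection shortcut)::

    import torch
    from lib.modules.abn import IdentityResidualBlock

    block = IdentityResidualBlock(in_channels=256, channels=[64, 64, 256], stride=2)
    out = block(torch.randn(2, 256, 56, 56))  # -> (2, 256, 28, 28)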
38 | """
39 | super(IdentityResidualBlock, self).__init__()
40 |
41 | # Check parameters for inconsistencies
42 | if len(channels) != 2 and len(channels) != 3:
43 | raise ValueError("channels must contain either two or three values")
44 | if len(channels) == 2 and groups != 1:
45 | raise ValueError("groups > 1 are only valid if len(channels) == 3")
46 |
47 | is_bottleneck = len(channels) == 3
48 | need_proj_conv = stride != 1 or in_channels != channels[-1]
49 |
50 | self.bn1 = norm_act(in_channels)
51 | if not is_bottleneck:
52 | layers = [
53 | ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
54 | dilation=dilation)),
55 | ("bn2", norm_act(channels[0])),
56 | ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
57 | dilation=dilation))
58 | ]
59 | if dropout is not None:
60 | layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
61 | else:
62 | layers = [
63 | ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)),
64 | ("bn2", norm_act(channels[0])),
65 | ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
66 | groups=groups, dilation=dilation)),
67 | ("bn3", norm_act(channels[1])),
68 | ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False))
69 | ]
70 | if dropout is not None:
71 | layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
72 | self.convs = nn.Sequential(OrderedDict(layers))
73 |
74 | if need_proj_conv:
75 | self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
76 |
77 | def forward(self, x):
78 | if hasattr(self, "proj_conv"):
79 | bn1 = self.bn1(x)
80 | shortcut = self.proj_conv(bn1)
81 | else:
82 | shortcut = x.clone()
83 | bn1 = self.bn1(x)
84 |
85 | out = self.convs(bn1)
86 | out.add_(shortcut)
87 |
88 | return out
89 |
-------------------------------------------------------------------------------- /lib/numpy_losses.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def binary_crossentropy(y_true, y_pred):
5 | y_true = np.reshape(y_true, (-1, 1))
6 | y_pred = np.reshape(y_pred, (-1, 1))
7 | eps = 1e-7
8 | y_pred = np.clip(y_pred, eps, 1 - eps)
9 | loss = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))  # both BCE terms; the original dropped the (1 - y_true) term
10 | return np.mean(loss)
11 |
12 |
13 | def jaccard_coef(y_true, y_pred):
14 | y_true = np.reshape(y_true, (-1, 1))
15 | y_pred = np.reshape(y_pred, (-1, 1))
16 | eps = 1e-7
17 | intersection = np.sum(y_true * y_pred)
18 | union = np.sum(y_true) + np.sum(y_pred) + eps
19 | return intersection / (union - intersection)
20 |
21 |
22 | def jaccard_loss(y_true, y_pred):
23 | return 1. - jaccard_coef(y_true, y_pred)
24 |
25 |
26 | def smooth_jaccard_loss(y_true, y_pred):
27 | """Jaccard distance for semantic segmentation, also known as the intersection-over-union loss.
28 | This loss is useful when you have unbalanced numbers of pixels within an image
29 | because it gives all classes equal weight. However, it is not the de facto
30 | standard for image segmentation.
31 | For example, assume you are trying to predict if each pixel is cat, dog, or background.
32 | You have 80% background pixels, 10% dog, and 10% cat. If the model predicts 100% background
33 | should it be 80% right (as with categorical cross entropy) or 30% (with this loss)?
34 | The loss has been modified to have a smooth gradient as it converges on zero.
35 | This has been shifted so it converges on 0 and is smoothed to avoid exploding 36 | or disappearing gradient. 37 | Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|) 38 | = sum(|A*B|)/(sum(|A|)+sum(|B|)-sum(|A*B|)) 39 | # References 40 | Csurka, Gabriela & Larlus, Diane & Perronnin, Florent. (2013). 41 | What is a good evaluation measure for semantic segmentation?. 42 | IEEE Trans. Pattern Anal. Mach. Intell.. 26. . 10.5244/C.27.32. 43 | https://en.wikipedia.org/wiki/Jaccard_index 44 | """ 45 | y_true = np.reshape(y_true, (-1, 1)) 46 | y_pred = np.reshape(y_pred, (-1, 1)) 47 | smooth = 100 48 | 49 | intersection = np.sum(y_true * y_pred) 50 | union = np.sum(y_true) + np.sum(y_pred) 51 | jac = (intersection + smooth) / (union - intersection + smooth) 52 | return (1 - jac) * smooth 53 | 54 | def bce_jaccard_loss(y_true, y_pred): 55 | return binary_crossentropy(y_true, y_pred) + jaccard_loss(y_true, y_pred) 56 | 57 | 58 | def bce_smooth_jaccard_loss(y_true, y_pred): 59 | return binary_crossentropy(y_true, y_pred) + smooth_jaccard_loss(y_true, y_pred) 60 | -------------------------------------------------------------------------------- /lib/tiles.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def compute_patch_weight_loss(width, height): 7 | xc = width * 0.5 8 | yc = height * 0.5 9 | xl = 0 10 | xr = width 11 | yb = 0 12 | yt = height 13 | Dc = np.zeros((width, height)) 14 | De = np.zeros((width, height)) 15 | 16 | for i in range(width): 17 | for j in range(height): 18 | Dc[i, j] = np.sqrt(np.square(i - xc+0.5) + np.square(j - yc+0.5)) 19 | De_l = np.sqrt(np.square(i - xl+0.5) + np.square(j - j+0.5)) 20 | De_r = np.sqrt(np.square(i - xr+0.5) + np.square(j - j+0.5)) 21 | De_b = np.sqrt(np.square(i - i+0.5) + np.square(j - yb+0.5)) 22 | De_t = np.sqrt(np.square(i - i+0.5) + np.square(j - yt+0.5)) 23 | De[i, j] = np.min([De_l, De_r, De_b, De_t]) 24 | 25 | alpha = (width * height) / np.sum(np.divide(De, np.add(Dc, De))) 26 | W = alpha * np.divide(De, np.add(Dc, De)) 27 | return W, Dc, De 28 | 29 | 30 | class ImageSlicer: 31 | """ 32 | Helper class to slice image into tiles and merge them back with fusion 33 | """ 34 | 35 | def __init__(self, image_shape, tile_size, tile_step=0, image_margin=0, weight='mean'): 36 | """ 37 | 38 | :param image_shape: Shape of the source image 39 | :param tile_size: Tile size 40 | :param tile_step: Step in pixels between tiles 41 | :param image_margin: 42 | :param weight: Fusion algorithm. 
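Used by `merge` to blend overlapping tiles: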
'mean' - averaging, 'pyramid' - weighting that favors tile centers over tile edges
43 | """
44 | self.image_height = image_shape[0]
45 | self.image_width = image_shape[1]
46 | self.tile_size = tile_size
47 | self.tile_step = tile_step
48 |
49 | weights = {
50 | 'mean': self._mean,
51 | 'pyramid': self._pyramid
52 | }
53 |
54 | self.compute_weight = weights[weight]
55 |
56 | if tile_step < 1 or tile_step > tile_size:
57 | raise ValueError('tile_step must be in range [1, tile_size]')
58 |
59 | overlap = tile_size - tile_step
60 |
61 | self.margin_left = 0
62 | self.margin_right = 0
63 | self.margin_top = 0
64 | self.margin_bottom = 0
65 |
66 | if image_margin == 0:
67 | # In case the margin is not set, we compute it manually
68 |
69 | nw = max(1, math.ceil((self.image_width - overlap) / tile_step))
70 | nh = max(1, math.ceil((self.image_height - overlap) / tile_step))
71 |
72 | extra_w = self.tile_step * nw - (self.image_width - overlap)
73 | extra_h = self.tile_step * nh - (self.image_height - overlap)
74 |
75 | self.margin_left = extra_w // 2
76 | self.margin_right = extra_w - self.margin_left
77 | self.margin_top = extra_h // 2
78 | self.margin_bottom = extra_h - self.margin_top
79 |
80 | else:
81 | if (self.image_width - overlap + 2 * image_margin) % tile_step != 0:
82 | raise ValueError('Image width plus margins is not evenly divisible by tile_step')
83 |
84 | if (self.image_height - overlap + 2 * image_margin) % tile_step != 0:
85 | raise ValueError('Image height plus margins is not evenly divisible by tile_step')
86 |
87 | self.margin_left = image_margin
88 | self.margin_right = image_margin
89 | self.margin_top = image_margin
90 | self.margin_bottom = image_margin
91 |
92 | self.crops = []
93 |
94 | for y in range(0, self.image_height + self.margin_top + self.margin_bottom - tile_size + 1, tile_step):
95 | for x in range(0, self.image_width + self.margin_left + self.margin_right - tile_size + 1, tile_step):
96 | self.crops.append((x, y, tile_size, tile_size))
97 |
98 | def split(self, image, borderType=cv2.BORDER_REFLECT101, value=0):
99 | assert image.shape[0] == self.image_height
100 | assert image.shape[1] == self.image_width
101 |
102 | orig_shape_len = len(image.shape)
103 | image = cv2.copyMakeBorder(image, self.margin_top, self.margin_bottom, self.margin_left, self.margin_right, borderType=borderType, value=value)
104 |
105 | # This check recovers a possibly missing last dummy dimension for single-channel images
106 | if len(image.shape) != orig_shape_len:
107 | image = np.expand_dims(image, axis=-1)
108 |
109 | tiles = []
110 | for x, y, tile_width, tile_height in self.crops:
111 | tile = image[y:y + tile_height, x:x + tile_width].copy()
112 | assert tile.shape[0] == self.tile_size
113 | assert tile.shape[1] == self.tile_size
114 |
115 | tiles.append(tile)
116 |
117 | return tiles
118 |
119 | def cut_patch(self, image, slice_index, borderType=cv2.BORDER_REFLECT101, value=0):
120 | assert image.shape[0] == self.image_height
121 | assert image.shape[1] == self.image_width
122 |
123 | orig_shape_len = len(image.shape)
124 | image = cv2.copyMakeBorder(image, self.margin_top, self.margin_bottom, self.margin_left, self.margin_right, borderType=borderType, value=value)
125 |
126 | # This check recovers a possibly missing last dummy dimension for single-channel images
127 | if len(image.shape) != orig_shape_len:
128 | image = np.expand_dims(image, axis=-1)
129 |
130 | x, y, tile_width, tile_height = self.crops[slice_index]
131 |
132 | tile = image[y:y + tile_height, x:x + tile_width].copy()
133 | assert tile.shape[0] == self.tile_size
134 | assert tile.shape[1] == self.tile_size
135 | return tile
136 |
137 | def merge(self, tiles, dtype=np.float32):
138 | if len(tiles) != len(self.crops):
139 | raise ValueError('Number of tiles does not match the number of crops')
140 |
141 |
channels = 1 if len(tiles[0].shape) == 2 else tiles[0].shape[2] 142 | target_shape = self.image_height + self.margin_bottom + self.margin_top, self.image_width + self.margin_right + self.margin_left, channels 143 | 144 | image = np.zeros(target_shape, dtype=np.float64) 145 | norm_mask = np.zeros(target_shape, dtype=np.float64) 146 | 147 | weight = self.compute_weight(self.tile_size) 148 | w = np.dstack([weight] * channels) 149 | 150 | for tile, (x, y, tile_width, tile_height) in zip(tiles, self.crops): 151 | # print(x, y, tile_width, tile_height, image.shape) 152 | image[y:y + tile_height, x:x + tile_width] += tile * w 153 | norm_mask[y:y + tile_height, x:x + tile_width] += w 154 | 155 | # print(norm_mask.min(), norm_mask.max()) 156 | norm_mask = np.clip(norm_mask, a_min=np.finfo(norm_mask.dtype).eps, a_max=None) 157 | normalized = np.divide(image, norm_mask).astype(dtype) 158 | crop = normalized[self.margin_top:self.image_height + self.margin_top, self.margin_left:self.image_width + self.margin_left] 159 | assert crop.shape[0] == self.image_height 160 | assert crop.shape[1] == self.image_width 161 | return crop 162 | 163 | def _mean(self, tile_size): 164 | return np.ones((tile_size, tile_size), dtype=np.float32) 165 | 166 | def _pyramid(self, tile_size): 167 | w, _, _ = compute_patch_weight_loss(tile_size, tile_size) 168 | return w 169 | -------------------------------------------------------------------------------- /lib/train_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import torch 5 | from sklearn.metrics import confusion_matrix 6 | 7 | from torch import nn 8 | from torch.optim import Optimizer, SGD 9 | import numpy as np 10 | from torch.optim.lr_scheduler import LambdaLR 11 | from tqdm import tqdm 12 | 13 | 14 | class AverageMeter(object): 15 | """Computes and stores the average and current value""" 16 | 17 | def __init__(self): 18 | self.reset() 19 | 20 | def reset(self): 21 | self.val = 0 22 | self.avg = 0 23 | self.sum = 0 24 | self.count = 0 25 | 26 | def update(self, val, n=1): 27 | self.val = val 28 | self.sum += val * n 29 | self.count += n 30 | self.avg = self.sum / self.count 31 | 32 | def __str__(self): 33 | return '%.3f' % self.avg 34 | 35 | 36 | def find_optimal_lr(model: nn.Module, criterion, optimizer: Optimizer, dataloader): 37 | min_lr = 1e-8 38 | lrs = [] 39 | lr = min_lr 40 | for i in range(30): 41 | lrs.append(lr) 42 | lr *= 2. 
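A round trip with ImageSlicer from lib/tiles.py above (an illustrative sketch, not repo code; sizes are arbitrary):

```python
import numpy as np
from lib.tiles import ImageSlicer

image = np.random.rand(1024, 1280, 3).astype(np.float32)
slicer = ImageSlicer(image.shape, tile_size=512, tile_step=256, weight='pyramid')

tiles = slicer.split(image)   # overlapping 512x512 tiles; borders are mirror-padded
merged = slicer.merge(tiles)  # weighted fusion back to the original resolution
assert merged.shape[:2] == image.shape[:2]
```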
43 |
44 | lrs = np.array(lrs, dtype=np.float32)
45 | print(lrs)
46 |
47 | loss = np.zeros_like(lrs)
48 |
49 | scheduler = LambdaLR(optimizer, lr_lambda=lambda x: lrs[x])  # note: LambdaLR multiplies the optimizer's base lr by this factor, so pass a base lr of 1.0 for absolute values
50 |
51 | with torch.set_grad_enabled(True):
52 | model.train()
53 | dataiter = iter(dataloader)
54 | for i, lr in enumerate(tqdm(lrs, total=len(lrs))):
55 | scheduler.step()
56 | x, y = next(dataiter)
57 | x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
58 |
59 | y_pred = model(x)
60 | batch_loss = criterion(y_pred, y)
61 |
62 | batch_size = x.size(0)
63 | (batch_size * batch_loss).backward()
64 |
65 | optimizer.step()
66 |
67 | loss[i] = batch_loss.cpu().item()
68 |
69 | return lrs, loss
70 |
71 |
72 | def auto_file(filename, where='.') -> str:
73 | """
74 | Helper function to find a unique file in a subdirectory without specifying the full path to it
75 | :param filename: name of the file to look for
76 | :return: path to the unique match
77 | """
78 | prob = os.path.join(where, filename)
79 | if os.path.exists(prob) and os.path.isfile(prob):
80 | return prob  # return the resolved path, not just the bare filename
81 |
82 | files = list(glob.iglob(os.path.join(where, '**', filename), recursive=True))
83 | if len(files) == 0:
84 | raise FileNotFoundError('Given file could not be found with recursive search: ' + filename)
85 |
86 | if len(files) > 1:
87 | raise FileNotFoundError('More than one file matches given filename. Please specify it explicitly: ' + filename)
88 |
89 | return files[0]
90 |
91 |
92 | class PRCurveMeter(object):
93 |
94 | def __init__(self, n_thresholds=127):
95 | self.n_thresholds = n_thresholds
96 | self.k = 2
97 | self.thresholds = np.arange(0., 1., 1. / n_thresholds, dtype=np.float32)
98 | self.tp = np.zeros(n_thresholds, dtype=np.uint64)
99 | self.tn = np.zeros(n_thresholds, dtype=np.uint64)
100 | self.fp = np.zeros(n_thresholds, dtype=np.uint64)
101 | self.fn = np.zeros(n_thresholds, dtype=np.uint64)
102 |
103 | def reset(self):
104 | self.tp.fill(0)
105 | self.tn.fill(0)
106 | self.fp.fill(0)
107 | self.fn.fill(0)
108 |
109 | def update(self, y_pred, y_true):
110 | y_pred = torch.sigmoid(y_pred.detach()).cpu().numpy().reshape(-1)
111 | y_true = y_true.cpu().numpy().astype(np.int32).reshape(-1)
112 |
113 | for i, value in enumerate(self.thresholds):
114 | y_pred_i = (y_pred > value).astype(np.int32)
115 |
116 | # hack for bincounting 2 arrays together
117 | x = y_pred_i + self.k * y_true
118 | bincount_2d = np.bincount(x, minlength=self.k ** 2)
119 | assert bincount_2d.size == self.k ** 2
120 | conf = bincount_2d.reshape((self.k, self.k))
121 |
122 | self.tp[i] += conf[1, 1]
123 | self.tn[i] += conf[0, 0]
124 | self.fp[i] += conf[0, 1]
125 | self.fn[i] += conf[1, 0]
126 |
127 | def precision(self):
128 | return np.divide(self.tp, self.tp + self.fp)
129 |
130 | def recall(self):
131 | return np.divide(self.tp, self.tp + self.fn)
132 |
-------------------------------------------------------------------------------- /plot.py: --------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import matplotlib.pyplot as plt
4 | import seaborn as sns
5 | import numpy as np
6 |
7 | sns.set()
8 |
9 |
10 | def plot_train_history(names, loss, val_loss, title=None, legend_loc='upper right'):
11 | fig = plt.figure(figsize=(15, 8))
12 |
13 | if
title is not None: 14 | fig.suptitle(title) 15 | 16 | ax1, ax2 = fig.subplots(1, 2) 17 | 18 | # ax1.set_color_cycle([cm(1.*i/NUM_COLORS) for i in range(NUM_COLORS)]) 19 | for m in loss: 20 | ax1.plot(m) 21 | 22 | ax1.set_ylabel('Value') 23 | ax1.set_xlabel('Epoch') 24 | ax1.set_title('Train') 25 | ax1.legend(names, loc=legend_loc) 26 | 27 | for m in val_loss: 28 | ax2.plot(m) 29 | 30 | ax2.set_ylabel('Value') 31 | ax2.set_xlabel('Epoch') 32 | ax2.set_title('Test') 33 | ax2.legend(names, loc=legend_loc) 34 | 35 | plt.show() 36 | 37 | 38 | def plot_experiment_train_history(name, loss, val_loss, metric, val_metric): 39 | fig = plt.figure(figsize=(15, 8)) 40 | 41 | fig.suptitle(name) 42 | 43 | ax1, ax2 = fig.subplots(1, 2) 44 | 45 | ax1.plot(loss) 46 | ax1.plot(val_loss) 47 | 48 | ax1.set_ylabel('Value') 49 | ax1.set_xlabel('Epoch') 50 | ax1.set_title('Loss') 51 | ax1.legend(['Train', 'Test'], loc='upper right') 52 | 53 | ax2.plot(metric) 54 | ax2.plot(val_metric) 55 | 56 | ax2.set_ylabel('Value') 57 | ax2.set_xlabel('Epoch') 58 | ax2.set_title('Score') 59 | ax2.legend(['Train', 'Test'], loc='upper left') 60 | 61 | plt.show() 62 | 63 | 64 | def main(): 65 | experiments = { 66 | 'ZF_UNET': pd.read_csv(os.path.join('experiments', 'dsb2018', 'bce', 'torch_dsb2018_zf_unet_224_rgb_bce', 'torch_dsb2018_zf_unet_224_rgb_bce.csv')), 67 | 'Linknet (Resnet34)': pd.read_csv(os.path.join('experiments', 'dsb2018', 'bce', 'torch_dsb2018_linknet34_224_rgb_bce', 'torch_dsb2018_linknet34_224_rgb_bce.csv')), 68 | 'Unet (VGG16)': pd.read_csv(os.path.join('experiments', 'dsb2018', 'bce', 'torch_dsb2018_unet16_224_rgb_bce', 'torch_dsb2018_unet16_224_rgb_bce.csv')), 69 | 'Unet (VGG11)': pd.read_csv(os.path.join('experiments', 'dsb2018', 'bce', 'torch_dsb2018_unet11_224_rgb_bce', 'torch_dsb2018_unet11_224_rgb_bce.csv')), 70 | 'GCN': pd.read_csv(os.path.join('experiments', 'dsb2018', 'bce', 'torch_dsb2018_gcn_224_rgb_bce', 'torch_dsb2018_gcn_224_rgb_bce.csv')), 71 | } 72 | 73 | names = [] 74 | loss = [] 75 | val_loss = [] 76 | metric = [] 77 | val_metric = [] 78 | for key, item in experiments.items(): 79 | names.append(key) 80 | loss.append(item[['loss']]) 81 | val_loss.append(item[['val_loss']]) 82 | 83 | metric.append(item[['JaccardScore']]) 84 | val_metric.append(item[['val_JaccardScore']]) 85 | 86 | plot_experiment_train_history(key, item[['loss']], item[['val_loss']], item[['JaccardScore']], item[['val_JaccardScore']]) 87 | 88 | plot_train_history(names, loss, val_loss, 'DSB2018, BCE loss', legend_loc='upper right') 89 | plot_train_history(names, metric, val_metric, 'DSB2018, Jaccard score', legend_loc='lower right') 90 | 91 | 92 | if __name__ == '__main__': 93 | main() 94 | -------------------------------------------------------------------------------- /plot_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | from lib import numpy_losses 5 | 6 | sns.set() 7 | 8 | 9 | def main(): 10 | steps = 10 11 | loss_functions = [numpy_losses.binary_crossentropy, 12 | numpy_losses.jaccard_loss, 13 | numpy_losses.smooth_jaccard_loss, 14 | numpy_losses.bce_smooth_jaccard_loss] 15 | 16 | for loss_fn in loss_functions: 17 | y_true = np.ones((224, 224), dtype=np.float32) 18 | y_pred = y_true.copy() 19 | 20 | losses = [loss_fn(y_true, y_pred)] 21 | 22 | for pred_val in range(0, 1000): 23 | y_pred[...] 
--------------------------------------------------------------------------------
/results/dsb2018_bce_all.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BloodAxe/segmentation-networks-benchmark/2e3feb560102230be9369ab442b4a59cc86dff61/results/dsb2018_bce_all.png
--------------------------------------------------------------------------------
/run_all.cmd:
--------------------------------------------------------------------------------
1 | REM python compare.py -m zf_unet -d dsb2018 -e 100 -p 512 -b 4 -dd e:/datasets
2 | REM python compare.py -m selunet -d dsb2018 -e 100 -p 512 -b 4 -dd e:/datasets
3 | REM python compare.py -m linknet -d dsb2018 -e 100 -p 512 -b 4 -dd e:/datasets
4 | REM python compare.py -m dilated_unet -d dsb2018 -e 100 -p 512 -b 4 -dd e:/datasets
5 | python compare.py -m tiramisu -d dsb2018 -e 100 -p 512 -b 1 -dd e:/datasets
6 | python compare.py -m dilated_resnet -d dsb2018 -e 100 -p 512 -b 4 -dd e:/datasets
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from torch.utils.data import DataLoader
4 | 
5 | from lib.train_utils import find_optimal_lr, auto_file
6 | import torch_train as TT
7 | 
8 | if __name__ == '__main__':
9 |     dd = 'e:/datasets/inria/train'
10 | 
11 |     model = TT.get_model('linknet34', patch_size=512, num_channels=3).cuda()
12 |     criterion = TT.get_loss('bce').cuda()
13 |     optimizer = TT.get_optimizer('sgd', model.parameters(), 1e-4)
14 |     trainset, validset, num_classes = TT.get_dataset('inria', dd, grayscale=False, patch_size=512)
15 | 
16 |     TT.restore_snapshot(model, None, auto_file('linknet34_checkpoint.pth'))
17 |     trainloader = DataLoader(trainset, batch_size=8, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
18 | 
19 |     lr, loss = find_optimal_lr(model, criterion, optimizer, trainloader)
20 | 
21 |     # Smooth the raw loss curve with a 4-tap moving average so the trend is readable
22 |     loss = np.convolve(loss, np.ones((4,)) / 4, mode='same')
23 | 
24 |     fig, ax = plt.subplots(figsize=(16, 12))
25 |     ax.plot(lr, loss)
26 |     ax.set(xlabel='lr', ylabel='loss', title='LR')
27 |     ax.set_xscale("log", nonposx='clip')
28 |     ax.grid()
29 | 
30 |     fig.savefig('loss_plot.png')  # Save via the figure object before show(); plt.savefig() after show() may write an empty image
31 |     fig.show()
32 |     print(lr, loss)
--------------------------------------------------------------------------------
/torch_train_ab.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path
3 | import sys
4 | 
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | from torch import nn
9 | from tensorboardX import SummaryWriter
10 | from torch.backends import cudnn
11 | from torch.optim import Optimizer
12 | from torch.utils.data import DataLoader
13 | from torchvision.utils import make_grid
14 | from tqdm import tqdm
15 | 
16 | from lib.common import count_parameters
17 | from lib.datasets.Inria import INRIA
18 | from 
lib.datasets.dsb2018 import DSB2018Sliced 19 | from lib.losses import JaccardLoss, FocalLossBinary, BCEWithLogitsLossAndSmoothJaccard, BCEWithSigmoidLoss 20 | from lib.metrics import JaccardScore, PixelAccuracy 21 | from lib.models import linknet, unet16, unet11 22 | from lib.models.afterburner import Afterburner 23 | from lib.models.duc_hdc import ResNetDUCHDC, ResNetDUC 24 | from lib.models.gcn152 import GCN152, GCN34 25 | from lib.models.psp_net import PSPNet 26 | from lib.models.tiramisu import FCDenseNet67 27 | from lib.models.unet import UNet 28 | from lib.models.zf_unet import ZF_UNET 29 | from lib.train_utils import AverageMeter, auto_file 30 | import torch_train as TT 31 | 32 | tqdm.monitor_interval = 0 # Workaround for https://github.com/tqdm/tqdm/issues/481 33 | 34 | 35 | def train(model, loss, optimizer, dataloader, epoch: int, metrics={}, summary_writer=None): 36 | losses = AverageMeter() 37 | 38 | train_scores = {} 39 | for key, _ in metrics.items(): 40 | train_scores[key] = AverageMeter() 41 | 42 | with torch.set_grad_enabled(True): 43 | model.train() 44 | n_batches = len(dataloader) 45 | with tqdm(total=n_batches) as tq: 46 | tq.set_description('Train') 47 | x = None 48 | y = None 49 | outputs = None 50 | batch_loss = None 51 | 52 | for batch_index, (x, y) in enumerate(dataloader): 53 | x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True) 54 | 55 | # zero the parameter gradients 56 | optimizer.zero_grad() 57 | 58 | # forward + backward + optimize 59 | outputs = model(x) 60 | 61 | batch_loss = loss(outputs, y) 62 | 63 | batch_size = x.size(0) 64 | (batch_size * batch_loss).backward() 65 | 66 | optimizer.step() 67 | 68 | # Batch train end 69 | # Log train progress 70 | 71 | batch_loss_val = batch_loss.cpu().item() 72 | if summary_writer is not None: 73 | summary_writer.add_scalar('train/batch/loss', batch_loss_val, epoch * n_batches + batch_index) 74 | 75 | # Plot gradient absmax to see if there are any gradient explosions 76 | grad_max = 0 77 | for name, param in model.named_parameters(): 78 | if param.grad is not None: 79 | grad_max = max(grad_max, param.grad.abs().max().cpu().item()) 80 | summary_writer.add_scalar('train/grad/global_max', grad_max, epoch * n_batches + batch_index) 81 | 82 | losses.update(batch_loss_val) 83 | 84 | for key, metric in metrics.items(): 85 | score = metric(outputs, y).cpu().item() 86 | train_scores[key].update(score) 87 | 88 | if summary_writer is not None: 89 | summary_writer.add_scalar('train/batch/' + key, score, epoch * n_batches + batch_index) 90 | 91 | tq.set_postfix(loss='{:.3f}'.format(losses.avg), **train_scores) 92 | tq.update() 93 | 94 | # End of train epoch 95 | if summary_writer is not None: 96 | summary_writer.add_image('train/image', make_grid(x.cpu(), normalize=True), epoch) 97 | summary_writer.add_image('train/y_true', make_grid(y.cpu(), normalize=True), epoch) 98 | summary_writer.add_image('train/y_pred', make_grid(outputs.sigmoid().cpu(), normalize=True), epoch) 99 | summary_writer.add_scalar('train/epoch/loss', losses.avg, epoch) 100 | for key, value in train_scores.items(): 101 | summary_writer.add_scalar('train/epoch/' + key, value.avg, epoch) 102 | 103 | # Plot histogram of parameters after each epoch 104 | for name, param in model.named_parameters(): 105 | if param.grad is not None: 106 | # Plot weighs 107 | param_data = param.data.cpu().numpy() 108 | summary_writer.add_histogram('model/' + name, param_data, epoch, bins='doane') 109 | 110 | # for m in model.modules(): 111 | # if isinstance(m, nn.Conv2d): 112 | # 
weights = m.weights.data.numpy() 113 | 114 | del x, y, outputs, batch_loss 115 | 116 | return losses, train_scores 117 | 118 | 119 | def validate(model, loss, dataloader, epoch: int, metrics=dict(), summary_writer: SummaryWriter = None): 120 | losses = AverageMeter() 121 | 122 | valid_scores = {} 123 | for key, _ in metrics.items(): 124 | valid_scores[key] = AverageMeter() 125 | 126 | with torch.set_grad_enabled(False): 127 | model.eval() 128 | 129 | n_batches = len(dataloader) 130 | with tqdm(total=len(dataloader)) as tq: 131 | tq.set_description('Validation') 132 | 133 | x = None 134 | y = None 135 | outputs = None 136 | batch_loss = None 137 | 138 | for batch_index, (x, y) in enumerate(dataloader): 139 | x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True) 140 | 141 | # forward + backward + optimize 142 | outputs = model(x) 143 | batch_loss = loss(outputs, y) 144 | 145 | # Log train progress 146 | 147 | batch_loss_val = batch_loss.cpu().item() 148 | if summary_writer is not None: 149 | summary_writer.add_scalar('val/batch/loss', batch_loss_val, epoch * n_batches + batch_index) 150 | 151 | losses.update(batch_loss_val) 152 | 153 | for key, metric in metrics.items(): 154 | score = metric(outputs, y).cpu().item() 155 | valid_scores[key].update(score) 156 | 157 | if summary_writer is not None: 158 | summary_writer.add_scalar('val/batch/' + key, score, epoch * n_batches + batch_index) 159 | 160 | tq.set_postfix(loss='{:.3f}'.format(losses.avg), **valid_scores) 161 | tq.update() 162 | 163 | if summary_writer is not None: 164 | summary_writer.add_image('val/image', make_grid(x.cpu(), normalize=True), epoch) 165 | summary_writer.add_image('val/y_true', make_grid(y.cpu(), normalize=True), epoch) 166 | summary_writer.add_image('val/y_pred', make_grid(outputs.sigmoid().cpu(), normalize=True), epoch) 167 | summary_writer.add_scalar('val/epoch/loss', losses.avg, epoch) 168 | for key, value in valid_scores.items(): 169 | summary_writer.add_scalar('val/epoch/' + key, value.avg, epoch) 170 | 171 | del x, y, outputs, batch_loss 172 | 173 | return losses, valid_scores 174 | 175 | 176 | def save_snapshot(model: nn.Module, optimizer: Optimizer, loss: float, epoch: int, train_history: pd.DataFrame, snapshot_file: str): 177 | torch.save({ 178 | 'model': model.state_dict(), 179 | 'optimizer': optimizer.state_dict(), 180 | 'epoch': epoch, 181 | 'loss': loss, 182 | 'train_history': train_history.to_dict(), 183 | 'args': ' '.join(sys.argv[1:]) 184 | }, snapshot_file) 185 | 186 | 187 | def restore_snapshot(model: nn.Module, optimizer: Optimizer, snapshot_file: str): 188 | checkpoint = torch.load(snapshot_file) 189 | start_epoch = checkpoint['epoch'] + 1 190 | best_loss = checkpoint['loss'] 191 | model.load_state_dict(checkpoint['model']) 192 | 193 | if optimizer is not None: 194 | optimizer.load_state_dict(checkpoint['optimizer']) 195 | 196 | train_history = pd.DataFrame.from_dict(checkpoint['train_history']) 197 | 198 | return start_epoch, train_history, best_loss 199 | 200 | 201 | def main(): 202 | parser = argparse.ArgumentParser() 203 | 204 | parser.add_argument('-g', '--grayscale', action='store_true', help='Whether to use grayscale image instead of RGB') 205 | parser.add_argument('-m', '--model', required=True, type=str, help='Name of the model') 206 | parser.add_argument('-p', '--patch-size', type=int, default=224) 207 | parser.add_argument('-b', '--batch-size', type=int, default=1, help='Batch Size during training, e.g. 
-b 64')
208 |     parser.add_argument('-lr', '--learning-rate', type=float, default=1e-3, help='Initial learning rate')
209 |     parser.add_argument('-l', '--loss', type=str, default='bce', help='Target loss')
210 |     parser.add_argument('-o', '--optimizer', default='SGD', help='Name of the optimizer')
211 |     parser.add_argument('-e', '--epochs', type=int, default=100, help='Epoch to run')
212 |     parser.add_argument('-d', '--dataset', type=str, help='Name of the dataset to use for training.')
213 |     parser.add_argument('-dd', '--data-dir', type=str, default='data', help='Root directory where datasets are located.')
214 |     parser.add_argument('-s', '--steps', type=int, default=128, help='Steps per epoch')
215 |     parser.add_argument('-x', '--experiment', type=str, help='Name of the experiment')
216 |     parser.add_argument('-w', '--workers', default=0, type=int, help='Num workers')
217 |     parser.add_argument('-r', '--resume', action='store_true')
218 |     parser.add_argument('-mem', '--memory', action='store_true')
219 | 
220 |     args = parser.parse_args()
221 |     cudnn.benchmark = True
222 | 
223 |     if args.experiment is None:
224 |         args.experiment = 'torch_%s_%s_afterburn_%d_%s_%s' % (args.dataset, args.model, args.patch_size, 'gray' if args.grayscale else 'rgb', args.loss)
225 | 
226 |     experiment_dir = os.path.join('experiments', args.dataset, args.loss, args.experiment)
227 |     os.makedirs(experiment_dir, exist_ok=True)
228 | 
229 |     writer = SummaryWriter(comment=args.experiment)
230 | 
231 |     with open(os.path.join(experiment_dir, 'arguments.txt'), 'w') as f:
232 |         f.write(' '.join(sys.argv[1:]))
233 | 
234 |     trainset, validset, num_classes = TT.get_dataset(args.dataset, args.data_dir, grayscale=args.grayscale, patch_size=args.patch_size, keep_in_mem=args.memory)
235 |     print('Train set size', len(trainset))
236 |     print('Valid set size', len(validset))
237 | 
238 |     trainloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
239 |     validloader = DataLoader(validset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True, drop_last=True)
240 | 
241 |     head_model = TT.get_model(args.model, patch_size=args.patch_size, num_channels=1 if args.grayscale else 3).cuda()
242 |     TT.restore_snapshot(head_model, None, auto_file('linknet34_checkpoint.pth'))
243 | 
244 |     # Freeze the pretrained head so that only the afterburner receives gradient updates
245 |     for param in head_model.parameters():
246 |         param.requires_grad = False
247 | 
248 |     afterburner = Afterburner()
249 |     model = nn.Sequential(head_model, nn.Sigmoid(), afterburner).cuda()
250 |     optimizer = TT.get_optimizer(args.optimizer, afterburner.parameters(), args.learning_rate)
251 | 
252 |     loss = TT.get_loss(args.loss).cuda()
253 |     metrics = {'iou': JaccardScore().cuda(), 'accuracy': PixelAccuracy().cuda()}
254 | 
255 |     start_epoch = 0
256 |     best_loss = np.inf
257 |     train_history = pd.DataFrame()
258 | 
259 |     checkpoint_filename = os.path.join(experiment_dir, f'{args.model}_checkpoint.pth')
260 |     if args.resume:
261 |         start_epoch, train_history, best_loss = restore_snapshot(model, optimizer, checkpoint_filename)
262 |         print('Resuming training from epoch', start_epoch, 'and loss', best_loss)
263 |         print(train_history)
264 | 
265 |     print('Head :', count_parameters(head_model))
266 |     print('Afterburner:', count_parameters(afterburner))
267 | 
268 |     for epoch in range(start_epoch, args.epochs):
269 |         train_loss, train_scores = train(model, loss, optimizer, trainloader, epoch, metrics, summary_writer=writer)
270 |         valid_loss, valid_scores = validate(model, loss, validloader, epoch, metrics, summary_writer=writer)
271 | 
272 |         summary = {
273 |             'epoch': [epoch],
274 |             'loss': [train_loss.avg],
275 |             'val_loss': [valid_loss.avg]
276 |         }
277 | 
278 |         for key, value in train_scores.items():
279 |             summary[key] = [value.avg]
280 | 
281 |         for key, value in valid_scores.items():
282 |             summary['val_' + key] = [value.avg]
283 | 
284 |         train_history = train_history.append(pd.DataFrame.from_dict(summary), ignore_index=True)
285 | 
286 |         print(epoch, summary)
287 | 
288 |         if valid_loss.avg < best_loss:
289 |             save_snapshot(model, optimizer, valid_loss.avg, epoch, train_history, checkpoint_filename)
290 |             best_loss = valid_loss.avg
291 |             print('Checkpoint saved', epoch, best_loss)
292 | 
293 |     print('Training is finished...')
294 | 
295 |     train_history.to_csv(os.path.join(experiment_dir, args.experiment + '.csv'),
296 |                          index=False,
297 |                          mode='a' if args.resume else 'w',
298 |                          header=not args.resume)
299 | 
300 | 
301 | if __name__ == '__main__':
302 |     main()
303 | 
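torch_train_ab.py trains a small post-processing network (the "afterburner") on top of a frozen, pretrained segmentation head. The pattern itself is compact; below is a minimal self-contained sketch with a stand-in head and a hypothetical two-layer afterburner (lib/models/afterburner.py is not shown in this listing, so the real architecture may differ):

```
import torch
from torch import nn

# Stand-in for the pretrained segmentation head (a real run would load linknet34 weights)
head = nn.Conv2d(3, 1, kernel_size=3, padding=1)

# Freeze the head: parameters with requires_grad=False receive no gradient updates
for param in head.parameters():
    param.requires_grad = False

# Hypothetical afterburner that refines the head's sigmoid output
afterburner = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(8, 1, kernel_size=1),
)

model = nn.Sequential(head, nn.Sigmoid(), afterburner)

# Only afterburner parameters go to the optimizer, mirroring the script above
optimizer = torch.optim.SGD(afterburner.parameters(), lr=1e-3)

x = torch.rand(2, 3, 64, 64)
print(model(x).shape)  # torch.Size([2, 1, 64, 64]) - logits for a loss such as BCEWithLogitsLoss
```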
--------------------------------------------------------------------------------
/torch_train_reg.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path
3 | import sys
4 | 
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | from torch import nn
9 | from tensorboardX import SummaryWriter
10 | from torch.backends import cudnn
11 | from torch.optim import Optimizer
12 | from torch.utils.data import DataLoader
13 | from torchvision.utils import make_grid
14 | from tqdm import tqdm
15 | 
16 | from lib.datasets.Inria import INRIA
17 | from lib.datasets.dsb2018 import DSB2018Sliced
18 | from lib.datasets.shapes import SHAPES
19 | from lib.losses import JaccardLoss, FocalLossBinary, BCEWithLogitsLossAndSmoothJaccard, BCEWithSigmoidLoss
20 | from lib.metrics import JaccardScore, PixelAccuracy
21 | from lib.models import linknet, unet16, unet11
22 | from lib.models.dilated_linknet import DilatedLinkNet34
23 | from lib.models.duc_hdc import ResNetDUCHDC, ResNetDUC
24 | from lib.models.gcn152 import GCN152, GCN34
25 | from lib.models.linknext import LinkNext
26 | from lib.models.psp_net import PSPNet
27 | from lib.models.tiramisu import FCDenseNet67
28 | from lib.models.unet import UNet
29 | from lib.models.unet_abn import UNetABN
30 | from lib.models.zf_unet import ZF_UNET
31 | from lib.train_utils import AverageMeter, PRCurveMeter
32 | from lib.common import count_parameters
33 | from torch.nn.modules.loss import _Loss
34 | 
35 | from torch_train import get_model, get_loss, get_optimizer, get_dataset
36 | 
37 | tqdm.monitor_interval = 0  # Workaround for https://github.com/tqdm/tqdm/issues/481
38 | 
39 | 
40 | class Conv2dRegularization(_Loss):
41 |     def __init__(self, l1_factor=0.0005, l2_factor=0.0005):
42 |         super(Conv2dRegularization, self).__init__()
43 |         self.l1_factor = l1_factor
44 |         self.l2_factor = l2_factor
45 |         self.l1_crit = nn.L1Loss(size_average=False)  # size_average=False is a sum reduction
46 |         self.l2_crit = nn.MSELoss(size_average=False)
47 | 
48 |     def forward(self, model):
49 |         reg_loss_l1 = 0
50 |         reg_loss_l2 = 0
51 | 
52 |         for module in model.modules():
53 |             if isinstance(module, nn.Conv2d):
54 |                 if module.weight.requires_grad:
55 |                     # We apply an L1 norm to the weights in order to make the kernels sparse
56 |                     reg_loss_l1 += self.l1_crit(module.weight, target=torch.zeros_like(module.weight))
57 | 
58 |                     # We apply an L2 norm to the biases in order to keep them close to zero
59 |                     if module.bias is not None:
60 |                         reg_loss_l2 += self.l2_crit(module.bias, 
target=torch.zeros_like(module.bias)) 61 | 62 | return self.l1_factor * reg_loss_l1, self.l2_factor * reg_loss_l2 63 | 64 | 65 | def train(model, loss, optimizer, dataloader, epoch: int, metrics={}, summary_writer=None): 66 | losses = AverageMeter() 67 | 68 | train_scores = {} 69 | for key, _ in metrics.items(): 70 | train_scores[key] = AverageMeter() 71 | 72 | conv2d_reg = Conv2dRegularization().cuda() 73 | 74 | with torch.set_grad_enabled(True): 75 | model.train() 76 | n_batches = len(dataloader) 77 | with tqdm(total=n_batches) as tq: 78 | tq.set_description('Train') 79 | x = None 80 | y = None 81 | outputs = None 82 | batch_loss = None 83 | 84 | for batch_index, (x, y) in enumerate(dataloader): 85 | x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True) 86 | 87 | # zero the parameter gradients 88 | optimizer.zero_grad() 89 | 90 | # forward + backward + optimize 91 | outputs = model(x) 92 | l1_penalty, l2_penalty = conv2d_reg(model) 93 | 94 | batch_loss = loss(outputs, y) 95 | 96 | batch_size = x.size(0) 97 | (batch_size * batch_loss + l1_penalty + l2_penalty).backward() 98 | 99 | optimizer.step() 100 | 101 | # Batch train end 102 | # Log train progress 103 | 104 | l1_penalty = l1_penalty.cpu().item() 105 | l2_penalty = l2_penalty.cpu().item() 106 | 107 | 108 | batch_loss_val = batch_loss.cpu().item() 109 | if summary_writer is not None: 110 | summary_writer.add_scalar('train/batch/loss', batch_loss_val, epoch * n_batches + batch_index) 111 | summary_writer.add_scalar('train/batch/l1_penalty', l1_penalty, epoch * n_batches + batch_index) 112 | summary_writer.add_scalar('train/batch/l2_penalty', l2_penalty, epoch * n_batches + batch_index) 113 | 114 | # Plot gradient absmax and absmin to see if there are any gradient explosions 115 | grad_max = 0 116 | for name, param in model.named_parameters(): 117 | if param.grad is not None: 118 | grad_max = max(grad_max, param.grad.abs().max().cpu().item()) 119 | 120 | summary_writer.add_scalar('train/grad/global_abs_max', grad_max, epoch * n_batches + batch_index) 121 | 122 | losses.update(batch_loss_val) 123 | 124 | for key, metric in metrics.items(): 125 | score = metric(outputs, y).cpu().item() 126 | train_scores[key].update(score) 127 | 128 | if summary_writer is not None: 129 | summary_writer.add_scalar('train/batch/' + key, score, epoch * n_batches + batch_index) 130 | 131 | tq.set_postfix(loss='{:.3f}'.format(losses.avg), 132 | l1_penalty='{:.3f}'.format(l1_penalty), 133 | l2_penalty='{:.3f}'.format(l2_penalty), **train_scores) 134 | tq.update() 135 | 136 | # End of train epoch 137 | if summary_writer is not None: 138 | summary_writer.add_image('train/image', make_grid(x.cpu(), normalize=True), epoch) 139 | summary_writer.add_image('train/y_true', make_grid(y.cpu(), normalize=True), epoch) 140 | summary_writer.add_image('train/y_pred', make_grid(outputs.sigmoid().cpu(), normalize=True), epoch) 141 | summary_writer.add_scalar('train/epoch/loss', losses.avg, epoch) 142 | for key, value in train_scores.items(): 143 | summary_writer.add_scalar('train/epoch/' + key, value.avg, epoch) 144 | 145 | # Plot histogram of parameters after each epoch 146 | for name, param in model.named_parameters(): 147 | if param.grad is not None: 148 | # Plot weighs 149 | param_data = param.data.cpu().numpy() 150 | summary_writer.add_histogram('model/' + name, param_data, epoch, bins='doane') 151 | 152 | del x, y, outputs, batch_loss 153 | 154 | return losses, train_scores 155 | 156 | 157 | def validate(model, loss, dataloader, epoch: int, metrics=dict(), 
summary_writer: SummaryWriter = None): 158 | losses = AverageMeter() 159 | pr_meter = PRCurveMeter() 160 | 161 | valid_scores = {} 162 | for key, _ in metrics.items(): 163 | valid_scores[key] = AverageMeter() 164 | 165 | with torch.set_grad_enabled(False): 166 | model.eval() 167 | 168 | n_batches = len(dataloader) 169 | with tqdm(total=len(dataloader)) as tq: 170 | tq.set_description('Validation') 171 | 172 | x = None 173 | y = None 174 | outputs = None 175 | batch_loss = None 176 | 177 | for batch_index, (x, y) in enumerate(dataloader): 178 | x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True) 179 | 180 | # forward + backward + optimize 181 | outputs = model(x) 182 | batch_loss = loss(outputs, y) 183 | 184 | # Log train progress 185 | 186 | batch_loss_val = batch_loss.cpu().item() 187 | if summary_writer is not None: 188 | summary_writer.add_scalar('val/batch/loss', batch_loss_val, epoch * n_batches + batch_index) 189 | 190 | losses.update(batch_loss_val) 191 | 192 | for key, metric in metrics.items(): 193 | score = metric(outputs, y).cpu().item() 194 | valid_scores[key].update(score) 195 | 196 | if summary_writer is not None: 197 | summary_writer.add_scalar('val/batch/' + key, score, epoch * n_batches + batch_index) 198 | 199 | tq.set_postfix(loss='{:.3f}'.format(losses.avg), **valid_scores) 200 | tq.update() 201 | 202 | if summary_writer is not None: 203 | summary_writer.add_image('val/image', make_grid(x.cpu(), normalize=True), epoch) 204 | summary_writer.add_image('val/y_true', make_grid(y.cpu(), normalize=True), epoch) 205 | summary_writer.add_image('val/y_pred', make_grid(outputs.sigmoid().cpu(), normalize=True), epoch) 206 | summary_writer.add_scalar('val/epoch/loss', losses.avg, epoch) 207 | for key, value in valid_scores.items(): 208 | summary_writer.add_scalar('val/epoch/' + key, value.avg, epoch) 209 | 210 | # Compute PR curve only for last batch, because computing it for entire validation set is costly 211 | pr_meter.update(outputs, y) 212 | summary_writer.add_pr_curve_raw('val/pr_curve', 213 | true_positive_counts=pr_meter.tp, 214 | true_negative_counts=pr_meter.tn, 215 | false_negative_counts=pr_meter.fn, 216 | false_positive_counts=pr_meter.fp, 217 | precision=pr_meter.precision(), 218 | recall=pr_meter.recall(), 219 | global_step=epoch) 220 | del x, y, outputs, batch_loss 221 | 222 | return losses, valid_scores 223 | 224 | 225 | def save_snapshot(model: nn.Module, optimizer: Optimizer, loss: float, epoch: int, train_history: pd.DataFrame, snapshot_file: str): 226 | torch.save({ 227 | 'model': model.state_dict(), 228 | 'optimizer': optimizer.state_dict(), 229 | 'epoch': epoch, 230 | 'loss': loss, 231 | 'train_history': train_history.to_dict(), 232 | 'args': ' '.join(sys.argv[1:]) 233 | }, snapshot_file) 234 | 235 | 236 | def restore_snapshot(model: nn.Module, optimizer: Optimizer, snapshot_file: str): 237 | checkpoint = torch.load(snapshot_file) 238 | start_epoch = checkpoint['epoch'] + 1 239 | best_loss = checkpoint['loss'] 240 | model.load_state_dict(checkpoint['model']) 241 | 242 | if optimizer is not None: 243 | optimizer.load_state_dict(checkpoint['optimizer']) 244 | 245 | train_history = pd.DataFrame.from_dict(checkpoint['train_history']) 246 | 247 | return start_epoch, train_history, best_loss 248 | 249 | 250 | def main(): 251 | parser = argparse.ArgumentParser() 252 | 253 | parser.add_argument('-g', '--grayscale', action='store_true', help='Whether to use grayscale image instead of RGB') 254 | parser.add_argument('-m', '--model', required=True, type=str, 
help='Name of the model') 255 | parser.add_argument('-p', '--patch-size', type=int, default=224) 256 | parser.add_argument('-b', '--batch-size', type=int, default=1, help='Batch Size during training, e.g. -b 64') 257 | parser.add_argument('-lr', '--learning-rate', type=float, default=1e-3, help='Initial learning rate') 258 | parser.add_argument('-l', '--loss', type=str, default='bce', help='Target loss') 259 | parser.add_argument('-o', '--optimizer', default='SGD', help='Name of the optimizer') 260 | parser.add_argument('-e', '--epochs', type=int, default=100, help='Epoch to run') 261 | parser.add_argument('-d', '--dataset', type=str, help='Name of the dataset to use for training.') 262 | parser.add_argument('-dd', '--data-dir', type=str, default='data', help='Root directory where datasets are located.') 263 | parser.add_argument('-s', '--steps', type=int, default=128, help='Steps per epoch') 264 | parser.add_argument('-x', '--experiment', type=str, help='Name of the experiment') 265 | parser.add_argument('-w', '--workers', default=0, type=int, help='Num workers') 266 | parser.add_argument('-r', '--resume', action='store_true') 267 | parser.add_argument('-mem', '--memory', action='store_true') 268 | 269 | args = parser.parse_args() 270 | cudnn.benchmark = True 271 | 272 | if args.experiment is None: 273 | args.experiment = '%s_%s_reg_%d_%s_%s' % (args.dataset, args.model, args.patch_size, 'gray' if args.grayscale else 'rgb', args.loss) 274 | 275 | experiment_dir = os.path.join('experiments', args.dataset, args.loss, args.experiment) 276 | os.makedirs(experiment_dir, exist_ok=True) 277 | 278 | writer = SummaryWriter(comment='_' + args.experiment) 279 | 280 | with open(os.path.join(experiment_dir, 'arguments.txt'), 'w') as f: 281 | f.write(' '.join(sys.argv[1:])) 282 | 283 | model = get_model(args.model, patch_size=args.patch_size, num_channels=1 if args.grayscale else 3) 284 | 285 | # Write model graph 286 | dummy_input = torch.autograd.Variable(torch.rand((args.batch_size, 1 if args.grayscale else 3, args.patch_size, args.patch_size))) 287 | writer.add_graph(model, dummy_input) 288 | 289 | model = model.cuda() 290 | loss = get_loss(args.loss).cuda() 291 | optimizer = get_optimizer(args.optimizer, model.parameters(), args.learning_rate) 292 | metrics = {'iou': JaccardScore().cuda(), 'accuracy': PixelAccuracy().cuda()} 293 | 294 | trainset, validset, num_classes = get_dataset(args.dataset, args.data_dir, grayscale=args.grayscale, patch_size=args.patch_size, keep_in_mem=args.memory) 295 | print('Train set size', len(trainset)) 296 | print('Valid set size', len(validset)) 297 | print('Model ', model) 298 | print('Parameters ', count_parameters(model)) 299 | 300 | trainloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True) 301 | validloader = DataLoader(validset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True, drop_last=True) 302 | 303 | start_epoch = 0 304 | best_loss = np.inf 305 | train_history = pd.DataFrame() 306 | 307 | # Checkpoint is train result of epoch with best loss 308 | checkpoint_filename = os.path.join(experiment_dir, f'{args.model}_checkpoint.pth') 309 | 310 | # Snapshot is train result of last epoch 311 | snapshot_filename = os.path.join(experiment_dir, f'{args.model}_snapshot.pth') 312 | 313 | if args.resume: 314 | start_epoch, train_history, best_loss = restore_snapshot(model, optimizer, checkpoint_filename) 315 | print('Resuming training from epoch', 
start_epoch, 'and loss', best_loss)
316 |         print(train_history)
317 | 
318 |     for epoch in range(start_epoch, args.epochs):
319 |         train_loss, train_scores = train(model, loss, optimizer, trainloader, epoch, metrics, summary_writer=writer)
320 |         valid_loss, valid_scores = validate(model, loss, validloader, epoch, metrics, summary_writer=writer)
321 | 
322 |         summary = {
323 |             'epoch': [epoch],
324 |             'loss': [train_loss.avg],
325 |             'val_loss': [valid_loss.avg]
326 |         }
327 | 
328 |         for key, value in train_scores.items():
329 |             summary[key] = [value.avg]
330 | 
331 |         for key, value in valid_scores.items():
332 |             summary['val_' + key] = [value.avg]
333 | 
334 |         train_history = train_history.append(pd.DataFrame.from_dict(summary), ignore_index=True)
335 | 
336 |         print(epoch, summary)
337 | 
338 |         if valid_loss.avg < best_loss:
339 |             save_snapshot(model, optimizer, valid_loss.avg, epoch, train_history, checkpoint_filename)
340 |             best_loss = valid_loss.avg
341 |             print('Checkpoint saved', epoch, best_loss)
342 | 
343 |         save_snapshot(model, optimizer, valid_loss.avg, epoch, train_history, snapshot_filename)
344 | 
345 |     print('Training is finished...')
346 | 
347 |     train_history.to_csv(os.path.join(experiment_dir, args.experiment + '.csv'),
348 |                          index=False,
349 |                          mode='a' if args.resume else 'w',
350 |                          header=not args.resume)
351 | 
352 | 
353 | if __name__ == '__main__':
354 |     main()
355 | 
--------------------------------------------------------------------------------
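A closing note on torch_train_reg.py: the Conv2dRegularization penalty only visits Conv2d modules whose weights still require gradients, so freezing a layer removes it from the L1 term. A small sanity check of that behavior (illustrative only; it assumes the repository root is on PYTHONPATH so the class defined above is importable):

```
import torch
from torch import nn

from torch_train_reg import Conv2dRegularization

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(8, 1, kernel_size=1),
)

reg = Conv2dRegularization(l1_factor=0.0005, l2_factor=0.0005)
l1_full, l2_full = reg(model)

# Freeze the first convolution: it should no longer contribute to the L1 penalty
for param in model[0].parameters():
    param.requires_grad = False

l1_frozen, _ = reg(model)
assert l1_frozen.item() < l1_full.item()
print(l1_full.item(), l2_full.item(), l1_frozen.item())
```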