├── unet ├── __init__.py ├── decoder_skipless.py ├── decoder.py ├── decoder_edge.py └── model.py ├── ims ├── 1.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png └── motiv.png ├── utils ├── __init__.py ├── meter.py ├── base.py ├── metrics.py ├── losses.py ├── train.py └── functional.py ├── base ├── __init__.py ├── initialization.py ├── heads.py ├── model.py └── modules.py ├── encoders ├── _preprocessing.py ├── _base.py ├── _utils.py ├── __init__.py └── resnet.py ├── params.py ├── README.md ├── evaluate.py ├── utilities.py └── dataset.py /unet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Unet -------------------------------------------------------------------------------- /ims/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burakekim/MTL_homoscedastic_SRB/HEAD/ims/1.png -------------------------------------------------------------------------------- /ims/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burakekim/MTL_homoscedastic_SRB/HEAD/ims/2.png -------------------------------------------------------------------------------- /ims/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burakekim/MTL_homoscedastic_SRB/HEAD/ims/3.png -------------------------------------------------------------------------------- /ims/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burakekim/MTL_homoscedastic_SRB/HEAD/ims/4.png -------------------------------------------------------------------------------- /ims/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burakekim/MTL_homoscedastic_SRB/HEAD/ims/5.png -------------------------------------------------------------------------------- /ims/motiv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burakekim/MTL_homoscedastic_SRB/HEAD/ims/motiv.png -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import train 2 | from . import losses 3 | from . 
import metrics -------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import SegmentationModel 2 | 3 | from .modules import ( 4 | Conv2dReLU, 5 | Attention, 6 | ) 7 | 8 | from .heads import ( 9 | SegmentationHead, 10 | ClassificationHead, 11 | AUX_edgehead, 12 | AUX_SegmentationHead 13 | ) -------------------------------------------------------------------------------- /encoders/_preprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def preprocess_input( 5 | x, mean=None, std=None, input_space="RGB", input_range=None, **kwargs 6 | ): 7 | 8 | if input_space == "BGR": 9 | x = x[..., ::-1].copy() 10 | 11 | if input_range is not None: 12 | if x.max() > 1 and input_range[1] == 1: 13 | x = x / 255.0 14 | 15 | if mean is not None: 16 | mean = np.array(mean) 17 | x = x - mean 18 | 19 | if std is not None: 20 | std = np.array(std) 21 | x = x / std 22 | 23 | return x 24 | -------------------------------------------------------------------------------- /base/initialization.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def initialize_decoder(module): 5 | for m in module.modules(): 6 | 7 | if isinstance(m, nn.Conv2d): 8 | nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu") 9 | if m.bias is not None: 10 | nn.init.constant_(m.bias, 0) 11 | 12 | elif isinstance(m, nn.BatchNorm2d): 13 | nn.init.constant_(m.weight, 1) 14 | nn.init.constant_(m.bias, 0) 15 | 16 | elif isinstance(m, nn.Linear): 17 | nn.init.xavier_uniform_(m.weight) 18 | if m.bias is not None: 19 | nn.init.constant_(m.bias, 0) 20 | 21 | 22 | def initialize_head(module): 23 | for m in module.modules(): 24 | if isinstance(m, (nn.Linear, nn.Conv2d)): 25 | nn.init.xavier_uniform_(m.weight) 26 | if m.bias is not None: 27 | nn.init.constant_(m.bias, 0) 28 | -------------------------------------------------------------------------------- /params.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import os 4 | import torch 5 | 6 | def get_params(): 7 | 8 | params = {} 9 | 10 | params['experiment_log'] = "MTL_ser_r101u_0" # experiment ID. 11 | params['main_dir'] = r"C:\Users\burak\Desktop" 12 | params['log_path'] = os.path.join(params['main_dir'], 'mtl', 'logs') # path to save logs. 13 | 14 | params['main_data_dir'] = r"D:\Veri-Setleri\SN6\SN6_buildings_AOI_11_Rotterdam_train\train\AOI_11_Rotterdam" 15 | params['masks_dir'] = os.path.join(params['main_data_dir'], 'MASKS_binary') # dataset path. 16 | params['psp_dir'] = os.path.join(params['main_data_dir'], 'PS-RGB') # dataset path. 17 | 18 | params['crop_size'] = (480,480) 19 | params['encoder'] = 'resnet101' # backbone architecture, encoder. 20 | params['encoder_weights'] = 'imagenet' # pre-trained weights. 21 | params['classes'] = np.arange(0,2,1) # used for encoding-decoding the mask.['building','not building'] 22 | params['activation'] = 'sigmoid' # activation function. 23 | params['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu") # GPU or CPU. 24 | params['batch_size'] = 4 # batch size. 25 | params['lr'] = 0.0001 # learning rate. 26 | params['n_epoch'] = 50 # number of epochs. 27 | params['n_workers'] = 0 #number of workers for data loader, multi-process scheme. 
28 | params['seed'] = 12 29 | 30 | return params 31 | 32 | -------------------------------------------------------------------------------- /encoders/_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import List 4 | from collections import OrderedDict 5 | 6 | from . import _utils as utils 7 | 8 | 9 | class EncoderMixin: 10 | """Add encoder functionality such as: 11 | - output channels specification of feature tensors (produced by encoder) 12 | - patching first convolution for arbitrary input channels 13 | """ 14 | 15 | @property 16 | def out_channels(self): 17 | """Return channels dimensions for each tensor of forward output of encoder""" 18 | return self._out_channels[: self._depth + 1] 19 | 20 | def set_in_channels(self, in_channels): 21 | """Change first convolution chennels""" 22 | if in_channels == 3: 23 | return 24 | 25 | self._in_channels = in_channels 26 | if self._out_channels[0] == 3: 27 | self._out_channels = tuple([in_channels] + list(self._out_channels)[1:]) 28 | 29 | utils.patch_first_conv(model=self, in_channels=in_channels) 30 | 31 | def get_stages(self): 32 | """Method should be overridden in encoder""" 33 | raise NotImplementedError 34 | 35 | def make_dilated(self, stage_list, dilation_list): 36 | stages = self.get_stages() 37 | for stage_indx, dilation_rate in zip(stage_list, dilation_list): 38 | utils.replace_strides_with_dilation( 39 | module=stages[stage_indx], 40 | dilation_rate=dilation_rate, 41 | ) 42 | -------------------------------------------------------------------------------- /encoders/_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def patch_first_conv(model, in_channels): 6 | """Change first convolution layer input channels. 
7 | In case: 8 | in_channels == 1 or in_channels == 2 -> reuse original weights 9 | in_channels > 3 -> make random kaiming normal initialization 10 | """ 11 | 12 | # get first conv 13 | for module in model.modules(): 14 | if isinstance(module, nn.Conv2d): 15 | break 16 | 17 | # change input channels for first conv 18 | module.in_channels = in_channels 19 | weight = module.weight.detach() 20 | reset = False 21 | 22 | if in_channels == 1: 23 | weight = weight.sum(1, keepdim=True) 24 | elif in_channels == 2: 25 | weight = weight[:, :2] * (3.0 / 2.0) 26 | else: 27 | reset = True 28 | weight = torch.Tensor( 29 | module.out_channels, 30 | module.in_channels // module.groups, 31 | *module.kernel_size 32 | ) 33 | 34 | module.weight = nn.parameter.Parameter(weight) 35 | if reset: 36 | module.reset_parameters() 37 | 38 | 39 | def replace_strides_with_dilation(module, dilation_rate): 40 | """Patch Conv2d modules replacing strides with dilation""" 41 | for mod in module.modules(): 42 | if isinstance(mod, nn.Conv2d): 43 | mod.stride = (1, 1) 44 | mod.dilation = (dilation_rate, dilation_rate) 45 | kh, kw = mod.kernel_size 46 | mod.padding = ((kh // 2) * dilation_rate, (kh // 2) * dilation_rate) 47 | 48 | # Kostyl for EfficientNet 49 | if hasattr(mod, "static_padding"): 50 | mod.static_padding = nn.Identity() 51 | -------------------------------------------------------------------------------- /encoders/__init__.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | from .resnet import resnet_encoders 5 | 6 | from ._preprocessing import preprocess_input 7 | 8 | encoders = {} 9 | encoders.update(resnet_encoders) 10 | 11 | def get_encoder(name, in_channels=3, depth=5, weights=None): 12 | Encoder = encoders[name]["encoder"] 13 | params = encoders[name]["params"] 14 | params.update(depth=depth) 15 | encoder = Encoder(**params) 16 | 17 | if weights is not None: 18 | print("Pretrained.") 19 | settings = encoders[name]["pretrained_settings"][weights] 20 | encoder.load_state_dict(model_zoo.load_url(settings["url"])) 21 | if weights is None: 22 | print("Not-pretrained.") 23 | encoder.set_in_channels(in_channels) 24 | 25 | return encoder 26 | 27 | 28 | def get_encoder_names(): 29 | return list(encoders.keys()) 30 | 31 | 32 | def get_preprocessing_params(encoder_name, pretrained="imagenet"): 33 | settings = encoders[encoder_name]["pretrained_settings"] 34 | 35 | if pretrained not in settings.keys(): 36 | raise ValueError("Avaliable pretrained options {}".format(settings.keys())) 37 | 38 | formatted_settings = {} 39 | formatted_settings["input_space"] = settings[pretrained].get("input_space") 40 | formatted_settings["input_range"] = settings[pretrained].get("input_range") 41 | formatted_settings["mean"] = settings[pretrained].get("mean") 42 | formatted_settings["std"] = settings[pretrained].get("std") 43 | return formatted_settings 44 | 45 | 46 | def get_preprocessing_fn(encoder_name, pretrained="imagenet"): 47 | params = get_preprocessing_params(encoder_name, pretrained=pretrained) 48 | return functools.partial(preprocess_input, **params) 49 | -------------------------------------------------------------------------------- /utils/meter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Meter(object): 4 | '''Meters provide a way to keep track of important statistics in an online manner. 
5 | This class is abstract, but provides a standard interface for all meters to follow. 6 | ''' 7 | 8 | def reset(self): 9 | '''Resets the meter to default settings.''' 10 | pass 11 | 12 | def add(self, value): 13 | '''Log a new value to the meter 14 | Args: 15 | value: Next restult to include. 16 | ''' 17 | pass 18 | 19 | def value(self): 20 | '''Get the value of the meter in the current state.''' 21 | pass 22 | 23 | 24 | class AverageValueMeter(Meter): 25 | def __init__(self): 26 | super(AverageValueMeter, self).__init__() 27 | self.reset() 28 | self.val = 0 29 | 30 | def add(self, value, n=1): 31 | self.val = value 32 | self.sum += value 33 | self.var += value * value 34 | self.n += n 35 | 36 | if self.n == 0: 37 | self.mean, self.std = np.nan, np.nan 38 | elif self.n == 1: 39 | self.mean = 0.0 + self.sum # This is to force a copy in torch/numpy 40 | self.std = np.inf 41 | self.mean_old = self.mean 42 | self.m_s = 0.0 43 | else: 44 | self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n) 45 | self.m_s += (value - self.mean_old) * (value - self.mean) 46 | self.mean_old = self.mean 47 | self.std = np.sqrt(self.m_s / (self.n - 1.0)) 48 | 49 | def value(self): 50 | return self.mean, self.std 51 | 52 | def reset(self): 53 | self.n = 0 54 | self.sum = 0.0 55 | self.var = 0.0 56 | self.val = 0.0 57 | self.mean = np.nan 58 | self.mean_old = 0.0 59 | self.m_s = 0.0 60 | self.std = np.nan 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Land Use and Land Cover Mapping Using Deep Learning Based Segmentation Approaches and VHR Worldview-3 images 2 | This repository contains the code for the paper [A MULTI-TASK DEEP LEARNING FRAMEWORK FOR BUILDING FOOTPRINT SEGMENTATION](https://ieeexplore.ieee.org/document/9554766). 3 | 4 | 5 | Framework 6 | --------------------- 7 | - This work constructs a multi-task learning framework in which the main segmentation task is coupled with the auxilliary image reconstruction and edge extraction tasks. A homoscedastic uncertainty aware-objective function is formed where individual loss contributions are learned throughout the training procedure, along with the default weights. 8 | 9 | ![alt text](ims/motiv.png) 10 | 11 | 12 | Outputs 13 | --------------------- 14 | Sample outputs, from left to right: 15 | - Input Image, 16 | - Segmentation Annotation, 17 | - Predicted Segmentation Map, 18 | - Edge Annotation, 19 | - Predicted Edge Map, 20 | - Reconstructed Input Image. 21 | 22 | ![alt text](ims/1.png) 23 | ![alt text](ims/2.png) 24 | ![alt text](ims/3.png) 25 | ![alt text](ims/4.png) 26 | ![alt text](ims/5.png) 27 | 28 | 29 | How to use it? 30 | --------------------- 31 | 32 | Simply download the repository and follow the *main_notebook.ipynb* after modifying the paths and the parameters in the *params.py* script. 33 | 34 | The [Spacenet6](https://arxiv.org/abs/2004.06500) dataset needs to be downloaded prior to running the main notebook (or use your own custom Dataset instance). 35 | 36 | The code was implemented in Python(3.8) and PyTroch(1.14.0) on Windows OS. The *segmentation models pytorch* library is used as a baseline for implementation. Apart from main data science libraries, RS-specific libraries such as GDAL, rasterio, and tifffile are also required. 37 | 38 | Citation 39 | --------------------- 40 | 41 | B. Ekim and E. 
Sertel, "A Multi-Task Deep Learning Framework for Building Footprint Segmentation," 2021 IEEE International Geoscience and Remote Sensing Symposium IGARSS, 2021, pp. 2500-2503, doi: 10.1109/IGARSS47720.2021.9554766. 42 | -------------------------------------------------------------------------------- /base/heads.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .modules import Flatten, Activation 3 | 4 | class SegmentationHead(nn.Sequential): 5 | 6 | def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1): 7 | conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) 8 | upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() 9 | activation = Activation(activation) 10 | super().__init__(conv2d, upsampling, activation) 11 | 12 | class AUX_SegmentationHead(nn.Sequential): 13 | 14 | def __init__(self, in_channels, out_channels, kernel_size=1, activation=None, upsampling=1): 15 | conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) 16 | upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() 17 | #activation = Activation(activation) 18 | super().__init__(conv2d, upsampling) 19 | 20 | class AUX_edgehead(nn.Sequential): 21 | 22 | def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1): 23 | conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) 24 | upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() 25 | activation = Activation(activation) 26 | super().__init__(conv2d, upsampling, activation) 27 | 28 | class ClassificationHead(nn.Sequential): 29 | 30 | def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None): 31 | if pooling not in ("max", "avg"): 32 | raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling)) 33 | pool = nn.AdaptiveAvgPool2d(1) if pooling == 'avg' else nn.AdaptiveMaxPool2d(1) 34 | flatten = Flatten() 35 | dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity() 36 | linear = nn.Linear(in_channels, classes, bias=True) 37 | activation = Activation(activation) 38 | super().__init__(pool, flatten, dropout, linear, activation) 39 | -------------------------------------------------------------------------------- /base/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . 
import initialization as init 3 | 4 | class SegmentationModel(torch.nn.Module): 5 | 6 | def initialize(self): 7 | init.initialize_decoder(self.decoder) 8 | init.initialize_decoder(self.skipless_decoder) 9 | init.initialize_decoder(self.UnetDecoder_Edge) 10 | 11 | init.initialize_head(self.segmentation_head) 12 | init.initialize_head(self.edge_head) 13 | init.initialize_head(self.reconstruct_segmentation_head) 14 | if self.classification_head is not None: 15 | init.initialize_head(self.classification_head) 16 | 17 | def forward(self, x): 18 | """Sequentially pass `x` trough model`s encoder, decoder and heads""" 19 | features = self.encoder(x) # Extract features - common 20 | ## Segmentation 21 | decoder_output = self.decoder(*features) # Decoder output for segmentation task 22 | segmentation_mask = self.segmentation_head(decoder_output) # Feed the decoder output to the segmentation head 23 | ## Edge 24 | edge_decoder_output = self.decoder(*features) # Decoder output for edge detection head 25 | edge_mask = self.edge_head(edge_decoder_output) # Feed the decoder output to the edge head 26 | ## Reconstruction 27 | reconstruct_decoder_output = self.skipless_decoder(*features) # Decoder output for reconstruction task 28 | reconstruction_mask = self.reconstruct_segmentation_head(reconstruct_decoder_output) # Feed the decoder output to the reconstruction head 29 | 30 | return segmentation_mask, edge_mask, reconstruction_mask, self.sigma 31 | 32 | def predict(self, x): 33 | """Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()` 34 | 35 | Args: 36 | x: 4D torch tensor with shape (batch_size, channels, height, width) 37 | 38 | Return: 39 | prediction: 4D torch tensor with shape (batch_size, classes, height, width) 40 | 41 | """ 42 | if self.training: 43 | self.eval() 44 | 45 | with torch.no_grad(): 46 | segmentation_mask, edge_mask, reconstruction_mask, self.sigma = self.forward(x) 47 | 48 | return segmentation_mask, edge_mask, reconstruction_mask -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import tqdm 4 | from sklearn.metrics import (accuracy_score, f1_score, jaccard_score, 5 | precision_score, recall_score) 6 | 7 | 8 | def evaluate(model, generator): 9 | """ Evaluation function. Calculates *y* and *y_hat* arrays. 10 | usage: 11 | from sklearn.metrics import classification_report (or any other evaluation metric) 12 | y_pred, y = evaluate(best_model, generator) 13 | y_pred = np.hstack(y_pred) 14 | y = np.hstack(y) 15 | targets = ['building', 'not_building'] 16 | print(classification_report(y, y_pred, target_names = targets))#, output_dict = True) 17 | """ 18 | 19 | best_model = model 20 | best_model.cuda() 21 | best_model.eval() 22 | 23 | y_preds = list() 24 | ys = list() 25 | with torch.no_grad(): 26 | for idx in tqdm.tqdm(range(len(generator))): 27 | X,y,z = generator[idx] #image, mask, edge_mask 28 | X = X.cuda() 29 | #X = X.detach().cpu().numpy() #no need to further increase the computational burden . 
30 | y = y.detach().cpu().numpy() 31 | gt_max = np.argmax(y, axis=0) 32 | segmentation, edge, reconstruction, sigma = best_model.forward(X[None,:,:]) 33 | y_pred = segmentation.argmax(dim=1) 34 | y_pred = y_pred.squeeze() 35 | gt_color = gt_max.flatten() 36 | y= gt_color 37 | y_pred = y_pred.flatten().detach().cpu().numpy() 38 | 39 | ys.append(y) 40 | y_preds.append(y_pred) 41 | 42 | best_model.cpu() 43 | return y_preds, ys 44 | 45 | 46 | def calculate_metrics(y_preds,ys): 47 | """ 48 | Intented to complement the *evaluate* function. 49 | Calculates evaluation metrics for a given *y* and *y_hat* arrays. 50 | """ 51 | y_preds = np.asarray(y_preds) 52 | ys = np.asarray(ys) 53 | 54 | include_label = [0,1] # omit background during metric calculation 55 | 56 | F1 = f1_score(ys.flatten(), y_preds.flatten(), average=None, labels=include_label) 57 | Precision = precision_score(ys.flatten(), y_preds.flatten(), average=None,labels=include_label) 58 | Recall = recall_score(ys.flatten(), y_preds.flatten(), average=None, labels=include_label) 59 | Jaccard = jaccard_score(ys.flatten(), y_preds.flatten(), average=None,labels=include_label) 60 | acc = accuracy_score(ys.flatten(), y_preds.flatten()) 61 | 62 | f1 = np.asarray(F1) 63 | prec = np.asarray(Precision) 64 | rec = np.asarray(Recall) 65 | jacc = np.asarray(Jaccard) 66 | 67 | return f1, prec, rec, jacc, acc 68 | -------------------------------------------------------------------------------- /utils/base.py: -------------------------------------------------------------------------------- 1 | import re 2 | import functools 3 | import torch 4 | import torch.nn as nn 5 | 6 | class Activation(nn.Module): 7 | def __init__(self, activation): 8 | super().__init__() 9 | if activation == None or activation == 'identity': 10 | self.activation = nn.Identity() 11 | elif activation == 'sigmoid': 12 | self.activation = torch.sigmoid 13 | elif activation == 'softmax2d': 14 | self.activation = functools.partial(torch.softmax, dim=1) 15 | elif callable(activation): 16 | self.activation = activation 17 | else: 18 | raise ValueError 19 | 20 | def forward(self, x): 21 | return self.activation(x) 22 | 23 | class BaseObject(nn.Module): 24 | 25 | def __init__(self, name=None): 26 | super().__init__() 27 | self._name = name 28 | 29 | @property 30 | def __name__(self): 31 | if self._name is None: 32 | name = self.__class__.__name__ 33 | s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) 34 | return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() 35 | else: 36 | return self._name 37 | 38 | class Metric(BaseObject): 39 | pass 40 | 41 | class Loss(BaseObject): 42 | 43 | def __add__(self, other): 44 | if isinstance(other, Loss): 45 | return SumOfLosses(self, other) 46 | else: 47 | raise ValueError('Loss should be inherited from `Loss` class') 48 | 49 | def __radd__(self, other): 50 | return self.__add__(other) 51 | 52 | def __mul__(self, value): 53 | if isinstance(value, (int, float)): 54 | return MultipliedLoss(self, value) 55 | else: 56 | raise ValueError('Loss should be inherited from `BaseLoss` class') 57 | 58 | def __rmul__(self, other): 59 | return self.__mul__(other) 60 | 61 | class SumOfLosses(Loss): 62 | 63 | def __init__(self, l1, l2): 64 | name = '{} + {}'.format(l1.__name__, l2.__name__) 65 | super().__init__(name=name) 66 | self.l1 = l1 67 | self.l2 = l2 68 | 69 | def __call__(self, *inputs): 70 | return self.l1.forward(*inputs) + self.l2.forward(*inputs) 71 | 72 | class MultipliedLoss(Loss): 73 | 74 | def __init__(self, loss, multiplier): 75 | 76 | # resolve name 77 
| if len(loss.__name__.split('+')) > 1: 78 | name = '{} * ({})'.format(multiplier, loss.__name__) 79 | else: 80 | name = '{} * {}'.format(multiplier, loss.__name__) 81 | super().__init__(name=name) 82 | self.loss = loss 83 | self.multiplier = multiplier 84 | 85 | def __call__(self, *inputs): 86 | return self.multiplier * self.loss.forward(*inputs) 87 | -------------------------------------------------------------------------------- /utilities.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | 6 | palette = {0 : (255, 0, 0), # buildings -red 7 | 1 : (0, 0 , 0) # rest of the mask - black 8 | } 9 | 10 | invert_palette = {v: k for k, v in palette.items()} 11 | 12 | def convert_to_color(arr_2d, palette=palette): 13 | """ Numeric labels to RGB-color encoding """ 14 | arr_3d = np.zeros((arr_2d.shape[0], arr_2d.shape[1], 3), dtype=np.uint8) 15 | 16 | for c, i in palette.items(): 17 | m = arr_2d == c 18 | arr_3d[m] = i 19 | 20 | return arr_3d 21 | 22 | def convert_from_color(arr_3d, palette=invert_palette): 23 | """ RGB-color encoding to grayscale labels """ 24 | arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8) 25 | 26 | for c, i in palette.items(): 27 | m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2) 28 | arr_2d[m] = i 29 | 30 | return arr_2d 31 | 32 | def visualize(fig_name = None, **images): 33 | """PLot images in one row.""" 34 | n = len(images) 35 | plt.figure(figsize=(16, 5)) 36 | for i, (name, image) in enumerate(images.items()): 37 | plt.subplot(1, n, i + 1) 38 | plt.xticks([]) 39 | plt.yticks([]) 40 | plt.title(' '.join(name.split('_')).title()) 41 | plt.imshow(image) 42 | if fig_name is not None: 43 | plt.savefig(r'D:\burak\MULTI_TASK\TRIPLE_TASK\outputs\predictions\{}'.format(fig_name)) 44 | plt.show() 45 | 46 | 47 | def get_id(FOLDER, test_ratio = 0.1, verbose = True): 48 | """ 49 | Walk through dataset folder, extract all the IDs and divide them into train/val/test sets by using adjustable ratio parameter. 50 | Default: 51 | 0.7 Train 52 | 0.2 Validation 53 | 0.1 Test 54 | """ 55 | 56 | image_dir = FOLDER + '\*' 57 | all_files = sorted(glob.glob(image_dir)) 58 | ids = [] 59 | 60 | for i in range(len(all_files)): 61 | all_things = os.path.basename(all_files[i]).split('.') 62 | all_things = all_things[0].split('_') 63 | first_id, second_id, patch_id = all_things[-4], all_things[-3], all_things[-1] 64 | ID = first_id + '_' + second_id + '_' + 'tile' + '_' + patch_id + '.tif' 65 | ids.append(ID) 66 | 67 | np.random.seed(0) 68 | test_ids = np.random.choice(ids, size=round(len(ids) * test_ratio), replace = False) 69 | validation_ids = np.random.choice(ids, size=round(len(ids) * (2 * test_ratio)), replace = False) 70 | 71 | train_val = np.setdiff1d(ids,test_ids) 72 | train_ids = np.setdiff1d(train_val, validation_ids) 73 | 74 | if verbose is not False: 75 | print("len(train_ids): {}\nlen(validation_ids):{}\nlen(test_ids):{}\ntotal_#_of_ids:{}".format(len(train_ids),len(validation_ids),len(test_ids),len(ids))) 76 | return train_ids, validation_ids, test_ids 77 | else: 78 | return train_ids, validation_ids, test_ids 79 | 80 | 81 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | from . import base 2 | from . 
import functional as F 3 | from .base import Activation 4 | 5 | 6 | class IoU(base.Metric): 7 | __name__ = 'iou_score' 8 | 9 | def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 10 | super().__init__(**kwargs) 11 | self.eps = eps 12 | self.threshold = threshold 13 | self.activation = Activation(activation) 14 | self.ignore_channels = ignore_channels 15 | 16 | def forward(self, y_pr, y_gt): 17 | y_pr = self.activation(y_pr) 18 | return F.iou( 19 | y_pr, y_gt, 20 | eps=self.eps, 21 | threshold=self.threshold, 22 | ignore_channels=self.ignore_channels, 23 | ) 24 | 25 | class Fscore(base.Metric): 26 | 27 | def __init__(self, beta=1, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 28 | super().__init__(**kwargs) 29 | self.eps = eps 30 | self.beta = beta 31 | self.threshold = threshold 32 | self.activation = Activation(activation) 33 | self.ignore_channels = ignore_channels 34 | 35 | def forward(self, y_pr, y_gt): 36 | y_pr = self.activation(y_pr) 37 | return F.f_score( 38 | y_pr, y_gt, 39 | eps=self.eps, 40 | beta=self.beta, 41 | threshold=self.threshold, 42 | ignore_channels=self.ignore_channels, 43 | ) 44 | 45 | 46 | class Accuracy(base.Metric): 47 | 48 | def __init__(self, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 49 | super().__init__(**kwargs) 50 | self.threshold = threshold 51 | self.activation = Activation(activation) 52 | self.ignore_channels = ignore_channels 53 | 54 | def forward(self, y_pr, y_gt): 55 | y_pr = self.activation(y_pr) 56 | return F.accuracy( 57 | y_pr, y_gt, 58 | threshold=self.threshold, 59 | ignore_channels=self.ignore_channels, 60 | ) 61 | 62 | 63 | class Recall(base.Metric): 64 | 65 | def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 66 | super().__init__(**kwargs) 67 | self.eps = eps 68 | self.threshold = threshold 69 | self.activation = Activation(activation) 70 | self.ignore_channels = ignore_channels 71 | 72 | def forward(self, y_pr, y_gt): 73 | y_pr = self.activation(y_pr) 74 | return F.recall( 75 | y_pr, y_gt, 76 | eps=self.eps, 77 | threshold=self.threshold, 78 | ignore_channels=self.ignore_channels, 79 | ) 80 | 81 | 82 | class Precision(base.Metric): 83 | 84 | def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 85 | super().__init__(**kwargs) 86 | self.eps = eps 87 | self.threshold = threshold 88 | self.activation = Activation(activation) 89 | self.ignore_channels = ignore_channels 90 | 91 | def forward(self, y_pr, y_gt): 92 | y_pr = self.activation(y_pr) 93 | return F.precision( 94 | y_pr, y_gt, 95 | eps=self.eps, 96 | threshold=self.threshold, 97 | ignore_channels=self.ignore_channels, 98 | ) 99 | 100 | # aliases 101 | iou_score = IoU() 102 | f1_score = Fscore(beta=1) 103 | f2_score = Fscore(beta=2) 104 | precision = Precision() 105 | recall = Recall() -------------------------------------------------------------------------------- /unet/decoder_skipless.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from base import modules as md 6 | 7 | class DecoderBlock(nn.Module): 8 | def __init__( 9 | self, 10 | in_channels, 11 | out_channels, 12 | use_batchnorm=True, 13 | attention_type=None, 14 | ): 15 | super().__init__() 16 | self.conv1 = md.Conv2dReLU( 17 | in_channels, 18 | out_channels, 19 | kernel_size=3, 20 | padding=1, 21 | use_batchnorm=use_batchnorm, 22 | ) 
23 | self.attention1 = md.Attention(attention_type, in_channels=in_channels) 24 | self.conv2 = md.Conv2dReLU( 25 | out_channels, 26 | out_channels, 27 | kernel_size=3, 28 | padding=1, 29 | use_batchnorm=use_batchnorm, 30 | ) 31 | self.attention2 = md.Attention(attention_type, in_channels=out_channels) 32 | 33 | def forward(self, x): 34 | x = F.interpolate(x, scale_factor=2, mode="nearest") 35 | x = self.conv1(x) 36 | x = self.conv2(x) 37 | x = self.attention2(x) 38 | return x 39 | 40 | 41 | class CenterBlock(nn.Sequential): 42 | def __init__(self, in_channels, out_channels, use_batchnorm=True): 43 | conv1 = md.Conv2dReLU( 44 | in_channels, 45 | out_channels, 46 | kernel_size=3, 47 | padding=1, 48 | use_batchnorm=use_batchnorm, 49 | ) 50 | conv2 = md.Conv2dReLU( 51 | out_channels, 52 | out_channels, 53 | kernel_size=3, 54 | padding=1, 55 | use_batchnorm=use_batchnorm, 56 | ) 57 | super().__init__(conv1, conv2) 58 | 59 | class UnetDecoder_skipless(nn.Module): 60 | def __init__( 61 | self, 62 | encoder_channels, 63 | decoder_channels, 64 | n_blocks=5, 65 | use_batchnorm=True, 66 | attention_type=None, 67 | center=False, 68 | ): 69 | super().__init__() 70 | 71 | if n_blocks != len(decoder_channels): 72 | raise ValueError( 73 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( 74 | n_blocks, len(decoder_channels) 75 | ) 76 | ) 77 | 78 | encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution 79 | encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder 80 | 81 | # computing blocks input and output channels 82 | head_channels = encoder_channels[0] 83 | in_channels = [head_channels] + list(decoder_channels[:-1]) 84 | out_channels = decoder_channels 85 | 86 | if center: 87 | self.center = CenterBlock( 88 | head_channels, head_channels, use_batchnorm=use_batchnorm 89 | ) 90 | else: 91 | self.center = nn.Identity() 92 | 93 | # combine decoder keyword arguments 94 | kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) 95 | blocks = [ 96 | DecoderBlock(in_ch, out_ch, **kwargs) 97 | for in_ch, out_ch in zip(in_channels, out_channels) 98 | ] 99 | self.blocks = nn.ModuleList(blocks) 100 | 101 | def forward(self, *features): 102 | 103 | features = features[1:] # remove first skip with same spatial resolution 104 | features = features[::-1] # reverse channels to start from head of encoder 105 | 106 | head = features[0] 107 | 108 | x = self.center(head) 109 | for i, decoder_block in enumerate(self.blocks): 110 | x = decoder_block(x) 111 | 112 | return x 113 | -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from . import base 4 | from . 
import functional as F 5 | from .base import Activation 6 | import torch 7 | 8 | class JaccardLoss(base.Loss): 9 | 10 | def __init__(self, eps=1., activation=None, ignore_channels=None, **kwargs): 11 | super().__init__(**kwargs) 12 | self.eps = eps 13 | self.activation = Activation(activation) 14 | self.ignore_channels = ignore_channels 15 | 16 | def forward(self, y_pr, y_gt): 17 | y_pr = self.activation(y_pr) 18 | return 1 - F.jaccard( 19 | y_pr, y_gt, 20 | eps=self.eps, 21 | threshold=None, 22 | ignore_channels=self.ignore_channels, 23 | ) 24 | 25 | class DiceLoss(base.Loss): 26 | 27 | def __init__(self, eps=1., beta=1., activation=None, ignore_channels=None, **kwargs): 28 | super().__init__(**kwargs) 29 | self.eps = eps 30 | self.beta = beta 31 | self.activation = Activation(activation) 32 | self.ignore_channels = ignore_channels 33 | 34 | def forward(self, y_pr, y_gt): 35 | y_pr = self.activation(y_pr) 36 | return 1 - F.f_score( 37 | y_pr, y_gt, 38 | beta=self.beta, 39 | eps=self.eps, 40 | threshold=None, 41 | ignore_channels=self.ignore_channels, 42 | ) 43 | 44 | 45 | ####### 46 | ## 47 | 48 | import torch 49 | import torch.nn as nn 50 | import torch.nn.functional as F 51 | from torch.autograd import Variable 52 | from torch.nn import Parameter, Module 53 | 54 | class _MSEloss(nn.MSELoss, base.Loss): 55 | def __init__(self, **kwargs): 56 | super().__init__(**kwargs) 57 | self.MSELoss = nn.MSELoss() 58 | 59 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 60 | return self.MSELoss(input, target) 61 | 62 | class _L1loss(nn.L1Loss, base.Loss): 63 | def __init__(self, **kwargs): 64 | super().__init__(**kwargs) 65 | self.L1Loss = nn.L1Loss() 66 | 67 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 68 | return self.L1Loss(input, target) 69 | 70 | class _BCrossEntropyloss(nn.BCELoss, base.Loss): 71 | def __init__(self, **kwargs): 72 | super().__init__(**kwargs) 73 | self.BCELoss = nn.BCELoss() 74 | 75 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 76 | return self.BCELoss(input, target) 77 | 78 | class MultiTaskLoss(base.Loss, nn.Module): 79 | """ 80 | https://github.com/oxcsaml2019/multitask-learning/blob/master/multitask-learning/losses.py 81 | """ 82 | def __init__(self, model, **kwargs): #task1 -> SEG, task2 -> BOUNDARY, task3 -> RECONSTRUCTION 83 | super().__init__(**kwargs) 84 | self.model = model 85 | self.BCE = _BCrossEntropyloss() 86 | self.L1 = _L1loss() 87 | print("MTL_learnable: Segmentation + Boundary + Reconstruction") 88 | 89 | def forward(self, targets): 90 | 91 | model = self.model 92 | segmentation_mask, edge_mask, reconstruction_mask, self.sigma = model.forward(targets[-1]) 93 | 94 | l1_bce = self.BCE(segmentation_mask, targets[0]) #*2 95 | l2_bce= self.BCE(edge_mask, targets[1]) #*2 96 | l3_l1 = self.L1(reconstruction_mask,targets[2]) 97 | 98 | precision1 = torch.exp(-self.sigma[0]) 99 | loss = torch.sum(precision1 * l1_bce + (self.sigma[0] * self.sigma[0]) , -1) 100 | 101 | precision2 = torch.exp( -self.sigma[1]) 102 | loss += torch.sum(precision2 * l2_bce + (self.sigma[1] * self.sigma[1]), -1) 103 | 104 | precision3 = torch.exp(-self.sigma[2]) 105 | loss += torch.sum(precision3 * l3_l1 + (self.sigma[2] * self.sigma[2]), -1) 106 | 107 | loss = torch.mean(loss) 108 | 109 | return loss, segmentation_mask, edge_mask, reconstruction_mask#, self.sigma.data.tolist() 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- 
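Note on the objective in utils/losses.py above: MultiTaskLoss combines the three task losses with learnable homoscedastic-uncertainty weights held in model.sigma, computing total = sum_i exp(-sigma_i) * L_i + sigma_i**2 over the segmentation (BCE), edge (BCE) and reconstruction (L1) terms. The following is a minimal wiring sketch, not code from this repository; the batch tensors x, y, z are illustrative placeholders, and a CUDA device is assumed because the repository's Unet creates its sigma parameter on the GPU.

# Illustrative sketch (assumptions noted above): build the model, wrap it in
# MultiTaskLoss, and run one optimization step.
import torch
from unet import Unet
from utils.losses import MultiTaskLoss

device = torch.device("cuda")  # the repo's Unet allocates sigma on the GPU by default
model = Unet(encoder_name="resnet101", classes=2, activation="sigmoid").to(device)
loss_fn = MultiTaskLoss(model)                               # total = sum_i exp(-sigma_i) * L_i + sigma_i**2
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)    # model.parameters() already includes sigma

# Placeholder batch: image (B,3,H,W), one-hot segmentation GT and edge GT (B,2,H,W)
x = torch.rand(2, 3, 480, 480, device=device)
y = torch.randint(0, 2, (2, 2, 480, 480), device=device).float()
z = torch.randint(0, 2, (2, 2, 480, 480), device=device).float()

loss, seg_mask, edge_mask, recon = loss_fn([y, z, x])        # target order used in utils/train.py
optimizer.zero_grad()
loss.backward()
optimizer.step()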
/base/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | try: 5 | from inplace_abn import InPlaceABN 6 | except ImportError: 7 | InPlaceABN = None 8 | 9 | 10 | class Conv2dReLU(nn.Sequential): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | padding=0, 17 | stride=1, 18 | use_batchnorm=True, 19 | ): 20 | 21 | if use_batchnorm == "inplace" and InPlaceABN is None: 22 | raise RuntimeError( 23 | "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. " 24 | + "To install see: https://github.com/mapillary/inplace_abn" 25 | ) 26 | 27 | conv = nn.Conv2d( 28 | in_channels, 29 | out_channels, 30 | kernel_size, 31 | stride=stride, 32 | padding=padding, 33 | bias=not (use_batchnorm), 34 | ) 35 | relu = nn.ReLU(inplace=True) 36 | 37 | if use_batchnorm == "inplace": 38 | bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) 39 | relu = nn.Identity() 40 | 41 | elif use_batchnorm and use_batchnorm != "inplace": 42 | bn = nn.BatchNorm2d(out_channels) 43 | 44 | else: 45 | bn = nn.Identity() 46 | 47 | super(Conv2dReLU, self).__init__(conv, bn, relu) 48 | 49 | 50 | class SCSEModule(nn.Module): 51 | def __init__(self, in_channels, reduction=16): 52 | super().__init__() 53 | self.cSE = nn.Sequential( 54 | nn.AdaptiveAvgPool2d(1), 55 | nn.Conv2d(in_channels, in_channels // reduction, 1), 56 | nn.ReLU(inplace=True), 57 | nn.Conv2d(in_channels // reduction, in_channels, 1), 58 | nn.Sigmoid(), 59 | ) 60 | self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid()) 61 | 62 | def forward(self, x): 63 | return x * self.cSE(x) + x * self.sSE(x) 64 | 65 | 66 | class ArgMax(nn.Module): 67 | 68 | def __init__(self, dim=None): 69 | super().__init__() 70 | self.dim = dim 71 | 72 | def forward(self, x): 73 | return torch.argmax(x, dim=self.dim) 74 | 75 | 76 | class Activation(nn.Module): 77 | 78 | def __init__(self, name, **params): 79 | 80 | super().__init__() 81 | 82 | if name is None or name == 'identity': 83 | self.activation = nn.Identity(**params) 84 | elif name == 'sigmoid': 85 | self.activation = nn.Sigmoid() 86 | elif name == 'softmax2d': 87 | self.activation = nn.Softmax(dim=1, **params) 88 | elif name == 'softmax': 89 | self.activation = nn.Softmax(**params) 90 | elif name == 'logsoftmax': 91 | self.activation = nn.LogSoftmax(**params) ######## 92 | elif name == 'argmax': 93 | self.activation = ArgMax(**params) 94 | elif name == 'argmax2d': 95 | self.activation = ArgMax(dim=1, **params) 96 | elif callable(name): 97 | self.activation = name(**params) 98 | else: 99 | raise ValueError('Activation should be callable/sigmoid/softmax/logsoftmax/None; got {}'.format(name)) 100 | 101 | def forward(self, x): 102 | return self.activation(x) 103 | 104 | 105 | class Attention(nn.Module): 106 | 107 | def __init__(self, name, **params): 108 | super().__init__() 109 | 110 | if name is None: 111 | self.attention = nn.Identity(**params) 112 | elif name == 'scse': 113 | self.attention = SCSEModule(**params) 114 | else: 115 | raise ValueError("Attention {} is not implemented".format(name)) 116 | 117 | def forward(self, x): 118 | return self.attention(x) 119 | 120 | 121 | class Flatten(nn.Module): 122 | def forward(self, x): 123 | return x.view(x.shape[0], -1) 124 | -------------------------------------------------------------------------------- /utils/train.py: -------------------------------------------------------------------------------- 1 
| import sys 2 | 3 | import torch 4 | import torch.nn as nn 5 | from tqdm import tqdm as tqdm 6 | 7 | from .meter import AverageValueMeter 8 | from .losses import MultiTaskLoss 9 | 10 | class Epoch: 11 | 12 | def __init__(self, model, loss, metrics, stage_name, device='cuda', verbose=True): 13 | self.model = model 14 | self.loss = loss 15 | self.metrics = metrics 16 | self.stage_name = stage_name 17 | self.verbose = verbose 18 | self.device = device 19 | 20 | self._to_device() 21 | 22 | def _to_device(self): 23 | self.model.to(self.device) 24 | self.loss.to(self.device) 25 | for metric in self.metrics: 26 | metric.to(self.device) 27 | 28 | def _format_logs(self, logs): 29 | str_logs = ['{} - {:.4}'.format(k, v) for k, v in logs.items()] 30 | s = ', '.join(str_logs) 31 | return s 32 | 33 | def batch_update(self, x, y, z): 34 | raise NotImplementedError 35 | 36 | def on_epoch_start(self): 37 | pass 38 | def on_epoch_end(self,): 39 | pass 40 | def run(self, dataloader): 41 | 42 | self.on_epoch_start() 43 | 44 | logs = {} 45 | loss_meter = AverageValueMeter() 46 | metrics_meters = {metric.__name__: AverageValueMeter() for metric in self.metrics} 47 | 48 | with tqdm(dataloader, desc=self.stage_name, file=sys.stdout, disable=not (self.verbose)) as iterator: 49 | for x, y, z in iterator: 50 | 51 | x, y, z= x.to(self.device), y.to(self.device), z.to(self.device) 52 | loss, y_pred, edge_mask, reconstruction_mask = self.batch_update(x, y, z) 53 | # update loss logs 54 | loss_value = loss.cpu().detach().numpy() 55 | loss_meter.add(loss_value) 56 | loss_logs = {self.loss.__name__: loss_meter.mean} 57 | logs.update(loss_logs) 58 | 59 | # update metrics logs 60 | for metric_fn in self.metrics: 61 | metric_value = metric_fn(y_pred, y).cpu().detach().numpy() 62 | metrics_meters[metric_fn.__name__].add(metric_value) 63 | metrics_logs = {k: v.mean for k, v in metrics_meters.items()} 64 | logs.update(metrics_logs) 65 | 66 | if self.verbose: 67 | s = self._format_logs(logs) 68 | iterator.set_postfix_str(s) 69 | 70 | return logs 71 | 72 | class TrainEpoch(Epoch): 73 | 74 | def __init__(self, model, loss, metrics, optimizer, device='cuda', verbose=True): 75 | super().__init__( 76 | model=model, 77 | loss=loss, 78 | stage_name='train', 79 | device=device, 80 | verbose=verbose, 81 | metrics = metrics 82 | ) 83 | self.optimizer = optimizer 84 | 85 | def on_epoch_start(self): 86 | self.model.train() 87 | 88 | def batch_update(self, x, y, z): #x-> input image, y-> GT #, z -> boundary_GT 89 | 90 | self.optimizer.zero_grad() 91 | 92 | targets = [y, z, x] 93 | 94 | MTLLoss, segmentation_mask, edge_mask, reconstruction_mask = self.loss(targets) 95 | total_loss = MTLLoss 96 | 97 | total_loss.backward() 98 | 99 | self.optimizer.step() 100 | 101 | return total_loss, segmentation_mask, edge_mask, reconstruction_mask#,log_vars 102 | 103 | class ValidEpoch(Epoch): 104 | 105 | def __init__(self, model, loss, metrics, device='cuda', verbose=True): 106 | super().__init__( 107 | model=model, 108 | loss=loss, 109 | metrics=metrics, 110 | stage_name='valid', 111 | device=device, 112 | verbose=verbose, 113 | ) 114 | 115 | def on_epoch_start(self): 116 | self.model.eval() 117 | 118 | def batch_update(self, x, y, z): 119 | with torch.no_grad(): 120 | 121 | targets = [y, z, x] 122 | 123 | MTLLoss, segmentation_mask, edge_mask, reconstruction_mask = self.loss(targets) 124 | total_loss = MTLLoss 125 | 126 | 127 | return total_loss, segmentation_mask, edge_mask, reconstruction_mask 128 | 
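The Epoch / TrainEpoch / ValidEpoch classes in utils/train.py above each encapsulate one pass over a data loader. A minimal sketch of how they are typically driven follows, reusing model, loss_fn and optimizer from the sketch after utils/losses.py; train_loader and valid_loader are assumed DataLoaders built on the SemSeg_Custom dataset and are not defined in this dump.

# Illustrative sketch (assumptions noted above): per-epoch training loop built on
# TrainEpoch / ValidEpoch with the IoU metric from utils/metrics.py.
from params import get_params
from utils import train, metrics

params = get_params()
iou = metrics.IoU(threshold=0.5)

train_epoch = train.TrainEpoch(model, loss=loss_fn, metrics=[iou],
                               optimizer=optimizer, device=params['device'], verbose=True)
valid_epoch = train.ValidEpoch(model, loss=loss_fn, metrics=[iou],
                               device=params['device'], verbose=True)

for epoch in range(params['n_epoch']):
    print("Epoch:", epoch)
    train_logs = train_epoch.run(train_loader)   # expects batches of (image, mask, edge_mask)
    valid_logs = valid_epoch.run(valid_loader)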
-------------------------------------------------------------------------------- /unet/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from base import modules as md 6 | 7 | class DecoderBlock(nn.Module): 8 | def __init__( 9 | self, 10 | in_channels, 11 | skip_channels, 12 | out_channels, 13 | use_batchnorm=True, 14 | attention_type=None, 15 | ): 16 | super().__init__() 17 | self.conv1 = md.Conv2dReLU( 18 | in_channels + skip_channels, 19 | out_channels, 20 | kernel_size=3, 21 | padding=1, 22 | use_batchnorm=use_batchnorm, 23 | ) 24 | self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels) 25 | self.conv2 = md.Conv2dReLU( 26 | out_channels, 27 | out_channels, 28 | kernel_size=3, 29 | padding=1, 30 | use_batchnorm=use_batchnorm, 31 | ) 32 | self.attention2 = md.Attention(attention_type, in_channels=out_channels) 33 | 34 | def forward(self, x, skip=None): 35 | x = F.interpolate(x, scale_factor=2, mode="nearest") 36 | if skip is not None: 37 | #print("X.shape", x.shape) 38 | #print("skip.shape", skip.shape) 39 | x = torch.cat([x, skip], dim=1) 40 | x = self.attention1(x) 41 | x = self.conv1(x) 42 | x = self.conv2(x) 43 | x = self.attention2(x) 44 | return x 45 | 46 | class CenterBlock(nn.Sequential): 47 | def __init__(self, in_channels, out_channels, use_batchnorm=True): 48 | conv1 = md.Conv2dReLU( 49 | in_channels, 50 | out_channels, 51 | kernel_size=3, 52 | padding=1, 53 | use_batchnorm=use_batchnorm, 54 | ) 55 | conv2 = md.Conv2dReLU( 56 | out_channels, 57 | out_channels, 58 | kernel_size=3, 59 | padding=1, 60 | use_batchnorm=use_batchnorm, 61 | ) 62 | super().__init__(conv1, conv2) 63 | 64 | 65 | class UnetDecoder(nn.Module): 66 | def __init__( 67 | self, 68 | encoder_channels, 69 | decoder_channels, 70 | n_blocks=5, 71 | use_batchnorm=True, 72 | attention_type=None, 73 | center=False, 74 | ): 75 | super().__init__() 76 | 77 | if n_blocks != len(decoder_channels): 78 | raise ValueError( 79 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( 80 | n_blocks, len(decoder_channels) 81 | ) 82 | ) 83 | 84 | encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution 85 | encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder 86 | 87 | # computing blocks input and output channels 88 | head_channels = encoder_channels[0] 89 | in_channels = [head_channels] + list(decoder_channels[:-1]) 90 | skip_channels = list(encoder_channels[1:]) + [0] 91 | out_channels = decoder_channels 92 | 93 | if center: 94 | self.center = CenterBlock( 95 | head_channels, head_channels, use_batchnorm=use_batchnorm 96 | ) 97 | else: 98 | self.center = nn.Identity() 99 | 100 | # combine decoder keyword arguments 101 | kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) 102 | blocks = [ 103 | DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) 104 | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) 105 | ] 106 | self.blocks = nn.ModuleList(blocks) 107 | 108 | def forward(self, *features): 109 | 110 | features = features[1:] # remove first skip with same spatial resolution 111 | features = features[::-1] # reverse channels to start from head of encoder 112 | 113 | head = features[0] 114 | skips = features[1:] 115 | 116 | x = self.center(head) 117 | for i, decoder_block in enumerate(self.blocks): 118 | skip = skips[i] if i < 
len(skips) else None 119 | x = decoder_block(x, skip) 120 | 121 | return x -------------------------------------------------------------------------------- /unet/decoder_edge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from base import modules as md 6 | 7 | 8 | class DecoderBlock(nn.Module): 9 | def __init__( 10 | self, 11 | in_channels, 12 | skip_channels, 13 | out_channels, 14 | use_batchnorm=True, 15 | attention_type=None, 16 | ): 17 | super().__init__() 18 | self.conv1 = md.Conv2dReLU( 19 | in_channels + skip_channels, 20 | out_channels, 21 | kernel_size=3, 22 | padding=1, 23 | use_batchnorm=use_batchnorm, 24 | ) 25 | self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels) 26 | self.conv2 = md.Conv2dReLU( 27 | out_channels, 28 | out_channels, 29 | kernel_size=3, 30 | padding=1, 31 | use_batchnorm=use_batchnorm, 32 | ) 33 | self.attention2 = md.Attention(attention_type, in_channels=out_channels) 34 | 35 | def forward(self, x, skip=None): 36 | x = F.interpolate(x, scale_factor=2, mode="nearest") 37 | if skip is not None: 38 | #print("X.shape", x.shape) 39 | #print("skip.shape", skip.shape) 40 | x = torch.cat([x, skip], dim=1) 41 | x = self.attention1(x) 42 | x = self.conv1(x) 43 | x = self.conv2(x) 44 | x = self.attention2(x) 45 | return x 46 | 47 | 48 | class CenterBlock(nn.Sequential): 49 | def __init__(self, in_channels, out_channels, use_batchnorm=True): 50 | conv1 = md.Conv2dReLU( 51 | in_channels, 52 | out_channels, 53 | kernel_size=3, 54 | padding=1, 55 | use_batchnorm=use_batchnorm, 56 | ) 57 | conv2 = md.Conv2dReLU( 58 | out_channels, 59 | out_channels, 60 | kernel_size=3, 61 | padding=1, 62 | use_batchnorm=use_batchnorm, 63 | ) 64 | super().__init__(conv1, conv2) 65 | 66 | 67 | class UnetDecoder_Edge(nn.Module): 68 | def __init__( 69 | self, 70 | encoder_channels, 71 | decoder_channels, 72 | n_blocks=5, 73 | use_batchnorm=True, 74 | attention_type=None, 75 | center=False, 76 | ): 77 | super().__init__() 78 | 79 | if n_blocks != len(decoder_channels): 80 | raise ValueError( 81 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( 82 | n_blocks, len(decoder_channels) 83 | ) 84 | ) 85 | 86 | encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution 87 | encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder 88 | 89 | # computing blocks input and output channels 90 | head_channels = encoder_channels[0] 91 | in_channels = [head_channels] + list(decoder_channels[:-1]) 92 | skip_channels = list(encoder_channels[1:]) + [0] 93 | out_channels = decoder_channels 94 | 95 | if center: 96 | self.center = CenterBlock( 97 | head_channels, head_channels, use_batchnorm=use_batchnorm 98 | ) 99 | else: 100 | self.center = nn.Identity() 101 | 102 | # combine decoder keyword arguments 103 | kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) 104 | blocks = [ 105 | DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) 106 | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) 107 | ] 108 | self.blocks = nn.ModuleList(blocks) 109 | 110 | def forward(self, *features): 111 | 112 | features = features[1:] # remove first skip with same spatial resolution 113 | features = features[::-1] # reverse channels to start from head of encoder 114 | 115 | head = features[0] 116 | skips = features[1:] 117 | 118 | x = 
self.center(head) 119 | for i, decoder_block in enumerate(self.blocks): 120 | skip = skips[i] if i < len(skips) else None 121 | x = decoder_block(x, skip) 122 | 123 | return x -------------------------------------------------------------------------------- /utils/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _take_channels(*xs, ignore_channels=None): 4 | if ignore_channels is None: 5 | return xs 6 | else: 7 | channels = [channel for channel in range(xs[0].shape[1]) if channel not in ignore_channels] 8 | xs = [torch.index_select(x, dim=1, index=torch.tensor(channels).to(x.device)) for x in xs] 9 | return xs 10 | 11 | 12 | def _threshold(x, threshold=None): 13 | if threshold is not None: 14 | return (x > threshold).type(x.dtype) 15 | else: 16 | return x 17 | 18 | 19 | def iou(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): 20 | """Calculate Intersection over Union between ground truth and prediction 21 | Args: 22 | pr (torch.Tensor): predicted tensor 23 | gt (torch.Tensor): ground truth tensor 24 | eps (float): epsilon to avoid zero division 25 | threshold: threshold for outputs binarization 26 | Returns: 27 | float: IoU (Jaccard) score 28 | """ 29 | 30 | pr = _threshold(pr, threshold=threshold) 31 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 32 | 33 | intersection = torch.sum(gt * pr) 34 | union = torch.sum(gt) + torch.sum(pr) - intersection + eps 35 | return (intersection + eps) / union 36 | 37 | 38 | jaccard = iou 39 | 40 | 41 | def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, ignore_channels=None): 42 | """Calculate F-score between ground truth and prediction 43 | Args: 44 | pr (torch.Tensor): predicted tensor 45 | gt (torch.Tensor): ground truth tensor 46 | beta (float): positive constant 47 | eps (float): epsilon to avoid zero division 48 | threshold: threshold for outputs binarization 49 | Returns: 50 | float: F score 51 | """ 52 | 53 | pr = _threshold(pr, threshold=threshold) 54 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 55 | #print("pr.shape",pr.shape) 56 | #print("gt.shape",gt.shape) 57 | 58 | tp = torch.sum(gt * pr) 59 | fp = torch.sum(pr) - tp 60 | fn = torch.sum(gt) - tp 61 | 62 | score = ((1 + beta ** 2) * tp + eps) \ 63 | / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + eps) 64 | 65 | return score 66 | 67 | 68 | def accuracy(pr, gt, threshold=0.5, ignore_channels=None): 69 | """Calculate accuracy score between ground truth and prediction 70 | Args: 71 | pr (torch.Tensor): predicted tensor 72 | gt (torch.Tensor): ground truth tensor 73 | eps (float): epsilon to avoid zero division 74 | threshold: threshold for outputs binarization 75 | Returns: 76 | float: precision score 77 | """ 78 | pr = _threshold(pr, threshold=threshold) 79 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 80 | 81 | tp = torch.sum(gt == pr, dtype=pr.dtype) 82 | score = tp / gt.view(-1).shape[0] 83 | return score 84 | 85 | 86 | def precision(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): 87 | """Calculate precision score between ground truth and prediction 88 | Args: 89 | pr (torch.Tensor): predicted tensor 90 | gt (torch.Tensor): ground truth tensor 91 | eps (float): epsilon to avoid zero division 92 | threshold: threshold for outputs binarization 93 | Returns: 94 | float: precision score 95 | """ 96 | 97 | pr = _threshold(pr, threshold=threshold) 98 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 99 | 100 | tp = torch.sum(gt 
* pr) 101 | fp = torch.sum(pr) - tp 102 | 103 | score = (tp + eps) / (tp + fp + eps) 104 | 105 | return score 106 | 107 | 108 | def recall(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): 109 | """Calculate Recall between ground truth and prediction 110 | Args: 111 | pr (torch.Tensor): A list of predicted elements 112 | gt (torch.Tensor): A list of elements that are to be predicted 113 | eps (float): epsilon to avoid zero division 114 | threshold: threshold for outputs binarization 115 | Returns: 116 | float: recall score 117 | """ 118 | 119 | pr = _threshold(pr, threshold=threshold) 120 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 121 | 122 | tp = torch.sum(gt * pr) 123 | fn = torch.sum(gt) - tp 124 | 125 | score = (tp + eps) / (tp + fn + eps) 126 | 127 | return score 128 | -------------------------------------------------------------------------------- /unet/model.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from base import SegmentationModel, heads 7 | from encoders import get_encoder 8 | from .decoder import UnetDecoder 9 | from .decoder_edge import UnetDecoder_Edge 10 | from .decoder_skipless import UnetDecoder_skipless 11 | 12 | class Unet(SegmentationModel): 13 | 14 | """Unet_ is a fully convolution neural network for image semantic segmentation 15 | Args: 16 | encoder_name: name of classification model (without last dense layers) used as feature 17 | extractor to build segmentation model. 18 | encoder_depth (int): number of stages used in decoder, larger depth - more features are generated. 19 | e.g. for depth=3 encoder will generate list of features with following spatial shapes 20 | [(H,W), (H/2, W/2), (H/4, W/4), (H/8, W/8)], so in general the deepest feature tensor will have 21 | spatial resolution (H/(2^depth), W/(2^depth)] 22 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). 23 | decoder_channels: list of numbers of ``Conv2D`` layer filters in decoder blocks 24 | decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers 25 | is used. If 'inplace' InplaceABN will be used, allows to decrease memory consumption. 26 | One of [True, False, 'inplace'] 27 | decoder_attention_type: attention module used in decoder of the model 28 | One of [``None``, ``scse``] 29 | in_channels: number of input channels for model, default is 3. 30 | classes: a number of classes for output (output shape - ``(batch, classes, h, w)``). 31 | activation: activation function to apply after final convolution; 32 | One of [``sigmoid``, ``softmax``, ``logsoftmax``, ``identity``, callable, None] 33 | aux_params: if specified model will have additional classification auxiliary output 34 | build on top of encoder, supported params: 35 | - classes (int): number of classes 36 | - pooling (str): one of 'max', 'avg'. Default is 'avg'. 37 | - dropout (float): dropout factor in [0, 1) 38 | - activation (str): activation function to apply "sigmoid"/"softmax" (could be None to return logits) 39 | Returns: 40 | ``torch.nn.Module``: **Unet** 41 | .. 
_Unet: 42 | https://arxiv.org/pdf/1505.04597 43 | """ 44 | 45 | def __init__( 46 | self, 47 | encoder_name: str = "resnet34", 48 | encoder_depth: int = 5, 49 | encoder_weights: str = "imagenet", 50 | decoder_use_batchnorm: bool = True, 51 | decoder_channels: List[int] = (256, 128, 64, 32, 16), 52 | decoder_attention_type: Optional[str] = None, 53 | in_channels: int = 3, 54 | classes: int = 1, 55 | activation: Optional[Union[str, callable]] = None, 56 | aux_params: Optional[dict] = None, 57 | sigma=None, # per-task uncertainty weights; a 3-element nn.Parameter is created below when not provided 58 | 59 | ): 60 | super().__init__() 61 | 62 | self.sigma = sigma if sigma is not None else nn.Parameter(torch.zeros(3, dtype=torch.float32)) 63 | self.encoder = get_encoder( 64 | encoder_name, 65 | in_channels=in_channels, 66 | depth=encoder_depth, 67 | weights=encoder_weights 68 | ) 69 | 70 | self.decoder = UnetDecoder( 71 | encoder_channels=self.encoder.out_channels, 72 | decoder_channels=decoder_channels, 73 | n_blocks=encoder_depth, 74 | use_batchnorm=decoder_use_batchnorm, 75 | center=encoder_name.startswith("vgg"), 76 | attention_type=decoder_attention_type, 77 | ) 78 | # reconstruction decoder (no skip connections) 79 | self.skipless_decoder = UnetDecoder_skipless( 80 | encoder_channels=self.encoder.out_channels, 81 | decoder_channels=decoder_channels, 82 | n_blocks=encoder_depth, 83 | use_batchnorm=decoder_use_batchnorm, 84 | center=encoder_name.startswith("vgg"), 85 | attention_type=decoder_attention_type, 86 | ) 87 | 88 | self.UnetDecoder_Edge = UnetDecoder_Edge( 89 | encoder_channels=self.encoder.out_channels, 90 | decoder_channels=decoder_channels, 91 | n_blocks=encoder_depth, 92 | use_batchnorm=decoder_use_batchnorm, 93 | center=encoder_name.startswith("vgg"), 94 | attention_type=decoder_attention_type, 95 | ) 96 | 97 | # reconstruction head 98 | self.reconstruct_segmentation_head = heads.AUX_SegmentationHead( 99 | in_channels=decoder_channels[-1], 100 | out_channels=3, # reconstructs the 3-channel (RGB) input 101 | activation=None, 102 | kernel_size=1, 103 | ) 104 | # segmentation head 105 | self.segmentation_head = heads.SegmentationHead( 106 | in_channels=decoder_channels[-1], 107 | out_channels=classes, 108 | activation=activation, 109 | kernel_size=3, 110 | ) 111 | # edge head 112 | self.edge_head = heads.AUX_edgehead( 113 | in_channels=decoder_channels[-1], 114 | out_channels=classes, 115 | activation=activation, 116 | kernel_size=3, 117 | ) 118 | 119 | 120 | if aux_params is not None: 121 | self.classification_head = heads.ClassificationHead( 122 | in_channels=self.encoder.out_channels[-1], **aux_params 123 | ) 124 | else: 125 | self.classification_head = None 126 | 127 | self.name = "u-{}".format(encoder_name) 128 | self.initialize() 129 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import albumentations as A 2 | import cv2 3 | import numpy as np 4 | import skimage.io 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | from utilities import convert_from_color, get_id 9 | 10 | 11 | class SemSeg_Custom(Dataset): 12 | """ 13 | Outputs image, mask, and boun_mask with shape (C, H, W), where C equals N (the number of classes). 14 | Cropping and boundary extraction are performed on the fly.
15 | """ 16 | def __init__( 17 | self, 18 | psp_dir, 19 | masks_dir, 20 | mode='train', 21 | augmentation=None, 22 | preprocessing=None, 23 | image_mask=False, 24 | inference=False): 25 | 26 | self.masks_dir = masks_dir 27 | self.psp_dir = psp_dir 28 | self.augmentation = augmentation 29 | self.preprocessing = preprocessing 30 | self.class_values = [0, 1] # class values to encode 31 | self.image_mask = image_mask 32 | self.inference = inference 33 | 34 | train_ids, validation_ids, test_ids = get_id(psp_dir, verbose=False) 35 | 36 | print("Len train_ids: {}, Len val_ids {}, Len test_ids {}".format(len(train_ids), len(validation_ids), len(test_ids))) 37 | masks_dir_ = masks_dir + r'\{}' # Windows-style paths, matching params.py 38 | psp_dir_ = psp_dir + r'\SN6_Train_AOI_11_Rotterdam_PS-RGB_{}' 39 | 40 | if mode == 'train': 41 | self.psp_list = [psp_dir_.format(id) for id in train_ids] 42 | self.mask_list = [masks_dir_.format(id) for id in train_ids] 43 | 44 | elif mode == 'validation': 45 | self.psp_list = [psp_dir_.format(id) for id in validation_ids] 46 | self.mask_list = [masks_dir_.format(id) for id in validation_ids] 47 | 48 | elif mode == 'test': 49 | self.psp_list = [psp_dir_.format(id) for id in test_ids] 50 | self.mask_list = [masks_dir_.format(id) for id in test_ids] 51 | 52 | @staticmethod 53 | def get_boundary(label, kernel_size=(3, 3)): 54 | tlabel = label.astype(np.uint8) 55 | temp = cv2.Canny(tlabel, 0, 1) 56 | tlabel = cv2.dilate( 57 | temp, 58 | cv2.getStructuringElement( 59 | cv2.MORPH_CROSS, 60 | kernel_size), 61 | iterations=2) 62 | tlabel = tlabel.astype(np.float32) 63 | tlabel /= 255. 64 | return tlabel 65 | 66 | @staticmethod 67 | def _read_img(image_path): 68 | img = skimage.io.imread(image_path, plugin='tifffile') 69 | return img 70 | 71 | def __len__(self): 72 | return len(self.mask_list) 73 | 74 | def __getitem__(self, idx): 75 | 76 | psp_filepath = self.psp_list[idx] 77 | mask_filepath = self.mask_list[idx] 78 | image = self._read_img(psp_filepath) 79 | mask = self._read_img(mask_filepath) 80 | 81 | mask_raw = convert_from_color(mask) 82 | masks = [(mask_raw == v) for v in [0, 1]] 83 | mask = np.stack(masks, axis=-1).astype('uint8') # (480, 480, 2) 84 | 85 | boun_mask = self.get_boundary(mask_raw) # (480, 480) 86 | boun_mask = [(boun_mask == v) for v in [1, 0]] 87 | boun_mask = np.stack(boun_mask, axis=-1).astype('uint8') # (480, 480, 2) 88 | 89 | if self.augmentation: 90 | transformed = A.Compose(self.augmentation, p=1)(image=image, masks=[mask, boun_mask]) 91 | 92 | image, mask, boun_mask = transformed['image'], transformed['masks'][0], transformed['masks'][1] 93 | 94 | if self.preprocessing: 95 | preprocessed = self.preprocessing(image=image, mask=mask, boundary_mask=boun_mask) 96 | image, mask, boun_mask = preprocessed['image'], preprocessed['mask'], preprocessed['boundary_mask'] 97 | 98 | image = image
/ 255.0 99 | 100 | image = np.asarray(image).transpose(2, 0, 1) 101 | mask = np.asarray(mask).transpose(2, 0, 1) 102 | boun_mask = np.asarray(boun_mask).transpose(2, 0, 1) 103 | 104 | image = torch.as_tensor(image, dtype=torch.float32) 105 | mask = torch.as_tensor(mask, dtype=torch.float32) 106 | boun_mask = torch.as_tensor(boun_mask, dtype=torch.float32) 107 | 108 | if self.inference: 109 | return image, mask, boun_mask, self.psp_list 110 | if self.image_mask: 111 | return image, mask 112 | else: 113 | return image, mask, boun_mask 114 | 115 | def get_training_augmentation(crop_size): 116 | train_transform = [ 117 | 118 | A.OneOf([A.RandomCrop(crop_size[0], crop_size[1], p=1.0) 119 | ], p=1.0), 120 | 121 | A.HorizontalFlip(p=0.5), 122 | A.VerticalFlip(p=0.5), 123 | A.IAAAdditiveGaussianNoise(p=0.2), # IAA* transforms require an albumentations release with imgaug support 124 | A.IAAPerspective(p=0.5), 125 | 126 | A.OneOf( 127 | [ 128 | A.CLAHE(p=1), 129 | A.RandomBrightness(p=1), 130 | A.RandomGamma(p=1), 131 | ], 132 | p=0.7, 133 | ), 134 | 135 | A.OneOf( 136 | [ 137 | A.IAASharpen(p=1), 138 | A.Blur(blur_limit=3, p=1), 139 | A.MotionBlur(blur_limit=3, p=1), 140 | ], 141 | p=0.7, 142 | ), 143 | 144 | A.OneOf( 145 | [ 146 | A.RandomContrast(p=1), 147 | A.HueSaturationValue(p=1), 148 | ], 149 | p=0.7, 150 | ), 151 | 152 | 153 | ] 154 | 155 | return A.Compose(train_transform) 156 | 157 | def get_val_augmentation(crop_size): 158 | val_transform = [ 159 | 160 | A.OneOf([A.RandomCrop(crop_size[0], crop_size[1], p=1.0) 161 | ], p=1.0),] 162 | 163 | return A.Compose(val_transform) 164 | 165 | def get_test_augmentation(crop_size): # ensures determinism: replaces the random crop with a center crop 166 | val_transform = [ 167 | 168 | A.OneOf([A.CenterCrop(crop_size[0], crop_size[1], p=1.0) 169 | ], p=1.0),] 170 | 171 | return A.Compose(val_transform) 172 | 173 | def to_tensor(x, **kwargs): 174 | return x.transpose(2, 0, 1).astype('float32') 175 | 176 | #preprocessing_fn = get_preprocessing_fn(params['encoder'], params['encoder_weights']) 177 | 178 | def get_preprocessing(preprocessing_fn): 179 | 180 | _transform = [ 181 | A.Lambda(image=preprocessing_fn), 182 | A.Lambda(image=to_tensor, mask=to_tensor), 183 | ] 184 | return A.Compose(_transform) 185 | -------------------------------------------------------------------------------- /encoders/resnet.py: -------------------------------------------------------------------------------- 1 | """ Each encoder should have the following attributes and methods, and should inherit from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify the number of channels for each encoder feature tensor 6 | _depth (int): specify the number of encoder stages (in other words, the number of downsampling operations) 7 | _in_channels (int): default number of input channels in the first Conv2d layer of the encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce a list of features at different spatial resolutions; each feature is a 4D torch.Tensor of 13 | shape NCHW (features are sorted in descending order of spatial resolution, starting 14 | at the same resolution as the input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | The number of returned features should also match the specified depth, e.g.
if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 24 | """ 25 | from copy import deepcopy 26 | 27 | import torch.nn as nn 28 | 29 | from torchvision.models.resnet import ResNet 30 | from torchvision.models.resnet import BasicBlock 31 | from torchvision.models.resnet import Bottleneck 32 | from pretrainedmodels.models.torchvision_models import pretrained_settings 33 | 34 | from ._base import EncoderMixin 35 | 36 | 37 | class ResNetEncoder(ResNet, EncoderMixin): 38 | def __init__(self, out_channels, depth=5, **kwargs): 39 | super().__init__(**kwargs) 40 | self._depth = depth 41 | self._out_channels = out_channels 42 | self._in_channels = 3 43 | 44 | del self.fc 45 | del self.avgpool 46 | 47 | def get_stages(self): 48 | return [ 49 | nn.Identity(), 50 | nn.Sequential(self.conv1, self.bn1, self.relu), 51 | nn.Sequential(self.maxpool, self.layer1), 52 | self.layer2, 53 | self.layer3, 54 | self.layer4, 55 | ] 56 | 57 | def forward(self, x): 58 | stages = self.get_stages() 59 | 60 | features = [] 61 | for i in range(self._depth + 1): 62 | x = stages[i](x) 63 | features.append(x) 64 | 65 | return features 66 | 67 | def load_state_dict(self, state_dict, **kwargs): 68 | state_dict.pop("fc.bias") 69 | state_dict.pop("fc.weight") 70 | super().load_state_dict(state_dict, **kwargs) 71 | 72 | 73 | new_settings = { 74 | "resnet18": { 75 | "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth", 76 | "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth" 77 | }, 78 | "resnet50": { 79 | "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth", 80 | "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth" 81 | }, 82 | "resnext50_32x4d": { 83 | "imagenet": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", 84 | "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth", 85 | "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth", 86 | }, 87 | "resnext101_32x4d": { 88 | "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth", 89 | "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth" 90 | }, 91 | "resnext101_32x8d": { 92 | "imagenet": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", 93 | "instagram": "https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth", 94 | "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth", 95 | "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth", 96 | }, 97 | "resnext101_32x16d": { 98 | "instagram": "https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth", 99 | "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth", 100 | "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth", 101 | }, 102 | "resnext101_32x32d": { 103 | 
"instagram": "https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth", 104 | }, 105 | "resnext101_32x48d": { 106 | "instagram": "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth", 107 | } 108 | } 109 | 110 | pretrained_settings = deepcopy(pretrained_settings) 111 | for model_name, sources in new_settings.items(): 112 | if model_name not in pretrained_settings: 113 | pretrained_settings[model_name] = {} 114 | 115 | for source_name, source_url in sources.items(): 116 | pretrained_settings[model_name][source_name] = { 117 | "url": source_url, 118 | 'input_size': [3, 224, 224], 119 | 'input_range': [0, 1], 120 | 'mean': [0.485, 0.456, 0.406], 121 | 'std': [0.229, 0.224, 0.225], 122 | 'num_classes': 1000 123 | } 124 | 125 | 126 | resnet_encoders = { 127 | "resnet18": { 128 | "encoder": ResNetEncoder, 129 | "pretrained_settings": pretrained_settings["resnet18"], 130 | "params": { 131 | "out_channels": (3, 64, 64, 128, 256, 512), 132 | "block": BasicBlock, 133 | "layers": [2, 2, 2, 2], 134 | }, 135 | }, 136 | "resnet34": { 137 | "encoder": ResNetEncoder, 138 | "pretrained_settings": pretrained_settings["resnet34"], 139 | "params": { 140 | "out_channels": (3, 64, 64, 128, 256, 512), 141 | "block": BasicBlock, 142 | "layers": [3, 4, 6, 3], 143 | }, 144 | }, 145 | "resnet50": { 146 | "encoder": ResNetEncoder, 147 | "pretrained_settings": pretrained_settings["resnet50"], 148 | "params": { 149 | "out_channels": (3, 64, 256, 512, 1024, 2048), 150 | "block": Bottleneck, 151 | "layers": [3, 4, 6, 3], 152 | }, 153 | }, 154 | "resnet101": { 155 | "encoder": ResNetEncoder, 156 | "pretrained_settings": pretrained_settings["resnet101"], 157 | "params": { 158 | "out_channels": (3, 64, 256, 512, 1024, 2048), 159 | "block": Bottleneck, 160 | "layers": [3, 4, 23, 3], 161 | }, 162 | }, 163 | "resnet152": { 164 | "encoder": ResNetEncoder, 165 | "pretrained_settings": pretrained_settings["resnet152"], 166 | "params": { 167 | "out_channels": (3, 64, 256, 512, 1024, 2048), 168 | "block": Bottleneck, 169 | "layers": [3, 8, 36, 3], 170 | }, 171 | }, 172 | "resnext50_32x4d": { 173 | "encoder": ResNetEncoder, 174 | "pretrained_settings": pretrained_settings["resnext50_32x4d"], 175 | "params": { 176 | "out_channels": (3, 64, 256, 512, 1024, 2048), 177 | "block": Bottleneck, 178 | "layers": [3, 4, 6, 3], 179 | "groups": 32, 180 | "width_per_group": 4, 181 | }, 182 | }, 183 | "resnext101_32x4d": { 184 | "encoder": ResNetEncoder, 185 | "pretrained_settings": pretrained_settings["resnext101_32x4d"], 186 | "params": { 187 | "out_channels": (3, 64, 256, 512, 1024, 2048), 188 | "block": Bottleneck, 189 | "layers": [3, 4, 23, 3], 190 | "groups": 32, 191 | "width_per_group": 4, 192 | }, 193 | }, 194 | "resnext101_32x8d": { 195 | "encoder": ResNetEncoder, 196 | "pretrained_settings": pretrained_settings["resnext101_32x8d"], 197 | "params": { 198 | "out_channels": (3, 64, 256, 512, 1024, 2048), 199 | "block": Bottleneck, 200 | "layers": [3, 4, 23, 3], 201 | "groups": 32, 202 | "width_per_group": 8, 203 | }, 204 | }, 205 | "resnext101_32x16d": { 206 | "encoder": ResNetEncoder, 207 | "pretrained_settings": pretrained_settings["resnext101_32x16d"], 208 | "params": { 209 | "out_channels": (3, 64, 256, 512, 1024, 2048), 210 | "block": Bottleneck, 211 | "layers": [3, 4, 23, 3], 212 | "groups": 32, 213 | "width_per_group": 16, 214 | }, 215 | }, 216 | "resnext101_32x32d": { 217 | "encoder": ResNetEncoder, 218 | "pretrained_settings": pretrained_settings["resnext101_32x32d"], 219 | "params": { 220 | 
"out_channels": (3, 64, 256, 512, 1024, 2048), 221 | "block": Bottleneck, 222 | "layers": [3, 4, 23, 3], 223 | "groups": 32, 224 | "width_per_group": 32, 225 | }, 226 | }, 227 | "resnext101_32x48d": { 228 | "encoder": ResNetEncoder, 229 | "pretrained_settings": pretrained_settings["resnext101_32x48d"], 230 | "params": { 231 | "out_channels": (3, 64, 256, 512, 1024, 2048), 232 | "block": Bottleneck, 233 | "layers": [3, 4, 23, 3], 234 | "groups": 32, 235 | "width_per_group": 48, 236 | }, 237 | }, 238 | } 239 | --------------------------------------------------------------------------------