├── change_detection_pytorch ├── V2 │ └── __init__.py ├── datasets │ ├── transforms │ │ └── __init__.py │ ├── __init__.py │ ├── SVCD.py │ └── LEVIR_CD.py ├── fpn │ ├── __init__.py │ ├── decoder.py │ └── model.py ├── pan │ ├── __init__.py │ └── model.py ├── pspnet │ ├── __init__.py │ ├── decoder.py │ └── model.py ├── stanet │ ├── __init__.py │ ├── BAM.py │ └── model.py ├── unet │ ├── __init__.py │ ├── decoder.py │ └── model.py ├── linknet │ ├── __init__.py │ ├── decoder.py │ └── model.py ├── manet │ ├── __init__.py │ └── model.py ├── upernet │ ├── __init__.py │ ├── model.py │ └── decoder.py ├── unetplusplus │ ├── __init__.py │ ├── model.py │ └── decoder.py ├── deeplabv3 │ └── __init__.py ├── __version__.py ├── utils │ ├── __init__.py │ ├── losses.py │ ├── meter.py │ ├── base.py │ ├── utils.py │ ├── metrics.py │ └── functional.py ├── base │ ├── __init__.py │ ├── initialization.py │ ├── decoder.py │ ├── heads.py │ └── model.py ├── losses │ ├── __init__.py │ ├── bcl.py │ ├── constants.py │ ├── soft_ce.py │ ├── hybrid_loss.py │ ├── tversky.py │ ├── soft_bce.py │ ├── focal.py │ ├── jaccard.py │ └── dice.py ├── encoders │ ├── _preprocessing.py │ ├── timm_universal.py │ ├── _base.py │ ├── xception.py │ ├── _utils.py │ ├── mobilenet.py │ ├── inceptionresnetv2.py │ ├── inceptionv4.py │ ├── timm_sknet.py │ ├── __init__.py │ ├── timm_gernet.py │ ├── densenet.py │ ├── vgg.py │ ├── timm_res2net.py │ ├── senet.py │ └── dpn.py └── __init__.py ├── __init__.py ├── resources ├── wechat.jpg └── model architecture.png ├── requirements.txt ├── lino_test.py ├── COMPETITIONS.md ├── misc ├── generate_table.py └── generate_table_timm.py ├── LICENSE ├── tests ├── test_preprocessing.py └── test_models.py ├── .gitignore ├── local_test.py └── setup.py /change_detection_pytorch/V2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from change_detection_pytorch import * -------------------------------------------------------------------------------- /change_detection_pytorch/datasets/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /change_detection_pytorch/fpn/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import FPN -------------------------------------------------------------------------------- /change_detection_pytorch/pan/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import PAN -------------------------------------------------------------------------------- /change_detection_pytorch/pspnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import PSPNet -------------------------------------------------------------------------------- /change_detection_pytorch/stanet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import STANet -------------------------------------------------------------------------------- /change_detection_pytorch/unet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Unet 
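A minimal usage sketch for the architecture classes exported above (illustrative, not a file in this repository; it assumes the keyword arguments that `create_model` in change_detection_pytorch/__init__.py forwards to every architecture, and the two-image `forward(x1, x2)` defined in change_detection_pytorch/base/model.py):

import torch
import change_detection_pytorch as cdp

# Siamese change-detection Unet with an ImageNet-pretrained ResNet-34 encoder.
model = cdp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",   # or None for random initialization
    in_channels=3,
    classes=2,                    # change / no-change
)
x1 = torch.randn(1, 3, 256, 256)  # image at time t1
x2 = torch.randn(1, 3, 256, 256)  # image at time t2
mask = model(x1, x2)              # expected shape: (1, classes, 256, 256)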
-------------------------------------------------------------------------------- /change_detection_pytorch/linknet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Linknet -------------------------------------------------------------------------------- /change_detection_pytorch/manet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import MAnet 2 | -------------------------------------------------------------------------------- /change_detection_pytorch/upernet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import UPerNet 2 | -------------------------------------------------------------------------------- /change_detection_pytorch/unetplusplus/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import UnetPlusPlus 2 | -------------------------------------------------------------------------------- /change_detection_pytorch/deeplabv3/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import DeepLabV3, DeepLabV3Plus -------------------------------------------------------------------------------- /resources/wechat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/likyoo/change_detection.pytorch/HEAD/resources/wechat.jpg -------------------------------------------------------------------------------- /change_detection_pytorch/__version__.py: -------------------------------------------------------------------------------- 1 | VERSION = (0, 1, 4) 2 | 3 | __version__ = '.'.join(map(str, VERSION)) 4 | -------------------------------------------------------------------------------- /change_detection_pytorch/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .custom import * 2 | from .LEVIR_CD import * 3 | from .SVCD import * 4 | -------------------------------------------------------------------------------- /resources/model architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/likyoo/change_detection.pytorch/HEAD/resources/model architecture.png -------------------------------------------------------------------------------- /change_detection_pytorch/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import train 2 | from . import losses 3 | from . import metrics 4 | from . 
import lr_scheduler 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision>=0.5.0 2 | pretrainedmodels==0.7.4 3 | efficientnet-pytorch==0.6.3 4 | timm==0.4.12 5 | albumentations>=1.0.0,<=1.0.3 -------------------------------------------------------------------------------- /change_detection_pytorch/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import SegmentationModel 2 | from .decoder import Decoder 3 | 4 | from .modules import ( 5 | Conv2dReLU, 6 | Attention, 7 | ) 8 | 9 | from .heads import ( 10 | SegmentationHead, 11 | ClassificationHead, 12 | ) 13 | -------------------------------------------------------------------------------- /lino_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from change_detection_pytorch.encoders import get_encoder 3 | 4 | if __name__ == '__main__': 5 | sample = torch.randn(1, 3, 256, 256) 6 | model = get_encoder('mit-b0', img_size=256) 7 | res = model(sample) 8 | for x in res: 9 | print(x.size()) 10 | -------------------------------------------------------------------------------- /change_detection_pytorch/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE 2 | 3 | from .jaccard import JaccardLoss 4 | from .dice import DiceLoss 5 | from .focal import FocalLoss 6 | from .lovasz import LovaszLoss 7 | from .soft_bce import SoftBCEWithLogitsLoss 8 | from .soft_ce import SoftCrossEntropyLoss 9 | from .tversky import TverskyLoss 10 | from .hybrid_loss import HybridLoss 11 | from .bcl import BCLLoss 12 | -------------------------------------------------------------------------------- /change_detection_pytorch/encoders/_preprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def preprocess_input( 5 | x, mean=None, std=None, input_space="RGB", input_range=None, **kwargs 6 | ): 7 | 8 | if input_space == "BGR": 9 | x = x[..., ::-1].copy() 10 | 11 | if input_range is not None: 12 | if x.max() > 1 and input_range[1] == 1: 13 | x = x / 255.0 14 | 15 | if mean is not None: 16 | mean = np.array(mean) 17 | x = x - mean 18 | 19 | if std is not None: 20 | std = np.array(std) 21 | x = x / std 22 | 23 | return x 24 | -------------------------------------------------------------------------------- /COMPETITIONS.md: -------------------------------------------------------------------------------- 1 | # Change Detection Competitions with cdp 2 | 3 | `change_detection.pytorch` has competitiveness and potential in the change detection competitions. 4 | Here you can find competitions, names of the winners and links to their solutions. 5 | 6 | 7 | 8 | ------ 9 | 10 | 11 | 12 | ### [PRCV2021 Change Detection Competition](https://captain-whu.github.io/PRCV2021_RS/index.html) 13 | 14 | - 3rd place. 
15 | 16 | [Kaiyu Li](https://github.com/likyoo), 17 | [Fulin Sun](https://github.com/LinoSun), 18 | Xudong Liu, 19 | Guoqiang Liu, 20 | [[description](https://github.com/likyoo/PRCV2021_ChangeDetection_Top3)] -------------------------------------------------------------------------------- /change_detection_pytorch/base/initialization.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def initialize_decoder(module): 5 | for m in module.modules(): 6 | 7 | if isinstance(m, nn.Conv2d): 8 | nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu") 9 | if m.bias is not None: 10 | nn.init.constant_(m.bias, 0) 11 | 12 | elif isinstance(m, nn.BatchNorm2d): 13 | nn.init.constant_(m.weight, 1) 14 | nn.init.constant_(m.bias, 0) 15 | 16 | elif isinstance(m, nn.Linear): 17 | nn.init.xavier_uniform_(m.weight) 18 | if m.bias is not None: 19 | nn.init.constant_(m.bias, 0) 20 | 21 | 22 | def initialize_head(module): 23 | for m in module.modules(): 24 | if isinstance(m, (nn.Linear, nn.Conv2d)): 25 | nn.init.xavier_uniform_(m.weight) 26 | if m.bias is not None: 27 | nn.init.constant_(m.bias, 0) 28 | -------------------------------------------------------------------------------- /misc/generate_table.py: -------------------------------------------------------------------------------- 1 | import change_detection_pytorch as smp 2 | 3 | encoders = smp.encoders.encoders 4 | 5 | 6 | WIDTH = 32 7 | COLUMNS = [ 8 | "Encoder", 9 | "Weights", 10 | "Params, M", 11 | ] 12 | 13 | def wrap_row(r): 14 | return "|{}|".format(r) 15 | 16 | header = "|".join([column.ljust(WIDTH, ' ') for column in COLUMNS]) 17 | separator = "|".join(["-" * WIDTH] + [":" + "-" * (WIDTH - 2) + ":"] * (len(COLUMNS) - 1)) 18 | 19 | print(wrap_row(header)) 20 | print(wrap_row(separator)) 21 | 22 | for encoder_name, encoder in encoders.items(): 23 | weights = "
".join(encoder["pretrained_settings"].keys()) 24 | encoder_name = encoder_name.ljust(WIDTH, " ") 25 | weights = weights.ljust(WIDTH, " ") 26 | 27 | model = encoder["encoder"](**encoder["params"], depth=5) 28 | params = sum(p.numel() for p in model.parameters()) 29 | params = str(params // 1000000) + "M" 30 | params = params.ljust(WIDTH, " ") 31 | 32 | row = "|".join([encoder_name, weights, params]) 33 | print(wrap_row(row)) 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Kaiyu Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /change_detection_pytorch/encoders/timm_universal.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import torch.nn as nn 3 | 4 | 5 | class TimmUniversalEncoder(nn.Module): 6 | 7 | def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=32): 8 | super().__init__() 9 | kwargs = dict( 10 | in_chans=in_channels, 11 | features_only=True, 12 | output_stride=output_stride, 13 | pretrained=pretrained, 14 | out_indices=tuple(range(depth)), 15 | ) 16 | 17 | # not all models support output stride argument, drop it by default 18 | if output_stride == 32: 19 | kwargs.pop("output_stride") 20 | 21 | self.model = timm.create_model(name, **kwargs) 22 | 23 | self._in_channels = in_channels 24 | self._out_channels = [3, ] + self.model.feature_info.channels() 25 | self._depth = depth 26 | 27 | def forward(self, x): 28 | features = self.model(x) 29 | features = [x,] + features 30 | return features 31 | 32 | @property 33 | def out_channels(self): 34 | return self._out_channels 35 | -------------------------------------------------------------------------------- /change_detection_pytorch/losses/bcl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.loss import _Loss 3 | 4 | 5 | class BCLLoss(_Loss): 6 | """loss function of stanet""" 7 | def __init__( 8 | self, 9 | label_value: int = 1, 10 | margin: int = 2 11 | ): 12 | super(BCLLoss, self).__init__() 13 | self.margin = margin 14 | self.label_value = label_value 15 | 16 | def forward(self, distance, label): 17 | label = label.float() 18 | label[label == 255] = 1 19 | label[label == 
self.label_value] = -1 20 | label[label == 0] = 1 21 | mask = (label != 255).float() 22 | distance = distance * mask 23 | pos_num = torch.sum((label == 1).float()) + 0.0001 24 | neg_num = torch.sum((label == -1).float()) + 0.0001 25 | 26 | loss_1 = torch.sum((1 + label) / 2 * torch.pow(distance, 2)) / pos_num 27 | loss_2 = torch.sum((1 - label) / 2 * mask * 28 | torch.pow(torch.clamp(self.margin - distance, min=0.0), 2) 29 | ) / neg_num 30 | loss = loss_1 + loss_2 31 | return loss 32 | -------------------------------------------------------------------------------- /change_detection_pytorch/losses/constants.py: -------------------------------------------------------------------------------- 1 | #: Loss binary mode supposes you are solving a binary segmentation task. 2 | #: That means you have only one class whose pixels are labeled as **1**; 3 | #: the rest of the pixels are background and labeled as **0**. 4 | #: Target mask shape - (N, H, W), model output mask shape (N, 1, H, W). 5 | BINARY_MODE: str = "binary" 6 | 7 | #: Loss multiclass mode supposes you are solving a multi-**class** segmentation task. 8 | #: That means you have *C = 1..N* classes with unique label values; 9 | #: classes are mutually exclusive and all pixels are labeled with these values. 10 | #: Target mask shape - (N, H, W), model output mask shape (N, C, H, W). 11 | MULTICLASS_MODE: str = "multiclass" 12 | 13 | #: Loss multilabel mode supposes you are solving a multi-**label** segmentation task. 14 | #: That means you have *C = 1..N* classes whose pixels are labeled as **1**; 15 | #: classes are not mutually exclusive and each class has its own *channel*; 16 | #: pixels in each channel that do not belong to the class are labeled as **0**. 17 | #: Target mask shape - (N, C, H, W), model output mask shape (N, C, H, W). 
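The three mode constants documented in this file are consumed by the loss classes in change_detection_pytorch/losses. A hedged usage sketch (illustrative; the DiceLoss signature is inferred from its TverskyLoss subclass in losses/tversky.py and may differ in detail):

import torch
import change_detection_pytorch as cdp

# Multiclass mode: logits of shape (N, C, H, W), integer targets of shape (N, H, W).
criterion = cdp.losses.DiceLoss(mode=cdp.losses.MULTICLASS_MODE, from_logits=True)
y_pred = torch.randn(4, 3, 64, 64)
y_true = torch.randint(0, 3, (4, 64, 64))
loss = criterion(y_pred, y_true)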
18 | MULTILABEL_MODE: str = "multilabel" 19 | -------------------------------------------------------------------------------- /change_detection_pytorch/base/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Decoder(torch.nn.Module): 5 | # TODO: support learnable fusion modules 6 | def __init__(self): 7 | super().__init__() 8 | self.FUSION_DIC = {"2to1_fusion": ["sum", "diff", "abs_diff"], 9 | "2to2_fusion": ["concat"]} 10 | 11 | def fusion(self, x1, x2, fusion_form="concat"): 12 | """Specify the form of feature fusion""" 13 | if fusion_form == "concat": 14 | x = torch.cat([x1, x2], dim=1) 15 | elif fusion_form == "sum": 16 | x = x1 + x2 17 | elif fusion_form == "diff": 18 | x = x2 - x1 19 | elif fusion_form == "abs_diff": 20 | x = torch.abs(x1 - x2) 21 | else: 22 | raise ValueError('the fusion form "{}" is not defined'.format(fusion_form)) 23 | 24 | return x 25 | 26 | def aggregation_layer(self, fea1, fea2, fusion_form="concat", ignore_original_img=True): 27 | """aggregate features from siamese or non-siamese branches""" 28 | 29 | start_idx = 1 if ignore_original_img else 0 30 | aggregate_fea = [self.fusion(fea1[idx], fea2[idx], fusion_form) 31 | for idx in range(start_idx, len(fea1))] 32 | 33 | return aggregate_fea 34 | -------------------------------------------------------------------------------- /change_detection_pytorch/base/heads.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .modules import Flatten, Activation 3 | 4 | 5 | class SegmentationHead(nn.Sequential): 6 | 7 | def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1, align_corners=True): 8 | conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) 9 | upsampling = nn.Upsample(scale_factor=upsampling, mode='bilinear', align_corners=align_corners) if upsampling > 1 else nn.Identity() 10 | activation = Activation(activation) 11 | super().__init__(conv2d, upsampling, activation) 12 | 13 | 14 | class ClassificationHead(nn.Sequential): 15 | 16 | def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None): 17 | if pooling not in ("max", "avg"): 18 | raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling)) 19 | pool = nn.AdaptiveAvgPool2d(1) if pooling == 'avg' else nn.AdaptiveMaxPool2d(1) 20 | flatten = Flatten() 21 | dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity() 22 | linear = nn.Linear(in_channels, classes, bias=True) 23 | activation = Activation(activation) 24 | super().__init__(pool, flatten, dropout, linear, activation) 25 | -------------------------------------------------------------------------------- /tests/test_preprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import mock 4 | import pytest 5 | import numpy as np 6 | 7 | # mock detection module 8 | sys.modules['torchvision._C'] = mock.Mock() 9 | 10 | import change_detection_pytorch as smp 11 | 12 | 13 | def _test_preprocessing(inp, out, **params): 14 | preprocessed_output = smp.encoders.preprocess_input(inp, **params) 15 | assert np.allclose(preprocessed_output, out) 16 | 17 | 18 | def test_mean(): 19 | inp = np.ones((32, 32, 3)) 20 | out = np.zeros((32, 32, 3)) 21 | mean = (1, 1, 1) 22 | _test_preprocessing(inp, out, mean=mean) 23 | 24 | 25 | def test_std(): 26 | inp = np.ones((32, 32, 
3)) * 255 27 | out = np.ones((32, 32, 3)) 28 | std = (255, 255, 255) 29 | _test_preprocessing(inp, out, std=std) 30 | 31 | 32 | def test_input_range(): 33 | inp = np.ones((32, 32, 3)) 34 | out = np.ones((32, 32, 3)) 35 | _test_preprocessing(inp, out, input_range=(0, 1)) 36 | _test_preprocessing(inp * 255, out, input_range=(0, 1)) 37 | _test_preprocessing(inp * 255, out * 255, input_range=(0, 255)) 38 | 39 | 40 | def test_input_space(): 41 | inp = np.stack( 42 | [np.ones((32, 32)), 43 | np.zeros((32, 32))], 44 | axis=-1 45 | ) 46 | out = np.stack( 47 | [np.zeros((32, 32)), 48 | np.ones((32, 32))], 49 | axis=-1 50 | ) 51 | _test_preprocessing(inp, out, input_space='BGR') 52 | -------------------------------------------------------------------------------- /change_detection_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet import Unet 2 | from .unetplusplus import UnetPlusPlus 3 | from .manet import MAnet 4 | from .linknet import Linknet 5 | from .fpn import FPN 6 | from .pspnet import PSPNet 7 | from .deeplabv3 import DeepLabV3, DeepLabV3Plus 8 | from .pan import PAN 9 | from .stanet import STANet 10 | from .upernet import UPerNet 11 | 12 | from . import encoders 13 | from . import utils 14 | from . import losses 15 | from . import datasets 16 | 17 | from .__version__ import __version__ 18 | 19 | from typing import Optional 20 | import torch 21 | 22 | 23 | def create_model( 24 | arch: str, 25 | encoder_name: str = "resnet34", 26 | encoder_weights: Optional[str] = "imagenet", 27 | in_channels: int = 3, 28 | classes: int = 1, 29 | **kwargs, 30 | ) -> torch.nn.Module: 31 | """Models wrapper. Allows to create any model just with parametes 32 | 33 | """ 34 | 35 | archs = [Unet, UnetPlusPlus, MAnet, Linknet, FPN, PSPNet, DeepLabV3, DeepLabV3Plus, PAN, STANet, UPerNet] 36 | archs_dict = {a.__name__.lower(): a for a in archs} 37 | try: 38 | model_class = archs_dict[arch.lower()] 39 | except KeyError: 40 | raise KeyError("Wrong architecture type `{}`. Available options are: {}".format( 41 | arch, list(archs_dict.keys()), 42 | )) 43 | return model_class( 44 | encoder_name=encoder_name, 45 | encoder_weights=encoder_weights, 46 | in_channels=in_channels, 47 | classes=classes, 48 | **kwargs, 49 | ) 50 | -------------------------------------------------------------------------------- /change_detection_pytorch/losses/soft_ce.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from torch import nn, Tensor 3 | import torch 4 | import torch.nn.functional as F 5 | from ._functional import label_smoothed_nll_loss 6 | 7 | __all__ = ["SoftCrossEntropyLoss"] 8 | 9 | 10 | class SoftCrossEntropyLoss(nn.Module): 11 | __name__ = "SoftCrossEntropyLoss" 12 | 13 | __constants__ = ["reduction", "ignore_index", "smooth_factor"] 14 | 15 | def __init__( 16 | self, 17 | reduction: str = "mean", 18 | smooth_factor: Optional[float] = None, 19 | ignore_index: Optional[int] = -100, 20 | dim: int = 1, 21 | ): 22 | """Drop-in replacement for torch.nn.CrossEntropyLoss with label_smoothing 23 | 24 | Args: 25 | smooth_factor: Factor to smooth target (e.g. 
if smooth_factor=0.1 then [1, 0, 0] -> [0.9, 0.05, 0.05]) 26 | 27 | Shape 28 | - **y_pred** - torch.Tensor of shape (N, C, H, W) 29 | - **y_true** - torch.Tensor of shape (N, H, W) 30 | 31 | Reference 32 | https://github.com/BloodAxe/pytorch-toolbelt 33 | """ 34 | super().__init__() 35 | self.smooth_factor = smooth_factor 36 | self.ignore_index = ignore_index 37 | self.reduction = reduction 38 | self.dim = dim 39 | 40 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 41 | log_prob = F.log_softmax(y_pred, dim=self.dim) 42 | return label_smoothed_nll_loss( 43 | log_prob, 44 | y_true, 45 | epsilon=self.smooth_factor, 46 | ignore_index=self.ignore_index, 47 | reduction=self.reduction, 48 | dim=self.dim, 49 | ) 50 | -------------------------------------------------------------------------------- /change_detection_pytorch/base/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . import initialization as init 3 | 4 | 5 | class SegmentationModel(torch.nn.Module): 6 | 7 | def initialize(self): 8 | init.initialize_decoder(self.decoder) 9 | init.initialize_head(self.segmentation_head) 10 | if self.classification_head is not None: 11 | init.initialize_head(self.classification_head) 12 | 13 | def base_forward(self, x1, x2): 14 | """Sequentially pass `x1`, `x2` through the model's encoder, decoder and heads""" 15 | if self.siam_encoder: 16 | features = self.encoder(x1), self.encoder(x2) 17 | else: 18 | features = self.encoder(x1), self.encoder_non_siam(x2) 19 | 20 | decoder_output = self.decoder(*features) 21 | 22 | # TODO: features = self.fusion_policy(features) 23 | 24 | masks = self.segmentation_head(decoder_output) 25 | 26 | if self.classification_head is not None: 27 | raise AttributeError("`classification_head` is not supported now.") 28 | # labels = self.classification_head(features[-1]) 29 | # return masks, labels 30 | 31 | return masks 32 | 33 | def forward(self, x1, x2): 34 | """Sequentially pass `x1`, `x2` through the model's encoder, decoder and heads""" 35 | return self.base_forward(x1, x2) 36 | 37 | def predict(self, x1, x2): 38 | """Inference method. 
Switch model to `eval` mode, call `.forward(x1, x2)` with `torch.no_grad()` 39 | 40 | Args: 41 | x1, x2: 4D torch tensor with shape (batch_size, channels, height, width) 42 | 43 | Return: 44 | prediction: 4D torch tensor with shape (batch_size, classes, height, width) 45 | 46 | """ 47 | if self.training: 48 | self.eval() 49 | 50 | with torch.no_grad(): 51 | x = self.forward(x1, x2) 52 | 53 | return x 54 | -------------------------------------------------------------------------------- /misc/generate_table_timm.py: -------------------------------------------------------------------------------- 1 | import timm 2 | from tqdm import tqdm 3 | 4 | 5 | def check_features_and_reduction(name): 6 | encoder = timm.create_model(name, features_only=True, pretrained=False) 7 | if not encoder.feature_info.reduction() == [2, 4, 8, 16, 32]: 8 | raise ValueError 9 | 10 | def has_dilation_support(name): 11 | try: 12 | timm.create_model(name, features_only=True, output_stride=8, pretrained=False) 13 | timm.create_model(name, features_only=True, output_stride=16, pretrained=False) 14 | return True 15 | except Exception as e: 16 | return False 17 | 18 | def make_table(data): 19 | names = supported.keys() 20 | max_len1 = max([len(x) for x in names]) + 2 21 | max_len2 = len("support dilation") + 2 22 | 23 | l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+\n" 24 | l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+\n" 25 | top = "| " + "Encoder name".ljust(max_len1 - 2) + " | " + "Support dilation".center(max_len2 - 2) + " |\n" 26 | 27 | table = l1 + top + l2 28 | 29 | for k in sorted(data.keys()): 30 | support = "✅".center(max_len2 - 3) if data[k]["has_dilation"] else " ".center(max_len2 - 2) 31 | table += "| " + k.ljust(max_len1 - 2) + " | " + support + " |\n" 32 | table += l1 33 | 34 | return table 35 | 36 | 37 | if __name__ == "__main__": 38 | 39 | supported_models = {} 40 | 41 | with tqdm(timm.list_models()) as names: 42 | for name in names: 43 | try: 44 | check_features_and_reduction(name) 45 | has_dilation = has_dilation_support(name) 46 | supported_models[name] = dict(has_dilation=has_dilation) 47 | except Exception: 48 | continue 49 | 50 | table = make_table(supported_models) 51 | print(table) 52 | print(f"Total encoders: {len(supported_models.keys())}") 53 | -------------------------------------------------------------------------------- /change_detection_pytorch/encoders/_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import List 4 | from collections import OrderedDict 5 | 6 | from . 
import _utils as utils 7 | 8 | 9 | class EncoderMixin: 10 | """Add encoder functionality such as: 11 | - output channels specification of feature tensors (produced by encoder) 12 | - patching first convolution for arbitrary input channels 13 | """ 14 | 15 | @property 16 | def out_channels(self): 17 | """Return channels dimensions for each tensor of forward output of encoder""" 18 | return self._out_channels[: self._depth + 1] 19 | 20 | def set_in_channels(self, in_channels, pretrained=True): 21 | """Change first convolution channels""" 22 | if in_channels == 3: 23 | return 24 | 25 | self._in_channels = in_channels 26 | if self._out_channels[0] == 3: 27 | self._out_channels = tuple([in_channels] + list(self._out_channels)[1:]) 28 | 29 | utils.patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained) 30 | 31 | def get_stages(self): 32 | """Method should be overridden in encoder""" 33 | raise NotImplementedError 34 | 35 | def make_dilated(self, output_stride): 36 | 37 | if output_stride == 16: 38 | stage_list=[5,] 39 | dilation_list=[2,] 40 | 41 | elif output_stride == 8: 42 | stage_list=[4, 5] 43 | dilation_list=[2, 4] 44 | 45 | else: 46 | raise ValueError("Output stride should be 16 or 8, got {}.".format(output_stride)) 47 | 48 | stages = self.get_stages() 49 | for stage_indx, dilation_rate in zip(stage_list, dilation_list): 50 | utils.replace_strides_with_dilation( 51 | module=stages[stage_indx], 52 | dilation_rate=dilation_rate, 53 | ) 54 | -------------------------------------------------------------------------------- /change_detection_pytorch/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from . import base 4 | from . import functional as F 5 | from ..base.modules import Activation 6 | 7 | # See change_detection_pytorch/losses 8 | # class JaccardLoss(base.Loss): 9 | # 10 | # def __init__(self, eps=1., activation=None, ignore_channels=None, **kwargs): 11 | # super().__init__(**kwargs) 12 | # self.eps = eps 13 | # self.activation = Activation(activation) 14 | # self.ignore_channels = ignore_channels 15 | # 16 | # def forward(self, y_pr, y_gt): 17 | # y_pr = self.activation(y_pr) 18 | # return 1 - F.jaccard( 19 | # y_pr, y_gt, 20 | # eps=self.eps, 21 | # threshold=None, 22 | # ignore_channels=self.ignore_channels, 23 | # ) 24 | # 25 | # 26 | # class DiceLoss(base.Loss): 27 | # 28 | # def __init__(self, eps=1., beta=1., activation=None, ignore_channels=None, **kwargs): 29 | # super().__init__(**kwargs) 30 | # self.eps = eps 31 | # self.beta = beta 32 | # self.activation = Activation(activation) 33 | # self.ignore_channels = ignore_channels 34 | # 35 | # def forward(self, y_pr, y_gt): 36 | # y_pr = self.activation(y_pr) 37 | # return 1 - F.f_score( 38 | # y_pr, y_gt, 39 | # beta=self.beta, 40 | # eps=self.eps, 41 | # threshold=None, 42 | # ignore_channels=self.ignore_channels, 43 | # ) 44 | 45 | 46 | class L1Loss(nn.L1Loss, base.Loss): 47 | pass 48 | 49 | 50 | class MSELoss(nn.MSELoss, base.Loss): 51 | pass 52 | 53 | 54 | class CrossEntropyLoss(nn.CrossEntropyLoss, base.Loss): 55 | pass 56 | 57 | 58 | class NLLLoss(nn.NLLLoss, base.Loss): 59 | pass 60 | 61 | 62 | class BCELoss(nn.BCELoss, base.Loss): 63 | pass 64 | 65 | 66 | class BCEWithLogitsLoss(nn.BCEWithLogitsLoss, base.Loss): 67 | pass 68 | -------------------------------------------------------------------------------- /change_detection_pytorch/utils/meter.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Meter(object): 5 | '''Meters provide a way to keep track of important statistics in an online manner. 6 | This class is abstract, but provides a standard interface for all meters to follow. 7 | ''' 8 | 9 | def reset(self): 10 | '''Resets the meter to default settings.''' 11 | pass 12 | 13 | def add(self, value): 14 | '''Log a new value to the meter 15 | Args: 16 | value: Next result to include. 17 | ''' 18 | pass 19 | 20 | def value(self): 21 | '''Get the value of the meter in the current state.''' 22 | pass 23 | 24 | 25 | class AverageValueMeter(Meter): 26 | def __init__(self): 27 | super(AverageValueMeter, self).__init__() 28 | self.reset() 29 | self.val = 0 30 | 31 | def add(self, value, n=1): 32 | self.val = value 33 | self.sum += value 34 | self.var += value * value 35 | self.n += n 36 | 37 | if self.n == 0: 38 | self.mean, self.std = np.nan, np.nan 39 | elif self.n == 1: 40 | self.mean = 0.0 + self.sum # This is to force a copy in torch/numpy 41 | self.std = np.inf 42 | self.mean_old = self.mean 43 | self.m_s = 0.0 44 | else: 45 | self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n) 46 | self.m_s += (value - self.mean_old) * (value - self.mean) 47 | self.mean_old = self.mean 48 | self.std = np.sqrt(self.m_s / (self.n - 1.0)) 49 | 50 | def value(self): 51 | return self.mean, self.std 52 | 53 | def reset(self): 54 | self.n = 0 55 | self.sum = 0.0 56 | self.var = 0.0 57 | self.val = 0.0 58 | self.mean = np.nan 59 | self.mean_old = 0.0 60 | self.m_s = 0.0 61 | self.std = np.nan 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | __pycache__/ 3 | .idea/ 4 | .pytest_cache 5 | changelog.md 6 | liky_train.py 7 | change_detection_pytorch/datasets/PRCV_CD.py 8 | tools 9 | change_detection_pytorch/V2 10 | change_detection_pytorch/ufpn 11 | change_detection_pytorch/encoders/hrnet.py 12 | weights 13 | # linoadd 20211202 14 | # Byte-compiled / optimized / DLL files 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # celery beat schedule file 91 | celerybeat-schedule 92 | 93 | # SageMath parsed files 94 | *.sage.py 95 | 96 | # Environments 97 | .env 98 | .venv 99 | env/ 100 | venv/ 101 | ENV/ 102 | env.bak/ 103 | venv.bak/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | .spyproject 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ -------------------------------------------------------------------------------- /change_detection_pytorch/stanet/BAM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class BAM(nn.Module): 7 | """ Basic self-attention module 8 | """ 9 | 10 | def __init__(self, in_dim, ds=8, activation=nn.ReLU): 11 | super(BAM, self).__init__() 12 | self.chanel_in = in_dim 13 | self.key_channel = self.chanel_in // 8 14 | self.activation = activation 15 | self.ds = ds # 16 | self.pool = nn.AvgPool2d(self.ds) 17 | print('ds: ', ds) 18 | self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) 19 | self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) 20 | self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 21 | self.gamma = nn.Parameter(torch.zeros(1)) 22 | 23 | self.softmax = nn.Softmax(dim=-1) # 24 | 25 | def forward(self, input): 26 | """ 27 | inputs : 28 | x : input feature maps( B X C X W X H) 29 | returns : 30 | out : self attention value + input feature 31 | attention: B X N X N (N is Width*Height) 32 | """ 33 | x = self.pool(input) 34 | m_batchsize, C, width, height = x.size() 35 | proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B X C X (N)/(ds*ds) 36 | proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) # B X C x (*W*H)/(ds*ds) 37 | energy = torch.bmm(proj_query, proj_key) # transpose check 38 | energy = (self.key_channel ** -.5) * energy 39 | 40 | attention = self.softmax(energy) # BX (N) X (N)/(ds*ds)/(ds*ds) 41 | 42 | proj_value = self.value_conv(x).view(m_batchsize, -1, width * height) # B X C X N 43 | 44 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 45 | out = out.view(m_batchsize, C, width, height) 46 | 47 | out = F.interpolate(out, [width * self.ds, height * self.ds]) 48 | out = out + input 49 | 50 | return out 51 | -------------------------------------------------------------------------------- /change_detection_pytorch/utils/base.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch.nn as nn 3 | 4 | class BaseObject(nn.Module): 5 | 6 | def __init__(self, name=None): 7 | super().__init__() 8 | self._name = name 9 | 10 | 
@property 11 | def __name__(self): 12 | if self._name is None: 13 | name = self.__class__.__name__ 14 | s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) 15 | return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() 16 | else: 17 | return self._name 18 | 19 | 20 | class Metric(BaseObject): 21 | pass 22 | 23 | 24 | class Loss(BaseObject): 25 | 26 | def __add__(self, other): 27 | if isinstance(other, Loss): 28 | return SumOfLosses(self, other) 29 | else: 30 | raise ValueError('Loss should be inherited from `Loss` class') 31 | 32 | def __radd__(self, other): 33 | return self.__add__(other) 34 | 35 | def __mul__(self, value): 36 | if isinstance(value, (int, float)): 37 | return MultipliedLoss(self, value) 38 | else: 39 | raise ValueError('Loss should be inherited from `BaseLoss` class') 40 | 41 | def __rmul__(self, other): 42 | return self.__mul__(other) 43 | 44 | 45 | class SumOfLosses(Loss): 46 | 47 | def __init__(self, l1, l2): 48 | name = '{} + {}'.format(l1.__name__, l2.__name__) 49 | super().__init__(name=name) 50 | self.l1 = l1 51 | self.l2 = l2 52 | 53 | def __call__(self, *inputs): 54 | return self.l1.forward(*inputs) + self.l2.forward(*inputs) 55 | 56 | 57 | class MultipliedLoss(Loss): 58 | 59 | def __init__(self, loss, multiplier): 60 | 61 | # resolve name 62 | if len(loss.__name__.split('+')) > 1: 63 | name = '{} * ({})'.format(multiplier, loss.__name__) 64 | else: 65 | name = '{} * {}'.format(multiplier, loss.__name__) 66 | super().__init__(name=name) 67 | self.loss = loss 68 | self.multiplier = multiplier 69 | 70 | def __call__(self, *inputs): 71 | return self.multiplier * self.loss.forward(*inputs) 72 | -------------------------------------------------------------------------------- /change_detection_pytorch/encoders/xception.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch.nn as nn 3 | 4 | from pretrainedmodels.models.xception import pretrained_settings 5 | from pretrainedmodels.models.xception import Xception 6 | 7 | from ._base import EncoderMixin 8 | 9 | 10 | class XceptionEncoder(Xception, EncoderMixin): 11 | 12 | def __init__(self, out_channels, *args, depth=5, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | 15 | self._out_channels = out_channels 16 | self._depth = depth 17 | self._in_channels = 3 18 | 19 | # modify padding to maintain output shape 20 | self.conv1.padding = (1, 1) 21 | self.conv2.padding = (1, 1) 22 | 23 | del self.fc 24 | 25 | def make_dilated(self, output_stride): 26 | raise ValueError("Xception encoder does not support dilated mode " 27 | "due to pooling operation for downsampling!") 28 | 29 | def get_stages(self): 30 | return [ 31 | nn.Identity(), 32 | nn.Sequential(self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.relu), 33 | self.block1, 34 | self.block2, 35 | nn.Sequential(self.block3, self.block4, self.block5, self.block6, self.block7, 36 | self.block8, self.block9, self.block10, self.block11), 37 | nn.Sequential(self.block12, self.conv3, self.bn3, self.relu, self.conv4, self.bn4), 38 | ] 39 | 40 | def forward(self, x): 41 | stages = self.get_stages() 42 | 43 | features = [] 44 | for i in range(self._depth + 1): 45 | x = stages[i](x) 46 | features.append(x) 47 | 48 | return features 49 | 50 | def load_state_dict(self, state_dict): 51 | # remove linear 52 | state_dict.pop('fc.bias', None) 53 | state_dict.pop('fc.weight', None) 54 | 55 | super().load_state_dict(state_dict) 56 | 57 | 58 | xception_encoders = { 59 | 'xception': { 60 | 'encoder': XceptionEncoder, 61 | 
'pretrained_settings': pretrained_settings['xception'], 62 | 'params': { 63 | 'out_channels': (3, 64, 128, 256, 728, 2048), 64 | } 65 | }, 66 | } 67 | -------------------------------------------------------------------------------- /change_detection_pytorch/encoders/_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True): 6 | """Change first convolution layer input channels. 7 | In case: 8 | in_channels == 1 or in_channels == 2 -> reuse original weights 9 | in_channels > 3 -> make random kaiming normal initialization 10 | """ 11 | 12 | # get first conv 13 | for module in model.modules(): 14 | if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels: 15 | break 16 | 17 | weight = module.weight.detach() 18 | module.in_channels = new_in_channels 19 | 20 | if not pretrained: 21 | module.weight = nn.parameter.Parameter( 22 | torch.Tensor( 23 | module.out_channels, 24 | new_in_channels // module.groups, 25 | *module.kernel_size 26 | ) 27 | ) 28 | module.reset_parameters() 29 | 30 | elif new_in_channels == 1: 31 | new_weight = weight.sum(1, keepdim=True) 32 | module.weight = nn.parameter.Parameter(new_weight) 33 | 34 | else: 35 | new_weight = torch.Tensor( 36 | module.out_channels, 37 | new_in_channels // module.groups, 38 | *module.kernel_size 39 | ) 40 | 41 | for i in range(new_in_channels): 42 | new_weight[:, i] = weight[:, i % default_in_channels] 43 | 44 | new_weight = new_weight * (default_in_channels / new_in_channels) 45 | module.weight = nn.parameter.Parameter(new_weight) 46 | 47 | 48 | def replace_strides_with_dilation(module, dilation_rate): 49 | """Patch Conv2d modules replacing strides with dilation""" 50 | for mod in module.modules(): 51 | if isinstance(mod, nn.Conv2d): 52 | mod.stride = (1, 1) 53 | mod.dilation = (dilation_rate, dilation_rate) 54 | kh, kw = mod.kernel_size 55 | mod.padding = ((kh // 2) * dilation_rate, (kh // 2) * dilation_rate) 56 | 57 | # Kostyl for EfficientNet 58 | if hasattr(mod, "static_padding"): 59 | mod.static_padding = nn.Identity() 60 | -------------------------------------------------------------------------------- /change_detection_pytorch/losses/hybrid_loss.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.modules.loss import _Loss 7 | 8 | __all__ = ["HybridLoss"] 9 | 10 | 11 | # Custom combination of multiple loss functions 12 | # For example: 13 | # loss1 = cdp.utils.losses.CrossEntropyLoss() 14 | # loss2 = cdp.losses.DiceLoss(mode='multiclass') 15 | # loss = cdp.losses.HybridLoss(loss1, loss2, reduction='sum') 16 | 17 | class HybridLoss(_Loss): 18 | __name__ = "HybridLoss" 19 | 20 | def __init__( 21 | self, 22 | loss1: _Loss, 23 | loss2: _Loss, 24 | reduction: Optional[str] = "mean", 25 | ): 26 | """Implementation of Hybrid loss for image segmentation task. 27 | It supports binary, multiclass and multilabel cases 28 | 29 | Args: 30 | loss1: The first loss function. 31 | loss2: The second loss function. 32 | reduction: Specifies the reduction to apply to the output: 33 | ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will 34 | be applied, ``'mean'``: the weighted mean of the output is taken, 35 | ``'sum'``: the output will be summed. 
Default: ``'mean'`` 36 | 37 | Shape 38 | - **y_pred** - torch.Tensor of shape (N, C, H, W) 39 | - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W) 40 | 41 | """ 42 | super(HybridLoss, self).__init__() 43 | 44 | self.loss1 = loss1 45 | self.loss2 = loss2 46 | self.reduction = reduction 47 | 48 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 49 | assert y_true.size(0) == y_pred.size(0) 50 | 51 | loss1 = self.loss1(y_pred, y_true) 52 | loss2 = self.loss2(y_pred, y_true) 53 | loss = torch.stack([loss1, loss2], dim=0) 54 | 55 | if self.reduction == "mean": 56 | loss = loss.mean() 57 | elif self.reduction == "sum": 58 | loss = loss.sum() 59 | elif self.reduction == "none": 60 | pass 61 | else: 62 | raise ValueError('reduction="{}" is not defined'.format(self.reduction)) 63 | 64 | return loss 65 | -------------------------------------------------------------------------------- /change_detection_pytorch/datasets/SVCD.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import albumentations as A 4 | from albumentations.pytorch import ToTensorV2 5 | from change_detection_pytorch.datasets.custom import CustomDataset 6 | 7 | 8 | class SVCD_Dataset(CustomDataset): 9 | """ season-varying change detection dataset""" 10 | 11 | def __init__(self, img_dir, sub_dir_1='A', sub_dir_2='B', ann_dir=None, img_suffix='.jpg', seg_map_suffix='.jpg', 12 | transform=None, split=None, data_root=None, test_mode=False, size=256, debug=False): 13 | super().__init__(img_dir, sub_dir_1, sub_dir_2, ann_dir, img_suffix, seg_map_suffix, transform, split, 14 | data_root, test_mode, size, debug) 15 | 16 | def get_default_transform(self): 17 | """Set the default transformation.""" 18 | 19 | default_transform = A.Compose([ 20 | A.Resize(self.size, self.size), 21 | # A.HorizontalFlip(p=0.5), 22 | # A.RandomRotate90(p=0.5), 23 | A.Normalize(), 24 | ToTensorV2() 25 | ], additional_targets={'image_2': 'image'}) 26 | return default_transform 27 | 28 | def get_test_transform(self): 29 | """Set the test transformation.""" 30 | 31 | test_transform = A.Compose([ 32 | A.Normalize(), 33 | ToTensorV2() 34 | ], additional_targets={'image_2': 'image'}) 35 | return test_transform 36 | 37 | def __getitem__(self, idx): 38 | """Get training/test data after pipeline. 39 | Args: 40 | idx (int): Index of data. 41 | Returns: 42 | dict: Training/test data (with annotation if `test_mode` is set 43 | False). 
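A hedged example of how this dataset is typically constructed and iterated (illustrative; the directory layout with 'A', 'B' and 'OUT' sub-folders and the paths are placeholders for a local copy of the season-varying/CDD data):

from torch.utils.data import DataLoader
from change_detection_pytorch.datasets import SVCD_Dataset

# ann_dir is set, so __getitem__ below returns (img1, img2, ann, filename).
train_set = SVCD_Dataset(img_dir="./SVCD/train", sub_dir_1="A", sub_dir_2="B",
                         ann_dir="./SVCD/train/OUT", size=256)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=2)
img1, img2, mask, filename = next(iter(train_loader))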
44 | """ 45 | 46 | if not self.ann_dir: 47 | ann = None 48 | img1, img2, filename = self.prepare_img(idx) 49 | transformed_data = self.transform(image=img1, image_2=img2) 50 | img1, img2 = transformed_data['image'], transformed_data['image_2'] 51 | return img1, img2, filename 52 | else: 53 | img1, img2, ann, filename = self.prepare_img_ann(idx) 54 | transformed_data = self.transform(image=img1, image_2=img2, mask=ann) 55 | img1, img2, ann = transformed_data['image'], transformed_data['image_2'], transformed_data['mask'] 56 | return img1, img2, ann, filename 57 | -------------------------------------------------------------------------------- /change_detection_pytorch/losses/tversky.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | from ._functional import soft_tversky_score 5 | from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE 6 | from .dice import DiceLoss 7 | 8 | __all__ = ["TverskyLoss", "TverskyLossFocal"] 9 | 10 | 11 | class TverskyLoss(DiceLoss): 12 | """Implementation of Tversky loss for image segmentation task. 13 | Where TP and FP is weighted by alpha and beta params. 14 | With alpha == beta == 0.5, this loss becomes equal DiceLoss. 15 | It supports binary, multiclass and multilabel cases 16 | 17 | Args: 18 | mode: Metric mode {'binary', 'multiclass', 'multilabel'} 19 | classes: Optional list of classes that contribute in loss computation; 20 | By default, all channels are included. 21 | log_loss: If True, loss computed as ``-log(tversky)`` otherwise ``1 - tversky`` 22 | from_logits: If True assumes input is raw logits 23 | smooth: 24 | ignore_index: Label that indicates ignored pixels (does not contribute to loss) 25 | eps: Small epsilon for numerical stability 26 | alpha: Weight constant that penalize model for FPs (False Positives) 27 | beta: Weight constant that penalize model for FNs (False Positives) 28 | gamma: Constant that squares the error function. 
Defaults to ``1.0`` 29 | 30 | Return: 31 | loss: torch.Tensor 32 | 33 | """ 34 | __name__ = "TverskyLoss" 35 | 36 | def __init__( 37 | self, 38 | mode: str, 39 | classes: List[int] = None, 40 | log_loss: bool = False, 41 | from_logits: bool = True, 42 | smooth: float = 0.0, 43 | ignore_index: Optional[int] = None, 44 | eps: float = 1e-7, 45 | alpha: float = 0.5, 46 | beta: float = 0.5, 47 | gamma: float = 1.0, 48 | ): 49 | 50 | assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} 51 | super().__init__(mode, classes, log_loss, from_logits, smooth, ignore_index, eps) 52 | self.alpha = alpha 53 | self.beta = beta 54 | self.gamma = gamma 55 | 56 | def aggregate_loss(self, loss): 57 | return loss.mean() ** self.gamma 58 | 59 | def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor: 60 | return soft_tversky_score(output, target, self.alpha, self.beta, smooth, eps, dims) 61 | -------------------------------------------------------------------------------- /change_detection_pytorch/datasets/LEVIR_CD.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import albumentations as A 4 | from albumentations.pytorch import ToTensorV2 5 | 6 | from .custom import CustomDataset 7 | from .transforms.albu import ChunkImage, ToTensorTest 8 | 9 | 10 | class LEVIR_CD_Dataset(CustomDataset): 11 | """LEVIR-CD dataset""" 12 | 13 | def __init__(self, img_dir, sub_dir_1='A', sub_dir_2='B', ann_dir=None, img_suffix='.png', seg_map_suffix='.png', 14 | transform=None, split=None, data_root=None, test_mode=False, size=256, debug=False): 15 | super().__init__(img_dir, sub_dir_1, sub_dir_2, ann_dir, img_suffix, seg_map_suffix, transform, split, 16 | data_root, test_mode, size, debug) 17 | 18 | def get_default_transform(self): 19 | """Set the default transformation.""" 20 | 21 | default_transform = A.Compose([ 22 | A.RandomCrop(self.size, self.size), 23 | # A.ShiftScaleRotate(), 24 | A.Normalize(), 25 | ToTensorV2() 26 | ], additional_targets={'image_2': 'image'}) 27 | return default_transform 28 | 29 | def get_test_transform(self): 30 | """Set the test transformation.""" 31 | 32 | test_transform = A.Compose([ 33 | A.Normalize(), 34 | ToTensorV2() 35 | ], additional_targets={'image_2': 'image'}) 36 | return test_transform 37 | 38 | def __getitem__(self, idx): 39 | """Get training/test data after pipeline. 40 | Args: 41 | idx (int): Index of data. 42 | Returns: 43 | dict: Training/test data (with annotation if `test_mode` is set 44 | False). 
45 | """ 46 | 47 | if not self.ann_dir: 48 | ann = None 49 | img1, img2, filename = self.prepare_img(idx) 50 | transformed_data = self.transform(image=img1, image_2=img2) 51 | img1, img2 = transformed_data['image'], transformed_data['image_2'] 52 | return img1, img2, filename 53 | else: 54 | img1, img2, ann, filename = self.prepare_img_ann(idx) 55 | transformed_data = self.transform(image=img1, image_2=img2, mask=ann) 56 | img1, img2, ann = transformed_data['image'], transformed_data['image_2'], transformed_data['mask'] 57 | return img1, img2, ann, filename 58 | 59 | 60 | if __name__ == "__main__": 61 | LEVIR_CD_Dataset('dir') 62 | -------------------------------------------------------------------------------- /change_detection_pytorch/losses/soft_bce.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, Tensor 6 | 7 | __all__ = ["SoftBCEWithLogitsLoss"] 8 | 9 | 10 | class SoftBCEWithLogitsLoss(nn.Module): 11 | __name__ = "SoftBCEWithLogitsLoss" 12 | 13 | __constants__ = ["weight", "pos_weight", "reduction", "ignore_index", "smooth_factor"] 14 | 15 | def __init__( 16 | self, 17 | weight: Optional[torch.Tensor] = None, 18 | ignore_index: Optional[int] = -100, 19 | reduction: str = "mean", 20 | smooth_factor: Optional[float] = None, 21 | pos_weight: Optional[torch.Tensor] = None, 22 | ): 23 | """Drop-in replacement for torch.nn.BCEWithLogitsLoss with few additions: ignore_index and label_smoothing 24 | 25 | Args: 26 | ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. 27 | smooth_factor: Factor to smooth target (e.g. if smooth_factor=0.1 then [1, 0, 1] -> [0.9, 0.1, 0.9]) 28 | 29 | Shape 30 | - **y_pred** - torch.Tensor of shape NxCxHxW 31 | - **y_true** - torch.Tensor of shape NxHxW or Nx1xHxW 32 | 33 | Reference 34 | https://github.com/BloodAxe/pytorch-toolbelt 35 | 36 | """ 37 | super().__init__() 38 | self.ignore_index = ignore_index 39 | self.reduction = reduction 40 | self.smooth_factor = smooth_factor 41 | self.register_buffer("weight", weight) 42 | self.register_buffer("pos_weight", pos_weight) 43 | 44 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 45 | """ 46 | Args: 47 | y_pred: torch.Tensor of shape (N, C, H, W) 48 | y_true: torch.Tensor of shape (N, H, W) or (N, 1, H, W) 49 | 50 | Returns: 51 | loss: torch.Tensor 52 | """ 53 | 54 | if self.smooth_factor is not None: 55 | soft_targets = (1 - y_true) * self.smooth_factor + y_true * (1 - self.smooth_factor) 56 | else: 57 | soft_targets = y_true 58 | 59 | loss = F.binary_cross_entropy_with_logits( 60 | y_pred, soft_targets, self.weight, pos_weight=self.pos_weight, reduction="none" 61 | ) 62 | 63 | if self.ignore_index is not None: 64 | not_ignored_mask = y_true != self.ignore_index 65 | loss *= not_ignored_mask.type_as(loss) 66 | 67 | if self.reduction == "mean": 68 | loss = loss.mean() 69 | 70 | if self.reduction == "sum": 71 | loss = loss.sum() 72 | 73 | return loss 74 | -------------------------------------------------------------------------------- /change_detection_pytorch/pspnet/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..base import Decoder, modules 6 | 7 | 8 | class PSPBlock(nn.Module): 9 | 10 | def __init__(self, in_channels, out_channels, pool_size, 
use_bathcnorm=True): 11 | super().__init__() 12 | if pool_size == 1: 13 | use_bathcnorm = False # PyTorch does not support BatchNorm for 1x1 shape 14 | self.pool = nn.Sequential( 15 | nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)), 16 | modules.Conv2dReLU(in_channels, out_channels, (1, 1), use_batchnorm=use_bathcnorm) 17 | ) 18 | 19 | def forward(self, x): 20 | h, w = x.size(2), x.size(3) 21 | x = self.pool(x) 22 | x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True) 23 | return x 24 | 25 | 26 | class PSPModule(nn.Module): 27 | def __init__(self, in_channels, sizes=(1, 2, 3, 6), use_bathcnorm=True): 28 | super().__init__() 29 | 30 | self.blocks = nn.ModuleList([ 31 | PSPBlock(in_channels, in_channels // len(sizes), size, use_bathcnorm=use_bathcnorm) for size in sizes 32 | ]) 33 | 34 | def forward(self, x): 35 | xs = [block(x) for block in self.blocks] + [x] 36 | x = torch.cat(xs, dim=1) 37 | return x 38 | 39 | 40 | class PSPDecoder(Decoder): 41 | 42 | def __init__( 43 | self, 44 | encoder_channels, 45 | use_batchnorm=True, 46 | out_channels=512, 47 | dropout=0.2, 48 | fusion_form="concat", 49 | ): 50 | super().__init__() 51 | 52 | # adjust encoder channels according to fusion form 53 | self.fusion_form = fusion_form 54 | if self.fusion_form in self.FUSION_DIC["2to2_fusion"]: 55 | encoder_channels = [ch*2 for ch in encoder_channels] 56 | 57 | self.psp = PSPModule( 58 | in_channels=encoder_channels[-1], 59 | sizes=(1, 2, 3, 6), 60 | use_bathcnorm=use_batchnorm, 61 | ) 62 | 63 | self.conv = modules.Conv2dReLU( 64 | in_channels=encoder_channels[-1] * 2, 65 | out_channels=out_channels, 66 | kernel_size=1, 67 | use_batchnorm=use_batchnorm, 68 | ) 69 | 70 | self.dropout = nn.Dropout2d(p=dropout) 71 | 72 | def forward(self, *features): 73 | # features = self.aggregation_layer(features[0], features[1], 74 | # self.fusion_form, ignore_original_img=True) 75 | # x = features[-1] 76 | x = self.fusion(features[0][-1], features[1][-1], self.fusion_form) 77 | x = self.psp(x) 78 | x = self.conv(x) 79 | x = self.dropout(x) 80 | 81 | return x 82 | -------------------------------------------------------------------------------- /change_detection_pytorch/linknet/decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ..base import Decoder, modules 4 | 5 | 6 | class TransposeX2(nn.Sequential): 7 | 8 | def __init__(self, in_channels, out_channels, use_batchnorm=True): 9 | super().__init__() 10 | layers = [ 11 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1), 12 | nn.ReLU(inplace=True) 13 | ] 14 | 15 | if use_batchnorm: 16 | layers.insert(1, nn.BatchNorm2d(out_channels)) 17 | 18 | super().__init__(*layers) 19 | 20 | 21 | class DecoderBlock(nn.Module): 22 | def __init__(self, in_channels, out_channels, use_batchnorm=True): 23 | super().__init__() 24 | 25 | self.block = nn.Sequential( 26 | modules.Conv2dReLU(in_channels, in_channels // 4, kernel_size=1, use_batchnorm=use_batchnorm), 27 | TransposeX2(in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm), 28 | modules.Conv2dReLU(in_channels // 4, out_channels, kernel_size=1, use_batchnorm=use_batchnorm), 29 | ) 30 | 31 | def forward(self, x, skip=None): 32 | x = self.block(x) 33 | if skip is not None: 34 | x = x + skip 35 | return x 36 | 37 | 38 | class LinknetDecoder(Decoder): 39 | 40 | def __init__( 41 | self, 42 | encoder_channels, 43 | prefinal_channels=32, 44 | n_blocks=5, 45 | use_batchnorm=True, 46 | 
fusion_form="concat", 47 | ): 48 | super().__init__() 49 | 50 | encoder_channels = encoder_channels[1:] # remove first skip 51 | encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder 52 | 53 | # adjust encoder channels according to fusion form 54 | self.fusion_form = fusion_form 55 | if self.fusion_form in self.FUSION_DIC["2to2_fusion"]: 56 | encoder_channels = [ch*2 for ch in encoder_channels] 57 | 58 | channels = list(encoder_channels) + [prefinal_channels] 59 | 60 | self.blocks = nn.ModuleList([ 61 | DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm) 62 | for i in range(n_blocks) 63 | ]) 64 | 65 | def forward(self, *features): 66 | 67 | features = self.aggregation_layer(features[0], features[1], 68 | self.fusion_form, ignore_original_img=True) 69 | # features = features[1:] # remove first skip with same spatial resolution 70 | features = features[::-1] # reverse channels to start from head of encoder 71 | 72 | x = features[0] 73 | skips = features[1:] 74 | 75 | for i, decoder_block in enumerate(self.blocks): 76 | skip = skips[i] if i < len(skips) else None 77 | x = decoder_block(x, skip) 78 | 79 | return x 80 | -------------------------------------------------------------------------------- /change_detection_pytorch/encoders/mobilenet.py: -------------------------------------------------------------------------------- 1 | """ Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
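The docstring above spells out the feature-pyramid contract that every encoder in this package is expected to honour. A minimal sketch of that contract, assuming the registered "resnet34" encoder and a 64x64 RGB input (the encoder name and input size are illustrative, not prescribed by the source):

import torch
from change_detection_pytorch.encoders import get_encoder

encoder = get_encoder("resnet34", in_channels=3, depth=5, weights=None)
x = torch.rand(1, 3, 64, 64)
features = encoder(x)

# depth=5 -> 6 feature maps: one at the input resolution plus 5 progressively downsampled stages
print(len(features))                    # 6
print([f.shape[-1] for f in features])  # [64, 32, 16, 8, 4, 2]; channel counts depend on the encoder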
24 | """ 25 | 26 | import torchvision 27 | import torch.nn as nn 28 | 29 | from ._base import EncoderMixin 30 | 31 | 32 | class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin): 33 | 34 | def __init__(self, out_channels, depth=5, **kwargs): 35 | super().__init__(**kwargs) 36 | self._depth = depth 37 | self._out_channels = out_channels 38 | self._in_channels = 3 39 | del self.classifier 40 | 41 | def get_stages(self): 42 | return [ 43 | nn.Identity(), 44 | self.features[:2], 45 | self.features[2:4], 46 | self.features[4:7], 47 | self.features[7:14], 48 | self.features[14:], 49 | ] 50 | 51 | def forward(self, x): 52 | stages = self.get_stages() 53 | 54 | features = [] 55 | for i in range(self._depth + 1): 56 | x = stages[i](x) 57 | features.append(x) 58 | 59 | return features 60 | 61 | def load_state_dict(self, state_dict, **kwargs): 62 | state_dict.pop("classifier.1.bias", None) 63 | state_dict.pop("classifier.1.weight", None) 64 | super().load_state_dict(state_dict, **kwargs) 65 | 66 | 67 | mobilenet_encoders = { 68 | "mobilenet_v2": { 69 | "encoder": MobileNetV2Encoder, 70 | "pretrained_settings": { 71 | "imagenet": { 72 | "mean": [0.485, 0.456, 0.406], 73 | "std": [0.229, 0.224, 0.225], 74 | "url": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", 75 | "input_space": "RGB", 76 | "input_range": [0, 1], 77 | }, 78 | }, 79 | "params": { 80 | "out_channels": (3, 16, 24, 32, 96, 1280), 81 | }, 82 | }, 83 | } 84 | -------------------------------------------------------------------------------- /change_detection_pytorch/stanet/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | from torch.nn import functional as F 4 | from ..encoders import get_encoder 5 | from .decoder import STANetDecoder 6 | from ..base import SegmentationHead 7 | 8 | 9 | class STANet(torch.nn.Module): 10 | """ 11 | Args: 12 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 13 | to extract features of different spatial resolution 14 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 15 | other pretrained weights (see table with available weights for each encoder_name) 16 | in_channels: A number of input channels for the model, default is 3 (RGB images) 17 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 18 | activation: An activation function to apply after the final convolution layer. 19 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**. 20 | Default is **None** 21 | return_distance_map: If True, return distance map, which shape is (BatchSize, Height, Width), of feature maps from images of two periods. Default False. 22 | 23 | Returns: 24 | ``torch.nn.Module``: STANet 25 | 26 | .. 
STANet: 27 | https://www.mdpi.com/2072-4292/12/10/1662 28 | 29 | """ 30 | 31 | def __init__( 32 | self, 33 | encoder_name: str = "resnet", 34 | encoder_weights: Optional[str] = "imagenet", 35 | sa_mode: str = "PAM", 36 | in_channels: int = 3, 37 | classes=2, 38 | activation=None, 39 | return_distance_map=False, 40 | **kwargs 41 | ): 42 | super(STANet, self).__init__() 43 | self.return_distance_map = return_distance_map 44 | self.encoder = get_encoder( 45 | encoder_name, 46 | in_channels=in_channels, 47 | weights=encoder_weights 48 | ) 49 | 50 | self.decoder = STANetDecoder( 51 | encoder_out_channels=self.encoder.out_channels, 52 | sa_mode=sa_mode 53 | ) 54 | self.segmentation_head = SegmentationHead( 55 | in_channels=self.decoder.out_channel * 2, 56 | out_channels=classes, 57 | activation=activation, 58 | kernel_size=3, 59 | ) 60 | 61 | def forward(self, x1, x2): 62 | # only support siam encoder 63 | features = self.encoder(x1), self.encoder(x2) 64 | features = self.decoder(*features) 65 | if self.return_distance_map: 66 | dist = F.pairwise_distance(features[0], features[1], keepdim=True) 67 | dist = F.interpolate(dist, x1.shape[2:], mode='bilinear', align_corners=True) 68 | return dist 69 | else: 70 | decoder_output = torch.cat([features[0], features[1]], dim=1) 71 | decoder_output = F.interpolate(decoder_output, x1.shape[2:], mode='bilinear', align_corners=True) 72 | masks = self.segmentation_head(decoder_output) 73 | return masks 74 | -------------------------------------------------------------------------------- /change_detection_pytorch/losses/focal.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from functools import partial 3 | 4 | import torch 5 | from torch.nn.modules.loss import _Loss 6 | from ._functional import focal_loss_with_logits 7 | from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE 8 | 9 | __all__ = ["FocalLoss"] 10 | 11 | 12 | class FocalLoss(_Loss): 13 | __name__ = "FocalLoss" 14 | 15 | def __init__( 16 | self, 17 | mode: str, 18 | alpha: Optional[float] = None, 19 | gamma: Optional[float] = 2., 20 | ignore_index: Optional[int] = None, 21 | reduction: Optional[str] = "mean", 22 | normalized: bool = False, 23 | reduced_threshold: Optional[float] = None, 24 | ): 25 | """Compute Focal loss 26 | 27 | Args: 28 | mode: Loss mode 'binary', 'multiclass' or 'multilabel' 29 | alpha: Prior probability of having positive value in target. 30 | gamma: Power factor for dampening weight (focal strength). 31 | ignore_index: If not None, targets may contain values to be ignored. 32 | Target values equal to ignore_index will be ignored from loss computation. 33 | normalized: Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf). 34 | reduced_threshold: Switch to reduced focal loss. Note, when using this mode you should use `reduction="sum"`. 
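A rough usage sketch for the STANet wrapper defined above. The backbone name, batch size, and input sizes are illustrative assumptions; an explicit encoder key such as "resnet18" is passed here rather than relying on the default:

import torch
from change_detection_pytorch.stanet import STANet

model = STANet(encoder_name="resnet18", encoder_weights=None, sa_mode="PAM", classes=2)
x1 = torch.rand(2, 3, 256, 256)  # image at time 1
x2 = torch.rand(2, 3, 256, 256)  # image at time 2
masks = model(x1, x2)            # (2, 2, 256, 256) change logits, upsampled to the input size

# With return_distance_map=True the forward pass instead returns the per-pixel feature
# distance map described in the docstring, typically trained with a contrastive loss (e.g. BCL).
dist_model = STANet(encoder_name="resnet18", encoder_weights=None, return_distance_map=True)
dist = dist_model(x1, x2)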
35 | 36 | Shape 37 | - **y_pred** - torch.Tensor of shape (N, C, H, W) 38 | - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W) 39 | 40 | Reference 41 | https://github.com/BloodAxe/pytorch-toolbelt 42 | 43 | """ 44 | assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} 45 | super().__init__() 46 | 47 | self.mode = mode 48 | self.ignore_index = ignore_index 49 | self.focal_loss_fn = partial( 50 | focal_loss_with_logits, 51 | alpha=alpha, 52 | gamma=gamma, 53 | reduced_threshold=reduced_threshold, 54 | reduction=reduction, 55 | normalized=normalized, 56 | ) 57 | 58 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 59 | 60 | if self.mode in {BINARY_MODE, MULTILABEL_MODE}: 61 | y_true = y_true.view(-1) 62 | y_pred = y_pred.view(-1) 63 | 64 | if self.ignore_index is not None: 65 | # Filter predictions with ignore label from loss computation 66 | not_ignored = y_true != self.ignore_index 67 | y_pred = y_pred[not_ignored] 68 | y_true = y_true[not_ignored] 69 | 70 | loss = self.focal_loss_fn(y_pred, y_true) 71 | 72 | elif self.mode == MULTICLASS_MODE: 73 | 74 | num_classes = y_pred.size(1) 75 | loss = 0 76 | 77 | # Filter anchors with -1 label from loss computation 78 | if self.ignore_index is not None: 79 | not_ignored = y_true != self.ignore_index 80 | 81 | for cls in range(num_classes): 82 | cls_y_true = (y_true == cls).long() 83 | cls_y_pred = y_pred[:, cls, ...] 84 | 85 | if self.ignore_index is not None: 86 | cls_y_true = cls_y_true[not_ignored] 87 | cls_y_pred = cls_y_pred[not_ignored] 88 | 89 | loss += self.focal_loss_fn(cls_y_pred, cls_y_true) 90 | 91 | return loss 92 | -------------------------------------------------------------------------------- /change_detection_pytorch/encoders/inceptionresnetv2.py: -------------------------------------------------------------------------------- 1 | """ Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
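A short usage sketch for the FocalLoss defined above, following the Shape section of its docstring; the hyper-parameter values and tensor sizes are illustrative:

import torch
from change_detection_pytorch.losses.focal import FocalLoss
from change_detection_pytorch.losses.constants import MULTICLASS_MODE

criterion = FocalLoss(mode=MULTICLASS_MODE, alpha=0.25, gamma=2.0, ignore_index=255)

logits = torch.randn(4, 2, 64, 64, requires_grad=True)  # (N, C, H, W) raw scores for 2 classes
target = torch.randint(0, 2, (4, 64, 64))               # (N, H, W) integer labels
loss = criterion(logits, target)
loss.backward()  # focal_loss_with_logits is differentiable, so this drives training as usual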
24 | """ 25 | 26 | import torch.nn as nn 27 | from pretrainedmodels.models.inceptionresnetv2 import InceptionResNetV2 28 | from pretrainedmodels.models.inceptionresnetv2 import pretrained_settings 29 | 30 | from ._base import EncoderMixin 31 | 32 | 33 | class InceptionResNetV2Encoder(InceptionResNetV2, EncoderMixin): 34 | def __init__(self, out_channels, depth=5, **kwargs): 35 | super().__init__(**kwargs) 36 | 37 | self._out_channels = out_channels 38 | self._depth = depth 39 | self._in_channels = 3 40 | 41 | # correct paddings 42 | for m in self.modules(): 43 | if isinstance(m, nn.Conv2d): 44 | if m.kernel_size == (3, 3): 45 | m.padding = (1, 1) 46 | if isinstance(m, nn.MaxPool2d): 47 | m.padding = (1, 1) 48 | 49 | # remove linear layers 50 | del self.avgpool_1a 51 | del self.last_linear 52 | 53 | def make_dilated(self, output_stride): 54 | raise ValueError("InceptionResnetV2 encoder does not support dilated mode " 55 | "due to pooling operation for downsampling!") 56 | 57 | def get_stages(self): 58 | return [ 59 | nn.Identity(), 60 | nn.Sequential(self.conv2d_1a, self.conv2d_2a, self.conv2d_2b), 61 | nn.Sequential(self.maxpool_3a, self.conv2d_3b, self.conv2d_4a), 62 | nn.Sequential(self.maxpool_5a, self.mixed_5b, self.repeat), 63 | nn.Sequential(self.mixed_6a, self.repeat_1), 64 | nn.Sequential(self.mixed_7a, self.repeat_2, self.block8, self.conv2d_7b), 65 | ] 66 | 67 | def forward(self, x): 68 | 69 | stages = self.get_stages() 70 | 71 | features = [] 72 | for i in range(self._depth + 1): 73 | x = stages[i](x) 74 | features.append(x) 75 | 76 | return features 77 | 78 | def load_state_dict(self, state_dict, **kwargs): 79 | state_dict.pop("last_linear.bias", None) 80 | state_dict.pop("last_linear.weight", None) 81 | super().load_state_dict(state_dict, **kwargs) 82 | 83 | 84 | inceptionresnetv2_encoders = { 85 | "inceptionresnetv2": { 86 | "encoder": InceptionResNetV2Encoder, 87 | "pretrained_settings": pretrained_settings["inceptionresnetv2"], 88 | "params": {"out_channels": (3, 64, 192, 320, 1088, 1536), "num_classes": 1000}, 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /change_detection_pytorch/encoders/inceptionv4.py: -------------------------------------------------------------------------------- 1 | """ Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. 
if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 24 | """ 25 | 26 | import torch.nn as nn 27 | from pretrainedmodels.models.inceptionv4 import InceptionV4, BasicConv2d 28 | from pretrainedmodels.models.inceptionv4 import pretrained_settings 29 | 30 | from ._base import EncoderMixin 31 | 32 | 33 | class InceptionV4Encoder(InceptionV4, EncoderMixin): 34 | def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): 35 | super().__init__(**kwargs) 36 | self._stage_idxs = stage_idxs 37 | self._out_channels = out_channels 38 | self._depth = depth 39 | self._in_channels = 3 40 | 41 | # correct paddings 42 | for m in self.modules(): 43 | if isinstance(m, nn.Conv2d): 44 | if m.kernel_size == (3, 3): 45 | m.padding = (1, 1) 46 | if isinstance(m, nn.MaxPool2d): 47 | m.padding = (1, 1) 48 | 49 | # remove linear layers 50 | del self.last_linear 51 | 52 | def make_dilated(self, output_stride): 53 | raise ValueError("InceptionV4 encoder does not support dilated mode " 54 | "due to pooling operation for downsampling!") 55 | 56 | def get_stages(self): 57 | return [ 58 | nn.Identity(), 59 | self.features[: self._stage_idxs[0]], 60 | self.features[self._stage_idxs[0]: self._stage_idxs[1]], 61 | self.features[self._stage_idxs[1]: self._stage_idxs[2]], 62 | self.features[self._stage_idxs[2]: self._stage_idxs[3]], 63 | self.features[self._stage_idxs[3]:], 64 | ] 65 | 66 | def forward(self, x): 67 | 68 | stages = self.get_stages() 69 | 70 | features = [] 71 | for i in range(self._depth + 1): 72 | x = stages[i](x) 73 | features.append(x) 74 | 75 | return features 76 | 77 | def load_state_dict(self, state_dict, **kwargs): 78 | state_dict.pop("last_linear.bias", None) 79 | state_dict.pop("last_linear.weight", None) 80 | super().load_state_dict(state_dict, **kwargs) 81 | 82 | 83 | inceptionv4_encoders = { 84 | "inceptionv4": { 85 | "encoder": InceptionV4Encoder, 86 | "pretrained_settings": pretrained_settings["inceptionv4"], 87 | "params": { 88 | "stage_idxs": (3, 5, 9, 15), 89 | "out_channels": (3, 64, 192, 384, 1024, 1536), 90 | "num_classes": 1001, 91 | }, 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /local_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, Dataset 3 | 4 | import change_detection_pytorch as cdp 5 | from change_detection_pytorch.datasets import LEVIR_CD_Dataset, SVCD_Dataset 6 | from change_detection_pytorch.utils.lr_scheduler import GradualWarmupScheduler 7 | 8 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 9 | 10 | model = cdp.Unet( 11 | encoder_name="resnet34", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7 12 | encoder_weights="imagenet", # use `imagenet` pre-trained weights for encoder initialization 13 | in_channels=3, # model input channels (1 for gray-scale images, 3 for RGB, etc.) 14 | classes=2, # model output channels (number of classes in your datasets) 15 | siam_encoder=True, # whether to use a siamese encoder 16 | fusion_form='concat', # the form of fusing features from two branches. e.g. concat, sum, diff, or abs_diff. 
17 | ) 18 | 19 | train_dataset = LEVIR_CD_Dataset('../LEVIR-CD/train', 20 | sub_dir_1='A', 21 | sub_dir_2='B', 22 | img_suffix='.png', 23 | ann_dir='../LEVIR-CD/train/label', 24 | debug=False) 25 | 26 | valid_dataset = LEVIR_CD_Dataset('../LEVIR-CD/test', 27 | sub_dir_1='A', 28 | sub_dir_2='B', 29 | img_suffix='.png', 30 | ann_dir='../LEVIR-CD/test/label', 31 | debug=False, 32 | test_mode=True) 33 | 34 | train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0) 35 | valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0) 36 | 37 | loss = cdp.utils.losses.CrossEntropyLoss() 38 | metrics = [ 39 | cdp.utils.metrics.Fscore(activation='argmax2d'), 40 | cdp.utils.metrics.Precision(activation='argmax2d'), 41 | cdp.utils.metrics.Recall(activation='argmax2d'), 42 | ] 43 | 44 | optimizer = torch.optim.Adam([ 45 | dict(params=model.parameters(), lr=0.0001), 46 | ]) 47 | 48 | scheduler_steplr = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, ], gamma=0.1) 49 | 50 | # create epoch runners 51 | # it is a simple loop of iterating over dataloader`s samples 52 | train_epoch = cdp.utils.train.TrainEpoch( 53 | model, 54 | loss=loss, 55 | metrics=metrics, 56 | optimizer=optimizer, 57 | device=DEVICE, 58 | verbose=True, 59 | ) 60 | 61 | valid_epoch = cdp.utils.train.ValidEpoch( 62 | model, 63 | loss=loss, 64 | metrics=metrics, 65 | device=DEVICE, 66 | verbose=True, 67 | ) 68 | 69 | # train model for 60 epochs 70 | 71 | max_score = 0 72 | MAX_EPOCH = 60 73 | 74 | for i in range(MAX_EPOCH): 75 | 76 | print('\nEpoch: {}'.format(i)) 77 | train_logs = train_epoch.run(train_loader) 78 | valid_logs = valid_epoch.run(valid_loader) 79 | scheduler_steplr.step() 80 | 81 | # do something (save model, change lr, etc.) 
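As a hedged companion to the checkpointing block below (which stores the entire nn.Module via torch.save to './best_model.pth'), reloading the best checkpoint for inference can look like the following sketch; the input sizes and dummy tensors are illustrative:

import torch

best_model = torch.load('./best_model.pth', map_location='cpu')
best_model.eval()
x1 = torch.rand(1, 3, 256, 256)  # bi-temporal image pair (replace with real, normalized tiles)
x2 = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    logits = best_model(x1, x2)        # (1, 2, 256, 256) class scores
    change_map = logits.argmax(dim=1)  # (1, 256, 256) predicted change mask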
82 | if max_score < valid_logs['fscore']: 83 | max_score = valid_logs['fscore'] 84 | print('max_score', max_score) 85 | torch.save(model, './best_model.pth') 86 | print('Model saved!') 87 | 88 | # save results (change maps) 89 | """ 90 | Note: if you use sliding window inference, set: 91 | from change_detection_pytorch.datasets.transforms.albu import ( 92 | ChunkImage, ToTensorTest) 93 | 94 | test_transform = A.Compose([ 95 | A.Normalize(), 96 | ChunkImage({window_size}}), 97 | ToTensorTest(), 98 | ], additional_targets={'image_2': 'image'}) 99 | 100 | """ 101 | valid_epoch.infer_vis(valid_loader, save=True, slide=False, save_dir='./res') 102 | -------------------------------------------------------------------------------- /change_detection_pytorch/encoders/timm_sknet.py: -------------------------------------------------------------------------------- 1 | from ._base import EncoderMixin 2 | from timm.models.resnet import ResNet 3 | from timm.models.sknet import SelectiveKernelBottleneck, SelectiveKernelBasic 4 | import torch.nn as nn 5 | 6 | 7 | class SkNetEncoder(ResNet, EncoderMixin): 8 | def __init__(self, out_channels, depth=5, **kwargs): 9 | super().__init__(**kwargs) 10 | self._depth = depth 11 | self._out_channels = out_channels 12 | self._in_channels = 3 13 | 14 | del self.fc 15 | del self.global_pool 16 | 17 | def get_stages(self): 18 | return [ 19 | nn.Identity(), 20 | nn.Sequential(self.conv1, self.bn1, self.act1), 21 | nn.Sequential(self.maxpool, self.layer1), 22 | self.layer2, 23 | self.layer3, 24 | self.layer4, 25 | ] 26 | 27 | def forward(self, x): 28 | stages = self.get_stages() 29 | 30 | features = [] 31 | for i in range(self._depth + 1): 32 | x = stages[i](x) 33 | features.append(x) 34 | 35 | return features 36 | 37 | def load_state_dict(self, state_dict, **kwargs): 38 | state_dict.pop("fc.bias", None) 39 | state_dict.pop("fc.weight", None) 40 | super().load_state_dict(state_dict, **kwargs) 41 | 42 | 43 | sknet_weights = { 44 | 'timm-skresnet18': { 45 | 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth' 46 | }, 47 | 'timm-skresnet34': { 48 | 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth' 49 | }, 50 | 'timm-skresnext50_32x4d': { 51 | 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth', 52 | } 53 | } 54 | 55 | pretrained_settings = {} 56 | for model_name, sources in sknet_weights.items(): 57 | pretrained_settings[model_name] = {} 58 | for source_name, source_url in sources.items(): 59 | pretrained_settings[model_name][source_name] = { 60 | "url": source_url, 61 | 'input_size': [3, 224, 224], 62 | 'input_range': [0, 1], 63 | 'mean': [0.485, 0.456, 0.406], 64 | 'std': [0.229, 0.224, 0.225], 65 | 'num_classes': 1000 66 | } 67 | 68 | timm_sknet_encoders = { 69 | 'timm-skresnet18': { 70 | 'encoder': SkNetEncoder, 71 | "pretrained_settings": pretrained_settings["timm-skresnet18"], 72 | 'params': { 73 | 'out_channels': (3, 64, 64, 128, 256, 512), 74 | 'block': SelectiveKernelBasic, 75 | 'layers': [2, 2, 2, 2], 76 | 'zero_init_last_bn': False, 77 | 'block_args': {'sk_kwargs': {'rd_ratio': 1/8, 'split_input': True}} 78 | } 79 | }, 80 | 'timm-skresnet34': { 81 | 'encoder': SkNetEncoder, 82 | "pretrained_settings": pretrained_settings["timm-skresnet34"], 83 | 'params': { 84 | 'out_channels': (3, 64, 64, 128, 256, 512), 85 | 'block': SelectiveKernelBasic, 86 | 'layers': 
[3, 4, 6, 3], 87 | 'zero_init_last_bn': False, 88 | 'block_args': {'sk_kwargs': {'rd_ratio': 1/8, 'split_input': True}} 89 | } 90 | }, 91 | 'timm-skresnext50_32x4d': { 92 | 'encoder': SkNetEncoder, 93 | "pretrained_settings": pretrained_settings["timm-skresnext50_32x4d"], 94 | 'params': { 95 | 'out_channels': (3, 64, 256, 512, 1024, 2048), 96 | 'block': SelectiveKernelBottleneck, 97 | 'layers': [3, 4, 6, 3], 98 | 'zero_init_last_bn': False, 99 | 'cardinality': 32, 100 | 'base_width': 4 101 | } 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /change_detection_pytorch/losses/jaccard.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.nn.modules.loss import _Loss 6 | from ._functional import soft_jaccard_score, to_tensor 7 | from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE 8 | 9 | __all__ = ["JaccardLoss"] 10 | 11 | 12 | class JaccardLoss(_Loss): 13 | __name__ = "JaccardLoss" 14 | 15 | def __init__( 16 | self, 17 | mode: str, 18 | classes: Optional[List[int]] = None, 19 | log_loss: bool = False, 20 | from_logits: bool = True, 21 | smooth: float = 0., 22 | eps: float = 1e-7, 23 | ): 24 | """Implementation of Jaccard loss for image segmentation task. 25 | It supports binary, multiclass and multilabel cases 26 | 27 | Args: 28 | mode: Loss mode 'binary', 'multiclass' or 'multilabel' 29 | classes: List of classes that contribute in loss computation. By default, all channels are included. 30 | log_loss: If True, loss computed as `- log(jaccard_coeff)`, otherwise `1 - jaccard_coeff` 31 | from_logits: If True, assumes input is raw logits 32 | smooth: Smoothness constant for dice coefficient 33 | eps: A small epsilon for numerical stability to avoid zero division error 34 | (denominator will be always greater or equal to eps) 35 | 36 | Shape 37 | - **y_pred** - torch.Tensor of shape (N, C, H, W) 38 | - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W) 39 | 40 | Reference 41 | https://github.com/BloodAxe/pytorch-toolbelt 42 | """ 43 | assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} 44 | super(JaccardLoss, self).__init__() 45 | 46 | self.mode = mode 47 | if classes is not None: 48 | assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary" 49 | classes = to_tensor(classes, dtype=torch.long) 50 | 51 | self.classes = classes 52 | self.from_logits = from_logits 53 | self.smooth = smooth 54 | self.eps = eps 55 | self.log_loss = log_loss 56 | 57 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 58 | 59 | assert y_true.size(0) == y_pred.size(0) 60 | 61 | if self.from_logits: 62 | # Apply activations to get [0..1] class probabilities 63 | # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on 64 | # extreme values 0 and 1 65 | if self.mode == MULTICLASS_MODE: 66 | y_pred = y_pred.log_softmax(dim=1).exp() 67 | else: 68 | y_pred = F.logsigmoid(y_pred).exp() 69 | 70 | bs = y_true.size(0) 71 | num_classes = y_pred.size(1) 72 | dims = (0, 2) 73 | 74 | if self.mode == BINARY_MODE: 75 | y_true = y_true.view(bs, 1, -1) 76 | y_pred = y_pred.view(bs, 1, -1) 77 | 78 | if self.mode == MULTICLASS_MODE: 79 | y_true = y_true.view(bs, -1) 80 | y_pred = y_pred.view(bs, num_classes, -1) 81 | 82 | y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C 83 | y_true = y_true.permute(0, 2, 
1) # H, C, H*W 84 | 85 | if self.mode == MULTILABEL_MODE: 86 | y_true = y_true.view(bs, num_classes, -1) 87 | y_pred = y_pred.view(bs, num_classes, -1) 88 | 89 | scores = soft_jaccard_score(y_pred, y_true.type(y_pred.dtype), smooth=self.smooth, eps=self.eps, dims=dims) 90 | 91 | if self.log_loss: 92 | loss = -torch.log(scores.clamp_min(self.eps)) 93 | else: 94 | loss = 1.0 - scores 95 | 96 | # IoU loss is defined for non-empty classes 97 | # So we zero contribution of channel that does not have true pixels 98 | # NOTE: A better workaround would be to use loss term `mean(y_pred)` 99 | # for this case, however it will be a modified jaccard loss 100 | 101 | mask = y_true.sum(dims) > 0 102 | loss *= mask.float() 103 | 104 | if self.classes is not None: 105 | loss = loss[self.classes] 106 | 107 | return loss.mean() 108 | -------------------------------------------------------------------------------- /change_detection_pytorch/utils/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import warnings 5 | from functools import wraps 6 | from typing import Optional 7 | 8 | import numpy as np 9 | import torch 10 | 11 | log = logging.getLogger(__name__) 12 | 13 | 14 | def seed_everything(seed: Optional[int] = None, workers: bool = False, deterministic: bool = False) -> int: 15 | """ 16 | `` 17 | This part of the code comes from PyTorch Lightning. 18 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py 19 | `` 20 | 21 | Function that sets seed for pseudo-random number generators in: 22 | pytorch, numpy, python.random 23 | In addition, sets the following environment variables: 24 | 25 | - `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend). 26 | - `PL_SEED_WORKERS`: (optional) is set to 1 if ``workers=True``. 27 | 28 | Args: 29 | seed: the integer value seed for global random state in Lightning. 30 | If `None`, will read seed from `PL_GLOBAL_SEED` env variable 31 | or select it randomly. 32 | workers: if set to ``True``, will properly configure all dataloaders passed to the 33 | Trainer with a ``worker_init_fn``. If the user already provides such a function 34 | for their dataloaders, setting this argument will have no influence. See also: 35 | :func:`~pytorch_lightning.utilities.seed.pl_worker_init_function`. 36 | deterministic (bool): Whether to set the deterministic option for 37 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 38 | to True and `torch.backends.cudnn.benchmark` to False. 39 | Default: False. 40 | """ 41 | max_seed_value = np.iinfo(np.uint32).max 42 | min_seed_value = np.iinfo(np.uint32).min 43 | 44 | try: 45 | if seed is None: 46 | seed = os.environ.get("PL_GLOBAL_SEED") 47 | seed = int(seed) 48 | except (TypeError, ValueError): 49 | seed = _select_seed_randomly(min_seed_value, max_seed_value) 50 | rank_zero_warn(f"No correct seed found, seed set to {seed}") 51 | 52 | if not (min_seed_value <= seed <= max_seed_value): 53 | rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") 54 | seed = _select_seed_randomly(min_seed_value, max_seed_value) 55 | 56 | # using `log.info` instead of `rank_zero_info`, 57 | # so users can verify the seed is properly set in distributed training. 
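A short usage sketch for seed_everything as defined above; the seed value is arbitrary, and deterministic=True trades speed for reproducibility as the docstring notes:

from change_detection_pytorch.utils.utils import seed_everything

seed = seed_everything(42, workers=True, deterministic=True)
print(seed)  # 42; the value is also exported through the PL_GLOBAL_SEED environment variable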
58 | log.info(f"Global seed set to {seed}") 59 | os.environ["PL_GLOBAL_SEED"] = str(seed) 60 | random.seed(seed) 61 | np.random.seed(seed) 62 | torch.manual_seed(seed) 63 | torch.cuda.manual_seed_all(seed) 64 | 65 | os.environ["PL_SEED_WORKERS"] = f"{int(workers)}" 66 | 67 | if deterministic: 68 | torch.backends.cudnn.deterministic = True 69 | torch.backends.cudnn.benchmark = False 70 | 71 | return seed 72 | 73 | 74 | def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int: 75 | return random.randint(min_seed_value, max_seed_value) 76 | 77 | 78 | def rank_zero_only(fn): 79 | @wraps(fn) 80 | def wrapped_fn(*args, **kwargs): 81 | if rank_zero_only.rank == 0: 82 | return fn(*args, **kwargs) 83 | 84 | return wrapped_fn 85 | 86 | 87 | @rank_zero_only 88 | def rank_zero_warn(*args, stacklevel: int = 4, **kwargs): 89 | warnings.warn(*args, stacklevel=stacklevel, **kwargs) 90 | 91 | 92 | def reset_seed() -> None: 93 | """ 94 | Reset the seed to the value that :func:`seed_everything` previously set. 95 | If :func:`seed_everything` is unused, this function will do nothing. 96 | """ 97 | seed = os.environ.get("PL_GLOBAL_SEED", None) 98 | workers = os.environ.get("PL_SEED_WORKERS", False) 99 | if seed is not None: 100 | seed_everything(int(seed), workers=bool(workers)) 101 | 102 | 103 | def format_logs(logs): 104 | str_logs = ['{} - {:.4}'.format(k, v) for k, v in logs.items()] 105 | s = ', '.join(str_logs) 106 | return s 107 | 108 | 109 | def check_tensor(data, is_label): 110 | if not is_label: 111 | return data if data.ndim <= 4 else data.squeeze() 112 | return data.long() if data.ndim <= 3 else data.squeeze().long() 113 | -------------------------------------------------------------------------------- /change_detection_pytorch/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | from ._preprocessing import preprocess_input 7 | from .densenet import densenet_encoders 8 | from .dpn import dpn_encoders 9 | from .efficientnet import efficient_net_encoders 10 | from .inceptionresnetv2 import inceptionresnetv2_encoders 11 | from .inceptionv4 import inceptionv4_encoders 12 | from .mobilenet import mobilenet_encoders 13 | from .resnet import resnet_encoders 14 | from .senet import senet_encoders 15 | from .timm_efficientnet import timm_efficientnet_encoders 16 | from .timm_gernet import timm_gernet_encoders 17 | from .timm_mobilenetv3 import timm_mobilenetv3_encoders 18 | from .timm_regnet import timm_regnet_encoders 19 | from .timm_res2net import timm_res2net_encoders 20 | from .timm_resnest import timm_resnest_encoders 21 | from .timm_sknet import timm_sknet_encoders 22 | from .timm_universal import TimmUniversalEncoder 23 | from .vgg import vgg_encoders 24 | from .xception import xception_encoders 25 | from .swin_transformer import swin_transformer_encoders 26 | from .mit_encoder import mit_encoders 27 | # from .hrnet import hrnet_encoders 28 | 29 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 30 | 31 | encoders = {} 32 | encoders.update(resnet_encoders) 33 | encoders.update(dpn_encoders) 34 | encoders.update(vgg_encoders) 35 | encoders.update(senet_encoders) 36 | encoders.update(densenet_encoders) 37 | encoders.update(inceptionresnetv2_encoders) 38 | encoders.update(inceptionv4_encoders) 39 | encoders.update(efficient_net_encoders) 40 | encoders.update(mobilenet_encoders) 41 | encoders.update(xception_encoders) 42 | 
encoders.update(timm_efficientnet_encoders) 43 | encoders.update(timm_resnest_encoders) 44 | encoders.update(timm_res2net_encoders) 45 | encoders.update(timm_regnet_encoders) 46 | encoders.update(timm_sknet_encoders) 47 | encoders.update(timm_mobilenetv3_encoders) 48 | encoders.update(timm_gernet_encoders) 49 | encoders.update(swin_transformer_encoders) 50 | encoders.update(mit_encoders) 51 | # encoders.update(hrnet_encoders) 52 | 53 | 54 | def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs): 55 | 56 | if name.startswith("tu-"): 57 | name = name[3:] 58 | encoder = TimmUniversalEncoder( 59 | name=name, 60 | in_channels=in_channels, 61 | depth=depth, 62 | output_stride=output_stride, 63 | pretrained=weights is not None, 64 | **kwargs 65 | ) 66 | return encoder 67 | 68 | try: 69 | Encoder = encoders[name]["encoder"] 70 | except KeyError: 71 | raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys()))) 72 | 73 | params = encoders[name]["params"] 74 | params.update(depth=depth) 75 | encoder = Encoder(**params) 76 | 77 | if weights is not None: 78 | try: 79 | settings = encoders[name]["pretrained_settings"][weights] 80 | except KeyError: 81 | raise KeyError("Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format( 82 | weights, name, list(encoders[name]["pretrained_settings"].keys()), 83 | )) 84 | encoder.load_state_dict(model_zoo.load_url(settings["url"], map_location=torch.device(DEVICE))) 85 | 86 | encoder.set_in_channels(in_channels, pretrained=weights is not None) 87 | if output_stride != 32: 88 | encoder.make_dilated(output_stride) 89 | 90 | return encoder 91 | 92 | 93 | def get_encoder_names(): 94 | return list(encoders.keys()) 95 | 96 | 97 | def get_preprocessing_params(encoder_name, pretrained="imagenet"): 98 | settings = encoders[encoder_name]["pretrained_settings"] 99 | 100 | if pretrained not in settings.keys(): 101 | raise ValueError("Available pretrained options {}".format(settings.keys())) 102 | 103 | formatted_settings = {} 104 | formatted_settings["input_space"] = settings[pretrained].get("input_space") 105 | formatted_settings["input_range"] = settings[pretrained].get("input_range") 106 | formatted_settings["mean"] = settings[pretrained].get("mean") 107 | formatted_settings["std"] = settings[pretrained].get("std") 108 | return formatted_settings 109 | 110 | 111 | def get_preprocessing_fn(encoder_name, pretrained="imagenet"): 112 | params = get_preprocessing_params(encoder_name, pretrained=pretrained) 113 | return functools.partial(preprocess_input, **params) 114 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Note: To use the 'upload' functionality of this file, you must: 5 | # $ pip install twine 6 | 7 | import io 8 | import os 9 | import sys 10 | from shutil import rmtree 11 | 12 | from setuptools import find_packages, setup, Command 13 | 14 | # Package meta-data. 15 | NAME = 'change_detection_pytorch' 16 | DESCRIPTION = 'Change detection models with pre-trained backbones. Inspired by segmentation_models.pytorch.' 
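A hedged sketch of the helper functions defined in encoders/__init__.py above; the encoder key and pretrained source are illustrative, and preprocess_input is assumed to normalize an HWC image array with the returned mean/std, as in segmentation_models.pytorch:

import numpy as np
from change_detection_pytorch.encoders import get_encoder_names, get_preprocessing_fn

print(get_encoder_names()[:5])  # a few of the registered encoder keys

preprocess = get_preprocessing_fn('resnet34', pretrained='imagenet')
image = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
image = preprocess(image)       # float array scaled to [0, 1] and normalized with ImageNet mean/std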
17 | URL = 'https://github.com/likyoo/change_detection.pytorch' 18 | EMAIL = 'likyoo@sdust.edu.cn' 19 | AUTHOR = 'Kaiyu Li, Fulin Sun, Xudong Liu' 20 | REQUIRES_PYTHON = '>=3.0.0' 21 | VERSION = None 22 | 23 | # The rest you shouldn't have to touch too much :) 24 | # ------------------------------------------------ 25 | # Except, perhaps the License and Trove Classifiers! 26 | # If you do change the License, remember to change the Trove Classifier for that! 27 | 28 | here = os.path.abspath(os.path.dirname(__file__)) 29 | 30 | # What packages are required for this module to be executed? 31 | try: 32 | with open(os.path.join(here, 'requirements.txt'), encoding='utf-8') as f: 33 | REQUIRED = f.read().split('\n') 34 | except: 35 | REQUIRED = [ 36 | "torchvision>=0.5.0", 37 | "pretrainedmodels==0.7.4", 38 | "efficientnet-pytorch==0.6.3", 39 | "timm==0.4.12", 40 | "albumentations", 41 | ] 42 | 43 | # What packages are optional? 44 | EXTRAS = { 45 | 'test': ['pytest'] 46 | } 47 | 48 | # Import the README and use it as the long-description. 49 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 50 | try: 51 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 52 | long_description = '\n' + f.read() 53 | except FileNotFoundError: 54 | long_description = DESCRIPTION 55 | 56 | # Load the package's __version__.py module as a dictionary. 57 | about = {} 58 | if not VERSION: 59 | with open(os.path.join(here, NAME, '__version__.py')) as f: 60 | exec(f.read(), about) 61 | else: 62 | about['__version__'] = VERSION 63 | 64 | 65 | class UploadCommand(Command): 66 | """Support setup.py upload.""" 67 | 68 | description = 'Build and publish the package.' 69 | user_options = [] 70 | 71 | @staticmethod 72 | def status(s): 73 | """Prints things in bold.""" 74 | print(s) 75 | 76 | def initialize_options(self): 77 | pass 78 | 79 | def finalize_options(self): 80 | pass 81 | 82 | def run(self): 83 | try: 84 | self.status('Removing previous builds...') 85 | rmtree(os.path.join(here, 'dist')) 86 | except OSError: 87 | pass 88 | 89 | self.status('Building Source and Wheel (universal) distribution...') 90 | os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable)) 91 | 92 | self.status('Uploading the package to PyPI via Twine...') 93 | os.system('twine upload dist/*') 94 | 95 | self.status('Pushing git tags...') 96 | os.system('git tag v{0}'.format(about['__version__'])) 97 | os.system('git push --tags') 98 | 99 | sys.exit() 100 | 101 | 102 | # Where the magic happens: 103 | setup( 104 | name=NAME, 105 | version=about['__version__'], 106 | description=DESCRIPTION, 107 | long_description=long_description, 108 | long_description_content_type='text/markdown', 109 | author=AUTHOR, 110 | author_email=EMAIL, 111 | python_requires=REQUIRES_PYTHON, 112 | url=URL, 113 | packages=find_packages(exclude=('tests', 'docs', 'images')), 114 | # If your package is a single module, use this instead of 'packages': 115 | # py_modules=['mypackage'], 116 | 117 | # entry_points={ 118 | # 'console_scripts': ['mycli=mymodule:cli'], 119 | # }, 120 | install_requires=REQUIRED, 121 | extras_require=EXTRAS, 122 | include_package_data=True, 123 | license='MIT', 124 | classifiers=[ 125 | # Trove classifiers 126 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 127 | 'License :: OSI Approved :: MIT License', 128 | 'Programming Language :: Python', 129 | 'Programming Language :: Python :: 3', 130 | 'Programming Language :: Python :: Implementation :: CPython', 131 
| 'Programming Language :: Python :: Implementation :: PyPy' 132 | ], 133 | # $ setup.py publish support. 134 | cmdclass={ 135 | 'upload': UploadCommand, 136 | }, 137 | ) 138 | -------------------------------------------------------------------------------- /change_detection_pytorch/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from . import base 2 | from . import functional as F 3 | from ..base.modules import Activation 4 | 5 | 6 | class IoU(base.Metric): 7 | __name__ = 'iou_score' 8 | 9 | def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 10 | super().__init__(**kwargs) 11 | self.eps = eps 12 | self.threshold = threshold 13 | self.activation = Activation(activation) 14 | self.ignore_channels = ignore_channels 15 | 16 | def forward(self, y_pr, y_gt): 17 | y_pr = self.activation(y_pr) 18 | return F.iou( 19 | y_pr, y_gt, 20 | eps=self.eps, 21 | threshold=self.threshold, 22 | ignore_channels=self.ignore_channels, 23 | ) 24 | 25 | 26 | class Fscore(base.Metric): 27 | 28 | def __init__(self, beta=1, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 29 | super().__init__(**kwargs) 30 | self.eps = eps 31 | self.beta = beta 32 | self.threshold = threshold 33 | self.activation = Activation(activation) 34 | self.ignore_channels = ignore_channels 35 | 36 | def forward(self, y_pr, y_gt): 37 | y_pr = self.activation(y_pr) 38 | return F.f_score( 39 | y_pr, y_gt, 40 | eps=self.eps, 41 | beta=self.beta, 42 | threshold=self.threshold, 43 | ignore_channels=self.ignore_channels, 44 | ) 45 | 46 | 47 | class Accuracy(base.Metric): 48 | 49 | def __init__(self, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 50 | super().__init__(**kwargs) 51 | self.threshold = threshold 52 | self.activation = Activation(activation) 53 | self.ignore_channels = ignore_channels 54 | 55 | def forward(self, y_pr, y_gt): 56 | y_pr = self.activation(y_pr) 57 | return F.accuracy( 58 | y_pr, y_gt, 59 | threshold=self.threshold, 60 | ignore_channels=self.ignore_channels, 61 | ) 62 | 63 | 64 | class Recall(base.Metric): 65 | 66 | def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 67 | super().__init__(**kwargs) 68 | self.eps = eps 69 | self.threshold = threshold 70 | self.activation = Activation(activation) 71 | self.ignore_channels = ignore_channels 72 | 73 | def forward(self, y_pr, y_gt): 74 | y_pr = self.activation(y_pr) 75 | return F.recall( 76 | y_pr, y_gt, 77 | eps=self.eps, 78 | threshold=self.threshold, 79 | ignore_channels=self.ignore_channels, 80 | ) 81 | 82 | 83 | class Precision(base.Metric): 84 | 85 | def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 86 | super().__init__(**kwargs) 87 | self.eps = eps 88 | self.threshold = threshold 89 | self.activation = Activation(activation) 90 | self.ignore_channels = ignore_channels 91 | 92 | def forward(self, y_pr, y_gt): 93 | y_pr = self.activation(y_pr) 94 | return F.precision( 95 | y_pr, y_gt, 96 | eps=self.eps, 97 | threshold=self.threshold, 98 | ignore_channels=self.ignore_channels, 99 | ) 100 | 101 | 102 | class Dice(base.Metric): 103 | 104 | def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 105 | super().__init__(**kwargs) 106 | self.eps = eps 107 | self.threshold = threshold 108 | self.activation = Activation(activation) 109 | self.ignore_channels = ignore_channels 110 | 111 | def forward(self, y_pr, y_gt): 
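A hedged sketch of calling these metric classes directly on tensors, outside the Epoch runners; activation='argmax2d' mirrors its use in local_test.py for two-channel logits, and the dummy shapes are illustrative:

import torch
from change_detection_pytorch.utils.metrics import Fscore, IoU

logits = torch.randn(2, 2, 64, 64)                 # (N, C, H, W) raw two-class outputs
target = torch.randint(0, 2, (2, 64, 64)).float()  # (N, H, W) binary ground truth

fscore = Fscore(activation='argmax2d')
iou = IoU(activation='argmax2d')
print(fscore(logits, target).item(), iou(logits, target).item())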
112 | y_pr = self.activation(y_pr) 113 | return F.dice( 114 | y_pr, y_gt, 115 | eps=self.eps, 116 | threshold=self.threshold, 117 | ignore_channels=self.ignore_channels, 118 | ) 119 | 120 | 121 | class Kappa(base.Metric): 122 | 123 | def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 124 | super().__init__(**kwargs) 125 | self.eps = eps 126 | self.threshold = threshold 127 | self.activation = Activation(activation) 128 | self.ignore_channels = ignore_channels 129 | 130 | def forward(self, y_pr, y_gt): 131 | y_pr = self.activation(y_pr) 132 | return F.kappa( 133 | y_pr, y_gt, 134 | eps=self.eps, 135 | threshold=self.threshold, 136 | ignore_channels=self.ignore_channels, 137 | ) 138 | 139 | -------------------------------------------------------------------------------- /change_detection_pytorch/fpn/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..base import Decoder 6 | 7 | 8 | class Conv3x3GNReLU(nn.Module): 9 | def __init__(self, in_channels, out_channels, upsample=False): 10 | super().__init__() 11 | self.upsample = upsample 12 | self.block = nn.Sequential( 13 | nn.Conv2d( 14 | in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False 15 | ), 16 | nn.GroupNorm(32, out_channels), 17 | nn.ReLU(inplace=True), 18 | ) 19 | 20 | def forward(self, x): 21 | x = self.block(x) 22 | if self.upsample: 23 | x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) 24 | return x 25 | 26 | 27 | class FPNBlock(nn.Module): 28 | def __init__(self, pyramid_channels, skip_channels): 29 | super().__init__() 30 | self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1) 31 | 32 | def forward(self, x, skip=None): 33 | x = F.interpolate(x, scale_factor=2, mode="nearest") 34 | skip = self.skip_conv(skip) 35 | x = x + skip 36 | return x 37 | 38 | 39 | class SegmentationBlock(nn.Module): 40 | def __init__(self, in_channels, out_channels, n_upsamples=0): 41 | super().__init__() 42 | 43 | blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))] 44 | 45 | if n_upsamples > 1: 46 | for _ in range(1, n_upsamples): 47 | blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True)) 48 | 49 | self.block = nn.Sequential(*blocks) 50 | 51 | def forward(self, x): 52 | return self.block(x) 53 | 54 | 55 | class MergeBlock(nn.Module): 56 | def __init__(self, policy): 57 | super().__init__() 58 | if policy not in ["add", "cat"]: 59 | raise ValueError( 60 | "`merge_policy` must be one of: ['add', 'cat'], got {}".format( 61 | policy 62 | ) 63 | ) 64 | self.policy = policy 65 | 66 | def forward(self, x): 67 | if self.policy == 'add': 68 | return sum(x) 69 | elif self.policy == 'cat': 70 | return torch.cat(x, dim=1) 71 | else: 72 | raise ValueError( 73 | "`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy) 74 | ) 75 | 76 | 77 | class FPNDecoder(Decoder): 78 | def __init__( 79 | self, 80 | encoder_channels, 81 | encoder_depth=5, 82 | pyramid_channels=256, 83 | segmentation_channels=128, 84 | dropout=0.2, 85 | merge_policy="add", 86 | fusion_form="concat", 87 | ): 88 | super().__init__() 89 | 90 | self.out_channels = segmentation_channels if merge_policy == "add" else segmentation_channels * 4 91 | if encoder_depth < 3: 92 | raise ValueError("Encoder depth for FPN decoder cannot be less than 3, got {}.".format(encoder_depth)) 93 | 94 | encoder_channels = 
encoder_channels[::-1] 95 | encoder_channels = encoder_channels[:encoder_depth + 1] 96 | # (512, 256, 128, 64, 64, 3) 97 | 98 | # adjust encoder channels according to fusion form 99 | self.fusion_form = fusion_form 100 | if self.fusion_form in self.FUSION_DIC["2to2_fusion"]: 101 | encoder_channels = [ch*2 for ch in encoder_channels] 102 | 103 | self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1) 104 | self.p4 = FPNBlock(pyramid_channels, encoder_channels[1]) 105 | self.p3 = FPNBlock(pyramid_channels, encoder_channels[2]) 106 | self.p2 = FPNBlock(pyramid_channels, encoder_channels[3]) 107 | 108 | self.seg_blocks = nn.ModuleList([ 109 | SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=n_upsamples) 110 | for n_upsamples in [3, 2, 1, 0] 111 | ]) 112 | 113 | self.merge = MergeBlock(merge_policy) 114 | self.dropout = nn.Dropout2d(p=dropout, inplace=True) 115 | 116 | def forward(self, *features): 117 | 118 | features = self.aggregation_layer(features[0], features[1], 119 | self.fusion_form, ignore_original_img=True) 120 | c2, c3, c4, c5 = features[-4:] 121 | 122 | p5 = self.p5(c5) 123 | p4 = self.p4(p5, c4) 124 | p3 = self.p3(p4, c3) 125 | p2 = self.p2(p3, c2) 126 | 127 | feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])] 128 | x = self.merge(feature_pyramid) 129 | x = self.dropout(x) 130 | 131 | return x 132 | -------------------------------------------------------------------------------- /change_detection_pytorch/unet/decoder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from ..base import modules as md 7 | from ..base import Decoder 8 | 9 | 10 | class DecoderBlock(nn.Module): 11 | def __init__( 12 | self, 13 | in_channels, 14 | skip_channels, 15 | out_channels, 16 | use_batchnorm=True, 17 | attention_type=None, 18 | ): 19 | super().__init__() 20 | self.conv1 = md.Conv2dReLU( 21 | in_channels + skip_channels, 22 | out_channels, 23 | kernel_size=3, 24 | padding=1, 25 | use_batchnorm=use_batchnorm, 26 | ) 27 | self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels) 28 | self.conv2 = md.Conv2dReLU( 29 | out_channels, 30 | out_channels, 31 | kernel_size=3, 32 | padding=1, 33 | use_batchnorm=use_batchnorm, 34 | ) 35 | self.attention2 = md.Attention(attention_type, in_channels=out_channels) 36 | 37 | def forward(self, x, skip=None): 38 | x = F.interpolate(x, scale_factor=2, mode="nearest") 39 | 40 | if skip is not None: 41 | x = torch.cat([x, skip], dim=1) 42 | x = self.attention1(x) 43 | x = self.conv1(x) 44 | x = self.conv2(x) 45 | x = self.attention2(x) 46 | return x 47 | 48 | 49 | class CenterBlock(nn.Sequential): 50 | def __init__(self, in_channels, out_channels, use_batchnorm=True): 51 | conv1 = md.Conv2dReLU( 52 | in_channels, 53 | out_channels, 54 | kernel_size=3, 55 | padding=1, 56 | use_batchnorm=use_batchnorm, 57 | ) 58 | conv2 = md.Conv2dReLU( 59 | out_channels, 60 | out_channels, 61 | kernel_size=3, 62 | padding=1, 63 | use_batchnorm=use_batchnorm, 64 | ) 65 | super().__init__(conv1, conv2) 66 | 67 | 68 | class UnetDecoder(Decoder): 69 | def __init__( 70 | self, 71 | encoder_channels, 72 | decoder_channels, 73 | n_blocks=5, 74 | use_batchnorm=True, 75 | attention_type=None, 76 | center=False, 77 | fusion_form="concat", 78 | ): 79 | super().__init__() 80 | 81 | if n_blocks != len(decoder_channels): 82 | raise ValueError( 83 | "Model 
depth is {}, but you provide `decoder_channels` for {} blocks.".format( 84 | n_blocks, len(decoder_channels) 85 | ) 86 | ) 87 | 88 | encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution 89 | encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder 90 | 91 | # computing blocks input and output channels 92 | head_channels = encoder_channels[0] 93 | in_channels = [head_channels] + list(decoder_channels[:-1]) 94 | skip_channels = list(encoder_channels[1:]) + [0] 95 | out_channels = decoder_channels 96 | 97 | # adjust encoder channels according to fusion form 98 | self.fusion_form = fusion_form 99 | if self.fusion_form in self.FUSION_DIC["2to2_fusion"]: 100 | skip_channels = [ch*2 for ch in skip_channels] 101 | in_channels[0] = in_channels[0] * 2 102 | head_channels = head_channels * 2 103 | 104 | if center: 105 | self.center = CenterBlock( 106 | head_channels, head_channels, use_batchnorm=use_batchnorm 107 | ) 108 | else: 109 | self.center = nn.Identity() 110 | 111 | # combine decoder keyword arguments 112 | kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) 113 | blocks = [ 114 | DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) 115 | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) 116 | ] 117 | self.blocks = nn.ModuleList(blocks) 118 | 119 | def forward(self, *features): 120 | 121 | features = self.aggregation_layer(features[0], features[1], 122 | self.fusion_form, ignore_original_img=True) 123 | # features = features[1:] # remove first skip with same spatial resolution 124 | features = features[::-1] # reverse channels to start from head of encoder 125 | 126 | head = features[0] 127 | skips = features[1:] 128 | 129 | x = self.center(head) 130 | for i, decoder_block in enumerate(self.blocks): 131 | skip = skips[i] if i < len(skips) else None 132 | x = decoder_block(x, skip) 133 | 134 | return x 135 | -------------------------------------------------------------------------------- /change_detection_pytorch/pan/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | from ..base import ClassificationHead, SegmentationHead, SegmentationModel 4 | from ..encoders import get_encoder 5 | from .decoder import PANDecoder 6 | 7 | 8 | class PAN(SegmentationModel): 9 | """ Implementation of PAN_ (Pyramid Attention Network). 10 | 11 | Note: 12 | Currently works with shape of input tensor >= [B x C x 128 x 128] for pytorch <= 1.1.0 13 | and with shape of input tensor >= [B x C x 256 x 256] for pytorch == 1.3.1 14 | 15 | Args: 16 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 17 | to extract features of different spatial resolution 18 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 19 | other pretrained weights (see table with available weights for each encoder_name) 20 | encoder_output_stride: 16 or 32, if 16 use dilation in encoder last layer. 21 | Doesn't work with ***ception***, **vgg***, **densenet*`** backbones.Default is 16. 22 | decoder_channels: A number of convolution layer filters in decoder blocks 23 | in_channels: A number of input channels for the model, default is 3 (RGB images) 24 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 25 | activation: An activation function to apply after the final convolution layer. 
26 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**. 27 | Default is **None** 28 | upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity 29 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 30 | on top of encoder if **aux_params** is not **None** (default). Supported params: 31 | - classes (int): A number of classes 32 | - pooling (str): One of "max", "avg". Default is "avg" 33 | - dropout (float): Dropout factor in [0, 1) 34 | - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) 35 | siam_encoder: Whether using siamese branch. Default is True 36 | fusion_form: The form of fusing features from two branches. Available options are **"concat"**, **"sum"**, and **"diff"**. 37 | Default is **concat** 38 | 39 | Returns: 40 | ``torch.nn.Module``: **PAN** 41 | 42 | .. _PAN: 43 | https://arxiv.org/abs/1805.10180 44 | 45 | """ 46 | 47 | def __init__( 48 | self, 49 | encoder_name: str = "resnet34", 50 | encoder_weights: Optional[str] = "imagenet", 51 | encoder_output_stride: int = 16, 52 | decoder_channels: int = 32, 53 | in_channels: int = 3, 54 | classes: int = 1, 55 | activation: Optional[Union[str, callable]] = None, 56 | upsampling: int = 4, 57 | aux_params: Optional[dict] = None, 58 | siam_encoder: bool = True, 59 | fusion_form: str = "concat", 60 | **kwargs 61 | ): 62 | super().__init__() 63 | 64 | if encoder_output_stride not in [16, 32]: 65 | raise ValueError("PAN support output stride 16 or 32, got {}".format(encoder_output_stride)) 66 | 67 | self.siam_encoder = siam_encoder 68 | 69 | self.encoder = get_encoder( 70 | encoder_name, 71 | in_channels=in_channels, 72 | depth=5, 73 | weights=encoder_weights, 74 | output_stride=encoder_output_stride, 75 | ) 76 | 77 | if not self.siam_encoder: 78 | self.encoder_non_siam = get_encoder( 79 | encoder_name, 80 | in_channels=in_channels, 81 | depth=5, 82 | weights=encoder_weights, 83 | output_stride=encoder_output_stride, 84 | ) 85 | 86 | self.decoder = PANDecoder( 87 | encoder_channels=self.encoder.out_channels, 88 | decoder_channels=decoder_channels, 89 | fusion_form=fusion_form, 90 | ) 91 | 92 | self.segmentation_head = SegmentationHead( 93 | in_channels=decoder_channels, 94 | out_channels=classes, 95 | activation=activation, 96 | kernel_size=3, 97 | upsampling=upsampling 98 | ) 99 | 100 | if aux_params is not None: 101 | self.classification_head = ClassificationHead( 102 | in_channels=self.encoder.out_channels[-1], **aux_params 103 | ) 104 | else: 105 | self.classification_head = None 106 | 107 | self.name = "pan-{}".format(encoder_name) 108 | self.initialize() 109 | -------------------------------------------------------------------------------- /change_detection_pytorch/encoders/timm_gernet.py: -------------------------------------------------------------------------------- 1 | from timm.models import ByoModelCfg, ByoBlockCfg, ByobNet 2 | 3 | from ._base import EncoderMixin 4 | import torch.nn as nn 5 | 6 | 7 | class GERNetEncoder(ByobNet, EncoderMixin): 8 | def __init__(self, out_channels, depth=5, **kwargs): 9 | super().__init__(**kwargs) 10 | self._depth = depth 11 | self._out_channels = out_channels 12 | self._in_channels = 3 13 | 14 | del self.head 15 | 16 | def get_stages(self): 17 | return [ 18 | nn.Identity(), 19 | self.stem, 20 | self.stages[0], 21 | self.stages[1], 22 | self.stages[2], 23 | 
nn.Sequential(self.stages[3], self.stages[4], self.final_conv) 24 | ] 25 | 26 | def forward(self, x): 27 | stages = self.get_stages() 28 | 29 | features = [] 30 | for i in range(self._depth + 1): 31 | x = stages[i](x) 32 | features.append(x) 33 | 34 | return features 35 | 36 | def load_state_dict(self, state_dict, **kwargs): 37 | state_dict.pop("head.fc.weight", None) 38 | state_dict.pop("head.fc.bias", None) 39 | super().load_state_dict(state_dict, **kwargs) 40 | 41 | 42 | regnet_weights = { 43 | 'timm-gernet_s': { 44 | 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth', 45 | }, 46 | 'timm-gernet_m': { 47 | 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_m-0873c53a.pth', 48 | }, 49 | 'timm-gernet_l': { 50 | 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_l-f31e2e8d.pth', 51 | }, 52 | } 53 | 54 | pretrained_settings = {} 55 | for model_name, sources in regnet_weights.items(): 56 | pretrained_settings[model_name] = {} 57 | for source_name, source_url in sources.items(): 58 | pretrained_settings[model_name][source_name] = { 59 | "url": source_url, 60 | 'input_range': [0, 1], 61 | 'mean': [0.485, 0.456, 0.406], 62 | 'std': [0.229, 0.224, 0.225], 63 | 'num_classes': 1000 64 | } 65 | 66 | timm_gernet_encoders = { 67 | 'timm-gernet_s': { 68 | 'encoder': GERNetEncoder, 69 | "pretrained_settings": pretrained_settings["timm-gernet_s"], 70 | 'params': { 71 | 'out_channels': (3, 13, 48, 48, 384, 1920), 72 | 'cfg': ByoModelCfg( 73 | blocks=( 74 | ByoBlockCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.), 75 | ByoBlockCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.), 76 | ByoBlockCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4), 77 | ByoBlockCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.), 78 | ByoBlockCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.), 79 | ), 80 | stem_chs=13, 81 | stem_pool=None, 82 | num_features=1920, 83 | ) 84 | }, 85 | }, 86 | 'timm-gernet_m': { 87 | 'encoder': GERNetEncoder, 88 | "pretrained_settings": pretrained_settings["timm-gernet_m"], 89 | 'params': { 90 | 'out_channels': (3, 32, 128, 192, 640, 2560), 91 | 'cfg': ByoModelCfg( 92 | blocks=( 93 | ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), 94 | ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), 95 | ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), 96 | ByoBlockCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.), 97 | ByoBlockCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.), 98 | ), 99 | stem_chs=32, 100 | stem_pool=None, 101 | num_features=2560, 102 | ) 103 | }, 104 | }, 105 | 'timm-gernet_l': { 106 | 'encoder': GERNetEncoder, 107 | "pretrained_settings": pretrained_settings["timm-gernet_l"], 108 | 'params': { 109 | 'out_channels': (3, 32, 128, 192, 640, 2560), 110 | 'cfg': ByoModelCfg( 111 | blocks=( 112 | ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), 113 | ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), 114 | ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), 115 | ByoBlockCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.), 116 | ByoBlockCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.), 117 | ), 118 | stem_chs=32, 119 | stem_pool=None, 120 | num_features=2560, 121 | ) 122 | }, 123 | }, 124 | } 125 | -------------------------------------------------------------------------------- /change_detection_pytorch/linknet/model.py: 
-------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | from ..base import ClassificationHead, SegmentationHead, SegmentationModel 4 | from ..encoders import get_encoder 5 | from .decoder import LinknetDecoder 6 | 7 | 8 | class Linknet(SegmentationModel): 9 | """Linknet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder* 10 | and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial 11 | resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *sum* 12 | for fusing decoder blocks with skip connections. 13 | 14 | Note: 15 | This implementation by default has 4 skip connections (original - 3). 16 | 17 | Args: 18 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 19 | to extract features of different spatial resolution 20 | encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features 21 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 22 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 23 | Default is 5 24 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 25 | other pretrained weights (see table with available weights for each encoder_name) 26 | decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers 27 | is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. 28 | Available options are **True, False, "inplace"** 29 | in_channels: A number of input channels for the model, default is 3 (RGB images) 30 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 31 | activation: An activation function to apply after the final convolution layer. 32 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**. 33 | Default is **None** 34 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 35 | on top of encoder if **aux_params** is not **None** (default). Supported params: 36 | - classes (int): A number of classes 37 | - pooling (str): One of "max", "avg". Default is "avg" 38 | - dropout (float): Dropout factor in [0, 1) 39 | - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) 40 | siam_encoder: Whether using siamese branch. Default is True 41 | fusion_form: The form of fusing features from two branches. Available options are **"concat"**, **"sum"**, and **"diff"**. 42 | Default is **concat** 43 | 44 | Returns: 45 | ``torch.nn.Module``: **Linknet** 46 | 47 | .. 
_Linknet: 48 | https://arxiv.org/abs/1707.03718 49 | """ 50 | 51 | def __init__( 52 | self, 53 | encoder_name: str = "resnet34", 54 | encoder_depth: int = 5, 55 | encoder_weights: Optional[str] = "imagenet", 56 | decoder_use_batchnorm: bool = True, 57 | in_channels: int = 3, 58 | classes: int = 1, 59 | activation: Optional[Union[str, callable]] = None, 60 | aux_params: Optional[dict] = None, 61 | siam_encoder: bool = True, 62 | fusion_form: str = "concat", 63 | **kwargs 64 | ): 65 | super().__init__() 66 | 67 | self.siam_encoder = siam_encoder 68 | 69 | self.encoder = get_encoder( 70 | encoder_name, 71 | in_channels=in_channels, 72 | depth=encoder_depth, 73 | weights=encoder_weights, 74 | ) 75 | 76 | if not self.siam_encoder: 77 | self.encoder_non_siam = get_encoder( 78 | encoder_name, 79 | in_channels=in_channels, 80 | depth=encoder_depth, 81 | weights=encoder_weights, 82 | ) 83 | 84 | self.decoder = LinknetDecoder( 85 | encoder_channels=self.encoder.out_channels, 86 | n_blocks=encoder_depth, 87 | prefinal_channels=32, 88 | use_batchnorm=decoder_use_batchnorm, 89 | fusion_form=fusion_form, 90 | ) 91 | 92 | self.segmentation_head = SegmentationHead( 93 | in_channels=32, out_channels=classes, activation=activation, kernel_size=1 94 | ) 95 | 96 | if aux_params is not None: 97 | self.classification_head = ClassificationHead( 98 | in_channels=self.encoder.out_channels[-1], **aux_params 99 | ) 100 | else: 101 | self.classification_head = None 102 | 103 | self.name = "link-{}".format(encoder_name) 104 | self.initialize() 105 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import mock 4 | import pytest 5 | import torch 6 | 7 | # mock detection module 8 | sys.modules["torchvision._C"] = mock.Mock() 9 | import change_detection_pytorch as smp 10 | 11 | 12 | def get_encoders(): 13 | exclude_encoders = [ 14 | "senet154", 15 | "resnext101_32x16d", 16 | "resnext101_32x32d", 17 | "resnext101_32x48d", 18 | ] 19 | encoders = smp.encoders.get_encoder_names() 20 | encoders = [e for e in encoders if e not in exclude_encoders] 21 | encoders.append("tu-resnet34") # for timm universal encoder 22 | return encoders 23 | 24 | 25 | ENCODERS = get_encoders() 26 | DEFAULT_ENCODER = "resnet18" 27 | 28 | 29 | def get_sample(model_class): 30 | if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.UnetPlusPlus, smp.MAnet]: 31 | sample = torch.ones([1, 3, 64, 64]) 32 | elif model_class == smp.PAN: 33 | sample = torch.ones([2, 3, 256, 256]) 34 | elif model_class == smp.DeepLabV3: 35 | sample = torch.ones([2, 3, 128, 128]) 36 | else: 37 | raise ValueError("Not supported model class {}".format(model_class)) 38 | return sample 39 | 40 | 41 | def _test_forward(model, sample, test_shape=False): 42 | with torch.no_grad(): 43 | out = model(sample) 44 | if test_shape: 45 | assert out.shape[2:] == sample.shape[2:] 46 | 47 | 48 | def _test_forward_backward(model, sample, test_shape=False): 49 | out = model(sample) 50 | out.mean().backward() 51 | if test_shape: 52 | assert out.shape[2:] == sample.shape[2:] 53 | 54 | 55 | @pytest.mark.parametrize("encoder_name", ENCODERS) 56 | @pytest.mark.parametrize("encoder_depth", [3, 5]) 57 | @pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus]) 58 | def test_forward(model_class, encoder_name, encoder_depth, **kwargs): 59 | if model_class is smp.Unet or model_class 
is smp.UnetPlusPlus or model_class is smp.MAnet: 60 | kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:] 61 | model = model_class( 62 | encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs 63 | ) 64 | sample = get_sample(model_class) 65 | model.eval() 66 | if encoder_depth == 5 and model_class != smp.PSPNet: 67 | test_shape = True 68 | else: 69 | test_shape = False 70 | 71 | _test_forward(model, sample, test_shape) 72 | 73 | 74 | @pytest.mark.parametrize( 75 | "model_class", 76 | [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.MAnet, smp.DeepLabV3] 77 | ) 78 | def test_forward_backward(model_class): 79 | sample = get_sample(model_class) 80 | model = model_class(DEFAULT_ENCODER, encoder_weights=None) 81 | _test_forward_backward(model, sample) 82 | 83 | 84 | @pytest.mark.parametrize("model_class", [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.MAnet]) 85 | def test_aux_output(model_class): 86 | model = model_class( 87 | DEFAULT_ENCODER, encoder_weights=None, aux_params=dict(classes=2) 88 | ) 89 | sample = get_sample(model_class) 90 | label_size = (sample.shape[0], 2) 91 | mask, label = model(sample) 92 | assert label.size() == label_size 93 | 94 | 95 | @pytest.mark.parametrize("upsampling", [2, 4, 8]) 96 | @pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet]) 97 | def test_upsample(model_class, upsampling): 98 | default_upsampling = 4 if model_class is smp.FPN else 8 99 | model = model_class(DEFAULT_ENCODER, encoder_weights=None, upsampling=upsampling) 100 | sample = get_sample(model_class) 101 | mask = model(sample) 102 | assert mask.size()[-1] / 64 == upsampling / default_upsampling 103 | 104 | 105 | @pytest.mark.parametrize("model_class", [smp.FPN]) 106 | @pytest.mark.parametrize("encoder_name", ENCODERS) 107 | @pytest.mark.parametrize("in_channels", [1, 2, 4]) 108 | def test_in_channels(model_class, encoder_name, in_channels): 109 | sample = torch.ones([1, in_channels, 64, 64]) 110 | model = model_class(DEFAULT_ENCODER, encoder_weights=None, in_channels=in_channels) 111 | model.eval() 112 | with torch.no_grad(): 113 | model(sample) 114 | 115 | assert model.encoder._in_channels == in_channels 116 | 117 | 118 | @pytest.mark.parametrize("encoder_name", ENCODERS) 119 | def test_dilation(encoder_name): 120 | if ( 121 | encoder_name in ['inceptionresnetv2', 'xception', 'inceptionv4'] or 122 | encoder_name.startswith('vgg') or 123 | encoder_name.startswith('densenet') or 124 | encoder_name.startswith('timm-res') 125 | ): 126 | return 127 | 128 | encoder = smp.encoders.get_encoder(encoder_name, output_stride=16) 129 | 130 | encoder.eval() 131 | with torch.no_grad(): 132 | sample = torch.ones([1, 3, 64, 64]) 133 | output = encoder(sample) 134 | 135 | shapes = [out.shape[-1] for out in output] 136 | assert shapes == [64, 32, 16, 8, 4, 4] # last downsampling replaced with dilation 137 | 138 | 139 | if __name__ == "__main__": 140 | pytest.main([__file__]) 141 | -------------------------------------------------------------------------------- /change_detection_pytorch/fpn/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | from ..base import ClassificationHead, SegmentationHead, SegmentationModel 4 | from ..encoders import get_encoder 5 | from .decoder import FPNDecoder 6 | 7 | 8 | class FPN(SegmentationModel): 9 | """FPN_ is a fully convolution neural network for image semantic segmentation. 
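    In this package the model is adapted for change detection: the (siamese) encoder processes
    two co-registered images and the decoder fuses their features according to ``fusion_form``.
    A minimal usage sketch (encoder name, shapes and argument values below are illustrative,
    not prescriptive):

        model = FPN("resnet34", encoder_weights=None, classes=2, fusion_form="concat")
        x1 = torch.randn(1, 3, 256, 256)   # image at time t1
        x2 = torch.randn(1, 3, 256, 256)   # image at time t2
        out = model(x1, x2)                # change logits, same spatial size as the inputs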
10 | 11 | Args: 12 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 13 | to extract features of different spatial resolution 14 | encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features 15 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 16 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 17 | Default is 5 18 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 19 | other pretrained weights (see table with available weights for each encoder_name) 20 | decoder_pyramid_channels: A number of convolution filters in Feature Pyramid of FPN_ 21 | decoder_segmentation_channels: A number of convolution filters in segmentation blocks of FPN_ 22 | decoder_merge_policy: Determines how to merge pyramid features inside FPN. Available options are **add** and **cat** 23 | decoder_dropout: Spatial dropout rate in range (0, 1) for feature pyramid in FPN_ 24 | in_channels: A number of input channels for the model, default is 3 (RGB images) 25 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 26 | activation: An activation function to apply after the final convolution layer. 27 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**. 28 | Default is **None** 29 | upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity 30 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 31 | on top of encoder if **aux_params** is not **None** (default). Supported params: 32 | - classes (int): A number of classes 33 | - pooling (str): One of "max", "avg". Default is "avg" 34 | - dropout (float): Dropout factor in [0, 1) 35 | - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) 36 | siam_encoder: Whether using siamese branch. Default is True 37 | fusion_form: The form of fusing features from two branches. Available options are **"concat"**, **"sum"**, and **"diff"**. 38 | Default is **concat** 39 | 40 | Returns: 41 | ``torch.nn.Module``: **FPN** 42 | 43 | .. 
_FPN: 44 | http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf 45 | 46 | """ 47 | 48 | def __init__( 49 | self, 50 | encoder_name: str = "resnet34", 51 | encoder_depth: int = 5, 52 | encoder_weights: Optional[str] = "imagenet", 53 | decoder_pyramid_channels: int = 256, 54 | decoder_segmentation_channels: int = 128, 55 | decoder_merge_policy: str = "add", 56 | decoder_dropout: float = 0.2, 57 | in_channels: int = 3, 58 | classes: int = 1, 59 | activation: Optional[str] = None, 60 | upsampling: int = 4, 61 | aux_params: Optional[dict] = None, 62 | siam_encoder: bool = True, 63 | fusion_form: str = "concat", 64 | **kwargs 65 | ): 66 | super().__init__() 67 | 68 | self.siam_encoder = siam_encoder 69 | 70 | self.encoder = get_encoder( 71 | encoder_name, 72 | in_channels=in_channels, 73 | depth=encoder_depth, 74 | weights=encoder_weights, 75 | ) 76 | 77 | if not self.siam_encoder: 78 | self.encoder_non_siam = get_encoder( 79 | encoder_name, 80 | in_channels=in_channels, 81 | depth=encoder_depth, 82 | weights=encoder_weights, 83 | ) 84 | 85 | self.decoder = FPNDecoder( 86 | encoder_channels=self.encoder.out_channels, 87 | encoder_depth=encoder_depth, 88 | pyramid_channels=decoder_pyramid_channels, 89 | segmentation_channels=decoder_segmentation_channels, 90 | dropout=decoder_dropout, 91 | merge_policy=decoder_merge_policy, 92 | fusion_form=fusion_form, 93 | ) 94 | 95 | self.segmentation_head = SegmentationHead( 96 | in_channels=self.decoder.out_channels, 97 | out_channels=classes, 98 | activation=activation, 99 | kernel_size=1, 100 | upsampling=upsampling, 101 | ) 102 | 103 | if aux_params is not None: 104 | self.classification_head = ClassificationHead( 105 | in_channels=self.encoder.out_channels[-1], **aux_params 106 | ) 107 | else: 108 | self.classification_head = None 109 | 110 | self.name = "fpn-{}".format(encoder_name) 111 | self.initialize() 112 | -------------------------------------------------------------------------------- /change_detection_pytorch/pspnet/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | from ..base import ClassificationHead, SegmentationHead, SegmentationModel 4 | from ..encoders import get_encoder 5 | from .decoder import PSPDecoder 6 | 7 | 8 | class PSPNet(SegmentationModel): 9 | """PSPNet_ is a fully convolution neural network for image semantic segmentation. Consist of 10 | *encoder* and *Spatial Pyramid* (decoder). Spatial Pyramid build on top of encoder and does not 11 | use "fine-features" (features of high spatial resolution). PSPNet can be used for multiclass segmentation 12 | of high resolution images, however it is not good for detecting small objects and producing accurate, pixel-level mask. 13 | 14 | Args: 15 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 16 | to extract features of different spatial resolution 17 | encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features 18 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 19 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 
20 | Default is 5 21 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 22 | other pretrained weights (see table with available weights for each encoder_name) 23 | psp_out_channels: A number of filters in Spatial Pyramid 24 | psp_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers 25 | is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. 26 | Available options are **True, False, "inplace"** 27 | psp_dropout: Spatial dropout rate in [0, 1) used in Spatial Pyramid 28 | in_channels: A number of input channels for the model, default is 3 (RGB images) 29 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 30 | activation: An activation function to apply after the final convolution layer. 31 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**. 32 | Default is **None** 33 | upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity 34 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 35 | on top of encoder if **aux_params** is not **None** (default). Supported params: 36 | - classes (int): A number of classes 37 | - pooling (str): One of "max", "avg". Default is "avg" 38 | - dropout (float): Dropout factor in [0, 1) 39 | - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) 40 | siam_encoder: Whether using siamese branch. Default is True 41 | fusion_form: The form of fusing features from two branches. Available options are **"concat"**, **"sum"**, and **"diff"**. 42 | Default is **concat** 43 | 44 | Returns: 45 | ``torch.nn.Module``: **PSPNet** 46 | 47 | .. 
_PSPNet: 48 | https://arxiv.org/abs/1612.01105 49 | """ 50 | 51 | def __init__( 52 | self, 53 | encoder_name: str = "resnet34", 54 | encoder_weights: Optional[str] = "imagenet", 55 | encoder_depth: int = 3, 56 | psp_out_channels: int = 512, 57 | psp_use_batchnorm: bool = True, 58 | psp_dropout: float = 0.2, 59 | in_channels: int = 3, 60 | classes: int = 1, 61 | activation: Optional[Union[str, callable]] = None, 62 | upsampling: int = 8, 63 | aux_params: Optional[dict] = None, 64 | siam_encoder: bool = True, 65 | fusion_form: str = "concat", 66 | **kwargs 67 | ): 68 | super().__init__() 69 | 70 | self.siam_encoder = siam_encoder 71 | 72 | self.encoder = get_encoder( 73 | encoder_name, 74 | in_channels=in_channels, 75 | depth=encoder_depth, 76 | weights=encoder_weights, 77 | ) 78 | 79 | if not self.siam_encoder: 80 | self.encoder_non_siam = get_encoder( 81 | encoder_name, 82 | in_channels=in_channels, 83 | depth=encoder_depth, 84 | weights=encoder_weights, 85 | ) 86 | 87 | self.decoder = PSPDecoder( 88 | encoder_channels=self.encoder.out_channels, 89 | use_batchnorm=psp_use_batchnorm, 90 | out_channels=psp_out_channels, 91 | dropout=psp_dropout, 92 | fusion_form=fusion_form, 93 | ) 94 | 95 | self.segmentation_head = SegmentationHead( 96 | in_channels=psp_out_channels, 97 | out_channels=classes, 98 | kernel_size=3, 99 | activation=activation, 100 | upsampling=upsampling, 101 | ) 102 | 103 | if aux_params: 104 | self.classification_head = ClassificationHead( 105 | in_channels=self.encoder.out_channels[-1], **aux_params 106 | ) 107 | else: 108 | self.classification_head = None 109 | 110 | self.name = "psp-{}".format(encoder_name) 111 | self.initialize() 112 | -------------------------------------------------------------------------------- /change_detection_pytorch/losses/dice.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.nn.modules.loss import _Loss 6 | from ._functional import soft_dice_score, to_tensor 7 | from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE 8 | 9 | __all__ = ["DiceLoss"] 10 | 11 | 12 | class DiceLoss(_Loss): 13 | __name__ = "DiceLoss" 14 | 15 | def __init__( 16 | self, 17 | mode: str, 18 | classes: Optional[List[int]] = None, 19 | log_loss: bool = False, 20 | from_logits: bool = True, 21 | smooth: float = 0.0, 22 | ignore_index: Optional[int] = None, 23 | eps: float = 1e-7, 24 | ): 25 | """Implementation of Dice loss for image segmentation task. 26 | It supports binary, multiclass and multilabel cases 27 | 28 | Args: 29 | mode: Loss mode 'binary', 'multiclass' or 'multilabel' 30 | classes: List of classes that contribute in loss computation. By default, all channels are included. 
31 | log_loss: If True, loss computed as `- log(dice_coeff)`, otherwise `1 - dice_coeff` 32 | from_logits: If True, assumes input is raw logits 33 | smooth: Smoothness constant for dice coefficient (a) 34 | ignore_index: Label that indicates ignored pixels (does not contribute to loss) 35 | eps: A small epsilon for numerical stability to avoid zero division error 36 | (denominator will be always greater or equal to eps) 37 | 38 | Shape 39 | - **y_pred** - torch.Tensor of shape (N, C, H, W) 40 | - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W) 41 | 42 | Reference 43 | https://github.com/BloodAxe/pytorch-toolbelt 44 | """ 45 | assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} 46 | super(DiceLoss, self).__init__() 47 | self.mode = mode 48 | if classes is not None: 49 | assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary" 50 | classes = to_tensor(classes, dtype=torch.long) 51 | 52 | self.classes = classes 53 | self.from_logits = from_logits 54 | self.smooth = smooth 55 | self.eps = eps 56 | self.log_loss = log_loss 57 | self.ignore_index = ignore_index 58 | 59 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 60 | 61 | assert y_true.size(0) == y_pred.size(0) 62 | 63 | if self.from_logits: 64 | # Apply activations to get [0..1] class probabilities 65 | # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on 66 | # extreme values 0 and 1 67 | if self.mode == MULTICLASS_MODE: 68 | y_pred = y_pred.log_softmax(dim=1).exp() 69 | else: 70 | y_pred = F.logsigmoid(y_pred).exp() 71 | 72 | bs = y_true.size(0) 73 | num_classes = y_pred.size(1) 74 | dims = (0, 2) 75 | 76 | if self.mode == BINARY_MODE: 77 | y_true = y_true.view(bs, 1, -1) 78 | y_pred = y_pred.view(bs, 1, -1) 79 | 80 | if self.ignore_index is not None: 81 | mask = y_true != self.ignore_index 82 | y_pred = y_pred * mask 83 | y_true = y_true * mask 84 | 85 | if self.mode == MULTICLASS_MODE: 86 | y_true = y_true.view(bs, -1) 87 | y_pred = y_pred.view(bs, num_classes, -1) 88 | 89 | if self.ignore_index is not None: 90 | mask = y_true != self.ignore_index 91 | y_pred = y_pred * mask.unsqueeze(1) 92 | 93 | y_true = F.one_hot((y_true * mask).to(torch.long), num_classes) # N,H*W -> N,H*W, C 94 | y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # H, C, H*W 95 | else: 96 | y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C 97 | y_true = y_true.permute(0, 2, 1) # H, C, H*W 98 | 99 | if self.mode == MULTILABEL_MODE: 100 | y_true = y_true.view(bs, num_classes, -1) 101 | y_pred = y_pred.view(bs, num_classes, -1) 102 | 103 | if self.ignore_index is not None: 104 | mask = y_true != self.ignore_index 105 | y_pred = y_pred * mask 106 | y_true = y_true * mask 107 | 108 | scores = self.compute_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims) 109 | 110 | if self.log_loss: 111 | loss = -torch.log(scores.clamp_min(self.eps)) 112 | else: 113 | loss = 1.0 - scores 114 | 115 | # Dice loss is undefined for non-empty classes 116 | # So we zero contribution of channel that does not have true pixels 117 | # NOTE: A better workaround would be to use loss term `mean(y_pred)` 118 | # for this case, however it will be a modified jaccard loss 119 | 120 | mask = y_true.sum(dims) > 0 121 | loss *= mask.to(loss.dtype) 122 | 123 | if self.classes is not None: 124 | loss = loss[self.classes] 125 | 126 | return self.aggregate_loss(loss) 127 | 128 | def aggregate_loss(self, loss): 129 | return loss.mean() 
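    # A minimal usage sketch (values are illustrative; assumes raw logits and a binary
    # change mask, matching the default ``from_logits=True`` behaviour documented above):
    #
    #     criterion = DiceLoss(mode="binary")
    #     y_pred = torch.randn(2, 1, 64, 64)            # logits from a change-detection head
    #     y_true = torch.randint(0, 2, (2, 1, 64, 64))  # binary ground-truth change mask
    #     loss = criterion(y_pred, y_true)
    #
    # For multiclass masks, use ``mode="multiclass"`` with integer labels of shape (N, H, W).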
130 | 131 | def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor: 132 | return soft_dice_score(output, target, smooth, eps, dims) 133 | -------------------------------------------------------------------------------- /change_detection_pytorch/upernet/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | from ..base import ClassificationHead, SegmentationHead, SegmentationModel 4 | from ..encoders import get_encoder 5 | from .decoder import UPerNetDecoder 6 | 7 | 8 | class UPerNet(SegmentationModel): 9 | """UPerNet_ is a fully convolution neural network for image semantic segmentation. 10 | 11 | Args: 12 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 13 | to extract features of different spatial resolution 14 | encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features 15 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 16 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 17 | Default is 5 18 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 19 | other pretrained weights (see table with available weights for each encoder_name) 20 | decoder_psp_channels: A number of filters in Spatial Pyramid 21 | decoder_pyramid_channels: A number of convolution filters in Feature Pyramid of FPN_ 22 | decoder_segmentation_channels: A number of convolution filters in segmentation blocks of FPN_ 23 | decoder_merge_policy: Determines how to merge pyramid features inside FPN. Available options are **add** and **cat** 24 | decoder_dropout: Spatial dropout rate in range (0, 1) for feature pyramid in FPN_ 25 | in_channels: A number of input channels for the model, default is 3 (RGB images) 26 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 27 | activation: An activation function to apply after the final convolution layer. 28 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**. 29 | Default is **None** 30 | upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity 31 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 32 | on top of encoder if **aux_params** is not **None** (default). Supported params: 33 | - classes (int): A number of classes 34 | - pooling (str): One of "max", "avg". Default is "avg" 35 | - dropout (float): Dropout factor in [0, 1) 36 | - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) 37 | siam_encoder: Whether using siamese branch. Default is True 38 | fusion_form: The form of fusing features from two branches. Available options are **"concat"**, **"sum"**, and **"diff"**. 39 | Default is **concat** 40 | 41 | Returns: 42 | ``torch.nn.Module``: **UPerNet** 43 | 44 | .. 
_UPerNet: 45 | https://arxiv.org/abs/1807.10221 46 | 47 | """ 48 | 49 | def __init__( 50 | self, 51 | encoder_name: str = "resnet34", 52 | encoder_depth: int = 5, 53 | encoder_weights: Optional[str] = "imagenet", 54 | decoder_psp_channels: int = 512, 55 | decoder_pyramid_channels: int = 256, 56 | decoder_segmentation_channels: int = 256, 57 | decoder_merge_policy: str = "add", 58 | decoder_dropout: float = 0.2, 59 | in_channels: int = 3, 60 | classes: int = 1, 61 | activation: Optional[str] = None, 62 | upsampling: int = 4, 63 | aux_params: Optional[dict] = None, 64 | siam_encoder: bool = True, 65 | fusion_form: str = "concat", 66 | **kwargs 67 | ): 68 | super().__init__() 69 | 70 | self.siam_encoder = siam_encoder 71 | 72 | self.encoder = get_encoder( 73 | encoder_name, 74 | in_channels=in_channels, 75 | depth=encoder_depth, 76 | weights=encoder_weights, 77 | ) 78 | 79 | if not self.siam_encoder: 80 | self.encoder_non_siam = get_encoder( 81 | encoder_name, 82 | in_channels=in_channels, 83 | depth=encoder_depth, 84 | weights=encoder_weights, 85 | ) 86 | 87 | self.decoder = UPerNetDecoder( 88 | encoder_channels=self.encoder.out_channels, 89 | encoder_depth=encoder_depth, 90 | psp_channels=decoder_psp_channels, 91 | pyramid_channels=decoder_pyramid_channels, 92 | segmentation_channels=decoder_segmentation_channels, 93 | dropout=decoder_dropout, 94 | merge_policy=decoder_merge_policy, 95 | fusion_form=fusion_form, 96 | ) 97 | 98 | self.segmentation_head = SegmentationHead( 99 | in_channels=self.decoder.out_channels, 100 | out_channels=classes, 101 | activation=activation, 102 | kernel_size=1, 103 | upsampling=upsampling, 104 | align_corners=False, 105 | ) 106 | 107 | if aux_params is not None: 108 | self.classification_head = ClassificationHead( 109 | in_channels=self.encoder.out_channels[-1], **aux_params 110 | ) 111 | else: 112 | self.classification_head = None 113 | 114 | self.name = "upernet-{}".format(encoder_name) 115 | self.initialize() 116 | -------------------------------------------------------------------------------- /change_detection_pytorch/upernet/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..base import Decoder, modules 6 | 7 | 8 | class PSPBlock(nn.Module): 9 | 10 | def __init__(self, in_channels, out_channels, pool_size, use_bathcnorm=True): 11 | super().__init__() 12 | if pool_size == 1: 13 | use_bathcnorm = False # PyTorch does not support BatchNorm for 1x1 shape 14 | self.pool = nn.Sequential( 15 | nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)), 16 | modules.Conv2dReLU(in_channels, out_channels, (1, 1), use_batchnorm=use_bathcnorm) 17 | ) 18 | 19 | def forward(self, x): 20 | h, w = x.size(2), x.size(3) 21 | x = self.pool(x) 22 | x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False) 23 | return x 24 | 25 | 26 | class PSPModule(nn.Module): 27 | def __init__(self, in_channels, out_channels, sizes=(1, 2, 3, 6), use_bathcnorm=True): 28 | super().__init__() 29 | 30 | self.blocks = nn.ModuleList([ 31 | PSPBlock(in_channels, out_channels, size, use_bathcnorm=use_bathcnorm) for size in sizes 32 | ]) 33 | 34 | def forward(self, x): 35 | xs = [block(x) for block in self.blocks] + [x] 36 | x = torch.cat(xs, dim=1) 37 | return x 38 | 39 | 40 | class FPNBlock(nn.Module): 41 | def __init__(self, pyramid_channels, skip_channels): 42 | super().__init__() 43 | self.skip_conv = nn.Sequential( 44 | nn.Conv2d(skip_channels, 
pyramid_channels, kernel_size=1, bias=False), 45 | nn.BatchNorm2d(pyramid_channels), # adjust to "SynchronizedBatchNorm2d" if you need. 46 | nn.ReLU(inplace=True) 47 | ) 48 | 49 | def forward(self, x, skip=None): 50 | x = F.interpolate(x, scale_factor=2, mode="nearest") 51 | skip = self.skip_conv(skip) 52 | x = x + skip 53 | return x 54 | 55 | 56 | class MergeBlock(nn.Module): 57 | def __init__(self, policy): 58 | super().__init__() 59 | if policy not in ["add", "cat"]: 60 | raise ValueError( 61 | "`merge_policy` must be one of: ['add', 'cat'], got {}".format( 62 | policy 63 | ) 64 | ) 65 | self.policy = policy 66 | 67 | def forward(self, x): 68 | if self.policy == 'add': 69 | return sum(x) 70 | elif self.policy == 'cat': 71 | return torch.cat(x, dim=1) 72 | else: 73 | raise ValueError( 74 | "`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy) 75 | ) 76 | 77 | 78 | class UPerNetDecoder(Decoder): 79 | def __init__( 80 | self, 81 | encoder_channels, 82 | encoder_depth=5, 83 | psp_channels=512, 84 | pyramid_channels=256, 85 | segmentation_channels=128, 86 | dropout=0.2, 87 | merge_policy="cat", 88 | fusion_form="concat", 89 | ): 90 | super().__init__() 91 | 92 | self.out_channels = segmentation_channels if merge_policy == "add" else segmentation_channels * 4 93 | if encoder_depth < 3: 94 | raise ValueError("Encoder depth for UPerNet decoder cannot be less than 3, got {}.".format(encoder_depth)) 95 | 96 | encoder_channels = encoder_channels[::-1] 97 | encoder_channels = encoder_channels[:encoder_depth + 1] 98 | 99 | # adjust encoder channels according to fusion form 100 | self.fusion_form = fusion_form 101 | if self.fusion_form in self.FUSION_DIC["2to2_fusion"]: 102 | encoder_channels = [ch*2 for ch in encoder_channels] 103 | 104 | self.psp = PSPModule( 105 | in_channels=encoder_channels[0], 106 | out_channels=psp_channels, 107 | sizes=(1, 2, 3, 6), 108 | use_bathcnorm=True, 109 | ) 110 | 111 | self.psp_last_conv = modules.Conv2dReLU( 112 | in_channels=psp_channels * len((1, 2, 3, 6)) + encoder_channels[0], 113 | out_channels=pyramid_channels, 114 | kernel_size=1, 115 | use_batchnorm=True, 116 | ) 117 | 118 | self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1) 119 | self.p4 = FPNBlock(pyramid_channels, encoder_channels[1]) 120 | self.p3 = FPNBlock(pyramid_channels, encoder_channels[2]) 121 | self.p2 = FPNBlock(pyramid_channels, encoder_channels[3]) 122 | 123 | self.merge = MergeBlock(merge_policy) 124 | 125 | self.conv_last = modules.Conv2dReLU(self.out_channels, pyramid_channels, 1) 126 | self.dropout = nn.Dropout2d(p=dropout, inplace=True) 127 | 128 | def forward(self, *features): 129 | 130 | features = self.aggregation_layer(features[0], features[1], 131 | self.fusion_form, ignore_original_img=True) 132 | c2, c3, c4, c5 = features[-4:] 133 | 134 | c5 = self.psp(c5) 135 | p5 = self.psp_last_conv(c5) 136 | 137 | p4 = self.p4(p5, c4) 138 | p3 = self.p3(p4, c3) 139 | p2 = self.p2(p3, c2) 140 | 141 | output_size = p2.size()[2:] 142 | feature_pyramid = [nn.functional.interpolate(p, output_size, 143 | mode='bilinear', align_corners=False) for p in [p5, p4, p3, p2]] 144 | x = self.merge(feature_pyramid) 145 | x = self.conv_last(x) 146 | # x = self.dropout(x) 147 | 148 | return x 149 | -------------------------------------------------------------------------------- /change_detection_pytorch/manet/model.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | from ..base 
import ClassificationHead, SegmentationHead, SegmentationModel 4 | from ..encoders import get_encoder 5 | from .decoder import MAnetDecoder 6 | 7 | 8 | class MAnet(SegmentationModel): 9 | """MAnet_ : Multi-scale Attention Net. The MA-Net can capture rich contextual dependencies based on the attention mechanism, 10 | using two blocks: 11 | - Position-wise Attention Block (PAB), which captures the spatial dependencies between pixels in a global view 12 | - Multi-scale Fusion Attention Block (MFAB), which captures the channel dependencies between any feature map by 13 | multi-scale semantic feature fusion 14 | 15 | Args: 16 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 17 | to extract features of different spatial resolution 18 | encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features 19 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 20 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 21 | Default is 5 22 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 23 | other pretrained weights (see table with available weights for each encoder_name) 24 | decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. 25 | Length of the list should be the same as **encoder_depth** 26 | decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers 27 | is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. 28 | Available options are **True, False, "inplace"** 29 | decoder_pab_channels: A number of channels for PAB module in decoder. 30 | Default is 64. 31 | in_channels: A number of input channels for the model, default is 3 (RGB images) 32 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 33 | activation: An activation function to apply after the final convolution layer. 34 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**. 35 | Default is **None** 36 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 37 | on top of encoder if **aux_params** is not **None** (default). Supported params: 38 | - classes (int): A number of classes 39 | - pooling (str): One of "max", "avg". Default is "avg" 40 | - dropout (float): Dropout factor in [0, 1) 41 | - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) 42 | siam_encoder: Whether using siamese branch. Default is True 43 | fusion_form: The form of fusing features from two branches. Available options are **"concat"**, **"sum"**, and **"diff"**. 44 | Default is **concat** 45 | 46 | Returns: 47 | ``torch.nn.Module``: **MAnet** 48 | 49 | .. 
_MAnet: 50 | https://ieeexplore.ieee.org/abstract/document/9201310 51 | 52 | """ 53 | 54 | def __init__( 55 | self, 56 | encoder_name: str = "resnet34", 57 | encoder_depth: int = 5, 58 | encoder_weights: Optional[str] = "imagenet", 59 | decoder_use_batchnorm: bool = True, 60 | decoder_channels: List[int] = (256, 128, 64, 32, 16), 61 | decoder_pab_channels: int = 64, 62 | in_channels: int = 3, 63 | classes: int = 1, 64 | activation: Optional[Union[str, callable]] = None, 65 | aux_params: Optional[dict] = None, 66 | siam_encoder: bool = True, 67 | fusion_form: str = "concat", 68 | **kwargs 69 | ): 70 | super().__init__() 71 | 72 | self.siam_encoder = siam_encoder 73 | 74 | self.encoder = get_encoder( 75 | encoder_name, 76 | in_channels=in_channels, 77 | depth=encoder_depth, 78 | weights=encoder_weights, 79 | ) 80 | 81 | if not self.siam_encoder: 82 | self.encoder_non_siam = get_encoder( 83 | encoder_name, 84 | in_channels=in_channels, 85 | depth=encoder_depth, 86 | weights=encoder_weights, 87 | ) 88 | 89 | self.decoder = MAnetDecoder( 90 | encoder_channels=self.encoder.out_channels, 91 | decoder_channels=decoder_channels, 92 | n_blocks=encoder_depth, 93 | use_batchnorm=decoder_use_batchnorm, 94 | pab_channels=decoder_pab_channels, 95 | fusion_form=fusion_form, 96 | ) 97 | 98 | self.segmentation_head = SegmentationHead( 99 | in_channels=decoder_channels[-1], 100 | out_channels=classes, 101 | activation=activation, 102 | kernel_size=3, 103 | ) 104 | 105 | if aux_params is not None: 106 | self.classification_head = ClassificationHead( 107 | in_channels=self.encoder.out_channels[-1], **aux_params 108 | ) 109 | else: 110 | self.classification_head = None 111 | 112 | self.name = "manet-{}".format(encoder_name) 113 | self.initialize() 114 | -------------------------------------------------------------------------------- /change_detection_pytorch/encoders/densenet.py: -------------------------------------------------------------------------------- 1 | """ Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
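    A minimal sketch of this contract (assuming the encoder is obtained via ``get_encoder``;
    the encoder name and shapes are illustrative):

        encoder = get_encoder("densenet121", in_channels=3, depth=5, weights=None)
        features = encoder(torch.ones(1, 3, 64, 64))
        # len(features) == 6: features[0] keeps the input resolution and every later
        # feature map halves the spatial size, as described above.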
24 | """ 25 | 26 | import re 27 | import torch.nn as nn 28 | 29 | from pretrainedmodels.models.torchvision_models import pretrained_settings 30 | from torchvision.models.densenet import DenseNet 31 | 32 | from ._base import EncoderMixin 33 | 34 | 35 | class TransitionWithSkip(nn.Module): 36 | 37 | def __init__(self, module): 38 | super().__init__() 39 | self.module = module 40 | 41 | def forward(self, x): 42 | for module in self.module: 43 | x = module(x) 44 | if isinstance(module, nn.ReLU): 45 | skip = x 46 | return x, skip 47 | 48 | 49 | class DenseNetEncoder(DenseNet, EncoderMixin): 50 | def __init__(self, out_channels, depth=5, **kwargs): 51 | super().__init__(**kwargs) 52 | self._out_channels = out_channels 53 | self._depth = depth 54 | self._in_channels = 3 55 | del self.classifier 56 | 57 | def make_dilated(self, output_stride): 58 | raise ValueError("DenseNet encoders do not support dilated mode " 59 | "due to pooling operation for downsampling!") 60 | 61 | def get_stages(self): 62 | return [ 63 | nn.Identity(), 64 | nn.Sequential(self.features.conv0, self.features.norm0, self.features.relu0), 65 | nn.Sequential(self.features.pool0, self.features.denseblock1, 66 | TransitionWithSkip(self.features.transition1)), 67 | nn.Sequential(self.features.denseblock2, TransitionWithSkip(self.features.transition2)), 68 | nn.Sequential(self.features.denseblock3, TransitionWithSkip(self.features.transition3)), 69 | nn.Sequential(self.features.denseblock4, self.features.norm5) 70 | ] 71 | 72 | def forward(self, x): 73 | 74 | stages = self.get_stages() 75 | 76 | features = [] 77 | for i in range(self._depth + 1): 78 | x = stages[i](x) 79 | if isinstance(x, (list, tuple)): 80 | x, skip = x 81 | features.append(skip) 82 | else: 83 | features.append(x) 84 | 85 | return features 86 | 87 | def load_state_dict(self, state_dict): 88 | pattern = re.compile( 89 | r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" 90 | ) 91 | for key in list(state_dict.keys()): 92 | res = pattern.match(key) 93 | if res: 94 | new_key = res.group(1) + res.group(2) 95 | state_dict[new_key] = state_dict[key] 96 | del state_dict[key] 97 | 98 | # remove linear 99 | state_dict.pop("classifier.bias", None) 100 | state_dict.pop("classifier.weight", None) 101 | 102 | super().load_state_dict(state_dict) 103 | 104 | 105 | densenet_encoders = { 106 | "densenet121": { 107 | "encoder": DenseNetEncoder, 108 | "pretrained_settings": pretrained_settings["densenet121"], 109 | "params": { 110 | "out_channels": (3, 64, 256, 512, 1024, 1024), 111 | "num_init_features": 64, 112 | "growth_rate": 32, 113 | "block_config": (6, 12, 24, 16), 114 | }, 115 | }, 116 | "densenet169": { 117 | "encoder": DenseNetEncoder, 118 | "pretrained_settings": pretrained_settings["densenet169"], 119 | "params": { 120 | "out_channels": (3, 64, 256, 512, 1280, 1664), 121 | "num_init_features": 64, 122 | "growth_rate": 32, 123 | "block_config": (6, 12, 32, 32), 124 | }, 125 | }, 126 | "densenet201": { 127 | "encoder": DenseNetEncoder, 128 | "pretrained_settings": pretrained_settings["densenet201"], 129 | "params": { 130 | "out_channels": (3, 64, 256, 512, 1792, 1920), 131 | "num_init_features": 64, 132 | "growth_rate": 32, 133 | "block_config": (6, 12, 48, 32), 134 | }, 135 | }, 136 | "densenet161": { 137 | "encoder": DenseNetEncoder, 138 | "pretrained_settings": pretrained_settings["densenet161"], 139 | "params": { 140 | "out_channels": (3, 96, 384, 768, 2112, 2208), 141 | "num_init_features": 96, 142 | "growth_rate": 48, 143 
| "block_config": (6, 12, 36, 24), 144 | }, 145 | }, 146 | } 147 | -------------------------------------------------------------------------------- /change_detection_pytorch/unetplusplus/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, List 2 | from .decoder import UnetPlusPlusDecoder 3 | from ..encoders import get_encoder 4 | from ..base import SegmentationModel 5 | from ..base import SegmentationHead, ClassificationHead 6 | 7 | 8 | class UnetPlusPlus(SegmentationModel): 9 | """Unet++ is a fully convolution neural network for image semantic segmentation. Consist of *encoder* 10 | and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial 11 | resolution (skip connections) which are used by decoder to define accurate segmentation mask. Decoder of 12 | Unet++ is more complex than in usual Unet. 13 | 14 | Args: 15 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 16 | to extract features of different spatial resolution 17 | encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features 18 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 19 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 20 | Default is 5 21 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 22 | other pretrained weights (see table with available weights for each encoder_name) 23 | decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. 24 | Length of the list should be the same as **encoder_depth** 25 | decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers 26 | is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. 27 | Available options are **True, False, "inplace"** 28 | decoder_attention_type: Attention module used in decoder of the model. Available options are **None** and **scse**. 29 | SCSE paper - https://arxiv.org/abs/1808.08127 30 | in_channels: A number of input channels for the model, default is 3 (RGB images) 31 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 32 | activation: An activation function to apply after the final convolution layer. 33 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**. 34 | Default is **None** 35 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 36 | on top of encoder if **aux_params** is not **None** (default). Supported params: 37 | - classes (int): A number of classes 38 | - pooling (str): One of "max", "avg". Default is "avg" 39 | - dropout (float): Dropout factor in [0, 1) 40 | - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) 41 | siam_encoder: Whether using siamese branch. Default is True 42 | fusion_form: The form of fusing features from two branches. Available options are **"concat"**, **"sum"**, and **"diff"**. 43 | Default is **concat** 44 | seg_ensemble: The module used to get ensemble result of UNet++. Available options are **None** and **ecam**. 
45 | ECAM paper - https://ieeexplore.ieee.org/document/9355573 46 | 47 | Returns: 48 | ``torch.nn.Module``: **Unet++** 49 | 50 | Reference: 51 | https://arxiv.org/abs/1807.10165 52 | 53 | """ 54 | 55 | def __init__( 56 | self, 57 | encoder_name: str = "resnet34", 58 | encoder_depth: int = 5, 59 | encoder_weights: Optional[str] = "imagenet", 60 | decoder_use_batchnorm: bool = True, 61 | decoder_channels: List[int] = (256, 128, 64, 32, 16), 62 | decoder_attention_type: Optional[str] = None, 63 | in_channels: int = 3, 64 | classes: int = 1, 65 | activation: Optional[Union[str, callable]] = None, 66 | aux_params: Optional[dict] = None, 67 | siam_encoder: bool = True, 68 | fusion_form: str = "diff", 69 | seg_ensemble: Optional[str] = None, 70 | **kwargs 71 | ): 72 | super().__init__() 73 | 74 | self.siam_encoder = siam_encoder 75 | 76 | self.encoder = get_encoder( 77 | encoder_name, 78 | in_channels=in_channels, 79 | depth=encoder_depth, 80 | weights=encoder_weights, 81 | ) 82 | 83 | if not self.siam_encoder: 84 | self.encoder_non_siam = get_encoder( 85 | encoder_name, 86 | in_channels=in_channels, 87 | depth=encoder_depth, 88 | weights=encoder_weights, 89 | ) 90 | 91 | self.decoder = UnetPlusPlusDecoder( 92 | encoder_channels=self.encoder.out_channels, 93 | decoder_channels=decoder_channels, 94 | n_blocks=encoder_depth, 95 | use_batchnorm=decoder_use_batchnorm, 96 | center=True if encoder_name.startswith("vgg") else False, 97 | attention_type=decoder_attention_type, 98 | fusion_form=fusion_form, 99 | seg_ensemble=seg_ensemble, 100 | ) 101 | 102 | self.segmentation_head = SegmentationHead( 103 | in_channels=decoder_channels[-1], 104 | out_channels=classes, 105 | activation=activation, 106 | kernel_size=3, 107 | ) 108 | 109 | if aux_params is not None: 110 | self.classification_head = ClassificationHead( 111 | in_channels=self.encoder.out_channels[-1], **aux_params 112 | ) 113 | else: 114 | self.classification_head = None 115 | 116 | self.name = "unetplusplus-{}".format(encoder_name) 117 | self.initialize() 118 | -------------------------------------------------------------------------------- /change_detection_pytorch/unet/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, List 2 | from .decoder import UnetDecoder 3 | from ..encoders import get_encoder 4 | from ..base import SegmentationModel 5 | from ..base import SegmentationHead, ClassificationHead 6 | 7 | 8 | class Unet(SegmentationModel): 9 | """Unet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder* 10 | and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial 11 | resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *concatenation* 12 | for fusing decoder blocks with skip connections. 13 | 14 | Args: 15 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 16 | to extract features of different spatial resolution 17 | encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features 18 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 19 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 
20 | Default is 5 21 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 22 | other pretrained weights (see table with available weights for each encoder_name) 23 | decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. 24 | Length of the list should be the same as **encoder_depth** 25 | decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers 26 | is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. 27 | Available options are **True, False, "inplace"** 28 | decoder_attention_type: Attention module used in decoder of the model. Available options are **None** and **scse**. 29 | SCSE paper - https://arxiv.org/abs/1808.08127 30 | in_channels: A number of input channels for the model, default is 3 (RGB images) 31 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 32 | activation: An activation function to apply after the final convolution layer. 33 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**. 34 | Default is **None** 35 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 36 | on top of encoder if **aux_params** is not **None** (default). Supported params: 37 | - classes (int): A number of classes 38 | - pooling (str): One of "max", "avg". Default is "avg" 39 | - dropout (float): Dropout factor in [0, 1) 40 | - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) 41 | siam_encoder: Whether using siamese branch. Default is True 42 | fusion_form: The form of fusing features from two branches. Available options are **"concat"**, **"sum"**, and **"diff"**. 43 | Default is **concat** 44 | 45 | Returns: 46 | ``torch.nn.Module``: Unet 47 | 48 | .. 
_Unet: 49 | https://arxiv.org/abs/1505.04597 50 | 51 | """ 52 | 53 | def __init__( 54 | self, 55 | encoder_name: str = "resnet34", 56 | encoder_depth: int = 5, 57 | encoder_weights: Optional[str] = "imagenet", 58 | decoder_use_batchnorm: bool = True, 59 | decoder_channels: List[int] = (256, 128, 64, 32, 16), 60 | decoder_attention_type: Optional[str] = None, 61 | in_channels: int = 3, 62 | classes: int = 1, 63 | activation: Optional[Union[str, callable]] = None, 64 | aux_params: Optional[dict] = None, 65 | siam_encoder: bool = True, 66 | fusion_form: str = "concat", 67 | **kwargs 68 | ): 69 | super().__init__() 70 | 71 | self.siam_encoder = siam_encoder 72 | 73 | self.encoder = get_encoder( 74 | encoder_name, 75 | in_channels=in_channels, 76 | depth=encoder_depth, 77 | weights=encoder_weights, 78 | ) 79 | 80 | if not self.siam_encoder: 81 | self.encoder_non_siam = get_encoder( 82 | encoder_name, 83 | in_channels=in_channels, 84 | depth=encoder_depth, 85 | weights=encoder_weights, 86 | ) 87 | 88 | self.decoder = UnetDecoder( 89 | encoder_channels=self.encoder.out_channels, 90 | decoder_channels=decoder_channels, 91 | n_blocks=encoder_depth, 92 | use_batchnorm=decoder_use_batchnorm, 93 | center=True if encoder_name.startswith("vgg") else False, 94 | attention_type=decoder_attention_type, 95 | fusion_form=fusion_form, 96 | ) 97 | 98 | self.segmentation_head = SegmentationHead( 99 | in_channels=decoder_channels[-1], 100 | out_channels=classes, 101 | activation=activation, 102 | kernel_size=3, 103 | ) 104 | 105 | if aux_params is not None: 106 | self.classification_head = ClassificationHead( 107 | in_channels=self.encoder.out_channels[-1], **aux_params 108 | ) 109 | else: 110 | self.classification_head = None 111 | 112 | self.name = "u-{}".format(encoder_name) 113 | self.initialize() 114 | 115 | if __name__ == "__main__": 116 | import torch 117 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 118 | input1 = torch.randn(1, 3, 256, 256).to(device) 119 | input2 = torch.randn(1, 3, 256, 256).to(device) 120 | net = Unet() 121 | res = net.forward(input1, input2) 122 | print(res.shape) 123 | -------------------------------------------------------------------------------- /change_detection_pytorch/encoders/vgg.py: -------------------------------------------------------------------------------- 1 | """ Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. 
if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 24 | """ 25 | 26 | import torch.nn as nn 27 | from torchvision.models.vgg import VGG 28 | from torchvision.models.vgg import make_layers 29 | from pretrainedmodels.models.torchvision_models import pretrained_settings 30 | 31 | from ._base import EncoderMixin 32 | 33 | # fmt: off 34 | cfg = { 35 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 36 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 37 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 38 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 39 | } 40 | # fmt: on 41 | 42 | 43 | class VGGEncoder(VGG, EncoderMixin): 44 | def __init__(self, out_channels, config, batch_norm=False, depth=5, **kwargs): 45 | super().__init__(make_layers(config, batch_norm=batch_norm), **kwargs) 46 | self._out_channels = out_channels 47 | self._depth = depth 48 | self._in_channels = 3 49 | del self.classifier 50 | 51 | def make_dilated(self, output_stride): 52 | raise ValueError("'VGG' models do not support dilated mode due to Max Pooling" 53 | " operations for downsampling!") 54 | 55 | def get_stages(self): 56 | stages = [] 57 | stage_modules = [] 58 | for module in self.features: 59 | if isinstance(module, nn.MaxPool2d): 60 | stages.append(nn.Sequential(*stage_modules)) 61 | stage_modules = [] 62 | stage_modules.append(module) 63 | stages.append(nn.Sequential(*stage_modules)) 64 | return stages 65 | 66 | def forward(self, x): 67 | stages = self.get_stages() 68 | 69 | features = [] 70 | for i in range(self._depth + 1): 71 | x = stages[i](x) 72 | features.append(x) 73 | 74 | return features 75 | 76 | def load_state_dict(self, state_dict, **kwargs): 77 | keys = list(state_dict.keys()) 78 | for k in keys: 79 | if k.startswith("classifier"): 80 | state_dict.pop(k, None) 81 | super().load_state_dict(state_dict, **kwargs) 82 | 83 | 84 | vgg_encoders = { 85 | "vgg11": { 86 | "encoder": VGGEncoder, 87 | "pretrained_settings": pretrained_settings["vgg11"], 88 | "params": { 89 | "out_channels": (64, 128, 256, 512, 512, 512), 90 | "config": cfg["A"], 91 | "batch_norm": False, 92 | }, 93 | }, 94 | "vgg11_bn": { 95 | "encoder": VGGEncoder, 96 | "pretrained_settings": pretrained_settings["vgg11_bn"], 97 | "params": { 98 | "out_channels": (64, 128, 256, 512, 512, 512), 99 | "config": cfg["A"], 100 | "batch_norm": True, 101 | }, 102 | }, 103 | "vgg13": { 104 | "encoder": VGGEncoder, 105 | "pretrained_settings": pretrained_settings["vgg13"], 106 | "params": { 107 | "out_channels": (64, 128, 256, 512, 512, 512), 108 | "config": cfg["B"], 109 | "batch_norm": False, 110 | }, 111 | }, 112 | "vgg13_bn": { 113 | "encoder": VGGEncoder, 114 | "pretrained_settings": pretrained_settings["vgg13_bn"], 115 | "params": { 116 | "out_channels": (64, 128, 256, 512, 512, 512), 117 | "config": cfg["B"], 118 | "batch_norm": True, 119 | }, 120 | }, 121 | "vgg16": { 122 | "encoder": VGGEncoder, 123 | "pretrained_settings": pretrained_settings["vgg16"], 124 | "params": { 125 | "out_channels": (64, 128, 256, 512, 512, 512), 126 | "config": cfg["D"], 127 | "batch_norm": False, 128 | }, 129 | }, 130 | "vgg16_bn": { 131 | "encoder": VGGEncoder, 132 | "pretrained_settings": pretrained_settings["vgg16_bn"], 133 | "params": { 134 | 
"out_channels": (64, 128, 256, 512, 512, 512), 135 | "config": cfg["D"], 136 | "batch_norm": True, 137 | }, 138 | }, 139 | "vgg19": { 140 | "encoder": VGGEncoder, 141 | "pretrained_settings": pretrained_settings["vgg19"], 142 | "params": { 143 | "out_channels": (64, 128, 256, 512, 512, 512), 144 | "config": cfg["E"], 145 | "batch_norm": False, 146 | }, 147 | }, 148 | "vgg19_bn": { 149 | "encoder": VGGEncoder, 150 | "pretrained_settings": pretrained_settings["vgg19_bn"], 151 | "params": { 152 | "out_channels": (64, 128, 256, 512, 512, 512), 153 | "config": cfg["E"], 154 | "batch_norm": True, 155 | }, 156 | }, 157 | } 158 | -------------------------------------------------------------------------------- /change_detection_pytorch/encoders/timm_res2net.py: -------------------------------------------------------------------------------- 1 | from ._base import EncoderMixin 2 | from timm.models.resnet import ResNet 3 | from timm.models.res2net import Bottle2neck 4 | import torch.nn as nn 5 | 6 | 7 | class Res2NetEncoder(ResNet, EncoderMixin): 8 | def __init__(self, out_channels, depth=5, **kwargs): 9 | super().__init__(**kwargs) 10 | self._depth = depth 11 | self._out_channels = out_channels 12 | self._in_channels = 3 13 | 14 | del self.fc 15 | del self.global_pool 16 | 17 | def get_stages(self): 18 | return [ 19 | nn.Identity(), 20 | nn.Sequential(self.conv1, self.bn1, self.act1), 21 | nn.Sequential(self.maxpool, self.layer1), 22 | self.layer2, 23 | self.layer3, 24 | self.layer4, 25 | ] 26 | 27 | def make_dilated(self, output_stride): 28 | raise ValueError("Res2Net encoders do not support dilated mode") 29 | 30 | def forward(self, x): 31 | stages = self.get_stages() 32 | 33 | features = [] 34 | for i in range(self._depth + 1): 35 | x = stages[i](x) 36 | features.append(x) 37 | 38 | return features 39 | 40 | def load_state_dict(self, state_dict, **kwargs): 41 | state_dict.pop("fc.bias", None) 42 | state_dict.pop("fc.weight", None) 43 | super().load_state_dict(state_dict, **kwargs) 44 | 45 | 46 | res2net_weights = { 47 | 'timm-res2net50_26w_4s': { 48 | 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth' 49 | }, 50 | 'timm-res2net50_48w_2s': { 51 | 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth' 52 | }, 53 | 'timm-res2net50_14w_8s': { 54 | 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth', 55 | }, 56 | 'timm-res2net50_26w_6s': { 57 | 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth', 58 | }, 59 | 'timm-res2net50_26w_8s': { 60 | 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth', 61 | }, 62 | 'timm-res2net101_26w_4s': { 63 | 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth', 64 | }, 65 | 'timm-res2next50': { 66 | 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth', 67 | } 68 | } 69 | 70 | pretrained_settings = {} 71 | for model_name, sources in res2net_weights.items(): 72 | pretrained_settings[model_name] = {} 73 | for source_name, source_url in sources.items(): 74 | pretrained_settings[model_name][source_name] = { 75 | "url": source_url, 76 | 'input_size': [3, 224, 224], 77 | 
'input_range': [0, 1], 78 | 'mean': [0.485, 0.456, 0.406], 79 | 'std': [0.229, 0.224, 0.225], 80 | 'num_classes': 1000 81 | } 82 | 83 | 84 | timm_res2net_encoders = { 85 | 'timm-res2net50_26w_4s': { 86 | 'encoder': Res2NetEncoder, 87 | "pretrained_settings": pretrained_settings["timm-res2net50_26w_4s"], 88 | 'params': { 89 | 'out_channels': (3, 64, 256, 512, 1024, 2048), 90 | 'block': Bottle2neck, 91 | 'layers': [3, 4, 6, 3], 92 | 'base_width': 26, 93 | 'block_args': {'scale': 4} 94 | }, 95 | }, 96 | 'timm-res2net101_26w_4s': { 97 | 'encoder': Res2NetEncoder, 98 | "pretrained_settings": pretrained_settings["timm-res2net101_26w_4s"], 99 | 'params': { 100 | 'out_channels': (3, 64, 256, 512, 1024, 2048), 101 | 'block': Bottle2neck, 102 | 'layers': [3, 4, 23, 3], 103 | 'base_width': 26, 104 | 'block_args': {'scale': 4} 105 | }, 106 | }, 107 | 'timm-res2net50_26w_6s': { 108 | 'encoder': Res2NetEncoder, 109 | "pretrained_settings": pretrained_settings["timm-res2net50_26w_6s"], 110 | 'params': { 111 | 'out_channels': (3, 64, 256, 512, 1024, 2048), 112 | 'block': Bottle2neck, 113 | 'layers': [3, 4, 6, 3], 114 | 'base_width': 26, 115 | 'block_args': {'scale': 6} 116 | }, 117 | }, 118 | 'timm-res2net50_26w_8s': { 119 | 'encoder': Res2NetEncoder, 120 | "pretrained_settings": pretrained_settings["timm-res2net50_26w_8s"], 121 | 'params': { 122 | 'out_channels': (3, 64, 256, 512, 1024, 2048), 123 | 'block': Bottle2neck, 124 | 'layers': [3, 4, 6, 3], 125 | 'base_width': 26, 126 | 'block_args': {'scale': 8} 127 | }, 128 | }, 129 | 'timm-res2net50_48w_2s': { 130 | 'encoder': Res2NetEncoder, 131 | "pretrained_settings": pretrained_settings["timm-res2net50_48w_2s"], 132 | 'params': { 133 | 'out_channels': (3, 64, 256, 512, 1024, 2048), 134 | 'block': Bottle2neck, 135 | 'layers': [3, 4, 6, 3], 136 | 'base_width': 48, 137 | 'block_args': {'scale': 2} 138 | }, 139 | }, 140 | 'timm-res2net50_14w_8s': { 141 | 'encoder': Res2NetEncoder, 142 | "pretrained_settings": pretrained_settings["timm-res2net50_14w_8s"], 143 | 'params': { 144 | 'out_channels': (3, 64, 256, 512, 1024, 2048), 145 | 'block': Bottle2neck, 146 | 'layers': [3, 4, 6, 3], 147 | 'base_width': 14, 148 | 'block_args': {'scale': 8} 149 | }, 150 | }, 151 | 'timm-res2next50': { 152 | 'encoder': Res2NetEncoder, 153 | "pretrained_settings": pretrained_settings["timm-res2next50"], 154 | 'params': { 155 | 'out_channels': (3, 64, 256, 512, 1024, 2048), 156 | 'block': Bottle2neck, 157 | 'layers': [3, 4, 6, 3], 158 | 'base_width': 4, 159 | 'cardinality': 8, 160 | 'block_args': {'scale': 4} 161 | }, 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /change_detection_pytorch/utils/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def _take_channels(*xs, ignore_channels=None): 5 | if ignore_channels is None: 6 | return xs 7 | else: 8 | channels = [channel for channel in range(xs[0].shape[1]) if channel not in ignore_channels] 9 | xs = [torch.index_select(x, dim=1, index=torch.tensor(channels).to(x.device)) for x in xs] 10 | return xs 11 | 12 | 13 | def _threshold(x, threshold=None): 14 | if threshold is not None: 15 | return (x > threshold).type(x.dtype) 16 | else: 17 | return x 18 | 19 | 20 | def iou(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): 21 | """Calculate Intersection over Union between ground truth and prediction 22 | Args: 23 | pr (torch.Tensor): predicted tensor 24 | gt (torch.Tensor): ground truth tensor 25 | eps 
(float): epsilon to avoid zero division 26 | threshold: threshold for outputs binarization 27 | Returns: 28 | float: IoU (Jaccard) score 29 | """ 30 | 31 | pr = _threshold(pr, threshold=threshold) 32 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 33 | 34 | intersection = torch.sum(gt * pr) 35 | union = torch.sum(gt) + torch.sum(pr) - intersection + eps 36 | return (intersection + eps) / union 37 | 38 | 39 | jaccard = iou 40 | 41 | 42 | def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, ignore_channels=None): 43 | """Calculate F-score between ground truth and prediction 44 | Args: 45 | pr (torch.Tensor): predicted tensor 46 | gt (torch.Tensor): ground truth tensor 47 | beta (float): positive constant 48 | eps (float): epsilon to avoid zero division 49 | threshold: threshold for outputs binarization 50 | Returns: 51 | float: F score 52 | """ 53 | 54 | pr = _threshold(pr, threshold=threshold) 55 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 56 | 57 | tp = torch.sum(gt * pr) 58 | fp = torch.sum(pr) - tp 59 | fn = torch.sum(gt) - tp 60 | 61 | score = ((1 + beta ** 2) * tp + eps) \ 62 | / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + eps) 63 | 64 | return score 65 | 66 | 67 | def accuracy(pr, gt, threshold=0.5, ignore_channels=None): 68 | """Calculate accuracy score between ground truth and prediction 69 | Args: 70 | pr (torch.Tensor): predicted tensor 71 | gt (torch.Tensor): ground truth tensor 72 | eps (float): epsilon to avoid zero division 73 | threshold: threshold for outputs binarization 74 | Returns: 75 | float: precision score 76 | """ 77 | pr = _threshold(pr, threshold=threshold) 78 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 79 | 80 | tp_tn = torch.sum(gt == pr, dtype=pr.dtype) 81 | 82 | score = tp_tn / gt.view(-1).shape[0] 83 | 84 | return score 85 | 86 | 87 | def precision(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): 88 | """Calculate precision score between ground truth and prediction 89 | Args: 90 | pr (torch.Tensor): predicted tensor 91 | gt (torch.Tensor): ground truth tensor 92 | eps (float): epsilon to avoid zero division 93 | threshold: threshold for outputs binarization 94 | Returns: 95 | float: precision score 96 | """ 97 | 98 | pr = _threshold(pr, threshold=threshold) 99 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 100 | 101 | tp = torch.sum(gt * pr) 102 | fp = torch.sum(pr) - tp 103 | 104 | score = (tp + eps) / (tp + fp + eps) 105 | 106 | return score 107 | 108 | 109 | def recall(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): 110 | """Calculate Recall between ground truth and prediction 111 | Args: 112 | pr (torch.Tensor): A list of predicted elements 113 | gt (torch.Tensor): A list of elements that are to be predicted 114 | eps (float): epsilon to avoid zero division 115 | threshold: threshold for outputs binarization 116 | Returns: 117 | float: recall score 118 | """ 119 | 120 | pr = _threshold(pr, threshold=threshold) 121 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 122 | 123 | tp = torch.sum(gt * pr) 124 | fn = torch.sum(gt) - tp 125 | 126 | score = (tp + eps) / (tp + fn + eps) 127 | 128 | return score 129 | 130 | 131 | def kappa(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): 132 | """Calculate kappa score between ground truth and prediction 133 | Args: 134 | pr (torch.Tensor): A list of predicted elements 135 | gt (torch.Tensor): A list of elements that are to be predicted 136 | eps (float): epsilon to avoid zero division 137 | 
threshold: threshold for outputs binarization 138 | Returns: 139 | float: kappa score 140 | """ 141 | 142 | pr = _threshold(pr, threshold=threshold) 143 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 144 | 145 | tp = torch.sum(gt * pr) 146 | fp = torch.sum(pr) - tp 147 | fn = torch.sum(gt) - tp 148 | tn = torch.sum((1 - gt)*(1 - pr)) 149 | 150 | N = tp + tn + fp + fn 151 | p0 = (tp + tn) / N 152 | pe = ((tp + fp) * (tp + fn) + (tn + fp) * (tn + fn)) / (N * N) 153 | 154 | score = (p0 - pe) / (1 - pe) 155 | 156 | return score 157 | 158 | 159 | def dice(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): 160 | """Calculate dice score between ground truth and prediction 161 | Args: 162 | pr (torch.Tensor): A list of predicted elements 163 | gt (torch.Tensor): A list of elements that are to be predicted 164 | eps (float): epsilon to avoid zero division 165 | threshold: threshold for outputs binarization 166 | Returns: 167 | float: dice score 168 | """ 169 | pr = _threshold(pr, threshold=threshold) 170 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 171 | 172 | tp = torch.sum(gt * pr) 173 | fp = torch.sum(pr) - tp 174 | fn = torch.sum(gt) - tp 175 | 176 | _precision = precision(pr, gt, eps=eps, threshold=threshold, ignore_channels=ignore_channels) 177 | _recall = recall(pr, gt, eps=eps, threshold=threshold, ignore_channels=ignore_channels) 178 | 179 | score = 2 * _precision * _recall / (_precision + _recall) 180 | 181 | return score -------------------------------------------------------------------------------- /change_detection_pytorch/encoders/senet.py: -------------------------------------------------------------------------------- 1 | """ Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
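For illustration, a minimal sketch of how an encoder defined in this file can be exercised against that contract. It assumes the package-level `get_encoder` helper used by the model classes; the encoder name and tensor size are arbitrary:

    import torch
    from change_detection_pytorch.encoders import get_encoder

    # build a SE-ResNet encoder with random weights (weights=None avoids any download)
    encoder = get_encoder("se_resnet50", in_channels=3, depth=5, weights=None)
    features = encoder(torch.randn(1, 3, 64, 64))
    # depth=5 -> 6 feature maps, from input resolution down to 1/32
    for f in features:
        print(f.shape)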
24 | """ 25 | 26 | import torch.nn as nn 27 | 28 | from pretrainedmodels.models.senet import ( 29 | SENet, 30 | SEBottleneck, 31 | SEResNetBottleneck, 32 | SEResNeXtBottleneck, 33 | pretrained_settings, 34 | ) 35 | from ._base import EncoderMixin 36 | 37 | 38 | class SENetEncoder(SENet, EncoderMixin): 39 | def __init__(self, out_channels, depth=5, **kwargs): 40 | super().__init__(**kwargs) 41 | 42 | self._out_channels = out_channels 43 | self._depth = depth 44 | self._in_channels = 3 45 | 46 | del self.last_linear 47 | del self.avg_pool 48 | 49 | def get_stages(self): 50 | return [ 51 | nn.Identity(), 52 | self.layer0[:-1], 53 | nn.Sequential(self.layer0[-1], self.layer1), 54 | self.layer2, 55 | self.layer3, 56 | self.layer4, 57 | ] 58 | 59 | def forward(self, x): 60 | stages = self.get_stages() 61 | 62 | features = [] 63 | for i in range(self._depth + 1): 64 | x = stages[i](x) 65 | features.append(x) 66 | 67 | return features 68 | 69 | def load_state_dict(self, state_dict, **kwargs): 70 | state_dict.pop("last_linear.bias", None) 71 | state_dict.pop("last_linear.weight", None) 72 | super().load_state_dict(state_dict, **kwargs) 73 | 74 | 75 | senet_encoders = { 76 | "senet154": { 77 | "encoder": SENetEncoder, 78 | "pretrained_settings": pretrained_settings["senet154"], 79 | "params": { 80 | "out_channels": (3, 128, 256, 512, 1024, 2048), 81 | "block": SEBottleneck, 82 | "dropout_p": 0.2, 83 | "groups": 64, 84 | "layers": [3, 8, 36, 3], 85 | "num_classes": 1000, 86 | "reduction": 16, 87 | }, 88 | }, 89 | "se_resnet50": { 90 | "encoder": SENetEncoder, 91 | "pretrained_settings": pretrained_settings["se_resnet50"], 92 | "params": { 93 | "out_channels": (3, 64, 256, 512, 1024, 2048), 94 | "block": SEResNetBottleneck, 95 | "layers": [3, 4, 6, 3], 96 | "downsample_kernel_size": 1, 97 | "downsample_padding": 0, 98 | "dropout_p": None, 99 | "groups": 1, 100 | "inplanes": 64, 101 | "input_3x3": False, 102 | "num_classes": 1000, 103 | "reduction": 16, 104 | }, 105 | }, 106 | "se_resnet101": { 107 | "encoder": SENetEncoder, 108 | "pretrained_settings": pretrained_settings["se_resnet101"], 109 | "params": { 110 | "out_channels": (3, 64, 256, 512, 1024, 2048), 111 | "block": SEResNetBottleneck, 112 | "layers": [3, 4, 23, 3], 113 | "downsample_kernel_size": 1, 114 | "downsample_padding": 0, 115 | "dropout_p": None, 116 | "groups": 1, 117 | "inplanes": 64, 118 | "input_3x3": False, 119 | "num_classes": 1000, 120 | "reduction": 16, 121 | }, 122 | }, 123 | "se_resnet152": { 124 | "encoder": SENetEncoder, 125 | "pretrained_settings": pretrained_settings["se_resnet152"], 126 | "params": { 127 | "out_channels": (3, 64, 256, 512, 1024, 2048), 128 | "block": SEResNetBottleneck, 129 | "layers": [3, 8, 36, 3], 130 | "downsample_kernel_size": 1, 131 | "downsample_padding": 0, 132 | "dropout_p": None, 133 | "groups": 1, 134 | "inplanes": 64, 135 | "input_3x3": False, 136 | "num_classes": 1000, 137 | "reduction": 16, 138 | }, 139 | }, 140 | "se_resnext50_32x4d": { 141 | "encoder": SENetEncoder, 142 | "pretrained_settings": pretrained_settings["se_resnext50_32x4d"], 143 | "params": { 144 | "out_channels": (3, 64, 256, 512, 1024, 2048), 145 | "block": SEResNeXtBottleneck, 146 | "layers": [3, 4, 6, 3], 147 | "downsample_kernel_size": 1, 148 | "downsample_padding": 0, 149 | "dropout_p": None, 150 | "groups": 32, 151 | "inplanes": 64, 152 | "input_3x3": False, 153 | "num_classes": 1000, 154 | "reduction": 16, 155 | }, 156 | }, 157 | "se_resnext101_32x4d": { 158 | "encoder": SENetEncoder, 159 | "pretrained_settings": 
pretrained_settings["se_resnext101_32x4d"], 160 | "params": { 161 | "out_channels": (3, 64, 256, 512, 1024, 2048), 162 | "block": SEResNeXtBottleneck, 163 | "layers": [3, 4, 23, 3], 164 | "downsample_kernel_size": 1, 165 | "downsample_padding": 0, 166 | "dropout_p": None, 167 | "groups": 32, 168 | "inplanes": 64, 169 | "input_3x3": False, 170 | "num_classes": 1000, 171 | "reduction": 16, 172 | }, 173 | }, 174 | } 175 | -------------------------------------------------------------------------------- /change_detection_pytorch/unetplusplus/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..base import modules as md 6 | from ..base import Decoder 7 | 8 | 9 | class DecoderBlock(nn.Module): 10 | def __init__( 11 | self, 12 | in_channels, 13 | skip_channels, 14 | out_channels, 15 | use_batchnorm=True, 16 | attention_type=None, 17 | ): 18 | super().__init__() 19 | self.conv1 = md.Conv2dReLU( 20 | in_channels + skip_channels, 21 | out_channels, 22 | kernel_size=3, 23 | padding=1, 24 | use_batchnorm=use_batchnorm, 25 | ) 26 | self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels) 27 | self.conv2 = md.Conv2dReLU( 28 | out_channels, 29 | out_channels, 30 | kernel_size=3, 31 | padding=1, 32 | use_batchnorm=use_batchnorm, 33 | ) 34 | self.attention2 = md.Attention(attention_type, in_channels=out_channels) 35 | 36 | def forward(self, x, skip=None): 37 | x = F.interpolate(x, scale_factor=2, mode="nearest") 38 | if skip is not None: 39 | x = torch.cat([x, skip], dim=1) 40 | x = self.attention1(x) 41 | x = self.conv1(x) 42 | x = self.conv2(x) 43 | x = self.attention2(x) 44 | return x 45 | 46 | 47 | class CenterBlock(nn.Sequential): 48 | def __init__(self, in_channels, out_channels, use_batchnorm=True): 49 | conv1 = md.Conv2dReLU( 50 | in_channels, 51 | out_channels, 52 | kernel_size=3, 53 | padding=1, 54 | use_batchnorm=use_batchnorm, 55 | ) 56 | conv2 = md.Conv2dReLU( 57 | out_channels, 58 | out_channels, 59 | kernel_size=3, 60 | padding=1, 61 | use_batchnorm=use_batchnorm, 62 | ) 63 | super().__init__(conv1, conv2) 64 | 65 | 66 | class UnetPlusPlusDecoder(Decoder): 67 | def __init__( 68 | self, 69 | encoder_channels, 70 | decoder_channels, 71 | n_blocks=5, 72 | use_batchnorm=True, 73 | attention_type=None, 74 | center=False, 75 | fusion_form="concat", 76 | seg_ensemble=None, 77 | ): 78 | super().__init__() 79 | 80 | if n_blocks != len(decoder_channels): 81 | raise ValueError( 82 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( 83 | n_blocks, len(decoder_channels) 84 | ) 85 | ) 86 | 87 | encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution 88 | encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder 89 | # computing blocks input and output channels 90 | head_channels = encoder_channels[0] 91 | self.in_channels = [head_channels] + list(decoder_channels[:-1]) 92 | self.skip_channels = list(encoder_channels[1:]) + [0] 93 | self.out_channels = decoder_channels 94 | 95 | # adjust encoder channels according to fusion form 96 | self.fusion_form = fusion_form 97 | if self.fusion_form in self.FUSION_DIC["2to2_fusion"]: 98 | self.skip_channels = [ch*2 for ch in self.skip_channels] 99 | self.in_channels[0] = self.in_channels[0] * 2 100 | 101 | if center: 102 | self.center = CenterBlock( 103 | head_channels, head_channels, use_batchnorm=use_batchnorm 
104 | ) 105 | else: 106 | self.center = nn.Identity() 107 | 108 | # combine decoder keyword arguments 109 | kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) 110 | 111 | blocks = {} 112 | for layer_idx in range(len(self.in_channels) - 1): 113 | for depth_idx in range(layer_idx+1): 114 | out_ch = self.out_channels[layer_idx] 115 | skip_ch = self.out_channels[layer_idx] * (layer_idx-depth_idx) + self.skip_channels[layer_idx] 116 | in_ch = self.skip_channels[layer_idx - 1] if depth_idx == layer_idx and depth_idx != 0 else self.in_channels[layer_idx] 117 | 118 | blocks[f'x_{depth_idx}_{layer_idx}'] = DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) 119 | blocks[f'x_{0}_{len(self.in_channels)-1}'] =\ 120 | DecoderBlock(self.in_channels[-1], 0, self.out_channels[-1], **kwargs) 121 | self.blocks = nn.ModuleDict(blocks) 122 | self.depth = len(self.in_channels) - 1 123 | self.seg_ensemble = seg_ensemble 124 | self.ECAM = md.ECAM(in_channels=self.out_channels[-2], out_channels=self.out_channels[-1]) if seg_ensemble == "ecam" else None 125 | 126 | def forward(self, *features): 127 | 128 | features = self.aggregation_layer(features[0], features[1], 129 | self.fusion_form, ignore_original_img=True) 130 | # features = features[1:] # remove first skip with same spatial resolution 131 | features = features[::-1] # reverse channels to start from head of encoder 132 | # start building dense connections 133 | dense_x = {} 134 | for layer_idx in range(len(self.in_channels) - 1): 135 | for depth_idx in range(self.depth-layer_idx): 136 | if layer_idx == 0: 137 | output = self.blocks[f'x_{depth_idx}_{depth_idx}'](features[depth_idx], features[depth_idx+1]) 138 | dense_x[f'x_{depth_idx}_{depth_idx}'] = output 139 | else: 140 | dense_l_i = depth_idx + layer_idx 141 | cat_features = [dense_x[f'x_{idx}_{dense_l_i}'] for idx in range(depth_idx+1, dense_l_i+1)] 142 | cat_features = torch.cat(cat_features + [features[dense_l_i+1]], dim=1) 143 | dense_x[f'x_{depth_idx}_{dense_l_i}'] =\ 144 | self.blocks[f'x_{depth_idx}_{dense_l_i}'](dense_x[f'x_{depth_idx}_{dense_l_i-1}'], cat_features) 145 | 146 | if self.seg_ensemble == "ecam": 147 | return self.ECAM([dense_x[f'x_{i}_{self.depth - 1}'] for i in range(self.depth)]) 148 | else: 149 | dense_x[f'x_{0}_{self.depth}'] = self.blocks[f'x_{0}_{self.depth}'](dense_x[f'x_{0}_{self.depth - 1}']) 150 | return dense_x[f'x_{0}_{self.depth}'] 151 | -------------------------------------------------------------------------------- /change_detection_pytorch/encoders/dpn.py: -------------------------------------------------------------------------------- 1 | """ Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 
15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 24 | """ 25 | 26 | import torch 27 | import torch.nn as nn 28 | import torch.nn.functional as F 29 | 30 | from pretrainedmodels.models.dpn import DPN 31 | from pretrainedmodels.models.dpn import pretrained_settings 32 | 33 | from ._base import EncoderMixin 34 | 35 | 36 | class DPNEncoder(DPN, EncoderMixin): 37 | def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): 38 | super().__init__(**kwargs) 39 | self._stage_idxs = stage_idxs 40 | self._depth = depth 41 | self._out_channels = out_channels 42 | self._in_channels = 3 43 | 44 | del self.last_linear 45 | 46 | def get_stages(self): 47 | return [ 48 | nn.Identity(), 49 | nn.Sequential(self.features[0].conv, self.features[0].bn, self.features[0].act), 50 | nn.Sequential(self.features[0].pool, self.features[1 : self._stage_idxs[0]]), 51 | self.features[self._stage_idxs[0] : self._stage_idxs[1]], 52 | self.features[self._stage_idxs[1] : self._stage_idxs[2]], 53 | self.features[self._stage_idxs[2] : self._stage_idxs[3]], 54 | ] 55 | 56 | def forward(self, x): 57 | 58 | stages = self.get_stages() 59 | 60 | features = [] 61 | for i in range(self._depth + 1): 62 | x = stages[i](x) 63 | if isinstance(x, (list, tuple)): 64 | features.append(F.relu(torch.cat(x, dim=1), inplace=True)) 65 | else: 66 | features.append(x) 67 | 68 | return features 69 | 70 | def load_state_dict(self, state_dict, **kwargs): 71 | state_dict.pop("last_linear.bias", None) 72 | state_dict.pop("last_linear.weight", None) 73 | super().load_state_dict(state_dict, **kwargs) 74 | 75 | 76 | dpn_encoders = { 77 | "dpn68": { 78 | "encoder": DPNEncoder, 79 | "pretrained_settings": pretrained_settings["dpn68"], 80 | "params": { 81 | "stage_idxs": (4, 8, 20, 24), 82 | "out_channels": (3, 10, 144, 320, 704, 832), 83 | "groups": 32, 84 | "inc_sec": (16, 32, 32, 64), 85 | "k_r": 128, 86 | "k_sec": (3, 4, 12, 3), 87 | "num_classes": 1000, 88 | "num_init_features": 10, 89 | "small": True, 90 | "test_time_pool": True, 91 | }, 92 | }, 93 | "dpn68b": { 94 | "encoder": DPNEncoder, 95 | "pretrained_settings": pretrained_settings["dpn68b"], 96 | "params": { 97 | "stage_idxs": (4, 8, 20, 24), 98 | "out_channels": (3, 10, 144, 320, 704, 832), 99 | "b": True, 100 | "groups": 32, 101 | "inc_sec": (16, 32, 32, 64), 102 | "k_r": 128, 103 | "k_sec": (3, 4, 12, 3), 104 | "num_classes": 1000, 105 | "num_init_features": 10, 106 | "small": True, 107 | "test_time_pool": True, 108 | }, 109 | }, 110 | "dpn92": { 111 | "encoder": DPNEncoder, 112 | "pretrained_settings": pretrained_settings["dpn92"], 113 | "params": { 114 | "stage_idxs": (4, 8, 28, 32), 115 | "out_channels": (3, 64, 336, 704, 1552, 2688), 116 | "groups": 32, 117 | "inc_sec": (16, 32, 24, 128), 118 | "k_r": 96, 119 | "k_sec": (3, 4, 20, 3), 120 | "num_classes": 1000, 121 | "num_init_features": 64, 122 | "test_time_pool": True, 123 | }, 124 | }, 125 | "dpn98": { 126 | "encoder": DPNEncoder, 127 | "pretrained_settings": pretrained_settings["dpn98"], 128 | "params": { 129 | "stage_idxs": (4, 10, 30, 34), 
130 | "out_channels": (3, 96, 336, 768, 1728, 2688), 131 | "groups": 40, 132 | "inc_sec": (16, 32, 32, 128), 133 | "k_r": 160, 134 | "k_sec": (3, 6, 20, 3), 135 | "num_classes": 1000, 136 | "num_init_features": 96, 137 | "test_time_pool": True, 138 | }, 139 | }, 140 | "dpn107": { 141 | "encoder": DPNEncoder, 142 | "pretrained_settings": pretrained_settings["dpn107"], 143 | "params": { 144 | "stage_idxs": (5, 13, 33, 37), 145 | "out_channels": (3, 128, 376, 1152, 2432, 2688), 146 | "groups": 50, 147 | "inc_sec": (20, 64, 64, 128), 148 | "k_r": 200, 149 | "k_sec": (4, 8, 20, 3), 150 | "num_classes": 1000, 151 | "num_init_features": 128, 152 | "test_time_pool": True, 153 | }, 154 | }, 155 | "dpn131": { 156 | "encoder": DPNEncoder, 157 | "pretrained_settings": pretrained_settings["dpn131"], 158 | "params": { 159 | "stage_idxs": (5, 13, 41, 45), 160 | "out_channels": (3, 128, 352, 832, 1984, 2688), 161 | "groups": 40, 162 | "inc_sec": (16, 32, 32, 128), 163 | "k_r": 160, 164 | "k_sec": (4, 8, 28, 3), 165 | "num_classes": 1000, 166 | "num_init_features": 128, 167 | "test_time_pool": True, 168 | }, 169 | }, 170 | } 171 | --------------------------------------------------------------------------------