├── LICENSE
├── README.md
├── requirements.txt
├── segmentation_models_pytorch_3d
│   ├── __init__.py
│   ├── __version__.py
│   ├── base
│   │   ├── __init__.py
│   │   ├── heads.py
│   │   ├── initialization.py
│   │   ├── model.py
│   │   └── modules.py
│   ├── datasets
│   │   ├── __init__.py
│   │   └── oxford_pet.py
│   ├── decoders
│   │   ├── __init__.py
│   │   ├── deeplabv3
│   │   │   ├── __init__.py
│   │   │   ├── decoder.py
│   │   │   └── model.py
│   │   ├── fpn
│   │   │   ├── __init__.py
│   │   │   ├── decoder.py
│   │   │   └── model.py
│   │   ├── linknet
│   │   │   ├── __init__.py
│   │   │   ├── decoder.py
│   │   │   └── model.py
│   │   ├── manet
│   │   │   ├── __init__.py
│   │   │   ├── decoder.py
│   │   │   └── model.py
│   │   ├── pan
│   │   │   ├── __init__.py
│   │   │   ├── decoder.py
│   │   │   └── model.py
│   │   ├── pspnet
│   │   │   ├── __init__.py
│   │   │   ├── decoder.py
│   │   │   └── model.py
│   │   ├── unet
│   │   │   ├── __init__.py
│   │   │   ├── decoder.py
│   │   │   └── model.py
│   │   └── unetplusplus
│   │       ├── __init__.py
│   │       ├── decoder.py
│   │       └── model.py
│   ├── encoders
│   │   ├── __init__.py
│   │   ├── _base.py
│   │   ├── _preprocessing.py
│   │   ├── _utils.py
│   │   ├── densenet.py
│   │   ├── dpn.py
│   │   ├── efficientnet.py
│   │   ├── inceptionresnetv2.py
│   │   ├── inceptionv4.py
│   │   ├── mix_transformer.py
│   │   ├── mobilenet.py
│   │   ├── mobileone.py
│   │   ├── resnet.py
│   │   ├── senet.py
│   │   ├── timm_efficientnet.py
│   │   ├── timm_gernet.py
│   │   ├── timm_mobilenetv3.py
│   │   ├── timm_regnet.py
│   │   ├── timm_res2net.py
│   │   ├── timm_resnest.py
│   │   ├── timm_sknet.py
│   │   ├── timm_universal.py
│   │   ├── vgg.py
│   │   └── xception.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── _functional.py
│   │   ├── constants.py
│   │   ├── dice.py
│   │   ├── focal.py
│   │   ├── jaccard.py
│   │   ├── lovasz.py
│   │   ├── mcc.py
│   │   ├── soft_bce.py
│   │   ├── soft_ce.py
│   │   └── tversky.py
│   ├── metrics
│   │   ├── __init__.py
│   │   └── functional.py
│   └── utils
│       ├── __init__.py
│       ├── base.py
│       ├── convert_weights.py
│       ├── functional.py
│       ├── losses.py
│       ├── meter.py
│       ├── metrics.py
│       └── train.py
├── setup.py
└── test.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Roman Solovyev (ZFTurbo)
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision>=0.5.0
3 | pretrainedmodels==0.7.4
4 | efficientnet-pytorch==0.7.1
5 | timm==0.9.7
6 | timm-3d==1.0.1
7 | tqdm
8 | pillow
9 | six
10 | 
--------------------------------------------------------------------------------
/segmentation_models_pytorch_3d/__init__.py:
--------------------------------------------------------------------------------
1 | from . import datasets
2 | from . import encoders
3 | from . import decoders
4 | from . import losses
5 | from . import metrics
6 | 
7 | from .decoders.unet import Unet
8 | from .decoders.unetplusplus import UnetPlusPlus
9 | from .decoders.manet import MAnet
10 | from .decoders.linknet import Linknet
11 | from .decoders.fpn import FPN
12 | from .decoders.pspnet import PSPNet
13 | from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
14 | from .decoders.pan import PAN
15 | 
16 | from .__version__ import __version__
17 | 
18 | # some private imports for create_model function
19 | from typing import Optional as _Optional
20 | import torch as _torch
21 | 
22 | 
23 | def create_model(
24 |     arch: str,
25 |     encoder_name: str = "resnet34",
26 |     encoder_weights: _Optional[str] = "imagenet",
27 |     in_channels: int = 3,
28 |     classes: int = 1,
29 |     **kwargs,
30 | ) -> _torch.nn.Module:
31 |     """Models entry point, allows creating any model architecture just with
32 |     parameters, without using its class
33 |     """
34 | 
35 |     archs = [
36 |         Unet,
37 |         UnetPlusPlus,
38 |         MAnet,
39 |         Linknet,
40 |         FPN,
41 |         PSPNet,
42 |         DeepLabV3,
43 |         DeepLabV3Plus,
44 |         PAN,
45 |     ]
46 |     archs_dict = {a.__name__.lower(): a for a in archs}
47 |     try:
48 |         model_class = archs_dict[arch.lower()]
49 |     except KeyError:
50 |         raise KeyError(
51 |             "Wrong architecture type `{}`. 
Available options are: {}".format(
52 |                 arch,
53 |                 list(archs_dict.keys()),
54 |             )
55 |         )
56 |     return model_class(
57 |         encoder_name=encoder_name,
58 |         encoder_weights=encoder_weights,
59 |         in_channels=in_channels,
60 |         classes=classes,
61 |         **kwargs,
62 |     )
63 | 
--------------------------------------------------------------------------------
/segmentation_models_pytorch_3d/__version__.py:
--------------------------------------------------------------------------------
1 | VERSION = (0, 3, 3)
2 | 
3 | __version__ = ".".join(map(str, VERSION))
4 | 
--------------------------------------------------------------------------------
/segmentation_models_pytorch_3d/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import SegmentationModel
2 | 
3 | from .modules import (
4 |     Conv3dReLU,
5 |     Attention,
6 | )
7 | 
8 | from .heads import (
9 |     SegmentationHead,
10 |     ClassificationHead,
11 | )
12 | 
--------------------------------------------------------------------------------
/segmentation_models_pytorch_3d/base/heads.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .modules import Activation
3 | 
4 | 
5 | class SegmentationHead(nn.Sequential):
6 |     def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1):
7 |         conv3d = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
8 |         upsampling = nn.Upsample(scale_factor=upsampling, mode='trilinear') if upsampling > 1 else nn.Identity()
9 |         activation = Activation(activation)
10 |         super().__init__(conv3d, upsampling, activation)
11 | 
12 | 
13 | class ClassificationHead(nn.Sequential):
14 |     def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None):
15 |         if pooling not in ("max", "avg"):
16 |             raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling))
17 |         pool = nn.AdaptiveAvgPool3d(1) if pooling == "avg" else nn.AdaptiveMaxPool3d(1)
18 |         flatten = nn.Flatten()
19 |         dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity()
20 |         linear = nn.Linear(in_channels, classes, bias=True)
21 |         activation = Activation(activation)
22 |         super().__init__(pool, flatten, dropout, linear, activation)
23 | 
--------------------------------------------------------------------------------
/segmentation_models_pytorch_3d/base/initialization.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | def initialize_decoder(module):
5 |     for m in module.modules():
6 | 
7 |         if isinstance(m, nn.Conv2d):
8 |             nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
9 |             if m.bias is not None:
10 |                 nn.init.constant_(m.bias, 0)
11 | 
12 |         elif isinstance(m, nn.BatchNorm2d):
13 |             nn.init.constant_(m.weight, 1)
14 |             nn.init.constant_(m.bias, 0)
15 | 
16 |         elif isinstance(m, nn.Linear):
17 |             nn.init.xavier_uniform_(m.weight)
18 |             if m.bias is not None:
19 |                 nn.init.constant_(m.bias, 0)
20 | 
21 | 
22 | def initialize_head(module):
23 |     for m in module.modules():
24 |         if isinstance(m, (nn.Linear, nn.Conv2d)):
25 |             nn.init.xavier_uniform_(m.weight)
26 |             if m.bias is not None:
27 |                 nn.init.constant_(m.bias, 0)
28 | 
--------------------------------------------------------------------------------
/segmentation_models_pytorch_3d/base/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .
import initialization as init 3 | 4 | 5 | class SegmentationModel(torch.nn.Module): 6 | def initialize(self): 7 | init.initialize_decoder(self.decoder) 8 | init.initialize_head(self.segmentation_head) 9 | if self.classification_head is not None: 10 | init.initialize_head(self.classification_head) 11 | 12 | def check_input_shape(self, x): 13 | 14 | h, w, d = x.shape[-3:] 15 | try: 16 | if self.encoder.strides is not None: 17 | hs, ws, ds = 1, 1, 1 18 | for stride in self.encoder.strides: 19 | hs *= stride[0] 20 | ws *= stride[1] 21 | ds *= stride[2] 22 | if h % hs != 0 or w % ws != 0 or d % ds != 0: 23 | new_h = (h // hs + 1) * hs if h % hs != 0 else h 24 | new_w = (w // ws + 1) * ws if w % ws != 0 else w 25 | new_d = (d // ds + 1) * ds if d % ds != 0 else d 26 | raise RuntimeError( 27 | f"Wrong input shape height={h}, width={w}, depth={d}. Expected image height and width and depth " 28 | f"divisible by {hs}, {ws}, {ds}. Consider pad your images to shape ({new_h}, {new_w}, {new_d})." 29 | ) 30 | else: 31 | output_stride = self.encoder.output_stride 32 | if h % output_stride != 0 or w % output_stride != 0 or d % output_stride != 0: 33 | new_h = (h // output_stride + 1) * output_stride if h % output_stride != 0 else h 34 | new_w = (w // output_stride + 1) * output_stride if w % output_stride != 0 else w 35 | new_d = (d // output_stride + 1) * output_stride if d % output_stride != 0 else d 36 | raise RuntimeError( 37 | f"Wrong input shape height={h}, width={w}, depth={d}. Expected image height and width and depth " 38 | f"divisible by {output_stride}. Consider pad your images to shape ({new_h}, {new_w}, {new_d})." 39 | ) 40 | except: 41 | pass 42 | 43 | def forward(self, x): 44 | """Sequentially pass `x` trough model`s encoder, decoder and heads""" 45 | 46 | self.check_input_shape(x) 47 | 48 | features = self.encoder(x) 49 | decoder_output = self.decoder(*features) 50 | 51 | masks = self.segmentation_head(decoder_output) 52 | 53 | if self.classification_head is not None: 54 | labels = self.classification_head(features[-1]) 55 | return masks, labels 56 | 57 | return masks 58 | 59 | @torch.no_grad() 60 | def predict(self, x): 61 | """Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()` 62 | 63 | Args: 64 | x: 4D torch tensor with shape (batch_size, channels, height, width) 65 | 66 | Return: 67 | prediction: 4D torch tensor with shape (batch_size, classes, height, width) 68 | 69 | """ 70 | if self.training: 71 | self.eval() 72 | 73 | x = self.forward(x) 74 | 75 | return x 76 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/base/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | try: 5 | from inplace_abn import InPlaceABN 6 | except ImportError: 7 | InPlaceABN = None 8 | 9 | 10 | class Conv3dReLU(nn.Sequential): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | padding=0, 17 | stride=1, 18 | use_batchnorm=True, 19 | ): 20 | 21 | if use_batchnorm == "inplace" and InPlaceABN is None: 22 | raise RuntimeError( 23 | "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. 
" 24 | + "To install see: https://github.com/mapillary/inplace_abn" 25 | ) 26 | 27 | conv = nn.Conv3d( 28 | in_channels, 29 | out_channels, 30 | kernel_size, 31 | stride=stride, 32 | padding=padding, 33 | bias=not (use_batchnorm), 34 | ) 35 | relu = nn.ReLU(inplace=True) 36 | 37 | if use_batchnorm == "inplace": 38 | bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) 39 | relu = nn.Identity() 40 | 41 | elif use_batchnorm and use_batchnorm != "inplace": 42 | bn = nn.BatchNorm3d(out_channels) 43 | 44 | else: 45 | bn = nn.Identity() 46 | 47 | super(Conv3dReLU, self).__init__(conv, bn, relu) 48 | 49 | 50 | class SCSEModule(nn.Module): 51 | def __init__(self, in_channels, reduction=16): 52 | super().__init__() 53 | self.cSE = nn.Sequential( 54 | nn.AdaptiveAvgPool3d(1), 55 | nn.Conv3d(in_channels, in_channels // reduction, 1), 56 | nn.ReLU(inplace=True), 57 | nn.Conv3d(in_channels // reduction, in_channels, 1), 58 | nn.Sigmoid(), 59 | ) 60 | self.sSE = nn.Sequential(nn.Conv3d(in_channels, 1, 1), nn.Sigmoid()) 61 | 62 | def forward(self, x): 63 | return x * self.cSE(x) + x * self.sSE(x) 64 | 65 | 66 | class ArgMax(nn.Module): 67 | def __init__(self, dim=None): 68 | super().__init__() 69 | self.dim = dim 70 | 71 | def forward(self, x): 72 | return torch.argmax(x, dim=self.dim) 73 | 74 | 75 | class Clamp(nn.Module): 76 | def __init__(self, min=0, max=1): 77 | super().__init__() 78 | self.min, self.max = min, max 79 | 80 | def forward(self, x): 81 | return torch.clamp(x, self.min, self.max) 82 | 83 | 84 | class Activation(nn.Module): 85 | def __init__(self, name, **params): 86 | 87 | super().__init__() 88 | 89 | if name is None or name == "identity": 90 | self.activation = nn.Identity(**params) 91 | elif name == "sigmoid": 92 | self.activation = nn.Sigmoid() 93 | elif name == "softmax2d": 94 | self.activation = nn.Softmax(dim=1, **params) 95 | elif name == "softmax": 96 | self.activation = nn.Softmax(**params) 97 | elif name == "logsoftmax": 98 | self.activation = nn.LogSoftmax(**params) 99 | elif name == "tanh": 100 | self.activation = nn.Tanh() 101 | elif name == "argmax": 102 | self.activation = ArgMax(**params) 103 | elif name == "argmax2d": 104 | self.activation = ArgMax(dim=1, **params) 105 | elif name == "clamp": 106 | self.activation = Clamp(**params) 107 | elif callable(name): 108 | self.activation = name(**params) 109 | else: 110 | raise ValueError( 111 | f"Activation should be callable/sigmoid/softmax/logsoftmax/tanh/" 112 | f"argmax/argmax2d/clamp/None; got {name}" 113 | ) 114 | 115 | def forward(self, x): 116 | return self.activation(x) 117 | 118 | 119 | class Attention(nn.Module): 120 | def __init__(self, name, **params): 121 | super().__init__() 122 | 123 | if name is None: 124 | self.attention = nn.Identity(**params) 125 | elif name == "scse": 126 | self.attention = SCSEModule(**params) 127 | else: 128 | raise ValueError("Attention {} is not implemented".format(name)) 129 | 130 | def forward(self, x): 131 | return self.attention(x) 132 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .oxford_pet import OxfordPetDataset, SimpleOxfordPetDataset 2 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/datasets/oxford_pet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
import torch 3 | import shutil 4 | import numpy as np 5 | 6 | from PIL import Image 7 | from tqdm import tqdm 8 | from urllib.request import urlretrieve 9 | 10 | 11 | class OxfordPetDataset(torch.utils.data.Dataset): 12 | def __init__(self, root, mode="train", transform=None): 13 | 14 | assert mode in {"train", "valid", "test"} 15 | 16 | self.root = root 17 | self.mode = mode 18 | self.transform = transform 19 | 20 | self.images_directory = os.path.join(self.root, "images") 21 | self.masks_directory = os.path.join(self.root, "annotations", "trimaps") 22 | 23 | self.filenames = self._read_split() # read train/valid/test splits 24 | 25 | def __len__(self): 26 | return len(self.filenames) 27 | 28 | def __getitem__(self, idx): 29 | 30 | filename = self.filenames[idx] 31 | image_path = os.path.join(self.images_directory, filename + ".jpg") 32 | mask_path = os.path.join(self.masks_directory, filename + ".png") 33 | 34 | image = np.array(Image.open(image_path).convert("RGB")) 35 | 36 | trimap = np.array(Image.open(mask_path)) 37 | mask = self._preprocess_mask(trimap) 38 | 39 | sample = dict(image=image, mask=mask, trimap=trimap) 40 | if self.transform is not None: 41 | sample = self.transform(**sample) 42 | 43 | return sample 44 | 45 | @staticmethod 46 | def _preprocess_mask(mask): 47 | mask = mask.astype(np.float32) 48 | mask[mask == 2.0] = 0.0 49 | mask[(mask == 1.0) | (mask == 3.0)] = 1.0 50 | return mask 51 | 52 | def _read_split(self): 53 | split_filename = "test.txt" if self.mode == "test" else "trainval.txt" 54 | split_filepath = os.path.join(self.root, "annotations", split_filename) 55 | with open(split_filepath) as f: 56 | split_data = f.read().strip("\n").split("\n") 57 | filenames = [x.split(" ")[0] for x in split_data] 58 | if self.mode == "train": # 90% for train 59 | filenames = [x for i, x in enumerate(filenames) if i % 10 != 0] 60 | elif self.mode == "valid": # 10% for validation 61 | filenames = [x for i, x in enumerate(filenames) if i % 10 == 0] 62 | return filenames 63 | 64 | @staticmethod 65 | def download(root): 66 | 67 | # load images 68 | filepath = os.path.join(root, "images.tar.gz") 69 | download_url( 70 | url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", 71 | filepath=filepath, 72 | ) 73 | extract_archive(filepath) 74 | 75 | # load annotations 76 | filepath = os.path.join(root, "annotations.tar.gz") 77 | download_url( 78 | url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", 79 | filepath=filepath, 80 | ) 81 | extract_archive(filepath) 82 | 83 | 84 | class SimpleOxfordPetDataset(OxfordPetDataset): 85 | def __getitem__(self, *args, **kwargs): 86 | 87 | sample = super().__getitem__(*args, **kwargs) 88 | 89 | # resize images 90 | image = np.array(Image.fromarray(sample["image"]).resize((256, 256), Image.LINEAR)) 91 | mask = np.array(Image.fromarray(sample["mask"]).resize((256, 256), Image.NEAREST)) 92 | trimap = np.array(Image.fromarray(sample["trimap"]).resize((256, 256), Image.NEAREST)) 93 | 94 | # convert to other format HWC -> CHW 95 | sample["image"] = np.moveaxis(image, -1, 0) 96 | sample["mask"] = np.expand_dims(mask, 0) 97 | sample["trimap"] = np.expand_dims(trimap, 0) 98 | 99 | return sample 100 | 101 | 102 | class TqdmUpTo(tqdm): 103 | def update_to(self, b=1, bsize=1, tsize=None): 104 | if tsize is not None: 105 | self.total = tsize 106 | self.update(b * bsize - self.n) 107 | 108 | 109 | def download_url(url, filepath): 110 | directory = os.path.dirname(os.path.abspath(filepath)) 111 | os.makedirs(directory, exist_ok=True) 
112 | if os.path.exists(filepath): 113 | return 114 | 115 | with TqdmUpTo( 116 | unit="B", 117 | unit_scale=True, 118 | unit_divisor=1024, 119 | miniters=1, 120 | desc=os.path.basename(filepath), 121 | ) as t: 122 | urlretrieve(url, filename=filepath, reporthook=t.update_to, data=None) 123 | t.total = t.n 124 | 125 | 126 | def extract_archive(filepath): 127 | extract_dir = os.path.dirname(os.path.abspath(filepath)) 128 | dst_dir = os.path.splitext(filepath)[0] 129 | if not os.path.exists(dst_dir): 130 | shutil.unpack_archive(filepath, extract_dir) 131 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZFTurbo/segmentation_models_pytorch_3d/fdb392238636056144b776c52ced3459e08ab1cf/segmentation_models_pytorch_3d/decoders/__init__.py -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/deeplabv3/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import DeepLabV3, DeepLabV3Plus 2 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/deeplabv3/decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) Soumith Chintala 2016, 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from 19 | this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | """ 32 | 33 | import torch 34 | from torch import nn 35 | from torch.nn import functional as F 36 | 37 | __all__ = ["DeepLabV3Decoder"] 38 | 39 | 40 | class DeepLabV3Decoder(nn.Sequential): 41 | def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36)): 42 | super().__init__( 43 | ASPP(in_channels, out_channels, atrous_rates), 44 | nn.Conv3d(out_channels, out_channels, 3, padding=1, bias=False), 45 | nn.BatchNorm3d(out_channels), 46 | nn.ReLU(), 47 | ) 48 | self.out_channels = out_channels 49 | 50 | def forward(self, *features): 51 | return super().forward(features[-1]) 52 | 53 | 54 | class DeepLabV3PlusDecoder(nn.Module): 55 | def __init__( 56 | self, 57 | encoder_channels, 58 | out_channels=256, 59 | atrous_rates=(12, 24, 36), 60 | output_stride=16, 61 | ): 62 | super().__init__() 63 | if output_stride not in {8, 16}: 64 | raise ValueError("Output stride should be 8 or 16, got {}.".format(output_stride)) 65 | 66 | self.out_channels = out_channels 67 | self.output_stride = output_stride 68 | 69 | self.aspp = nn.Sequential( 70 | ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True), 71 | SeparableConv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), 72 | nn.BatchNorm3d(out_channels), 73 | nn.ReLU(), 74 | ) 75 | 76 | scale_factor = 2 if output_stride == 8 else 4 77 | self.up = nn.Upsample(scale_factor=scale_factor, mode="trilinear") 78 | 79 | highres_in_channels = encoder_channels[-4] 80 | highres_out_channels = 48 # proposed by authors of paper 81 | self.block1 = nn.Sequential( 82 | nn.Conv3d(highres_in_channels, highres_out_channels, kernel_size=1, bias=False), 83 | nn.BatchNorm3d(highres_out_channels), 84 | nn.ReLU(), 85 | ) 86 | self.block2 = nn.Sequential( 87 | SeparableConv3d( 88 | highres_out_channels + out_channels, 89 | out_channels, 90 | kernel_size=3, 91 | padding=1, 92 | bias=False, 93 | ), 94 | nn.BatchNorm3d(out_channels), 95 | nn.ReLU(), 96 | ) 97 | 98 | def forward(self, *features): 99 | aspp_features = self.aspp(features[-1]) 100 | aspp_features = self.up(aspp_features) 101 | high_res_features = self.block1(features[-4]) 102 | concat_features = torch.cat([aspp_features, high_res_features], dim=1) 103 | fused_features = self.block2(concat_features) 104 | return fused_features 105 | 106 | 107 | class ASPPConv(nn.Sequential): 108 | def __init__(self, in_channels, out_channels, dilation): 109 | super().__init__( 110 | nn.Conv3d( 111 | in_channels, 112 | out_channels, 113 | kernel_size=3, 114 | padding=dilation, 115 | dilation=dilation, 116 | bias=False, 117 | ), 118 | nn.BatchNorm3d(out_channels), 119 | nn.ReLU(), 120 | ) 121 | 122 | 123 | class ASPPSeparableConv(nn.Sequential): 124 | def __init__(self, in_channels, out_channels, dilation): 125 | super().__init__( 126 | SeparableConv3d( 127 | in_channels, 128 | out_channels, 129 | kernel_size=3, 130 | padding=dilation, 131 | dilation=dilation, 132 | bias=False, 133 | ), 134 | nn.BatchNorm3d(out_channels), 135 | nn.ReLU(), 136 | ) 137 | 138 | 139 | class ASPPPooling(nn.Sequential): 140 | def __init__(self, in_channels, out_channels): 141 | super().__init__( 142 | nn.AdaptiveAvgPool3d(1), 143 | nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False), 144 | nn.BatchNorm3d(out_channels), 145 | nn.ReLU(), 146 | ) 147 | 148 | def forward(self, x): 149 | size = x.shape[-3:] 150 | for mod in self: 151 | x = mod(x) 152 | return F.interpolate(x, size=size, mode="trilinear", align_corners=False) 153 | 154 | 155 | class ASPP(nn.Module): 156 | def __init__(self, 
in_channels, out_channels, atrous_rates, separable=False): 157 | super(ASPP, self).__init__() 158 | modules = [] 159 | modules.append( 160 | nn.Sequential( 161 | nn.Conv3d(in_channels, out_channels, 1, bias=False), 162 | nn.BatchNorm3d(out_channels), 163 | nn.ReLU(), 164 | ) 165 | ) 166 | 167 | rate1, rate2, rate3 = tuple(atrous_rates) 168 | ASPPConvModule = ASPPConv if not separable else ASPPSeparableConv 169 | 170 | modules.append(ASPPConvModule(in_channels, out_channels, rate1)) 171 | modules.append(ASPPConvModule(in_channels, out_channels, rate2)) 172 | modules.append(ASPPConvModule(in_channels, out_channels, rate3)) 173 | modules.append(ASPPPooling(in_channels, out_channels)) 174 | 175 | self.convs = nn.ModuleList(modules) 176 | 177 | self.project = nn.Sequential( 178 | nn.Conv3d(5 * out_channels, out_channels, kernel_size=1, bias=False), 179 | nn.BatchNorm3d(out_channels), 180 | nn.ReLU(), 181 | nn.Dropout(0.5), 182 | ) 183 | 184 | def forward(self, x): 185 | res = [] 186 | for conv in self.convs: 187 | res.append(conv(x)) 188 | res = torch.cat(res, dim=1) 189 | return self.project(res) 190 | 191 | 192 | class SeparableConv3d(nn.Sequential): 193 | def __init__( 194 | self, 195 | in_channels, 196 | out_channels, 197 | kernel_size, 198 | stride=1, 199 | padding=0, 200 | dilation=1, 201 | bias=True, 202 | ): 203 | dephtwise_conv = nn.Conv3d( 204 | in_channels, 205 | in_channels, 206 | kernel_size, 207 | stride=stride, 208 | padding=padding, 209 | dilation=dilation, 210 | groups=in_channels, 211 | bias=False, 212 | ) 213 | pointwise_conv = nn.Conv3d( 214 | in_channels, 215 | out_channels, 216 | kernel_size=1, 217 | bias=bias, 218 | ) 219 | super().__init__(dephtwise_conv, pointwise_conv) 220 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/deeplabv3/model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from typing import Optional 3 | 4 | from segmentation_models_pytorch_3d.base import ( 5 | SegmentationModel, 6 | SegmentationHead, 7 | ClassificationHead, 8 | ) 9 | from segmentation_models_pytorch_3d.encoders import get_encoder 10 | from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder 11 | 12 | 13 | class DeepLabV3(SegmentationModel): 14 | """DeepLabV3_ implementation from "Rethinking Atrous Convolution for Semantic Image Segmentation" 15 | 16 | Args: 17 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 18 | to extract features of different spatial resolution 19 | encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features 20 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 21 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 22 | Default is 5 23 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 24 | other pretrained weights (see table with available weights for each encoder_name) 25 | decoder_channels: A number of convolution filters in ASPP module. Default is 256 26 | in_channels: A number of input channels for the model, default is 3 (RGB images) 27 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 28 | activation: An activation function to apply after the final convolution layer. 
29 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, 30 | **callable** and **None**. 31 | Default is **None** 32 | upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity 33 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 34 | on top of encoder if **aux_params** is not **None** (default). Supported params: 35 | - classes (int): A number of classes 36 | - pooling (str): One of "max", "avg". Default is "avg" 37 | - dropout (float): Dropout factor in [0, 1) 38 | - activation (str): An activation function to apply "sigmoid"/"softmax" 39 | (could be **None** to return logits) 40 | Returns: 41 | ``torch.nn.Module``: **DeepLabV3** 42 | 43 | .. _DeeplabV3: 44 | https://arxiv.org/abs/1706.05587 45 | 46 | """ 47 | 48 | def __init__( 49 | self, 50 | encoder_name: str = "resnet34", 51 | encoder_depth: int = 5, 52 | encoder_weights: Optional[str] = "imagenet", 53 | decoder_channels: int = 256, 54 | in_channels: int = 3, 55 | classes: int = 1, 56 | activation: Optional[str] = None, 57 | upsampling: int = 8, 58 | aux_params: Optional[dict] = None, 59 | ): 60 | super().__init__() 61 | 62 | self.encoder = get_encoder( 63 | encoder_name, 64 | in_channels=in_channels, 65 | depth=encoder_depth, 66 | weights=encoder_weights, 67 | output_stride=8, 68 | ) 69 | 70 | self.decoder = DeepLabV3Decoder( 71 | in_channels=self.encoder.out_channels[-1], 72 | out_channels=decoder_channels, 73 | ) 74 | 75 | self.segmentation_head = SegmentationHead( 76 | in_channels=self.decoder.out_channels, 77 | out_channels=classes, 78 | activation=activation, 79 | kernel_size=1, 80 | upsampling=upsampling, 81 | ) 82 | 83 | if aux_params is not None: 84 | self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) 85 | else: 86 | self.classification_head = None 87 | 88 | 89 | class DeepLabV3Plus(SegmentationModel): 90 | """DeepLabV3+ implementation from "Encoder-Decoder with Atrous Separable 91 | Convolution for Semantic Image Segmentation" 92 | 93 | Args: 94 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 95 | to extract features of different spatial resolution 96 | encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features 97 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 98 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 99 | Default is 5 100 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 101 | other pretrained weights (see table with available weights for each encoder_name) 102 | encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation) 103 | decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values) 104 | decoder_channels: A number of convolution filters in ASPP module. Default is 256 105 | in_channels: A number of input channels for the model, default is 3 (RGB images) 106 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 107 | activation: An activation function to apply after the final convolution layer. 108 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, 109 | **callable** and **None**. 
110 | Default is **None** 111 | upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity 112 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 113 | on top of encoder if **aux_params** is not **None** (default). Supported params: 114 | - classes (int): A number of classes 115 | - pooling (str): One of "max", "avg". Default is "avg" 116 | - dropout (float): Dropout factor in [0, 1) 117 | - activation (str): An activation function to apply "sigmoid"/"softmax" 118 | (could be **None** to return logits) 119 | Returns: 120 | ``torch.nn.Module``: **DeepLabV3Plus** 121 | 122 | Reference: 123 | https://arxiv.org/abs/1802.02611v3 124 | 125 | """ 126 | 127 | def __init__( 128 | self, 129 | encoder_name: str = "resnet34", 130 | encoder_depth: int = 5, 131 | encoder_weights: Optional[str] = "imagenet", 132 | encoder_output_stride: int = 16, 133 | decoder_channels: int = 256, 134 | decoder_atrous_rates: tuple = (12, 24, 36), 135 | in_channels: int = 3, 136 | classes: int = 1, 137 | activation: Optional[str] = None, 138 | upsampling: int = 4, 139 | aux_params: Optional[dict] = None, 140 | ): 141 | super().__init__() 142 | 143 | if encoder_output_stride not in [8, 16]: 144 | raise ValueError("Encoder output stride should be 8 or 16, got {}".format(encoder_output_stride)) 145 | 146 | self.encoder = get_encoder( 147 | encoder_name, 148 | in_channels=in_channels, 149 | depth=encoder_depth, 150 | weights=encoder_weights, 151 | output_stride=encoder_output_stride, 152 | ) 153 | 154 | self.decoder = DeepLabV3PlusDecoder( 155 | encoder_channels=self.encoder.out_channels, 156 | out_channels=decoder_channels, 157 | atrous_rates=decoder_atrous_rates, 158 | output_stride=encoder_output_stride, 159 | ) 160 | 161 | self.segmentation_head = SegmentationHead( 162 | in_channels=self.decoder.out_channels, 163 | out_channels=classes, 164 | activation=activation, 165 | kernel_size=1, 166 | upsampling=upsampling, 167 | ) 168 | 169 | if aux_params is not None: 170 | self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) 171 | else: 172 | self.classification_head = None 173 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/fpn/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import FPN 2 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/fpn/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Conv3x3GNReLU(nn.Module): 7 | def __init__(self, in_channels, out_channels, upsample=False): 8 | super().__init__() 9 | self.upsample = upsample 10 | self.block = nn.Sequential( 11 | nn.Conv3d(in_channels, out_channels, (3, 3, 3), stride=1, padding=1, bias=False), 12 | nn.GroupNorm(32, out_channels), 13 | nn.ReLU(inplace=True), 14 | ) 15 | 16 | def forward(self, x): 17 | x = self.block(x) 18 | if self.upsample: 19 | x = F.interpolate(x, scale_factor=2, mode="trilinear", align_corners=True) 20 | return x 21 | 22 | 23 | class FPNBlock(nn.Module): 24 | def __init__(self, pyramid_channels, skip_channels): 25 | super().__init__() 26 | self.skip_conv = nn.Conv3d(skip_channels, pyramid_channels, kernel_size=1) 27 | 28 | def forward(self, x, skip=None): 29 
| x = F.interpolate(x, scale_factor=2, mode="nearest") 30 | skip = self.skip_conv(skip) 31 | x = x + skip 32 | return x 33 | 34 | 35 | class SegmentationBlock(nn.Module): 36 | def __init__(self, in_channels, out_channels, n_upsamples=0): 37 | super().__init__() 38 | 39 | blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))] 40 | 41 | if n_upsamples > 1: 42 | for _ in range(1, n_upsamples): 43 | blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True)) 44 | 45 | self.block = nn.Sequential(*blocks) 46 | 47 | def forward(self, x): 48 | return self.block(x) 49 | 50 | 51 | class MergeBlock(nn.Module): 52 | def __init__(self, policy): 53 | super().__init__() 54 | if policy not in ["add", "cat"]: 55 | raise ValueError("`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy)) 56 | self.policy = policy 57 | 58 | def forward(self, x): 59 | if self.policy == "add": 60 | return sum(x) 61 | elif self.policy == "cat": 62 | return torch.cat(x, dim=1) 63 | else: 64 | raise ValueError("`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy)) 65 | 66 | 67 | class FPNDecoder(nn.Module): 68 | def __init__( 69 | self, 70 | encoder_channels, 71 | encoder_depth=5, 72 | pyramid_channels=256, 73 | segmentation_channels=128, 74 | dropout=0.2, 75 | merge_policy="add", 76 | ): 77 | super().__init__() 78 | 79 | self.out_channels = segmentation_channels if merge_policy == "add" else segmentation_channels * 4 80 | if encoder_depth < 3: 81 | raise ValueError("Encoder depth for FPN decoder cannot be less than 3, got {}.".format(encoder_depth)) 82 | 83 | encoder_channels = encoder_channels[::-1] 84 | encoder_channels = encoder_channels[: encoder_depth + 1] 85 | 86 | self.p5 = nn.Conv3d(encoder_channels[0], pyramid_channels, kernel_size=1) 87 | self.p4 = FPNBlock(pyramid_channels, encoder_channels[1]) 88 | self.p3 = FPNBlock(pyramid_channels, encoder_channels[2]) 89 | self.p2 = FPNBlock(pyramid_channels, encoder_channels[3]) 90 | 91 | self.seg_blocks = nn.ModuleList( 92 | [ 93 | SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=n_upsamples) 94 | for n_upsamples in [3, 2, 1, 0] 95 | ] 96 | ) 97 | 98 | self.merge = MergeBlock(merge_policy) 99 | self.dropout = nn.Dropout3d(p=dropout, inplace=True) 100 | 101 | def forward(self, *features): 102 | c2, c3, c4, c5 = features[-4:] 103 | 104 | p5 = self.p5(c5) 105 | p4 = self.p4(p5, c4) 106 | p3 = self.p3(p4, c3) 107 | p2 = self.p2(p3, c2) 108 | 109 | feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])] 110 | x = self.merge(feature_pyramid) 111 | x = self.dropout(x) 112 | 113 | return x 114 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/fpn/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | from segmentation_models_pytorch_3d.base import ( 4 | SegmentationModel, 5 | SegmentationHead, 6 | ClassificationHead, 7 | ) 8 | from segmentation_models_pytorch_3d.encoders import get_encoder 9 | from .decoder import FPNDecoder 10 | 11 | 12 | class FPN(SegmentationModel): 13 | """FPN_ is a fully convolution neural network for image semantic segmentation. 14 | 15 | Args: 16 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 17 | to extract features of different spatial resolution 18 | encoder_depth: A number of stages used in encoder in range [3, 5]. 
Each stage generate features 19 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 20 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 21 | Default is 5 22 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 23 | other pretrained weights (see table with available weights for each encoder_name) 24 | decoder_pyramid_channels: A number of convolution filters in Feature Pyramid of FPN_ 25 | decoder_segmentation_channels: A number of convolution filters in segmentation blocks of FPN_ 26 | decoder_merge_policy: Determines how to merge pyramid features inside FPN. Available options are **add** 27 | and **cat** 28 | decoder_dropout: Spatial dropout rate in range (0, 1) for feature pyramid in FPN_ 29 | in_channels: A number of input channels for the model, default is 3 (RGB images) 30 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 31 | activation: An activation function to apply after the final convolution layer. 32 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, 33 | **callable** and **None**. 34 | Default is **None** 35 | upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity 36 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 37 | on top of encoder if **aux_params** is not **None** (default). Supported params: 38 | - classes (int): A number of classes 39 | - pooling (str): One of "max", "avg". Default is "avg" 40 | - dropout (float): Dropout factor in [0, 1) 41 | - activation (str): An activation function to apply "sigmoid"/"softmax" 42 | (could be **None** to return logits) 43 | 44 | Returns: 45 | ``torch.nn.Module``: **FPN** 46 | 47 | .. 
_FPN: 48 | http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf 49 | 50 | """ 51 | 52 | def __init__( 53 | self, 54 | encoder_name: str = "resnet34", 55 | encoder_depth: int = 5, 56 | encoder_weights: Optional[str] = "imagenet", 57 | decoder_pyramid_channels: int = 256, 58 | decoder_segmentation_channels: int = 128, 59 | decoder_merge_policy: str = "add", 60 | decoder_dropout: float = 0.2, 61 | in_channels: int = 3, 62 | classes: int = 1, 63 | activation: Optional[str] = None, 64 | upsampling: int = 4, 65 | aux_params: Optional[dict] = None, 66 | ): 67 | super().__init__() 68 | 69 | # validate input params 70 | if encoder_name.startswith("mit_b") and encoder_depth != 5: 71 | raise ValueError("Encoder {} support only encoder_depth=5".format(encoder_name)) 72 | 73 | self.encoder = get_encoder( 74 | encoder_name, 75 | in_channels=in_channels, 76 | depth=encoder_depth, 77 | weights=encoder_weights, 78 | ) 79 | 80 | self.decoder = FPNDecoder( 81 | encoder_channels=self.encoder.out_channels, 82 | encoder_depth=encoder_depth, 83 | pyramid_channels=decoder_pyramid_channels, 84 | segmentation_channels=decoder_segmentation_channels, 85 | dropout=decoder_dropout, 86 | merge_policy=decoder_merge_policy, 87 | ) 88 | 89 | self.segmentation_head = SegmentationHead( 90 | in_channels=self.decoder.out_channels, 91 | out_channels=classes, 92 | activation=activation, 93 | kernel_size=1, 94 | upsampling=upsampling, 95 | ) 96 | 97 | if aux_params is not None: 98 | self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) 99 | else: 100 | self.classification_head = None 101 | 102 | self.name = "fpn-{}".format(encoder_name) 103 | self.initialize() 104 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/linknet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Linknet 2 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/linknet/decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from segmentation_models_pytorch_3d.base import modules 4 | 5 | 6 | class TransposeX2(nn.Sequential): 7 | def __init__(self, in_channels, out_channels, use_batchnorm=True): 8 | super().__init__() 9 | layers = [ 10 | nn.ConvTranspose3d(in_channels, out_channels, kernel_size=4, stride=2, padding=1), 11 | nn.ReLU(inplace=True), 12 | ] 13 | 14 | if use_batchnorm: 15 | layers.insert(1, nn.BatchNorm3d(out_channels)) 16 | 17 | super().__init__(*layers) 18 | 19 | 20 | class DecoderBlock(nn.Module): 21 | def __init__(self, in_channels, out_channels, use_batchnorm=True): 22 | super().__init__() 23 | 24 | self.block = nn.Sequential( 25 | modules.Conv3dReLU( 26 | in_channels, 27 | in_channels // 4, 28 | kernel_size=1, 29 | use_batchnorm=use_batchnorm, 30 | ), 31 | TransposeX2(in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm), 32 | modules.Conv3dReLU( 33 | in_channels // 4, 34 | out_channels, 35 | kernel_size=1, 36 | use_batchnorm=use_batchnorm, 37 | ), 38 | ) 39 | 40 | def forward(self, x, skip=None): 41 | x = self.block(x) 42 | if skip is not None: 43 | x = x + skip 44 | return x 45 | 46 | 47 | class LinknetDecoder(nn.Module): 48 | def __init__( 49 | self, 50 | encoder_channels, 51 | prefinal_channels=32, 52 | n_blocks=5, 53 | use_batchnorm=True, 54 | ): 55 | super().__init__() 56 | 57 | # remove first 
skip 58 | encoder_channels = encoder_channels[1:] 59 | # reverse channels to start from head of encoder 60 | encoder_channels = encoder_channels[::-1] 61 | 62 | channels = list(encoder_channels) + [prefinal_channels] 63 | 64 | self.blocks = nn.ModuleList( 65 | [DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm) for i in range(n_blocks)] 66 | ) 67 | 68 | def forward(self, *features): 69 | features = features[1:] # remove first skip 70 | features = features[::-1] # reverse channels to start from head of encoder 71 | 72 | x = features[0] 73 | skips = features[1:] 74 | 75 | for i, decoder_block in enumerate(self.blocks): 76 | skip = skips[i] if i < len(skips) else None 77 | x = decoder_block(x, skip) 78 | 79 | return x 80 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/linknet/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | from segmentation_models_pytorch_3d.base import ( 4 | SegmentationHead, 5 | SegmentationModel, 6 | ClassificationHead, 7 | ) 8 | from segmentation_models_pytorch_3d.encoders import get_encoder 9 | from .decoder import LinknetDecoder 10 | 11 | 12 | class Linknet(SegmentationModel): 13 | """Linknet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder* 14 | and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial 15 | resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *sum* 16 | for fusing decoder blocks with skip connections. 17 | 18 | Note: 19 | This implementation by default has 4 skip connections (original - 3). 20 | 21 | Args: 22 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 23 | to extract features of different spatial resolution 24 | encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features 25 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 26 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 27 | Default is 5 28 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 29 | other pretrained weights (see table with available weights for each encoder_name) 30 | decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers 31 | is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. 32 | Available options are **True, False, "inplace"** 33 | in_channels: A number of input channels for the model, default is 3 (RGB images) 34 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 35 | activation: An activation function to apply after the final convolution layer. 36 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, 37 | **callable** and **None**. 38 | Default is **None** 39 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 40 | on top of encoder if **aux_params** is not **None** (default). Supported params: 41 | - classes (int): A number of classes 42 | - pooling (str): One of "max", "avg". 
Default is "avg" 43 | - dropout (float): Dropout factor in [0, 1) 44 | - activation (str): An activation function to apply "sigmoid"/"softmax" 45 | (could be **None** to return logits) 46 | 47 | Returns: 48 | ``torch.nn.Module``: **Linknet** 49 | 50 | .. _Linknet: 51 | https://arxiv.org/abs/1707.03718 52 | """ 53 | 54 | def __init__( 55 | self, 56 | encoder_name: str = "resnet34", 57 | encoder_depth: int = 5, 58 | encoder_weights: Optional[str] = "imagenet", 59 | decoder_use_batchnorm: bool = True, 60 | in_channels: int = 3, 61 | classes: int = 1, 62 | activation: Optional[Union[str, callable]] = None, 63 | aux_params: Optional[dict] = None, 64 | ): 65 | super().__init__() 66 | 67 | if encoder_name.startswith("mit_b"): 68 | raise ValueError("Encoder `{}` is not supported for Linknet".format(encoder_name)) 69 | 70 | self.encoder = get_encoder( 71 | encoder_name, 72 | in_channels=in_channels, 73 | depth=encoder_depth, 74 | weights=encoder_weights, 75 | ) 76 | 77 | self.decoder = LinknetDecoder( 78 | encoder_channels=self.encoder.out_channels, 79 | n_blocks=encoder_depth, 80 | prefinal_channels=32, 81 | use_batchnorm=decoder_use_batchnorm, 82 | ) 83 | 84 | self.segmentation_head = SegmentationHead( 85 | in_channels=32, out_channels=classes, activation=activation, kernel_size=1 86 | ) 87 | 88 | if aux_params is not None: 89 | self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) 90 | else: 91 | self.classification_head = None 92 | 93 | self.name = "link-{}".format(encoder_name) 94 | self.initialize() 95 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/manet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import MAnet 2 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/manet/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from segmentation_models_pytorch_3d.base import modules as md 6 | 7 | 8 | class PAB(nn.Module): 9 | def __init__(self, in_channels, out_channels, pab_channels=64): 10 | super(PAB, self).__init__() 11 | # Series of 1x1 conv to generate attention feature maps 12 | self.pab_channels = pab_channels 13 | self.in_channels = in_channels 14 | self.top_conv = nn.Conv3d(in_channels, pab_channels, kernel_size=1) 15 | self.center_conv = nn.Conv3d(in_channels, pab_channels, kernel_size=1) 16 | self.bottom_conv = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) 17 | self.map_softmax = nn.Softmax(dim=1) 18 | self.out_conv = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) 19 | 20 | def forward(self, x): 21 | bsize = x.size()[0] 22 | h = x.size()[2] 23 | w = x.size()[3] 24 | d = x.size()[4] 25 | x_top = self.top_conv(x) 26 | x_center = self.center_conv(x) 27 | x_bottom = self.bottom_conv(x) 28 | 29 | x_top = x_top.flatten(2) 30 | x_center = x_center.flatten(2).transpose(1, 2) 31 | x_bottom = x_bottom.flatten(2).transpose(1, 2) 32 | 33 | sp_map = torch.matmul(x_center, x_top) 34 | sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h * w * d, h * w * d) 35 | sp_map = torch.matmul(sp_map, x_bottom) 36 | sp_map = sp_map.reshape(bsize, self.in_channels, h, w, d) 37 | x = x + sp_map 38 | x = self.out_conv(x) 39 | return x 40 | 41 | 42 | class MFAB(nn.Module): 43 | def 
__init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16): 44 | # MFAB is just a modified version of SE-blocks, one for skip, one for input 45 | super(MFAB, self).__init__() 46 | self.hl_conv = nn.Sequential( 47 | md.Conv3dReLU( 48 | in_channels, 49 | in_channels, 50 | kernel_size=3, 51 | padding=1, 52 | use_batchnorm=use_batchnorm, 53 | ), 54 | md.Conv3dReLU( 55 | in_channels, 56 | skip_channels, 57 | kernel_size=1, 58 | use_batchnorm=use_batchnorm, 59 | ), 60 | ) 61 | reduced_channels = max(1, skip_channels // reduction) 62 | self.SE_ll = nn.Sequential( 63 | nn.AdaptiveAvgPool3d(1), 64 | nn.Conv3d(skip_channels, reduced_channels, 1), 65 | nn.ReLU(inplace=True), 66 | nn.Conv3d(reduced_channels, skip_channels, 1), 67 | nn.Sigmoid(), 68 | ) 69 | self.SE_hl = nn.Sequential( 70 | nn.AdaptiveAvgPool3d(1), 71 | nn.Conv3d(skip_channels, reduced_channels, 1), 72 | nn.ReLU(inplace=True), 73 | nn.Conv3d(reduced_channels, skip_channels, 1), 74 | nn.Sigmoid(), 75 | ) 76 | self.conv1 = md.Conv3dReLU( 77 | skip_channels + skip_channels, # we transform C-prime form high level to C from skip connection 78 | out_channels, 79 | kernel_size=3, 80 | padding=1, 81 | use_batchnorm=use_batchnorm, 82 | ) 83 | self.conv2 = md.Conv3dReLU( 84 | out_channels, 85 | out_channels, 86 | kernel_size=3, 87 | padding=1, 88 | use_batchnorm=use_batchnorm, 89 | ) 90 | 91 | def forward(self, x, skip=None): 92 | x = self.hl_conv(x) 93 | x = F.interpolate(x, scale_factor=2, mode="nearest") 94 | attention_hl = self.SE_hl(x) 95 | if skip is not None: 96 | attention_ll = self.SE_ll(skip) 97 | attention_hl = attention_hl + attention_ll 98 | x = x * attention_hl 99 | x = torch.cat([x, skip], dim=1) 100 | x = self.conv1(x) 101 | x = self.conv2(x) 102 | return x 103 | 104 | 105 | class DecoderBlock(nn.Module): 106 | def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True): 107 | super().__init__() 108 | self.conv1 = md.Conv3dReLU( 109 | in_channels + skip_channels, 110 | out_channels, 111 | kernel_size=3, 112 | padding=1, 113 | use_batchnorm=use_batchnorm, 114 | ) 115 | self.conv2 = md.Conv3dReLU( 116 | out_channels, 117 | out_channels, 118 | kernel_size=3, 119 | padding=1, 120 | use_batchnorm=use_batchnorm, 121 | ) 122 | 123 | def forward(self, x, skip=None): 124 | x = F.interpolate(x, scale_factor=2, mode="nearest") 125 | if skip is not None: 126 | x = torch.cat([x, skip], dim=1) 127 | x = self.conv1(x) 128 | x = self.conv2(x) 129 | return x 130 | 131 | 132 | class MAnetDecoder(nn.Module): 133 | def __init__( 134 | self, 135 | encoder_channels, 136 | decoder_channels, 137 | n_blocks=5, 138 | reduction=16, 139 | use_batchnorm=True, 140 | pab_channels=64, 141 | ): 142 | super().__init__() 143 | 144 | if n_blocks != len(decoder_channels): 145 | raise ValueError( 146 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( 147 | n_blocks, len(decoder_channels) 148 | ) 149 | ) 150 | 151 | # remove first skip with same spatial resolution 152 | encoder_channels = encoder_channels[1:] 153 | 154 | # reverse channels to start from head of encoder 155 | encoder_channels = encoder_channels[::-1] 156 | 157 | # computing blocks input and output channels 158 | head_channels = encoder_channels[0] 159 | in_channels = [head_channels] + list(decoder_channels[:-1]) 160 | skip_channels = list(encoder_channels[1:]) + [0] 161 | out_channels = decoder_channels 162 | 163 | self.center = PAB(head_channels, head_channels, pab_channels=pab_channels) 164 | 165 | # combine decoder 
keyword arguments 166 | kwargs = dict(use_batchnorm=use_batchnorm) # no attention type here 167 | blocks = [ 168 | MFAB(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs) 169 | if skip_ch > 0 170 | else DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) 171 | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) 172 | ] 173 | # for the last we dont have skip connection -> use simple decoder block 174 | self.blocks = nn.ModuleList(blocks) 175 | 176 | def forward(self, *features): 177 | 178 | features = features[1:] # remove first skip with same spatial resolution 179 | features = features[::-1] # reverse channels to start from head of encoder 180 | 181 | head = features[0] 182 | skips = features[1:] 183 | 184 | x = self.center(head) 185 | for i, decoder_block in enumerate(self.blocks): 186 | skip = skips[i] if i < len(skips) else None 187 | x = decoder_block(x, skip) 188 | 189 | return x 190 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/manet/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, List 2 | 3 | from segmentation_models_pytorch_3d.encoders import get_encoder 4 | from segmentation_models_pytorch_3d.base import ( 5 | SegmentationModel, 6 | SegmentationHead, 7 | ClassificationHead, 8 | ) 9 | from .decoder import MAnetDecoder 10 | 11 | 12 | class MAnet(SegmentationModel): 13 | """MAnet_ : Multi-scale Attention Net. The MA-Net can capture rich contextual dependencies based on 14 | the attention mechanism, using two blocks: 15 | - Position-wise Attention Block (PAB), which captures the spatial dependencies between pixels in a global view 16 | - Multi-scale Fusion Attention Block (MFAB), which captures the channel dependencies between any feature map by 17 | multi-scale semantic feature fusion 18 | 19 | Args: 20 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 21 | to extract features of different spatial resolution 22 | encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features 23 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 24 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 25 | Default is 5 26 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 27 | other pretrained weights (see table with available weights for each encoder_name) 28 | decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. 29 | Length of the list should be the same as **encoder_depth** 30 | decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers 31 | is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. 32 | Available options are **True, False, "inplace"** 33 | decoder_pab_channels: A number of channels for PAB module in decoder. 34 | Default is 64. 35 | in_channels: A number of input channels for the model, default is 3 (RGB images) 36 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 37 | activation: An activation function to apply after the final convolution layer. 38 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, 39 | **callable** and **None**. 
40 | Default is **None** 41 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 42 | on top of encoder if **aux_params** is not **None** (default). Supported params: 43 | - classes (int): A number of classes 44 | - pooling (str): One of "max", "avg". Default is "avg" 45 | - dropout (float): Dropout factor in [0, 1) 46 | - activation (str): An activation function to apply "sigmoid"/"softmax" 47 | (could be **None** to return logits) 48 | 49 | Returns: 50 | ``torch.nn.Module``: **MAnet** 51 | 52 | .. _MAnet: 53 | https://ieeexplore.ieee.org/abstract/document/9201310 54 | 55 | """ 56 | 57 | def __init__( 58 | self, 59 | encoder_name: str = "resnet34", 60 | encoder_depth: int = 5, 61 | encoder_weights: Optional[str] = "imagenet", 62 | decoder_use_batchnorm: bool = True, 63 | decoder_channels: List[int] = (256, 128, 64, 32, 16), 64 | decoder_pab_channels: int = 64, 65 | in_channels: int = 3, 66 | classes: int = 1, 67 | activation: Optional[Union[str, callable]] = None, 68 | aux_params: Optional[dict] = None, 69 | ): 70 | super().__init__() 71 | 72 | self.encoder = get_encoder( 73 | encoder_name, 74 | in_channels=in_channels, 75 | depth=encoder_depth, 76 | weights=encoder_weights, 77 | ) 78 | 79 | self.decoder = MAnetDecoder( 80 | encoder_channels=self.encoder.out_channels, 81 | decoder_channels=decoder_channels, 82 | n_blocks=encoder_depth, 83 | use_batchnorm=decoder_use_batchnorm, 84 | pab_channels=decoder_pab_channels, 85 | ) 86 | 87 | self.segmentation_head = SegmentationHead( 88 | in_channels=decoder_channels[-1], 89 | out_channels=classes, 90 | activation=activation, 91 | kernel_size=3, 92 | ) 93 | 94 | if aux_params is not None: 95 | self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) 96 | else: 97 | self.classification_head = None 98 | 99 | self.name = "manet-{}".format(encoder_name) 100 | self.initialize() 101 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/pan/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import PAN 2 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/pan/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ConvBnRelu(nn.Module): 7 | def __init__( 8 | self, 9 | in_channels: int, 10 | out_channels: int, 11 | kernel_size: int, 12 | stride: int = 1, 13 | padding: int = 0, 14 | dilation: int = 1, 15 | groups: int = 1, 16 | bias: bool = True, 17 | add_relu: bool = True, 18 | interpolate: bool = False, 19 | ): 20 | super(ConvBnRelu, self).__init__() 21 | self.conv = nn.Conv3d( 22 | in_channels=in_channels, 23 | out_channels=out_channels, 24 | kernel_size=kernel_size, 25 | stride=stride, 26 | padding=padding, 27 | dilation=dilation, 28 | bias=bias, 29 | groups=groups, 30 | ) 31 | self.add_relu = add_relu 32 | self.interpolate = interpolate 33 | self.bn = nn.BatchNorm3d(out_channels) 34 | self.activation = nn.ReLU(inplace=True) 35 | 36 | def forward(self, x): 37 | x = self.conv(x) 38 | x = self.bn(x) 39 | if self.add_relu: 40 | x = self.activation(x) 41 | if self.interpolate: 42 | x = F.interpolate(x, scale_factor=2, mode="trilinear", align_corners=True) 43 | return x 44 | 45 | 46 | class FPABlock(nn.Module): 47 | def 
__init__(self, in_channels, out_channels, upscale_mode="trilinear"): 48 | super(FPABlock, self).__init__() 49 | 50 | self.upscale_mode = upscale_mode 51 | if self.upscale_mode == "trilinear": 52 | self.align_corners = True 53 | else: 54 | self.align_corners = False 55 | 56 | # global pooling branch 57 | self.branch1 = nn.Sequential( 58 | nn.AdaptiveAvgPool3d(1), 59 | ConvBnRelu( 60 | in_channels=in_channels, 61 | out_channels=out_channels, 62 | kernel_size=1, 63 | stride=1, 64 | padding=0, 65 | ), 66 | ) 67 | 68 | # midddle branch 69 | self.mid = nn.Sequential( 70 | ConvBnRelu( 71 | in_channels=in_channels, 72 | out_channels=out_channels, 73 | kernel_size=1, 74 | stride=1, 75 | padding=0, 76 | ) 77 | ) 78 | self.down1 = nn.Sequential( 79 | nn.MaxPool3d(kernel_size=2, stride=2), 80 | ConvBnRelu( 81 | in_channels=in_channels, 82 | out_channels=1, 83 | kernel_size=7, 84 | stride=1, 85 | padding=3, 86 | ), 87 | ) 88 | self.down2 = nn.Sequential( 89 | nn.MaxPool3d(kernel_size=2, stride=2), 90 | ConvBnRelu(in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2), 91 | ) 92 | self.down3 = nn.Sequential( 93 | nn.MaxPool3d(kernel_size=2, stride=2), 94 | ConvBnRelu(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1), 95 | ConvBnRelu(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1), 96 | ) 97 | self.conv2 = ConvBnRelu(in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2) 98 | self.conv1 = ConvBnRelu(in_channels=1, out_channels=1, kernel_size=7, stride=1, padding=3) 99 | 100 | def forward(self, x): 101 | h, w, d = x.size(2), x.size(3), x.size(4) 102 | b1 = self.branch1(x) 103 | upscale_parameters = dict(mode=self.upscale_mode, align_corners=self.align_corners) 104 | b1 = F.interpolate(b1, size=(h, w, d), **upscale_parameters) 105 | 106 | mid = self.mid(x) 107 | x1 = self.down1(x) 108 | x2 = self.down2(x1) 109 | x3 = self.down3(x2) 110 | x3 = F.interpolate(x3, size=(h // 4, w // 4, d // 4), **upscale_parameters) 111 | 112 | x2 = self.conv2(x2) 113 | x = x2 + x3 114 | x = F.interpolate(x, size=(h // 2, w // 2, d // 2), **upscale_parameters) 115 | 116 | x1 = self.conv1(x1) 117 | x = x + x1 118 | x = F.interpolate(x, size=(h, w, d), **upscale_parameters) 119 | 120 | x = torch.mul(x, mid) 121 | x = x + b1 122 | return x 123 | 124 | 125 | class GAUBlock(nn.Module): 126 | def __init__(self, in_channels: int, out_channels: int, upscale_mode: str = "trilinear"): 127 | super(GAUBlock, self).__init__() 128 | 129 | self.upscale_mode = upscale_mode 130 | self.align_corners = True if upscale_mode == "trilinear" else None 131 | 132 | self.conv1 = nn.Sequential( 133 | nn.AdaptiveAvgPool3d(1), 134 | ConvBnRelu( 135 | in_channels=out_channels, 136 | out_channels=out_channels, 137 | kernel_size=1, 138 | add_relu=False, 139 | ), 140 | nn.Sigmoid(), 141 | ) 142 | self.conv2 = ConvBnRelu(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1) 143 | 144 | def forward(self, x, y): 145 | """ 146 | Args: 147 | x: low level feature 148 | y: high level feature 149 | """ 150 | h, w, d = x.size(2), x.size(3), x.size(4) 151 | y_up = F.interpolate(y, size=(h, w, d), mode=self.upscale_mode, align_corners=self.align_corners) 152 | x = self.conv2(x) 153 | y = self.conv1(y) 154 | z = torch.mul(x, y) 155 | return y_up + z 156 | 157 | 158 | class PANDecoder(nn.Module): 159 | def __init__(self, encoder_channels, decoder_channels, upscale_mode: str = "trilinear"): 160 | super().__init__() 161 | 162 | self.fpa = FPABlock(in_channels=encoder_channels[-1], 
out_channels=decoder_channels) 163 | self.gau3 = GAUBlock( 164 | in_channels=encoder_channels[-2], 165 | out_channels=decoder_channels, 166 | upscale_mode=upscale_mode, 167 | ) 168 | self.gau2 = GAUBlock( 169 | in_channels=encoder_channels[-3], 170 | out_channels=decoder_channels, 171 | upscale_mode=upscale_mode, 172 | ) 173 | self.gau1 = GAUBlock( 174 | in_channels=encoder_channels[-4], 175 | out_channels=decoder_channels, 176 | upscale_mode=upscale_mode, 177 | ) 178 | 179 | def forward(self, *features): 180 | bottleneck = features[-1] 181 | x5 = self.fpa(bottleneck) # 1/32 182 | x4 = self.gau3(features[-2], x5) # 1/16 183 | x3 = self.gau2(features[-3], x4) # 1/8 184 | x2 = self.gau1(features[-4], x3) # 1/4 185 | 186 | return x2 187 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/pan/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | from segmentation_models_pytorch_3d.encoders import get_encoder 4 | from segmentation_models_pytorch_3d.base import ( 5 | SegmentationModel, 6 | SegmentationHead, 7 | ClassificationHead, 8 | ) 9 | from .decoder import PANDecoder 10 | 11 | 12 | class PAN(SegmentationModel): 13 | """Implementation of PAN_ (Pyramid Attention Network). 14 | 15 | Note: 16 | Currently works with shape of input tensor >= [B x C x 128 x 128] for pytorch <= 1.1.0 17 | and with shape of input tensor >= [B x C x 256 x 256] for pytorch == 1.3.1 18 | 19 | Args: 20 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 21 | to extract features of different spatial resolution 22 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 23 | other pretrained weights (see table with available weights for each encoder_name) 24 | encoder_output_stride: 16 or 32, if 16 use dilation in encoder last layer. 25 | Doesn't work with ***ception***, **vgg***, **densenet*`** backbones.Default is 16. 26 | decoder_channels: A number of convolution layer filters in decoder blocks 27 | in_channels: A number of input channels for the model, default is 3 (RGB images) 28 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 29 | activation: An activation function to apply after the final convolution layer. 30 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, 31 | **callable** and **None**. 32 | Default is **None** 33 | upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity 34 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 35 | on top of encoder if **aux_params** is not **None** (default). Supported params: 36 | - classes (int): A number of classes 37 | - pooling (str): One of "max", "avg". Default is "avg" 38 | - dropout (float): Dropout factor in [0, 1) 39 | - activation (str): An activation function to apply "sigmoid"/"softmax" 40 | (could be **None** to return logits) 41 | 42 | Returns: 43 | ``torch.nn.Module``: **PAN** 44 | 45 | .. 
_PAN: 46 | https://arxiv.org/abs/1805.10180 47 | 48 | """ 49 | 50 | def __init__( 51 | self, 52 | encoder_name: str = "resnet34", 53 | encoder_weights: Optional[str] = "imagenet", 54 | encoder_output_stride: int = 16, 55 | decoder_channels: int = 32, 56 | in_channels: int = 3, 57 | classes: int = 1, 58 | activation: Optional[Union[str, callable]] = None, 59 | upsampling: int = 4, 60 | aux_params: Optional[dict] = None, 61 | ): 62 | super().__init__() 63 | 64 | if encoder_output_stride not in [16, 32]: 65 | raise ValueError("PAN support output stride 16 or 32, got {}".format(encoder_output_stride)) 66 | 67 | self.encoder = get_encoder( 68 | encoder_name, 69 | in_channels=in_channels, 70 | depth=5, 71 | weights=encoder_weights, 72 | output_stride=encoder_output_stride, 73 | ) 74 | 75 | self.decoder = PANDecoder( 76 | encoder_channels=self.encoder.out_channels, 77 | decoder_channels=decoder_channels, 78 | ) 79 | 80 | self.segmentation_head = SegmentationHead( 81 | in_channels=decoder_channels, 82 | out_channels=classes, 83 | activation=activation, 84 | kernel_size=3, 85 | upsampling=upsampling, 86 | ) 87 | 88 | if aux_params is not None: 89 | self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) 90 | else: 91 | self.classification_head = None 92 | 93 | self.name = "pan-{}".format(encoder_name) 94 | self.initialize() 95 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/pspnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import PSPNet 2 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/pspnet/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from segmentation_models_pytorch_3d.base import modules 6 | 7 | 8 | class PSPBlock(nn.Module): 9 | def __init__(self, in_channels, out_channels, pool_size, use_bathcnorm=True): 10 | super().__init__() 11 | if pool_size == 1: 12 | use_bathcnorm = False # PyTorch does not support BatchNorm for 1x1 shape 13 | self.pool = nn.Sequential( 14 | nn.AdaptiveAvgPool3d(output_size=(pool_size, pool_size, pool_size)), 15 | modules.Conv3dReLU(in_channels, out_channels, (1, 1, 1), use_batchnorm=use_bathcnorm), 16 | ) 17 | 18 | def forward(self, x): 19 | h, w, d = x.size(2), x.size(3), x.size(4) 20 | x = self.pool(x) 21 | x = F.interpolate(x, size=(h, w, d), mode="trilinear", align_corners=True) 22 | return x 23 | 24 | 25 | class PSPModule(nn.Module): 26 | def __init__(self, in_channels, sizes=(1, 2, 3, 6), use_bathcnorm=True): 27 | super().__init__() 28 | 29 | self.blocks = nn.ModuleList( 30 | [ 31 | PSPBlock( 32 | in_channels, 33 | in_channels // len(sizes), 34 | size, 35 | use_bathcnorm=use_bathcnorm, 36 | ) 37 | for size in sizes 38 | ] 39 | ) 40 | 41 | def forward(self, x): 42 | xs = [block(x) for block in self.blocks] + [x] 43 | x = torch.cat(xs, dim=1) 44 | return x 45 | 46 | 47 | class PSPDecoder(nn.Module): 48 | def __init__( 49 | self, 50 | encoder_channels, 51 | use_batchnorm=True, 52 | out_channels=512, 53 | dropout=0.2, 54 | ): 55 | super().__init__() 56 | 57 | self.psp = PSPModule( 58 | in_channels=encoder_channels[-1], 59 | sizes=(1, 2, 3, 6), 60 | use_bathcnorm=use_batchnorm, 61 | ) 62 | 63 | self.conv = modules.Conv3dReLU( 64 | in_channels=encoder_channels[-1] * 2, 65 | 
out_channels=out_channels, 66 | kernel_size=1, 67 | use_batchnorm=use_batchnorm, 68 | ) 69 | 70 | self.dropout = nn.Dropout3d(p=dropout) 71 | 72 | def forward(self, *features): 73 | x = features[-1] 74 | x = self.psp(x) 75 | x = self.conv(x) 76 | x = self.dropout(x) 77 | 78 | return x 79 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/pspnet/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | from segmentation_models_pytorch_3d.encoders import get_encoder 4 | from segmentation_models_pytorch_3d.base import ( 5 | SegmentationModel, 6 | SegmentationHead, 7 | ClassificationHead, 8 | ) 9 | from .decoder import PSPDecoder 10 | 11 | 12 | class PSPNet(SegmentationModel): 13 | """PSPNet_ is a fully convolution neural network for image semantic segmentation. Consist of 14 | *encoder* and *Spatial Pyramid* (decoder). Spatial Pyramid build on top of encoder and does not 15 | use "fine-features" (features of high spatial resolution). PSPNet can be used for multiclass segmentation 16 | of high resolution images, however it is not good for detecting small objects and producing accurate, 17 | pixel-level mask. 18 | 19 | Args: 20 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 21 | to extract features of different spatial resolution 22 | encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features 23 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 24 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 25 | Default is 5 26 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 27 | other pretrained weights (see table with available weights for each encoder_name) 28 | psp_out_channels: A number of filters in Spatial Pyramid 29 | psp_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers 30 | is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. 31 | Available options are **True, False, "inplace"** 32 | psp_dropout: Spatial dropout rate in [0, 1) used in Spatial Pyramid 33 | in_channels: A number of input channels for the model, default is 3 (RGB images) 34 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 35 | activation: An activation function to apply after the final convolution layer. 36 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, 37 | **callable** and **None**. 38 | Default is **None** 39 | upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity 40 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 41 | on top of encoder if **aux_params** is not **None** (default). Supported params: 42 | - classes (int): A number of classes 43 | - pooling (str): One of "max", "avg". Default is "avg" 44 | - dropout (float): Dropout factor in [0, 1) 45 | - activation (str): An activation function to apply "sigmoid"/"softmax" 46 | (could be **None** to return logits) 47 | 48 | Returns: 49 | ``torch.nn.Module``: **PSPNet** 50 | 51 | .. 
_PSPNet: 52 | https://arxiv.org/abs/1612.01105 53 | """ 54 | 55 | def __init__( 56 | self, 57 | encoder_name: str = "resnet34", 58 | encoder_weights: Optional[str] = "imagenet", 59 | encoder_depth: int = 3, 60 | psp_out_channels: int = 512, 61 | psp_use_batchnorm: bool = True, 62 | psp_dropout: float = 0.2, 63 | in_channels: int = 3, 64 | classes: int = 1, 65 | activation: Optional[Union[str, callable]] = None, 66 | upsampling: int = 8, 67 | aux_params: Optional[dict] = None, 68 | ): 69 | super().__init__() 70 | 71 | self.encoder = get_encoder( 72 | encoder_name, 73 | in_channels=in_channels, 74 | depth=encoder_depth, 75 | weights=encoder_weights, 76 | ) 77 | 78 | self.decoder = PSPDecoder( 79 | encoder_channels=self.encoder.out_channels, 80 | use_batchnorm=psp_use_batchnorm, 81 | out_channels=psp_out_channels, 82 | dropout=psp_dropout, 83 | ) 84 | 85 | self.segmentation_head = SegmentationHead( 86 | in_channels=psp_out_channels, 87 | out_channels=classes, 88 | kernel_size=3, 89 | activation=activation, 90 | upsampling=upsampling, 91 | ) 92 | 93 | if aux_params: 94 | self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) 95 | else: 96 | self.classification_head = None 97 | 98 | self.name = "psp-{}".format(encoder_name) 99 | self.initialize() 100 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/unet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Unet 2 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/unet/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from segmentation_models_pytorch_3d.base import modules as md 6 | 7 | 8 | class DecoderBlock(nn.Module): 9 | def __init__( 10 | self, 11 | in_channels, 12 | skip_channels, 13 | out_channels, 14 | stride, 15 | use_batchnorm=True, 16 | attention_type=None, 17 | ): 18 | super().__init__() 19 | self.conv1 = md.Conv3dReLU( 20 | in_channels + skip_channels, 21 | out_channels, 22 | kernel_size=3, 23 | padding=1, 24 | use_batchnorm=use_batchnorm, 25 | ) 26 | self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels) 27 | self.conv2 = md.Conv3dReLU( 28 | out_channels, 29 | out_channels, 30 | kernel_size=3, 31 | padding=1, 32 | use_batchnorm=use_batchnorm, 33 | ) 34 | self.attention2 = md.Attention(attention_type, in_channels=out_channels) 35 | self.stride = stride 36 | 37 | def forward(self, x, skip=None): 38 | x = F.interpolate(x, scale_factor=self.stride, mode="nearest") 39 | # print("!!!", x.shape) 40 | if skip is not None: 41 | # print("!!!!", skip.shape) 42 | x = torch.cat([x, skip], dim=1) 43 | x = self.attention1(x) 44 | x = self.conv1(x) 45 | x = self.conv2(x) 46 | x = self.attention2(x) 47 | return x 48 | 49 | 50 | class CenterBlock(nn.Sequential): 51 | def __init__(self, in_channels, out_channels, use_batchnorm=True): 52 | conv1 = md.Conv3dReLU( 53 | in_channels, 54 | out_channels, 55 | kernel_size=3, 56 | padding=1, 57 | use_batchnorm=use_batchnorm, 58 | ) 59 | conv2 = md.Conv3dReLU( 60 | out_channels, 61 | out_channels, 62 | kernel_size=3, 63 | padding=1, 64 | use_batchnorm=use_batchnorm, 65 | ) 66 | super().__init__(conv1, conv2) 67 | 68 | 69 | class UnetDecoder(nn.Module): 70 | def __init__( 71 | self, 72 | 
encoder_channels, 73 | decoder_channels, 74 | n_blocks=5, 75 | use_batchnorm=True, 76 | attention_type=None, 77 | center=False, 78 | strides=((2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)), 79 | ): 80 | super().__init__() 81 | 82 | if n_blocks != len(decoder_channels): 83 | raise ValueError( 84 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( 85 | n_blocks, len(decoder_channels) 86 | ) 87 | ) 88 | 89 | if n_blocks != len(strides): 90 | raise ValueError( 91 | "Model depth is {}, but you provide `strides` as {}.".format( 92 | n_blocks, len(strides) 93 | ) 94 | ) 95 | 96 | # remove first skip with same spatial resolution 97 | encoder_channels = encoder_channels[1:] 98 | # reverse channels to start from head of encoder 99 | encoder_channels = encoder_channels[::-1] 100 | 101 | # computing blocks input and output channels 102 | head_channels = encoder_channels[0] 103 | in_channels = [head_channels] + list(decoder_channels[:-1]) 104 | skip_channels = list(encoder_channels[1:]) + [0] 105 | out_channels = decoder_channels 106 | 107 | if center: 108 | self.center = CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm) 109 | else: 110 | self.center = nn.Identity() 111 | 112 | # combine decoder keyword arguments 113 | kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) 114 | blocks = [ 115 | DecoderBlock(in_ch, skip_ch, out_ch, stride, **kwargs) 116 | for in_ch, skip_ch, out_ch, stride in zip(in_channels, skip_channels, out_channels, strides[::-1]) 117 | ] 118 | self.blocks = nn.ModuleList(blocks) 119 | 120 | def forward(self, *features): 121 | 122 | features = features[1:] # remove first skip with same spatial resolution 123 | features = features[::-1] # reverse channels to start from head of encoder 124 | 125 | head = features[0] 126 | skips = features[1:] 127 | 128 | x = self.center(head) 129 | for i, decoder_block in enumerate(self.blocks): 130 | skip = skips[i] if i < len(skips) else None 131 | x = decoder_block(x, skip) 132 | 133 | return x 134 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/unet/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, List 2 | 3 | from segmentation_models_pytorch_3d.encoders import get_encoder 4 | from segmentation_models_pytorch_3d.base import ( 5 | SegmentationModel, 6 | SegmentationHead, 7 | ClassificationHead, 8 | ) 9 | from .decoder import UnetDecoder 10 | 11 | 12 | class Unet(SegmentationModel): 13 | """Unet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder* 14 | and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial 15 | resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *concatenation* 16 | for fusing decoder blocks with skip connections. 17 | 18 | Args: 19 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 20 | to extract features of different spatial resolution 21 | encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features 22 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 23 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 
24 | Default is 5 25 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 26 | other pretrained weights (see table with available weights for each encoder_name) 27 | decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. 28 | Length of the list should be the same as **encoder_depth** 29 | decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers 30 | is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. 31 | Available options are **True, False, "inplace"** 32 | decoder_attention_type: Attention module used in decoder of the model. Available options are 33 | **None** and **scse** (https://arxiv.org/abs/1808.08127). 34 | in_channels: A number of input channels for the model, default is 3 (RGB images) 35 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 36 | activation: An activation function to apply after the final convolution layer. 37 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, 38 | **callable** and **None**. 39 | Default is **None** 40 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 41 | on top of encoder if **aux_params** is not **None** (default). Supported params: 42 | - classes (int): A number of classes 43 | - pooling (str): One of "max", "avg". Default is "avg" 44 | - dropout (float): Dropout factor in [0, 1) 45 | - activation (str): An activation function to apply "sigmoid"/"softmax" 46 | (could be **None** to return logits) 47 | 48 | Returns: 49 | ``torch.nn.Module``: Unet 50 | 51 | .. _Unet: 52 | https://arxiv.org/abs/1505.04597 53 | 54 | """ 55 | 56 | def __init__( 57 | self, 58 | encoder_name: str = "resnet34", 59 | encoder_depth: int = 5, 60 | encoder_weights: Optional[str] = "imagenet", 61 | decoder_use_batchnorm: bool = True, 62 | decoder_channels: List[int] = (256, 128, 64, 32, 16), 63 | decoder_attention_type: Optional[str] = None, 64 | in_channels: int = 3, 65 | classes: int = 1, 66 | activation: Optional[Union[str, callable]] = None, 67 | aux_params: Optional[dict] = None, 68 | strides=((2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)) 69 | ): 70 | super().__init__() 71 | 72 | self.encoder = get_encoder( 73 | encoder_name, 74 | in_channels=in_channels, 75 | depth=encoder_depth, 76 | weights=encoder_weights, 77 | strides=strides, 78 | ) 79 | 80 | self.decoder = UnetDecoder( 81 | encoder_channels=self.encoder.out_channels, 82 | decoder_channels=decoder_channels, 83 | n_blocks=encoder_depth, 84 | use_batchnorm=decoder_use_batchnorm, 85 | center=True if encoder_name.startswith("vgg") else False, 86 | attention_type=decoder_attention_type, 87 | strides=strides, 88 | ) 89 | 90 | self.segmentation_head = SegmentationHead( 91 | in_channels=decoder_channels[-1], 92 | out_channels=classes, 93 | activation=activation, 94 | kernel_size=3, 95 | ) 96 | 97 | if aux_params is not None: 98 | self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) 99 | else: 100 | self.classification_head = None 101 | 102 | self.name = "u-{}".format(encoder_name) 103 | self.initialize() 104 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/unetplusplus/__init__.py: 
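# A minimal usage sketch for the 3D Unet defined above (illustrative only, not part of the
# library sources; it assumes `Unet` is re-exported at the package top level and uses an
# arbitrary 64x64x64 single-channel volume):
import torch
import segmentation_models_pytorch_3d as smp

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,            # or "imagenet" to load converted 2D weights
    in_channels=1,                   # e.g. CT/MRI volumes with a single channel
    classes=2,
    strides=((2, 2, 2),) * 5,        # per-stage downsampling, same as the default above
)
model.eval()
x = torch.randn(1, 1, 64, 64, 64)    # (batch, channels, H, W, D)
with torch.no_grad():
    mask = model(x)                  # logits of shape (1, 2, 64, 64, 64)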
-------------------------------------------------------------------------------- 1 | from .model import UnetPlusPlus 2 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/unetplusplus/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from segmentation_models_pytorch_3d.base import modules as md 6 | 7 | 8 | class DecoderBlock(nn.Module): 9 | def __init__( 10 | self, 11 | in_channels, 12 | skip_channels, 13 | out_channels, 14 | use_batchnorm=True, 15 | attention_type=None, 16 | ): 17 | super().__init__() 18 | self.conv1 = md.Conv3dReLU( 19 | in_channels + skip_channels, 20 | out_channels, 21 | kernel_size=3, 22 | padding=1, 23 | use_batchnorm=use_batchnorm, 24 | ) 25 | self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels) 26 | self.conv2 = md.Conv3dReLU( 27 | out_channels, 28 | out_channels, 29 | kernel_size=3, 30 | padding=1, 31 | use_batchnorm=use_batchnorm, 32 | ) 33 | self.attention2 = md.Attention(attention_type, in_channels=out_channels) 34 | 35 | def forward(self, x, skip=None): 36 | x = F.interpolate(x, scale_factor=2, mode="nearest") 37 | if skip is not None: 38 | x = torch.cat([x, skip], dim=1) 39 | x = self.attention1(x) 40 | x = self.conv1(x) 41 | x = self.conv2(x) 42 | x = self.attention2(x) 43 | return x 44 | 45 | 46 | class CenterBlock(nn.Sequential): 47 | def __init__(self, in_channels, out_channels, use_batchnorm=True): 48 | conv1 = md.Conv3dReLU( 49 | in_channels, 50 | out_channels, 51 | kernel_size=3, 52 | padding=1, 53 | use_batchnorm=use_batchnorm, 54 | ) 55 | conv2 = md.Conv3dReLU( 56 | out_channels, 57 | out_channels, 58 | kernel_size=3, 59 | padding=1, 60 | use_batchnorm=use_batchnorm, 61 | ) 62 | super().__init__(conv1, conv2) 63 | 64 | 65 | class UnetPlusPlusDecoder(nn.Module): 66 | def __init__( 67 | self, 68 | encoder_channels, 69 | decoder_channels, 70 | n_blocks=5, 71 | use_batchnorm=True, 72 | attention_type=None, 73 | center=False, 74 | ): 75 | super().__init__() 76 | 77 | if n_blocks != len(decoder_channels): 78 | raise ValueError( 79 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( 80 | n_blocks, len(decoder_channels) 81 | ) 82 | ) 83 | 84 | # remove first skip with same spatial resolution 85 | encoder_channels = encoder_channels[1:] 86 | # reverse channels to start from head of encoder 87 | encoder_channels = encoder_channels[::-1] 88 | 89 | # computing blocks input and output channels 90 | head_channels = encoder_channels[0] 91 | self.in_channels = [head_channels] + list(decoder_channels[:-1]) 92 | self.skip_channels = list(encoder_channels[1:]) + [0] 93 | self.out_channels = decoder_channels 94 | if center: 95 | self.center = CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm) 96 | else: 97 | self.center = nn.Identity() 98 | 99 | # combine decoder keyword arguments 100 | kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) 101 | 102 | blocks = {} 103 | for layer_idx in range(len(self.in_channels) - 1): 104 | for depth_idx in range(layer_idx + 1): 105 | if depth_idx == 0: 106 | in_ch = self.in_channels[layer_idx] 107 | skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1) 108 | out_ch = self.out_channels[layer_idx] 109 | else: 110 | out_ch = self.skip_channels[layer_idx] 111 | skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1 - depth_idx) 112 
| in_ch = self.skip_channels[layer_idx - 1] 113 | blocks[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) 114 | blocks[f"x_{0}_{len(self.in_channels)-1}"] = DecoderBlock( 115 | self.in_channels[-1], 0, self.out_channels[-1], **kwargs 116 | ) 117 | self.blocks = nn.ModuleDict(blocks) 118 | self.depth = len(self.in_channels) - 1 119 | 120 | def forward(self, *features): 121 | 122 | features = features[1:] # remove first skip with same spatial resolution 123 | features = features[::-1] # reverse channels to start from head of encoder 124 | # start building dense connections 125 | dense_x = {} 126 | for layer_idx in range(len(self.in_channels) - 1): 127 | for depth_idx in range(self.depth - layer_idx): 128 | if layer_idx == 0: 129 | output = self.blocks[f"x_{depth_idx}_{depth_idx}"](features[depth_idx], features[depth_idx + 1]) 130 | dense_x[f"x_{depth_idx}_{depth_idx}"] = output 131 | else: 132 | dense_l_i = depth_idx + layer_idx 133 | cat_features = [dense_x[f"x_{idx}_{dense_l_i}"] for idx in range(depth_idx + 1, dense_l_i + 1)] 134 | cat_features = torch.cat(cat_features + [features[dense_l_i + 1]], dim=1) 135 | dense_x[f"x_{depth_idx}_{dense_l_i}"] = self.blocks[f"x_{depth_idx}_{dense_l_i}"]( 136 | dense_x[f"x_{depth_idx}_{dense_l_i-1}"], cat_features 137 | ) 138 | dense_x[f"x_{0}_{self.depth}"] = self.blocks[f"x_{0}_{self.depth}"](dense_x[f"x_{0}_{self.depth-1}"]) 139 | return dense_x[f"x_{0}_{self.depth}"] 140 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/decoders/unetplusplus/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, List 2 | 3 | from segmentation_models_pytorch_3d.encoders import get_encoder 4 | from segmentation_models_pytorch_3d.base import ( 5 | SegmentationModel, 6 | SegmentationHead, 7 | ClassificationHead, 8 | ) 9 | from .decoder import UnetPlusPlusDecoder 10 | 11 | 12 | class UnetPlusPlus(SegmentationModel): 13 | """Unet++ is a fully convolution neural network for image semantic segmentation. Consist of *encoder* 14 | and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial 15 | resolution (skip connections) which are used by decoder to define accurate segmentation mask. Decoder of 16 | Unet++ is more complex than in usual Unet. 17 | 18 | Args: 19 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 20 | to extract features of different spatial resolution 21 | encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features 22 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 23 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 24 | Default is 5 25 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 26 | other pretrained weights (see table with available weights for each encoder_name) 27 | decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. 28 | Length of the list should be the same as **encoder_depth** 29 | decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers 30 | is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. 
31 | Available options are **True, False, "inplace"** 32 | decoder_attention_type: Attention module used in decoder of the model. 33 | Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127). 34 | in_channels: A number of input channels for the model, default is 3 (RGB images) 35 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 36 | activation: An activation function to apply after the final convolution layer. 37 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, 38 | **callable** and **None**. 39 | Default is **None** 40 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 41 | on top of encoder if **aux_params** is not **None** (default). Supported params: 42 | - classes (int): A number of classes 43 | - pooling (str): One of "max", "avg". Default is "avg" 44 | - dropout (float): Dropout factor in [0, 1) 45 | - activation (str): An activation function to apply "sigmoid"/"softmax" 46 | (could be **None** to return logits) 47 | 48 | Returns: 49 | ``torch.nn.Module``: **Unet++** 50 | 51 | Reference: 52 | https://arxiv.org/abs/1807.10165 53 | 54 | """ 55 | 56 | def __init__( 57 | self, 58 | encoder_name: str = "resnet34", 59 | encoder_depth: int = 5, 60 | encoder_weights: Optional[str] = "imagenet", 61 | decoder_use_batchnorm: bool = True, 62 | decoder_channels: List[int] = (256, 128, 64, 32, 16), 63 | decoder_attention_type: Optional[str] = None, 64 | in_channels: int = 3, 65 | classes: int = 1, 66 | activation: Optional[Union[str, callable]] = None, 67 | aux_params: Optional[dict] = None, 68 | ): 69 | super().__init__() 70 | 71 | if encoder_name.startswith("mit_b"): 72 | raise ValueError("UnetPlusPlus is not support encoder_name={}".format(encoder_name)) 73 | 74 | self.encoder = get_encoder( 75 | encoder_name, 76 | in_channels=in_channels, 77 | depth=encoder_depth, 78 | weights=encoder_weights, 79 | ) 80 | 81 | self.decoder = UnetPlusPlusDecoder( 82 | encoder_channels=self.encoder.out_channels, 83 | decoder_channels=decoder_channels, 84 | n_blocks=encoder_depth, 85 | use_batchnorm=decoder_use_batchnorm, 86 | center=True if encoder_name.startswith("vgg") else False, 87 | attention_type=decoder_attention_type, 88 | ) 89 | 90 | self.segmentation_head = SegmentationHead( 91 | in_channels=decoder_channels[-1], 92 | out_channels=classes, 93 | activation=activation, 94 | kernel_size=3, 95 | ) 96 | 97 | if aux_params is not None: 98 | self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) 99 | else: 100 | self.classification_head = None 101 | 102 | self.name = "unetplusplus-{}".format(encoder_name) 103 | self.initialize() 104 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import functools 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | from .resnet import resnet_encoders 6 | from .dpn import dpn_encoders 7 | from .vgg import vgg_encoders 8 | from .senet import senet_encoders 9 | from .densenet import densenet_encoders 10 | from .inceptionresnetv2 import inceptionresnetv2_encoders 11 | from .inceptionv4 import inceptionv4_encoders 12 | from .efficientnet import efficient_net_encoders 13 | from .mobilenet import mobilenet_encoders 14 | from .xception import 
xception_encoders 15 | from .timm_efficientnet import timm_efficientnet_encoders 16 | from .timm_resnest import timm_resnest_encoders 17 | from .timm_res2net import timm_res2net_encoders 18 | from .timm_regnet import timm_regnet_encoders 19 | from .timm_sknet import timm_sknet_encoders 20 | from .timm_mobilenetv3 import timm_mobilenetv3_encoders 21 | from .timm_gernet import timm_gernet_encoders 22 | from .mix_transformer import mix_transformer_encoders 23 | from .mobileone import mobileone_encoders 24 | 25 | from .timm_universal import TimmUniversalEncoder 26 | 27 | from ._preprocessing import preprocess_input 28 | 29 | encoders = {} 30 | encoders.update(resnet_encoders) 31 | encoders.update(dpn_encoders) 32 | encoders.update(vgg_encoders) 33 | encoders.update(senet_encoders) 34 | encoders.update(densenet_encoders) 35 | encoders.update(inceptionresnetv2_encoders) 36 | encoders.update(inceptionv4_encoders) 37 | encoders.update(efficient_net_encoders) 38 | encoders.update(mobilenet_encoders) 39 | encoders.update(xception_encoders) 40 | encoders.update(timm_efficientnet_encoders) 41 | encoders.update(timm_resnest_encoders) 42 | encoders.update(timm_res2net_encoders) 43 | encoders.update(timm_regnet_encoders) 44 | encoders.update(timm_sknet_encoders) 45 | encoders.update(timm_mobilenetv3_encoders) 46 | encoders.update(timm_gernet_encoders) 47 | encoders.update(mix_transformer_encoders) 48 | encoders.update(mobileone_encoders) 49 | 50 | 51 | def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, strides=((2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)), **kwargs): 52 | 53 | if name.startswith("tu-"): 54 | name = name[3:] 55 | encoder = TimmUniversalEncoder( 56 | name=name, 57 | in_channels=in_channels, 58 | depth=depth, 59 | output_stride=output_stride, 60 | pretrained=weights is not None, 61 | **kwargs, 62 | ) 63 | return encoder 64 | 65 | try: 66 | Encoder = encoders[name]["encoder"] 67 | except KeyError: 68 | raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys()))) 69 | 70 | params = encoders[name]["params"] 71 | params.update(depth=depth) 72 | params.update(strides=strides) 73 | encoder = Encoder(**params) 74 | 75 | if weights is not None: 76 | try: 77 | settings = encoders[name]["pretrained_settings"][weights] 78 | except KeyError: 79 | raise KeyError( 80 | "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format( 81 | weights, 82 | name, 83 | list(encoders[name]["pretrained_settings"].keys()), 84 | ) 85 | ) 86 | state_dict = model_zoo.load_url(settings["url"], map_location='cpu') 87 | try: 88 | from segmentation_models_pytorch_3d.utils.convert_weights import convert_2d_weights_to_3d 89 | state_dict = convert_2d_weights_to_3d(state_dict) 90 | except Exception as e: 91 | print('Can\'t convert. 
Exception: {}'.format(e)) 92 | pass 93 | encoder.load_state_dict(state_dict) 94 | 95 | encoder.set_in_channels(in_channels, pretrained=weights is not None) 96 | if output_stride != 32: 97 | encoder.make_dilated(output_stride) 98 | 99 | return encoder 100 | 101 | 102 | def get_encoder_names(): 103 | return list(encoders.keys()) 104 | 105 | 106 | def get_preprocessing_params(encoder_name, pretrained="imagenet"): 107 | 108 | if encoder_name.startswith("tu-"): 109 | encoder_name = encoder_name[3:] 110 | if not timm.models.is_model_pretrained(encoder_name): 111 | raise ValueError(f"{encoder_name} does not have pretrained weights and preprocessing parameters") 112 | settings = timm.models.get_pretrained_cfg(encoder_name).__dict__ 113 | else: 114 | all_settings = encoders[encoder_name]["pretrained_settings"] 115 | if pretrained not in all_settings.keys(): 116 | raise ValueError("Available pretrained options {}".format(all_settings.keys())) 117 | settings = all_settings[pretrained] 118 | 119 | formatted_settings = {} 120 | formatted_settings["input_space"] = settings.get("input_space", "RGB") 121 | formatted_settings["input_range"] = list(settings.get("input_range", [0, 1])) 122 | formatted_settings["mean"] = list(settings["mean"]) 123 | formatted_settings["std"] = list(settings["std"]) 124 | 125 | return formatted_settings 126 | 127 | 128 | def get_preprocessing_fn(encoder_name, pretrained="imagenet"): 129 | params = get_preprocessing_params(encoder_name, pretrained=pretrained) 130 | return functools.partial(preprocess_input, **params) 131 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/encoders/_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import List 4 | from collections import OrderedDict 5 | 6 | from . 
import _utils as utils 7 | 8 | 9 | class EncoderMixin: 10 | """Add encoder functionality such as: 11 | - output channels specification of feature tensors (produced by encoder) 12 | - patching first convolution for arbitrary input channels 13 | """ 14 | 15 | _output_stride = 32 16 | 17 | @property 18 | def out_channels(self): 19 | """Return channels dimensions for each tensor of forward output of encoder""" 20 | return self._out_channels[: self._depth + 1] 21 | 22 | @property 23 | def output_stride(self): 24 | return min(self._output_stride, 2**self._depth) 25 | 26 | def set_in_channels(self, in_channels, pretrained=True): 27 | """Change first convolution channels""" 28 | if in_channels == 3: 29 | return 30 | 31 | self._in_channels = in_channels 32 | if self._out_channels[0] == 3: 33 | self._out_channels = tuple([in_channels] + list(self._out_channels)[1:]) 34 | 35 | utils.patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained) 36 | 37 | def get_stages(self): 38 | """Override it in your implementation""" 39 | raise NotImplementedError 40 | 41 | def make_dilated(self, output_stride): 42 | 43 | if output_stride == 16: 44 | stage_list = [ 45 | 5, 46 | ] 47 | dilation_list = [ 48 | 2, 49 | ] 50 | 51 | elif output_stride == 8: 52 | stage_list = [4, 5] 53 | dilation_list = [2, 4] 54 | 55 | else: 56 | raise ValueError("Output stride should be 16 or 8, got {}.".format(output_stride)) 57 | 58 | self._output_stride = output_stride 59 | 60 | stages = self.get_stages() 61 | for stage_indx, dilation_rate in zip(stage_list, dilation_list): 62 | utils.replace_strides_with_dilation( 63 | module=stages[stage_indx], 64 | dilation_rate=dilation_rate, 65 | ) 66 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/encoders/_preprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def preprocess_input(x, mean=None, std=None, input_space="RGB", input_range=None, **kwargs): 5 | 6 | if input_space == "BGR": 7 | x = x[..., ::-1].copy() 8 | 9 | if input_range is not None: 10 | if x.max() > 1 and input_range[1] == 1: 11 | x = x / 255.0 12 | 13 | if mean is not None: 14 | mean = np.array(mean) 15 | x = x - mean 16 | 17 | if std is not None: 18 | std = np.array(std) 19 | x = x / std 20 | 21 | return x 22 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/encoders/_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True): 6 | """Change first convolution layer input channels. 
7 | In case: 8 | in_channels == 1 -> sum pretrained weights over the input channel dimension 9 | in_channels > 1 -> repeat pretrained weights across channels and rescale (or re-initialize if not pretrained) 10 | """ 11 | 12 | # get first conv 13 | for module in model.modules(): 14 | if isinstance(module, nn.Conv3d) and module.in_channels == default_in_channels: 15 | break 16 | 17 | weight = module.weight.detach() 18 | module.in_channels = new_in_channels 19 | 20 | if not pretrained: 21 | module.weight = nn.parameter.Parameter( 22 | torch.Tensor(module.out_channels, new_in_channels // module.groups, *module.kernel_size) 23 | ) 24 | module.reset_parameters() 25 | 26 | elif new_in_channels == 1: 27 | new_weight = weight.sum(1, keepdim=True) 28 | module.weight = nn.parameter.Parameter(new_weight) 29 | 30 | else: 31 | new_weight = torch.Tensor(module.out_channels, new_in_channels // module.groups, *module.kernel_size) 32 | 33 | for i in range(new_in_channels): 34 | new_weight[:, i] = weight[:, i % default_in_channels] 35 | 36 | new_weight = new_weight * (default_in_channels / new_in_channels) 37 | module.weight = nn.parameter.Parameter(new_weight) 38 | 39 | 40 | def replace_strides_with_dilation(module, dilation_rate): 41 | """Patch Conv3d modules replacing strides with dilation""" 42 | for mod in module.modules(): 43 | if isinstance(mod, nn.Conv3d): 44 | mod.stride = (1, 1, 1) 45 | mod.dilation = (dilation_rate, dilation_rate, dilation_rate) 46 | kh, kw, kd = mod.kernel_size 47 | mod.padding = ((kh // 2) * dilation_rate, (kw // 2) * dilation_rate, (kd // 2) * dilation_rate) 48 | 49 | # Workaround for EfficientNet static padding 50 | if hasattr(mod, "static_padding"): 51 | mod.static_padding = nn.Identity() 52 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/encoders/inceptionresnetv2.py: -------------------------------------------------------------------------------- 1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
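For example, the stages are consumed roughly like this (a sketch with illustrative names
`encoder`, `volume` and `depth`, mirroring the `forward` implementations below):

    stages = encoder.get_stages()     # one module per resolution level
    features = []
    x = volume
    for stage in stages[:depth + 1]:
        x = stage(x)
        features.append(x)            # progressively downsampled feature pyramid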
24 | """ 25 | 26 | import torch.nn as nn 27 | from pretrainedmodels.models.inceptionresnetv2 import InceptionResNetV2 28 | from pretrainedmodels.models.inceptionresnetv2 import pretrained_settings 29 | 30 | from ._base import EncoderMixin 31 | 32 | 33 | class InceptionResNetV2Encoder(InceptionResNetV2, EncoderMixin): 34 | def __init__(self, out_channels, depth=5, **kwargs): 35 | super().__init__(**kwargs) 36 | 37 | self._out_channels = out_channels 38 | self._depth = depth 39 | self._in_channels = 3 40 | 41 | # correct paddings 42 | for m in self.modules(): 43 | if isinstance(m, nn.Conv2d): 44 | if m.kernel_size == (3, 3): 45 | m.padding = (1, 1) 46 | if isinstance(m, nn.MaxPool2d): 47 | m.padding = (1, 1) 48 | 49 | # remove linear layers 50 | del self.avgpool_1a 51 | del self.last_linear 52 | 53 | def make_dilated(self, *args, **kwargs): 54 | raise ValueError( 55 | "InceptionResnetV2 encoder does not support dilated mode " "due to pooling operation for downsampling!" 56 | ) 57 | 58 | def get_stages(self): 59 | return [ 60 | nn.Identity(), 61 | nn.Sequential(self.conv2d_1a, self.conv2d_2a, self.conv2d_2b), 62 | nn.Sequential(self.maxpool_3a, self.conv2d_3b, self.conv2d_4a), 63 | nn.Sequential(self.maxpool_5a, self.mixed_5b, self.repeat), 64 | nn.Sequential(self.mixed_6a, self.repeat_1), 65 | nn.Sequential(self.mixed_7a, self.repeat_2, self.block8, self.conv2d_7b), 66 | ] 67 | 68 | def forward(self, x): 69 | 70 | stages = self.get_stages() 71 | 72 | features = [] 73 | for i in range(self._depth + 1): 74 | x = stages[i](x) 75 | features.append(x) 76 | 77 | return features 78 | 79 | def load_state_dict(self, state_dict, **kwargs): 80 | state_dict.pop("last_linear.bias", None) 81 | state_dict.pop("last_linear.weight", None) 82 | super().load_state_dict(state_dict, **kwargs) 83 | 84 | 85 | inceptionresnetv2_encoders = { 86 | "inceptionresnetv2": { 87 | "encoder": InceptionResNetV2Encoder, 88 | "pretrained_settings": pretrained_settings["inceptionresnetv2"], 89 | "params": {"out_channels": (3, 64, 192, 320, 1088, 1536), "num_classes": 1000}, 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/encoders/inceptionv4.py: -------------------------------------------------------------------------------- 1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. 
if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 24 | """ 25 | 26 | import torch.nn as nn 27 | from pretrainedmodels.models.inceptionv4 import InceptionV4, BasicConv2d 28 | from pretrainedmodels.models.inceptionv4 import pretrained_settings 29 | 30 | from ._base import EncoderMixin 31 | 32 | 33 | class InceptionV4Encoder(InceptionV4, EncoderMixin): 34 | def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): 35 | super().__init__(**kwargs) 36 | self._stage_idxs = stage_idxs 37 | self._out_channels = out_channels 38 | self._depth = depth 39 | self._in_channels = 3 40 | 41 | # correct paddings 42 | for m in self.modules(): 43 | if isinstance(m, nn.Conv2d): 44 | if m.kernel_size == (3, 3): 45 | m.padding = (1, 1) 46 | if isinstance(m, nn.MaxPool2d): 47 | m.padding = (1, 1) 48 | 49 | # remove linear layers 50 | del self.last_linear 51 | 52 | def make_dilated(self, stage_list, dilation_list): 53 | raise ValueError( 54 | "InceptionV4 encoder does not support dilated mode " "due to pooling operation for downsampling!" 55 | ) 56 | 57 | def get_stages(self): 58 | return [ 59 | nn.Identity(), 60 | self.features[: self._stage_idxs[0]], 61 | self.features[self._stage_idxs[0] : self._stage_idxs[1]], 62 | self.features[self._stage_idxs[1] : self._stage_idxs[2]], 63 | self.features[self._stage_idxs[2] : self._stage_idxs[3]], 64 | self.features[self._stage_idxs[3] :], 65 | ] 66 | 67 | def forward(self, x): 68 | 69 | stages = self.get_stages() 70 | 71 | features = [] 72 | for i in range(self._depth + 1): 73 | x = stages[i](x) 74 | features.append(x) 75 | 76 | return features 77 | 78 | def load_state_dict(self, state_dict, **kwargs): 79 | state_dict.pop("last_linear.bias", None) 80 | state_dict.pop("last_linear.weight", None) 81 | super().load_state_dict(state_dict, **kwargs) 82 | 83 | 84 | inceptionv4_encoders = { 85 | "inceptionv4": { 86 | "encoder": InceptionV4Encoder, 87 | "pretrained_settings": pretrained_settings["inceptionv4"], 88 | "params": { 89 | "stage_idxs": (3, 5, 9, 15), 90 | "out_channels": (3, 64, 192, 384, 1024, 1536), 91 | "num_classes": 1001, 92 | }, 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/encoders/mobilenet.py: -------------------------------------------------------------------------------- 1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 
15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 24 | """ 25 | 26 | import torchvision 27 | import torch.nn as nn 28 | 29 | from ._base import EncoderMixin 30 | 31 | 32 | class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin): 33 | def __init__(self, out_channels, depth=5, **kwargs): 34 | super().__init__(**kwargs) 35 | self._depth = depth 36 | self._out_channels = out_channels 37 | self._in_channels = 3 38 | del self.classifier 39 | 40 | def get_stages(self): 41 | return [ 42 | nn.Identity(), 43 | self.features[:2], 44 | self.features[2:4], 45 | self.features[4:7], 46 | self.features[7:14], 47 | self.features[14:], 48 | ] 49 | 50 | def forward(self, x): 51 | stages = self.get_stages() 52 | 53 | features = [] 54 | for i in range(self._depth + 1): 55 | x = stages[i](x) 56 | features.append(x) 57 | 58 | return features 59 | 60 | def load_state_dict(self, state_dict, **kwargs): 61 | state_dict.pop("classifier.1.bias", None) 62 | state_dict.pop("classifier.1.weight", None) 63 | super().load_state_dict(state_dict, **kwargs) 64 | 65 | 66 | mobilenet_encoders = { 67 | "mobilenet_v2": { 68 | "encoder": MobileNetV2Encoder, 69 | "pretrained_settings": { 70 | "imagenet": { 71 | "mean": [0.485, 0.456, 0.406], 72 | "std": [0.229, 0.224, 0.225], 73 | "url": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", 74 | "input_space": "RGB", 75 | "input_range": [0, 1], 76 | }, 77 | }, 78 | "params": { 79 | "out_channels": (3, 16, 24, 32, 96, 1280), 80 | }, 81 | }, 82 | } 83 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/encoders/senet.py: -------------------------------------------------------------------------------- 1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
24 | """ 25 | 26 | import torch.nn as nn 27 | 28 | from pretrainedmodels.models.senet import ( 29 | SENet, 30 | SEBottleneck, 31 | SEResNetBottleneck, 32 | SEResNeXtBottleneck, 33 | pretrained_settings, 34 | ) 35 | from ._base import EncoderMixin 36 | 37 | 38 | class SENetEncoder(SENet, EncoderMixin): 39 | def __init__(self, out_channels, depth=5, **kwargs): 40 | super().__init__(**kwargs) 41 | 42 | self._out_channels = out_channels 43 | self._depth = depth 44 | self._in_channels = 3 45 | 46 | del self.last_linear 47 | del self.avg_pool 48 | 49 | def get_stages(self): 50 | return [ 51 | nn.Identity(), 52 | self.layer0[:-1], 53 | nn.Sequential(self.layer0[-1], self.layer1), 54 | self.layer2, 55 | self.layer3, 56 | self.layer4, 57 | ] 58 | 59 | def forward(self, x): 60 | stages = self.get_stages() 61 | 62 | features = [] 63 | for i in range(self._depth + 1): 64 | x = stages[i](x) 65 | features.append(x) 66 | 67 | return features 68 | 69 | def load_state_dict(self, state_dict, **kwargs): 70 | state_dict.pop("last_linear.bias", None) 71 | state_dict.pop("last_linear.weight", None) 72 | super().load_state_dict(state_dict, **kwargs) 73 | 74 | 75 | senet_encoders = { 76 | "senet154": { 77 | "encoder": SENetEncoder, 78 | "pretrained_settings": pretrained_settings["senet154"], 79 | "params": { 80 | "out_channels": (3, 128, 256, 512, 1024, 2048), 81 | "block": SEBottleneck, 82 | "dropout_p": 0.2, 83 | "groups": 64, 84 | "layers": [3, 8, 36, 3], 85 | "num_classes": 1000, 86 | "reduction": 16, 87 | }, 88 | }, 89 | "se_resnet50": { 90 | "encoder": SENetEncoder, 91 | "pretrained_settings": pretrained_settings["se_resnet50"], 92 | "params": { 93 | "out_channels": (3, 64, 256, 512, 1024, 2048), 94 | "block": SEResNetBottleneck, 95 | "layers": [3, 4, 6, 3], 96 | "downsample_kernel_size": 1, 97 | "downsample_padding": 0, 98 | "dropout_p": None, 99 | "groups": 1, 100 | "inplanes": 64, 101 | "input_3x3": False, 102 | "num_classes": 1000, 103 | "reduction": 16, 104 | }, 105 | }, 106 | "se_resnet101": { 107 | "encoder": SENetEncoder, 108 | "pretrained_settings": pretrained_settings["se_resnet101"], 109 | "params": { 110 | "out_channels": (3, 64, 256, 512, 1024, 2048), 111 | "block": SEResNetBottleneck, 112 | "layers": [3, 4, 23, 3], 113 | "downsample_kernel_size": 1, 114 | "downsample_padding": 0, 115 | "dropout_p": None, 116 | "groups": 1, 117 | "inplanes": 64, 118 | "input_3x3": False, 119 | "num_classes": 1000, 120 | "reduction": 16, 121 | }, 122 | }, 123 | "se_resnet152": { 124 | "encoder": SENetEncoder, 125 | "pretrained_settings": pretrained_settings["se_resnet152"], 126 | "params": { 127 | "out_channels": (3, 64, 256, 512, 1024, 2048), 128 | "block": SEResNetBottleneck, 129 | "layers": [3, 8, 36, 3], 130 | "downsample_kernel_size": 1, 131 | "downsample_padding": 0, 132 | "dropout_p": None, 133 | "groups": 1, 134 | "inplanes": 64, 135 | "input_3x3": False, 136 | "num_classes": 1000, 137 | "reduction": 16, 138 | }, 139 | }, 140 | "se_resnext50_32x4d": { 141 | "encoder": SENetEncoder, 142 | "pretrained_settings": pretrained_settings["se_resnext50_32x4d"], 143 | "params": { 144 | "out_channels": (3, 64, 256, 512, 1024, 2048), 145 | "block": SEResNeXtBottleneck, 146 | "layers": [3, 4, 6, 3], 147 | "downsample_kernel_size": 1, 148 | "downsample_padding": 0, 149 | "dropout_p": None, 150 | "groups": 32, 151 | "inplanes": 64, 152 | "input_3x3": False, 153 | "num_classes": 1000, 154 | "reduction": 16, 155 | }, 156 | }, 157 | "se_resnext101_32x4d": { 158 | "encoder": SENetEncoder, 159 | "pretrained_settings": 
pretrained_settings["se_resnext101_32x4d"], 160 | "params": { 161 | "out_channels": (3, 64, 256, 512, 1024, 2048), 162 | "block": SEResNeXtBottleneck, 163 | "layers": [3, 4, 23, 3], 164 | "downsample_kernel_size": 1, 165 | "downsample_padding": 0, 166 | "dropout_p": None, 167 | "groups": 32, 168 | "inplanes": 64, 169 | "input_3x3": False, 170 | "num_classes": 1000, 171 | "reduction": 16, 172 | }, 173 | }, 174 | } 175 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/encoders/timm_gernet.py: -------------------------------------------------------------------------------- 1 | from timm.models import ByoModelCfg, ByoBlockCfg, ByobNet 2 | 3 | from ._base import EncoderMixin 4 | import torch.nn as nn 5 | 6 | 7 | class GERNetEncoder(ByobNet, EncoderMixin): 8 | def __init__(self, out_channels, depth=5, **kwargs): 9 | super().__init__(**kwargs) 10 | self._depth = depth 11 | self._out_channels = out_channels 12 | self._in_channels = 3 13 | 14 | del self.head 15 | 16 | def get_stages(self): 17 | return [ 18 | nn.Identity(), 19 | self.stem, 20 | self.stages[0], 21 | self.stages[1], 22 | self.stages[2], 23 | nn.Sequential(self.stages[3], self.stages[4], self.final_conv), 24 | ] 25 | 26 | def forward(self, x): 27 | stages = self.get_stages() 28 | 29 | features = [] 30 | for i in range(self._depth + 1): 31 | x = stages[i](x) 32 | features.append(x) 33 | 34 | return features 35 | 36 | def load_state_dict(self, state_dict, **kwargs): 37 | state_dict.pop("head.fc.weight", None) 38 | state_dict.pop("head.fc.bias", None) 39 | super().load_state_dict(state_dict, **kwargs) 40 | 41 | 42 | regnet_weights = { 43 | "timm-gernet_s": { 44 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth", # noqa 45 | }, 46 | "timm-gernet_m": { 47 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_m-0873c53a.pth", # noqa 48 | }, 49 | "timm-gernet_l": { 50 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_l-f31e2e8d.pth", # noqa 51 | }, 52 | } 53 | 54 | pretrained_settings = {} 55 | for model_name, sources in regnet_weights.items(): 56 | pretrained_settings[model_name] = {} 57 | for source_name, source_url in sources.items(): 58 | pretrained_settings[model_name][source_name] = { 59 | "url": source_url, 60 | "input_range": [0, 1], 61 | "mean": [0.485, 0.456, 0.406], 62 | "std": [0.229, 0.224, 0.225], 63 | "num_classes": 1000, 64 | } 65 | 66 | timm_gernet_encoders = { 67 | "timm-gernet_s": { 68 | "encoder": GERNetEncoder, 69 | "pretrained_settings": pretrained_settings["timm-gernet_s"], 70 | "params": { 71 | "out_channels": (3, 13, 48, 48, 384, 1920), 72 | "cfg": ByoModelCfg( 73 | blocks=( 74 | ByoBlockCfg(type="basic", d=1, c=48, s=2, gs=0, br=1.0), 75 | ByoBlockCfg(type="basic", d=3, c=48, s=2, gs=0, br=1.0), 76 | ByoBlockCfg(type="bottle", d=7, c=384, s=2, gs=0, br=1 / 4), 77 | ByoBlockCfg(type="bottle", d=2, c=560, s=2, gs=1, br=3.0), 78 | ByoBlockCfg(type="bottle", d=1, c=256, s=1, gs=1, br=3.0), 79 | ), 80 | stem_chs=13, 81 | stem_pool=None, 82 | num_features=1920, 83 | ), 84 | }, 85 | }, 86 | "timm-gernet_m": { 87 | "encoder": GERNetEncoder, 88 | "pretrained_settings": pretrained_settings["timm-gernet_m"], 89 | "params": { 90 | "out_channels": (3, 32, 128, 192, 640, 2560), 91 | "cfg": ByoModelCfg( 92 | blocks=( 93 | ByoBlockCfg(type="basic", d=1, c=128, s=2, gs=0, br=1.0), 94 | 
ByoBlockCfg(type="basic", d=2, c=192, s=2, gs=0, br=1.0), 95 | ByoBlockCfg(type="bottle", d=6, c=640, s=2, gs=0, br=1 / 4), 96 | ByoBlockCfg(type="bottle", d=4, c=640, s=2, gs=1, br=3.0), 97 | ByoBlockCfg(type="bottle", d=1, c=640, s=1, gs=1, br=3.0), 98 | ), 99 | stem_chs=32, 100 | stem_pool=None, 101 | num_features=2560, 102 | ), 103 | }, 104 | }, 105 | "timm-gernet_l": { 106 | "encoder": GERNetEncoder, 107 | "pretrained_settings": pretrained_settings["timm-gernet_l"], 108 | "params": { 109 | "out_channels": (3, 32, 128, 192, 640, 2560), 110 | "cfg": ByoModelCfg( 111 | blocks=( 112 | ByoBlockCfg(type="basic", d=1, c=128, s=2, gs=0, br=1.0), 113 | ByoBlockCfg(type="basic", d=2, c=192, s=2, gs=0, br=1.0), 114 | ByoBlockCfg(type="bottle", d=6, c=640, s=2, gs=0, br=1 / 4), 115 | ByoBlockCfg(type="bottle", d=5, c=640, s=2, gs=1, br=3.0), 116 | ByoBlockCfg(type="bottle", d=4, c=640, s=1, gs=1, br=3.0), 117 | ), 118 | stem_chs=32, 119 | stem_pool=None, 120 | num_features=2560, 121 | ), 122 | }, 123 | }, 124 | } 125 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/encoders/timm_mobilenetv3.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import numpy as np 3 | import torch.nn as nn 4 | 5 | from ._base import EncoderMixin 6 | 7 | 8 | def _make_divisible(x, divisible_by=8): 9 | return int(np.ceil(x * 1.0 / divisible_by) * divisible_by) 10 | 11 | 12 | class MobileNetV3Encoder(nn.Module, EncoderMixin): 13 | def __init__(self, model_name, width_mult, depth=5, **kwargs): 14 | super().__init__() 15 | if "large" not in model_name and "small" not in model_name: 16 | raise ValueError("MobileNetV3 wrong model name {}".format(model_name)) 17 | 18 | self._mode = "small" if "small" in model_name else "large" 19 | self._depth = depth 20 | self._out_channels = self._get_channels(self._mode, width_mult) 21 | self._in_channels = 3 22 | 23 | # minimal models replace hardswish with relu 24 | self.model = timm.create_model( 25 | model_name=model_name, 26 | scriptable=True, # torch.jit scriptable 27 | exportable=True, # onnx export 28 | features_only=True, 29 | ) 30 | 31 | def _get_channels(self, mode, width_mult): 32 | if mode == "small": 33 | channels = [16, 16, 24, 48, 576] 34 | else: 35 | channels = [16, 24, 40, 112, 960] 36 | channels = [ 37 | 3, 38 | ] + [_make_divisible(x * width_mult) for x in channels] 39 | return tuple(channels) 40 | 41 | def get_stages(self): 42 | if self._mode == "small": 43 | return [ 44 | nn.Identity(), 45 | nn.Sequential( 46 | self.model.conv_stem, 47 | self.model.bn1, 48 | self.model.act1, 49 | ), 50 | self.model.blocks[0], 51 | self.model.blocks[1], 52 | self.model.blocks[2:4], 53 | self.model.blocks[4:], 54 | ] 55 | elif self._mode == "large": 56 | return [ 57 | nn.Identity(), 58 | nn.Sequential( 59 | self.model.conv_stem, 60 | self.model.bn1, 61 | self.model.act1, 62 | self.model.blocks[0], 63 | ), 64 | self.model.blocks[1], 65 | self.model.blocks[2], 66 | self.model.blocks[3:5], 67 | self.model.blocks[5:], 68 | ] 69 | else: 70 | raise ValueError("MobileNetV3 mode should be small or large, got {}".format(self._mode)) 71 | 72 | def forward(self, x): 73 | stages = self.get_stages() 74 | 75 | features = [] 76 | for i in range(self._depth + 1): 77 | x = stages[i](x) 78 | features.append(x) 79 | 80 | return features 81 | 82 | def load_state_dict(self, state_dict, **kwargs): 83 | state_dict.pop("conv_head.weight", None) 84 | state_dict.pop("conv_head.bias", None) 85 |
state_dict.pop("classifier.weight", None) 86 | state_dict.pop("classifier.bias", None) 87 | self.model.load_state_dict(state_dict, **kwargs) 88 | 89 | 90 | mobilenetv3_weights = { 91 | "tf_mobilenetv3_large_075": { 92 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth" # noqa 93 | }, 94 | "tf_mobilenetv3_large_100": { 95 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth" # noqa 96 | }, 97 | "tf_mobilenetv3_large_minimal_100": { 98 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth" # noqa 99 | }, 100 | "tf_mobilenetv3_small_075": { 101 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth" # noqa 102 | }, 103 | "tf_mobilenetv3_small_100": { 104 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth" # noqa 105 | }, 106 | "tf_mobilenetv3_small_minimal_100": { 107 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth" # noqa 108 | }, 109 | } 110 | 111 | pretrained_settings = {} 112 | for model_name, sources in mobilenetv3_weights.items(): 113 | pretrained_settings[model_name] = {} 114 | for source_name, source_url in sources.items(): 115 | pretrained_settings[model_name][source_name] = { 116 | "url": source_url, 117 | "input_range": [0, 1], 118 | "mean": [0.485, 0.456, 0.406], 119 | "std": [0.229, 0.224, 0.225], 120 | "input_space": "RGB", 121 | } 122 | 123 | 124 | timm_mobilenetv3_encoders = { 125 | "timm-mobilenetv3_large_075": { 126 | "encoder": MobileNetV3Encoder, 127 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_large_075"], 128 | "params": {"model_name": "tf_mobilenetv3_large_075", "width_mult": 0.75}, 129 | }, 130 | "timm-mobilenetv3_large_100": { 131 | "encoder": MobileNetV3Encoder, 132 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_large_100"], 133 | "params": {"model_name": "tf_mobilenetv3_large_100", "width_mult": 1.0}, 134 | }, 135 | "timm-mobilenetv3_large_minimal_100": { 136 | "encoder": MobileNetV3Encoder, 137 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_large_minimal_100"], 138 | "params": {"model_name": "tf_mobilenetv3_large_minimal_100", "width_mult": 1.0}, 139 | }, 140 | "timm-mobilenetv3_small_075": { 141 | "encoder": MobileNetV3Encoder, 142 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_075"], 143 | "params": {"model_name": "tf_mobilenetv3_small_075", "width_mult": 0.75}, 144 | }, 145 | "timm-mobilenetv3_small_100": { 146 | "encoder": MobileNetV3Encoder, 147 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_100"], 148 | "params": {"model_name": "tf_mobilenetv3_small_100", "width_mult": 1.0}, 149 | }, 150 | "timm-mobilenetv3_small_minimal_100": { 151 | "encoder": MobileNetV3Encoder, 152 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_minimal_100"], 153 | "params": {"model_name": "tf_mobilenetv3_small_minimal_100", "width_mult": 1.0}, 154 | }, 155 | } 156 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/encoders/timm_res2net.py: -------------------------------------------------------------------------------- 1 | 
from ._base import EncoderMixin 2 | from timm.models.resnet import ResNet 3 | from timm.models.res2net import Bottle2neck 4 | import torch.nn as nn 5 | 6 | 7 | class Res2NetEncoder(ResNet, EncoderMixin): 8 | def __init__(self, out_channels, depth=5, **kwargs): 9 | super().__init__(**kwargs) 10 | self._depth = depth 11 | self._out_channels = out_channels 12 | self._in_channels = 3 13 | 14 | del self.fc 15 | del self.global_pool 16 | 17 | def get_stages(self): 18 | return [ 19 | nn.Identity(), 20 | nn.Sequential(self.conv1, self.bn1, self.act1), 21 | nn.Sequential(self.maxpool, self.layer1), 22 | self.layer2, 23 | self.layer3, 24 | self.layer4, 25 | ] 26 | 27 | def make_dilated(self, *args, **kwargs): 28 | raise ValueError("Res2Net encoders do not support dilated mode") 29 | 30 | def forward(self, x): 31 | stages = self.get_stages() 32 | 33 | features = [] 34 | for i in range(self._depth + 1): 35 | x = stages[i](x) 36 | features.append(x) 37 | 38 | return features 39 | 40 | def load_state_dict(self, state_dict, **kwargs): 41 | state_dict.pop("fc.bias", None) 42 | state_dict.pop("fc.weight", None) 43 | super().load_state_dict(state_dict, **kwargs) 44 | 45 | 46 | res2net_weights = { 47 | "timm-res2net50_26w_4s": { 48 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth", # noqa 49 | }, 50 | "timm-res2net50_48w_2s": { 51 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth", # noqa 52 | }, 53 | "timm-res2net50_14w_8s": { 54 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth", # noqa 55 | }, 56 | "timm-res2net50_26w_6s": { 57 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth", # noqa 58 | }, 59 | "timm-res2net50_26w_8s": { 60 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth", # noqa 61 | }, 62 | "timm-res2net101_26w_4s": { 63 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth", # noqa 64 | }, 65 | "timm-res2next50": { 66 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth", # noqa 67 | }, 68 | } 69 | 70 | pretrained_settings = {} 71 | for model_name, sources in res2net_weights.items(): 72 | pretrained_settings[model_name] = {} 73 | for source_name, source_url in sources.items(): 74 | pretrained_settings[model_name][source_name] = { 75 | "url": source_url, 76 | "input_size": [3, 224, 224], 77 | "input_range": [0, 1], 78 | "mean": [0.485, 0.456, 0.406], 79 | "std": [0.229, 0.224, 0.225], 80 | "num_classes": 1000, 81 | } 82 | 83 | 84 | timm_res2net_encoders = { 85 | "timm-res2net50_26w_4s": { 86 | "encoder": Res2NetEncoder, 87 | "pretrained_settings": pretrained_settings["timm-res2net50_26w_4s"], 88 | "params": { 89 | "out_channels": (3, 64, 256, 512, 1024, 2048), 90 | "block": Bottle2neck, 91 | "layers": [3, 4, 6, 3], 92 | "base_width": 26, 93 | "block_args": {"scale": 4}, 94 | }, 95 | }, 96 | "timm-res2net101_26w_4s": { 97 | "encoder": Res2NetEncoder, 98 | "pretrained_settings": pretrained_settings["timm-res2net101_26w_4s"], 99 | "params": { 100 | "out_channels": (3, 64, 256, 512, 1024, 2048), 101 | "block": Bottle2neck, 102 | "layers": [3, 4, 23, 3], 103 | "base_width": 26, 104 | 
"block_args": {"scale": 4}, 105 | }, 106 | }, 107 | "timm-res2net50_26w_6s": { 108 | "encoder": Res2NetEncoder, 109 | "pretrained_settings": pretrained_settings["timm-res2net50_26w_6s"], 110 | "params": { 111 | "out_channels": (3, 64, 256, 512, 1024, 2048), 112 | "block": Bottle2neck, 113 | "layers": [3, 4, 6, 3], 114 | "base_width": 26, 115 | "block_args": {"scale": 6}, 116 | }, 117 | }, 118 | "timm-res2net50_26w_8s": { 119 | "encoder": Res2NetEncoder, 120 | "pretrained_settings": pretrained_settings["timm-res2net50_26w_8s"], 121 | "params": { 122 | "out_channels": (3, 64, 256, 512, 1024, 2048), 123 | "block": Bottle2neck, 124 | "layers": [3, 4, 6, 3], 125 | "base_width": 26, 126 | "block_args": {"scale": 8}, 127 | }, 128 | }, 129 | "timm-res2net50_48w_2s": { 130 | "encoder": Res2NetEncoder, 131 | "pretrained_settings": pretrained_settings["timm-res2net50_48w_2s"], 132 | "params": { 133 | "out_channels": (3, 64, 256, 512, 1024, 2048), 134 | "block": Bottle2neck, 135 | "layers": [3, 4, 6, 3], 136 | "base_width": 48, 137 | "block_args": {"scale": 2}, 138 | }, 139 | }, 140 | "timm-res2net50_14w_8s": { 141 | "encoder": Res2NetEncoder, 142 | "pretrained_settings": pretrained_settings["timm-res2net50_14w_8s"], 143 | "params": { 144 | "out_channels": (3, 64, 256, 512, 1024, 2048), 145 | "block": Bottle2neck, 146 | "layers": [3, 4, 6, 3], 147 | "base_width": 14, 148 | "block_args": {"scale": 8}, 149 | }, 150 | }, 151 | "timm-res2next50": { 152 | "encoder": Res2NetEncoder, 153 | "pretrained_settings": pretrained_settings["timm-res2next50"], 154 | "params": { 155 | "out_channels": (3, 64, 256, 512, 1024, 2048), 156 | "block": Bottle2neck, 157 | "layers": [3, 4, 6, 3], 158 | "base_width": 4, 159 | "cardinality": 8, 160 | "block_args": {"scale": 4}, 161 | }, 162 | }, 163 | } 164 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/encoders/timm_resnest.py: -------------------------------------------------------------------------------- 1 | from ._base import EncoderMixin 2 | from timm.models.resnet import ResNet 3 | from timm.models.resnest import ResNestBottleneck 4 | import torch.nn as nn 5 | 6 | 7 | class ResNestEncoder(ResNet, EncoderMixin): 8 | def __init__(self, out_channels, depth=5, **kwargs): 9 | super().__init__(**kwargs) 10 | self._depth = depth 11 | self._out_channels = out_channels 12 | self._in_channels = 3 13 | 14 | del self.fc 15 | del self.global_pool 16 | 17 | def get_stages(self): 18 | return [ 19 | nn.Identity(), 20 | nn.Sequential(self.conv1, self.bn1, self.act1), 21 | nn.Sequential(self.maxpool, self.layer1), 22 | self.layer2, 23 | self.layer3, 24 | self.layer4, 25 | ] 26 | 27 | def make_dilated(self, *args, **kwargs): 28 | raise ValueError("ResNest encoders do not support dilated mode") 29 | 30 | def forward(self, x): 31 | stages = self.get_stages() 32 | 33 | features = [] 34 | for i in range(self._depth + 1): 35 | x = stages[i](x) 36 | features.append(x) 37 | 38 | return features 39 | 40 | def load_state_dict(self, state_dict, **kwargs): 41 | state_dict.pop("fc.bias", None) 42 | state_dict.pop("fc.weight", None) 43 | super().load_state_dict(state_dict, **kwargs) 44 | 45 | 46 | resnest_weights = { 47 | "timm-resnest14d": { 48 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth", # noqa 49 | }, 50 | "timm-resnest26d": { 51 | "imagenet": 
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth", # noqa 52 | }, 53 | "timm-resnest50d": { 54 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth", # noqa 55 | }, 56 | "timm-resnest101e": { 57 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth", # noqa 58 | }, 59 | "timm-resnest200e": { 60 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth", # noqa 61 | }, 62 | "timm-resnest269e": { 63 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth", # noqa 64 | }, 65 | "timm-resnest50d_4s2x40d": { 66 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth", # noqa 67 | }, 68 | "timm-resnest50d_1s4x24d": { 69 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth", # noqa 70 | }, 71 | } 72 | 73 | pretrained_settings = {} 74 | for model_name, sources in resnest_weights.items(): 75 | pretrained_settings[model_name] = {} 76 | for source_name, source_url in sources.items(): 77 | pretrained_settings[model_name][source_name] = { 78 | "url": source_url, 79 | "input_size": [3, 224, 224], 80 | "input_range": [0, 1], 81 | "mean": [0.485, 0.456, 0.406], 82 | "std": [0.229, 0.224, 0.225], 83 | "num_classes": 1000, 84 | } 85 | 86 | 87 | timm_resnest_encoders = { 88 | "timm-resnest14d": { 89 | "encoder": ResNestEncoder, 90 | "pretrained_settings": pretrained_settings["timm-resnest14d"], 91 | "params": { 92 | "out_channels": (3, 64, 256, 512, 1024, 2048), 93 | "block": ResNestBottleneck, 94 | "layers": [1, 1, 1, 1], 95 | "stem_type": "deep", 96 | "stem_width": 32, 97 | "avg_down": True, 98 | "base_width": 64, 99 | "cardinality": 1, 100 | "block_args": {"radix": 2, "avd": True, "avd_first": False}, 101 | }, 102 | }, 103 | "timm-resnest26d": { 104 | "encoder": ResNestEncoder, 105 | "pretrained_settings": pretrained_settings["timm-resnest26d"], 106 | "params": { 107 | "out_channels": (3, 64, 256, 512, 1024, 2048), 108 | "block": ResNestBottleneck, 109 | "layers": [2, 2, 2, 2], 110 | "stem_type": "deep", 111 | "stem_width": 32, 112 | "avg_down": True, 113 | "base_width": 64, 114 | "cardinality": 1, 115 | "block_args": {"radix": 2, "avd": True, "avd_first": False}, 116 | }, 117 | }, 118 | "timm-resnest50d": { 119 | "encoder": ResNestEncoder, 120 | "pretrained_settings": pretrained_settings["timm-resnest50d"], 121 | "params": { 122 | "out_channels": (3, 64, 256, 512, 1024, 2048), 123 | "block": ResNestBottleneck, 124 | "layers": [3, 4, 6, 3], 125 | "stem_type": "deep", 126 | "stem_width": 32, 127 | "avg_down": True, 128 | "base_width": 64, 129 | "cardinality": 1, 130 | "block_args": {"radix": 2, "avd": True, "avd_first": False}, 131 | }, 132 | }, 133 | "timm-resnest101e": { 134 | "encoder": ResNestEncoder, 135 | "pretrained_settings": pretrained_settings["timm-resnest101e"], 136 | "params": { 137 | "out_channels": (3, 128, 256, 512, 1024, 2048), 138 | "block": ResNestBottleneck, 139 | "layers": [3, 4, 23, 3], 140 | "stem_type": "deep", 141 | "stem_width": 64, 142 | "avg_down": True, 143 | "base_width": 64, 144 | "cardinality": 1, 145 | "block_args": {"radix": 2, "avd": True, "avd_first": False}, 146 | }, 147 | }, 148 | "timm-resnest200e": { 
149 | "encoder": ResNestEncoder, 150 | "pretrained_settings": pretrained_settings["timm-resnest200e"], 151 | "params": { 152 | "out_channels": (3, 128, 256, 512, 1024, 2048), 153 | "block": ResNestBottleneck, 154 | "layers": [3, 24, 36, 3], 155 | "stem_type": "deep", 156 | "stem_width": 64, 157 | "avg_down": True, 158 | "base_width": 64, 159 | "cardinality": 1, 160 | "block_args": {"radix": 2, "avd": True, "avd_first": False}, 161 | }, 162 | }, 163 | "timm-resnest269e": { 164 | "encoder": ResNestEncoder, 165 | "pretrained_settings": pretrained_settings["timm-resnest269e"], 166 | "params": { 167 | "out_channels": (3, 128, 256, 512, 1024, 2048), 168 | "block": ResNestBottleneck, 169 | "layers": [3, 30, 48, 8], 170 | "stem_type": "deep", 171 | "stem_width": 64, 172 | "avg_down": True, 173 | "base_width": 64, 174 | "cardinality": 1, 175 | "block_args": {"radix": 2, "avd": True, "avd_first": False}, 176 | }, 177 | }, 178 | "timm-resnest50d_4s2x40d": { 179 | "encoder": ResNestEncoder, 180 | "pretrained_settings": pretrained_settings["timm-resnest50d_4s2x40d"], 181 | "params": { 182 | "out_channels": (3, 64, 256, 512, 1024, 2048), 183 | "block": ResNestBottleneck, 184 | "layers": [3, 4, 6, 3], 185 | "stem_type": "deep", 186 | "stem_width": 32, 187 | "avg_down": True, 188 | "base_width": 40, 189 | "cardinality": 2, 190 | "block_args": {"radix": 4, "avd": True, "avd_first": True}, 191 | }, 192 | }, 193 | "timm-resnest50d_1s4x24d": { 194 | "encoder": ResNestEncoder, 195 | "pretrained_settings": pretrained_settings["timm-resnest50d_1s4x24d"], 196 | "params": { 197 | "out_channels": (3, 64, 256, 512, 1024, 2048), 198 | "block": ResNestBottleneck, 199 | "layers": [3, 4, 6, 3], 200 | "stem_type": "deep", 201 | "stem_width": 32, 202 | "avg_down": True, 203 | "base_width": 24, 204 | "cardinality": 4, 205 | "block_args": {"radix": 1, "avd": True, "avd_first": True}, 206 | }, 207 | }, 208 | } 209 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/encoders/timm_sknet.py: -------------------------------------------------------------------------------- 1 | from ._base import EncoderMixin 2 | from timm.models.resnet import ResNet 3 | from timm.models.sknet import SelectiveKernelBottleneck, SelectiveKernelBasic 4 | import torch.nn as nn 5 | 6 | 7 | class SkNetEncoder(ResNet, EncoderMixin): 8 | def __init__(self, out_channels, depth=5, **kwargs): 9 | super().__init__(**kwargs) 10 | self._depth = depth 11 | self._out_channels = out_channels 12 | self._in_channels = 3 13 | 14 | del self.fc 15 | del self.global_pool 16 | 17 | def get_stages(self): 18 | return [ 19 | nn.Identity(), 20 | nn.Sequential(self.conv1, self.bn1, self.act1), 21 | nn.Sequential(self.maxpool, self.layer1), 22 | self.layer2, 23 | self.layer3, 24 | self.layer4, 25 | ] 26 | 27 | def forward(self, x): 28 | stages = self.get_stages() 29 | 30 | features = [] 31 | for i in range(self._depth + 1): 32 | x = stages[i](x) 33 | features.append(x) 34 | 35 | return features 36 | 37 | def load_state_dict(self, state_dict, **kwargs): 38 | state_dict.pop("fc.bias", None) 39 | state_dict.pop("fc.weight", None) 40 | super().load_state_dict(state_dict, **kwargs) 41 | 42 | 43 | sknet_weights = { 44 | "timm-skresnet18": { 45 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth", # noqa 46 | }, 47 | "timm-skresnet34": { 48 | "imagenet": 
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth", # noqa 49 | }, 50 | "timm-skresnext50_32x4d": { 51 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth", # noqa 52 | }, 53 | } 54 | 55 | pretrained_settings = {} 56 | for model_name, sources in sknet_weights.items(): 57 | pretrained_settings[model_name] = {} 58 | for source_name, source_url in sources.items(): 59 | pretrained_settings[model_name][source_name] = { 60 | "url": source_url, 61 | "input_size": [3, 224, 224], 62 | "input_range": [0, 1], 63 | "mean": [0.485, 0.456, 0.406], 64 | "std": [0.229, 0.224, 0.225], 65 | "num_classes": 1000, 66 | } 67 | 68 | timm_sknet_encoders = { 69 | "timm-skresnet18": { 70 | "encoder": SkNetEncoder, 71 | "pretrained_settings": pretrained_settings["timm-skresnet18"], 72 | "params": { 73 | "out_channels": (3, 64, 64, 128, 256, 512), 74 | "block": SelectiveKernelBasic, 75 | "layers": [2, 2, 2, 2], 76 | "zero_init_last": False, 77 | "block_args": {"sk_kwargs": {"rd_ratio": 1 / 8, "split_input": True}}, 78 | }, 79 | }, 80 | "timm-skresnet34": { 81 | "encoder": SkNetEncoder, 82 | "pretrained_settings": pretrained_settings["timm-skresnet34"], 83 | "params": { 84 | "out_channels": (3, 64, 64, 128, 256, 512), 85 | "block": SelectiveKernelBasic, 86 | "layers": [3, 4, 6, 3], 87 | "zero_init_last": False, 88 | "block_args": {"sk_kwargs": {"rd_ratio": 1 / 8, "split_input": True}}, 89 | }, 90 | }, 91 | "timm-skresnext50_32x4d": { 92 | "encoder": SkNetEncoder, 93 | "pretrained_settings": pretrained_settings["timm-skresnext50_32x4d"], 94 | "params": { 95 | "out_channels": (3, 64, 256, 512, 1024, 2048), 96 | "block": SelectiveKernelBottleneck, 97 | "layers": [3, 4, 6, 3], 98 | "zero_init_last": False, 99 | "cardinality": 32, 100 | "base_width": 4, 101 | }, 102 | }, 103 | } 104 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/encoders/timm_universal.py: -------------------------------------------------------------------------------- 1 | import timm_3d 2 | import torch.nn as nn 3 | 4 | 5 | class TimmUniversalEncoder(nn.Module): 6 | def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=32): 7 | super().__init__() 8 | kwargs = dict( 9 | in_chans=in_channels, 10 | features_only=True, 11 | output_stride=output_stride, 12 | pretrained=pretrained, 13 | out_indices=tuple(range(depth)), 14 | ) 15 | 16 | # not all models support output stride argument, drop it by default 17 | if output_stride == 32: 18 | kwargs.pop("output_stride") 19 | 20 | self.model = timm_3d.create_model(name, **kwargs) 21 | 22 | self._in_channels = in_channels 23 | self._out_channels = [ 24 | in_channels, 25 | ] + self.model.feature_info.channels() 26 | self._depth = depth 27 | self._output_stride = output_stride 28 | 29 | def forward(self, x): 30 | features = self.model(x) 31 | features = [ 32 | x, 33 | ] + features 34 | return features 35 | 36 | @property 37 | def out_channels(self): 38 | return self._out_channels 39 | 40 | @property 41 | def output_stride(self): 42 | return min(self._output_stride, 2**self._depth) 43 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/encoders/vgg.py: -------------------------------------------------------------------------------- 1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 
| 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 24 | """ 25 | 26 | from typing import Any, cast, Dict, List, Optional, Union 27 | import torch 28 | import torch.nn as nn 29 | from pretrainedmodels.models.torchvision_models import pretrained_settings 30 | 31 | from ._base import EncoderMixin 32 | 33 | 34 | class VGG(nn.Module): 35 | def __init__( 36 | self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5 37 | ) -> None: 38 | super().__init__() 39 | self.features = features 40 | self.avgpool = nn.AdaptiveAvgPool3d((7, 7, 7)) 41 | self.classifier = nn.Sequential( 42 | nn.Linear(512 * 7 * 7, 4096), 43 | nn.ReLU(True), 44 | nn.Dropout(p=dropout), 45 | nn.Linear(4096, 4096), 46 | nn.ReLU(True), 47 | nn.Dropout(p=dropout), 48 | nn.Linear(4096, num_classes), 49 | ) 50 | if init_weights: 51 | for m in self.modules(): 52 | if isinstance(m, nn.Conv3d): 53 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 54 | if m.bias is not None: 55 | nn.init.constant_(m.bias, 0) 56 | elif isinstance(m, nn.BatchNorm3d): 57 | nn.init.constant_(m.weight, 1) 58 | nn.init.constant_(m.bias, 0) 59 | elif isinstance(m, nn.Linear): 60 | nn.init.normal_(m.weight, 0, 0.01) 61 | nn.init.constant_(m.bias, 0) 62 | 63 | def forward(self, x: torch.Tensor) -> torch.Tensor: 64 | x = self.features(x) 65 | x = self.avgpool(x) 66 | x = torch.flatten(x, 1) 67 | x = self.classifier(x) 68 | return x 69 | 70 | 71 | def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential: 72 | layers: List[nn.Module] = [] 73 | in_channels = 3 74 | for v in cfg: 75 | if v == "M": 76 | layers += [nn.MaxPool3d(kernel_size=2, stride=2)] 77 | else: 78 | v = cast(int, v) 79 | conv3d = nn.Conv3d(in_channels, v, kernel_size=3, padding=1) 80 | if batch_norm: 81 | layers += [conv3d, nn.BatchNorm3d(v), nn.ReLU(inplace=True)] 82 | else: 83 | layers += [conv3d, nn.ReLU(inplace=True)] 84 | in_channels = v 85 | return nn.Sequential(*layers) 86 | 87 | 88 | # fmt: off 89 | cfg = { 90 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 91 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 92 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 93 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 94 | } 95 | # 
fmt: on 96 | 97 | 98 | class VGGEncoder(VGG, EncoderMixin): 99 | def __init__(self, out_channels, config, batch_norm=False, depth=5, strides=((2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)), **kwargs): 100 | super().__init__(make_layers(config, batch_norm=batch_norm), **kwargs) 101 | self._out_channels = out_channels 102 | self._depth = depth 103 | self._in_channels = 3 104 | self.strides = strides 105 | 106 | del self.classifier 107 | 108 | def make_dilated(self, *args, **kwargs): 109 | raise ValueError("'VGG' models do not support dilated mode due to Max Pooling" " operations for downsampling!") 110 | 111 | def get_stages(self): 112 | stages = [] 113 | stage_modules = [] 114 | for module in self.features: 115 | if isinstance(module, nn.MaxPool3d): 116 | stages.append(nn.Sequential(*stage_modules)) 117 | stage_modules = [] 118 | stage_modules.append(module) 119 | stages.append(nn.Sequential(*stage_modules)) 120 | return stages 121 | 122 | def forward(self, x): 123 | stages = self.get_stages() 124 | 125 | features = [] 126 | for i in range(self._depth + 1): 127 | x = stages[i](x) 128 | features.append(x) 129 | 130 | return features 131 | 132 | def load_state_dict(self, state_dict, **kwargs): 133 | keys = list(state_dict.keys()) 134 | for k in keys: 135 | if k.startswith("classifier"): 136 | state_dict.pop(k, None) 137 | super().load_state_dict(state_dict, **kwargs) 138 | 139 | 140 | vgg_encoders = { 141 | "vgg11": { 142 | "encoder": VGGEncoder, 143 | "pretrained_settings": pretrained_settings["vgg11"], 144 | "params": { 145 | "out_channels": (64, 128, 256, 512, 512, 512), 146 | "config": cfg["A"], 147 | "batch_norm": False, 148 | }, 149 | }, 150 | "vgg11_bn": { 151 | "encoder": VGGEncoder, 152 | "pretrained_settings": pretrained_settings["vgg11_bn"], 153 | "params": { 154 | "out_channels": (64, 128, 256, 512, 512, 512), 155 | "config": cfg["A"], 156 | "batch_norm": True, 157 | }, 158 | }, 159 | "vgg13": { 160 | "encoder": VGGEncoder, 161 | "pretrained_settings": pretrained_settings["vgg13"], 162 | "params": { 163 | "out_channels": (64, 128, 256, 512, 512, 512), 164 | "config": cfg["B"], 165 | "batch_norm": False, 166 | }, 167 | }, 168 | "vgg13_bn": { 169 | "encoder": VGGEncoder, 170 | "pretrained_settings": pretrained_settings["vgg13_bn"], 171 | "params": { 172 | "out_channels": (64, 128, 256, 512, 512, 512), 173 | "config": cfg["B"], 174 | "batch_norm": True, 175 | }, 176 | }, 177 | "vgg16": { 178 | "encoder": VGGEncoder, 179 | "pretrained_settings": pretrained_settings["vgg16"], 180 | "params": { 181 | "out_channels": (64, 128, 256, 512, 512, 512), 182 | "config": cfg["D"], 183 | "batch_norm": False, 184 | }, 185 | }, 186 | "vgg16_bn": { 187 | "encoder": VGGEncoder, 188 | "pretrained_settings": pretrained_settings["vgg16_bn"], 189 | "params": { 190 | "out_channels": (64, 128, 256, 512, 512, 512), 191 | "config": cfg["D"], 192 | "batch_norm": True, 193 | }, 194 | }, 195 | "vgg19": { 196 | "encoder": VGGEncoder, 197 | "pretrained_settings": pretrained_settings["vgg19"], 198 | "params": { 199 | "out_channels": (64, 128, 256, 512, 512, 512), 200 | "config": cfg["E"], 201 | "batch_norm": False, 202 | }, 203 | }, 204 | "vgg19_bn": { 205 | "encoder": VGGEncoder, 206 | "pretrained_settings": pretrained_settings["vgg19_bn"], 207 | "params": { 208 | "out_channels": (64, 128, 256, 512, 512, 512), 209 | "config": cfg["E"], 210 | "batch_norm": True, 211 | }, 212 | }, 213 | } 214 | -------------------------------------------------------------------------------- 
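The encoder registries above (for example `vgg_encoders`) map an encoder name to its class and the constructor parameters used to build it; an encoder built from such an entry returns one feature map per stage, highest resolution first, as described in the module docstrings. Below is a minimal sketch, not part of the repository, of how an entry might be exercised directly with a 3D volume; the direct registry lookup and the input shape are assumptions made for illustration only, and in normal use the encoders are consumed through the decoder models (Unet, FPN, etc.).

import torch
from segmentation_models_pytorch_3d.encoders.vgg import vgg_encoders

# Hypothetical direct use of the registry defined above (an assumption, not the package's documented API)
entry = vgg_encoders["vgg16"]
encoder = entry["encoder"](**entry["params"])  # VGGEncoder(out_channels=(64, ...), config=cfg["D"], batch_norm=False)

x = torch.randn(1, 3, 64, 64, 64)   # (N, C, D, H, W); a 64**3 volume is an arbitrary example size
features = encoder(x)               # depth + 1 = 6 feature maps, highest resolution first
print([f.shape for f in features])  # channel counts follow the registry's out_channels entry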
/segmentation_models_pytorch_3d/encoders/xception.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch.nn as nn 3 | 4 | from pretrainedmodels.models.xception import pretrained_settings 5 | from pretrainedmodels.models.xception import Xception 6 | 7 | from ._base import EncoderMixin 8 | 9 | 10 | class XceptionEncoder(Xception, EncoderMixin): 11 | def __init__(self, out_channels, *args, depth=5, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | 14 | self._out_channels = out_channels 15 | self._depth = depth 16 | self._in_channels = 3 17 | 18 | # modify padding to maintain output shape 19 | self.conv1.padding = (1, 1) 20 | self.conv2.padding = (1, 1) 21 | 22 | del self.fc 23 | 24 | def make_dilated(self, *args, **kwargs): 25 | raise ValueError( 26 | "Xception encoder does not support dilated mode " "due to pooling operation for downsampling!" 27 | ) 28 | 29 | def get_stages(self): 30 | return [ 31 | nn.Identity(), 32 | nn.Sequential(self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.relu), 33 | self.block1, 34 | self.block2, 35 | nn.Sequential( 36 | self.block3, 37 | self.block4, 38 | self.block5, 39 | self.block6, 40 | self.block7, 41 | self.block8, 42 | self.block9, 43 | self.block10, 44 | self.block11, 45 | ), 46 | nn.Sequential(self.block12, self.conv3, self.bn3, self.relu, self.conv4, self.bn4), 47 | ] 48 | 49 | def forward(self, x): 50 | stages = self.get_stages() 51 | 52 | features = [] 53 | for i in range(self._depth + 1): 54 | x = stages[i](x) 55 | features.append(x) 56 | 57 | return features 58 | 59 | def load_state_dict(self, state_dict): 60 | # remove linear 61 | state_dict.pop("fc.bias", None) 62 | state_dict.pop("fc.weight", None) 63 | 64 | super().load_state_dict(state_dict) 65 | 66 | 67 | xception_encoders = { 68 | "xception": { 69 | "encoder": XceptionEncoder, 70 | "pretrained_settings": pretrained_settings["xception"], 71 | "params": { 72 | "out_channels": (3, 64, 128, 256, 728, 2048), 73 | }, 74 | }, 75 | } 76 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE 2 | 3 | from .jaccard import JaccardLoss 4 | from .dice import DiceLoss 5 | from .focal import FocalLoss 6 | from .lovasz import LovaszLoss 7 | from .soft_bce import SoftBCEWithLogitsLoss 8 | from .soft_ce import SoftCrossEntropyLoss 9 | from .tversky import TverskyLoss 10 | from .mcc import MCCLoss 11 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/losses/constants.py: -------------------------------------------------------------------------------- 1 | #: Loss binary mode supposes you are solving a binary segmentation task. 2 | #: That means you have only one class whose pixels are labeled as **1**, 3 | #: the rest of the pixels are background and labeled as **0**. 4 | #: Target mask shape - (N, H, W), model output mask shape (N, 1, H, W).
11 | MULTICLASS_MODE: str = "multiclass" 12 | 13 | #: Loss multilabel mode supposes you are solving a multi-**label** segmentation task. 14 | #: That means you have *C = 1..N* classes whose pixels are labeled as **1**, 15 | #: classes are not mutually exclusive and each class has its own *channel*, 16 | #: pixels in each channel that do not belong to the class are labeled as **0**. 17 | #: Target mask shape - (N, C, H, W), model output mask shape (N, C, H, W). 18 | MULTILABEL_MODE: str = "multilabel" 19 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/losses/dice.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.nn.modules.loss import _Loss 6 | from ._functional import soft_dice_score, to_tensor 7 | from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE 8 | 9 | __all__ = ["DiceLoss"] 10 | 11 | 12 | class DiceLoss(_Loss): 13 | def __init__( 14 | self, 15 | mode: str, 16 | classes: Optional[List[int]] = None, 17 | log_loss: bool = False, 18 | from_logits: bool = True, 19 | smooth: float = 0.0, 20 | ignore_index: Optional[int] = None, 21 | eps: float = 1e-7, 22 | ): 23 | """Dice loss for image segmentation task. 24 | It supports binary, multiclass and multilabel cases 25 | 26 | Args: 27 | mode: Loss mode 'binary', 'multiclass' or 'multilabel' 28 | classes: List of classes that contribute in loss computation. By default, all channels are included. 29 | log_loss: If True, loss computed as `- log(dice_coeff)`, otherwise `1 - dice_coeff` 30 | from_logits: If True, assumes input is raw logits 31 | smooth: Smoothness constant for dice coefficient (a) 32 | ignore_index: Label that indicates ignored pixels (does not contribute to loss) 33 | eps: A small epsilon for numerical stability to avoid zero division error 34 | (denominator will be always greater or equal to eps) 35 | 36 | Shape 37 | - **y_pred** - torch.Tensor of shape (N, C, H, W) 38 | - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W) 39 | 40 | Reference 41 | https://github.com/BloodAxe/pytorch-toolbelt 42 | """ 43 | assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} 44 | super(DiceLoss, self).__init__() 45 | self.mode = mode 46 | if classes is not None: 47 | assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary" 48 | classes = to_tensor(classes, dtype=torch.long) 49 | 50 | self.classes = classes 51 | self.from_logits = from_logits 52 | self.smooth = smooth 53 | self.eps = eps 54 | self.log_loss = log_loss 55 | self.ignore_index = ignore_index 56 | 57 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 58 | 59 | assert y_true.size(0) == y_pred.size(0) 60 | 61 | if self.from_logits: 62 | # Apply activations to get [0..1] class probabilities 63 | # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on 64 | # extreme values 0 and 1 65 | if self.mode == MULTICLASS_MODE: 66 | y_pred = y_pred.log_softmax(dim=1).exp() 67 | else: 68 | y_pred = F.logsigmoid(y_pred).exp() 69 | 70 | bs = y_true.size(0) 71 | num_classes = y_pred.size(1) 72 | dims = (0, 2) 73 | 74 | if self.mode == BINARY_MODE: 75 | y_true = y_true.view(bs, 1, -1) 76 | y_pred = y_pred.view(bs, 1, -1) 77 | 78 | if self.ignore_index is not None: 79 | mask = y_true != self.ignore_index 80 | y_pred = y_pred * mask 81 | y_true = y_true * mask 82 | 83 | if
self.mode == MULTICLASS_MODE: 84 | y_true = y_true.view(bs, -1) 85 | y_pred = y_pred.view(bs, num_classes, -1) 86 | 87 | if self.ignore_index is not None: 88 | mask = y_true != self.ignore_index 89 | y_pred = y_pred * mask.unsqueeze(1) 90 | 91 | y_true = F.one_hot((y_true * mask).to(torch.long), num_classes) # N,H*W -> N,H*W, C 92 | y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # N, C, H*W 93 | else: 94 | y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C 95 | y_true = y_true.permute(0, 2, 1) # N, C, H*W 96 | 97 | if self.mode == MULTILABEL_MODE: 98 | y_true = y_true.view(bs, num_classes, -1) 99 | y_pred = y_pred.view(bs, num_classes, -1) 100 | 101 | if self.ignore_index is not None: 102 | mask = y_true != self.ignore_index 103 | y_pred = y_pred * mask 104 | y_true = y_true * mask 105 | 106 | scores = self.compute_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims) 107 | 108 | if self.log_loss: 109 | loss = -torch.log(scores.clamp_min(self.eps)) 110 | else: 111 | loss = 1.0 - scores 112 | 113 | # Dice loss is undefined for empty classes (those without any true pixels), 114 | # so we zero the contribution of channels that do not have true pixels 115 | # NOTE: A better workaround would be to use loss term `mean(y_pred)` 116 | # for this case, however it will be a modified jaccard loss 117 | 118 | mask = y_true.sum(dims) > 0 119 | loss *= mask.to(loss.dtype) 120 | 121 | if self.classes is not None: 122 | loss = loss[self.classes] 123 | 124 | return self.aggregate_loss(loss) 125 | 126 | def aggregate_loss(self, loss): 127 | return loss.mean() 128 | 129 | def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor: 130 | return soft_dice_score(output, target, smooth, eps, dims) 131 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/losses/focal.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from functools import partial 3 | 4 | import torch 5 | from torch.nn.modules.loss import _Loss 6 | from ._functional import focal_loss_with_logits 7 | from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE 8 | 9 | __all__ = ["FocalLoss"] 10 | 11 | 12 | class FocalLoss(_Loss): 13 | def __init__( 14 | self, 15 | mode: str, 16 | alpha: Optional[float] = None, 17 | gamma: Optional[float] = 2.0, 18 | ignore_index: Optional[int] = None, 19 | reduction: Optional[str] = "mean", 20 | normalized: bool = False, 21 | reduced_threshold: Optional[float] = None, 22 | ): 23 | """Compute Focal loss 24 | 25 | Args: 26 | mode: Loss mode 'binary', 'multiclass' or 'multilabel' 27 | alpha: Prior probability of having positive value in target. 28 | gamma: Power factor for dampening weight (focal strength). 29 | ignore_index: If not None, targets may contain values to be ignored. 30 | Target values equal to ignore_index will be excluded from loss computation. 31 | normalized: Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf). 32 | reduced_threshold: Switch to reduced focal loss. Note that when using this mode you 33 | should use `reduction="sum"`.
34 | 35 | Shape 36 | - **y_pred** - torch.Tensor of shape (N, C, H, W) 37 | - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W) 38 | 39 | Reference 40 | https://github.com/BloodAxe/pytorch-toolbelt 41 | 42 | """ 43 | assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} 44 | super().__init__() 45 | 46 | self.mode = mode 47 | self.ignore_index = ignore_index 48 | self.focal_loss_fn = partial( 49 | focal_loss_with_logits, 50 | alpha=alpha, 51 | gamma=gamma, 52 | reduced_threshold=reduced_threshold, 53 | reduction=reduction, 54 | normalized=normalized, 55 | ) 56 | 57 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 58 | 59 | if self.mode in {BINARY_MODE, MULTILABEL_MODE}: 60 | y_true = y_true.view(-1) 61 | y_pred = y_pred.view(-1) 62 | 63 | if self.ignore_index is not None: 64 | # Filter predictions with ignore label from loss computation 65 | not_ignored = y_true != self.ignore_index 66 | y_pred = y_pred[not_ignored] 67 | y_true = y_true[not_ignored] 68 | 69 | loss = self.focal_loss_fn(y_pred, y_true) 70 | 71 | elif self.mode == MULTICLASS_MODE: 72 | 73 | num_classes = y_pred.size(1) 74 | loss = 0 75 | 76 | # Filter anchors with -1 label from loss computation 77 | if self.ignore_index is not None: 78 | not_ignored = y_true != self.ignore_index 79 | 80 | for cls in range(num_classes): 81 | cls_y_true = (y_true == cls).long() 82 | cls_y_pred = y_pred[:, cls, ...] 83 | 84 | if self.ignore_index is not None: 85 | cls_y_true = cls_y_true[not_ignored] 86 | cls_y_pred = cls_y_pred[not_ignored] 87 | 88 | loss += self.focal_loss_fn(cls_y_pred, cls_y_true) 89 | 90 | return loss 91 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/losses/jaccard.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.nn.modules.loss import _Loss 6 | from ._functional import soft_jaccard_score, to_tensor 7 | from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE 8 | 9 | __all__ = ["JaccardLoss"] 10 | 11 | 12 | class JaccardLoss(_Loss): 13 | def __init__( 14 | self, 15 | mode: str, 16 | classes: Optional[List[int]] = None, 17 | log_loss: bool = False, 18 | from_logits: bool = True, 19 | smooth: float = 0.0, 20 | eps: float = 1e-7, 21 | ): 22 | """Jaccard loss for image segmentation task. 23 | It supports binary, multiclass and multilabel cases 24 | 25 | Args: 26 | mode: Loss mode 'binary', 'multiclass' or 'multilabel' 27 | classes: List of classes that contribute in loss computation. By default, all channels are included. 
28 | log_loss: If True, loss computed as `- log(jaccard_coeff)`, otherwise `1 - jaccard_coeff` 29 | from_logits: If True, assumes input is raw logits 30 | smooth: Smoothness constant for dice coefficient 31 | eps: A small epsilon for numerical stability to avoid zero division error 32 | (denominator will be always greater or equal to eps) 33 | 34 | Shape 35 | - **y_pred** - torch.Tensor of shape (N, C, H, W) 36 | - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W) 37 | 38 | Reference 39 | https://github.com/BloodAxe/pytorch-toolbelt 40 | """ 41 | assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} 42 | super(JaccardLoss, self).__init__() 43 | 44 | self.mode = mode 45 | if classes is not None: 46 | assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary" 47 | classes = to_tensor(classes, dtype=torch.long) 48 | 49 | self.classes = classes 50 | self.from_logits = from_logits 51 | self.smooth = smooth 52 | self.eps = eps 53 | self.log_loss = log_loss 54 | 55 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 56 | 57 | assert y_true.size(0) == y_pred.size(0) 58 | 59 | if self.from_logits: 60 | # Apply activations to get [0..1] class probabilities 61 | # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on 62 | # extreme values 0 and 1 63 | if self.mode == MULTICLASS_MODE: 64 | y_pred = y_pred.log_softmax(dim=1).exp() 65 | else: 66 | y_pred = F.logsigmoid(y_pred).exp() 67 | 68 | bs = y_true.size(0) 69 | num_classes = y_pred.size(1) 70 | dims = (0, 2) 71 | 72 | if self.mode == BINARY_MODE: 73 | y_true = y_true.view(bs, 1, -1) 74 | y_pred = y_pred.view(bs, 1, -1) 75 | 76 | if self.mode == MULTICLASS_MODE: 77 | y_true = y_true.view(bs, -1) 78 | y_pred = y_pred.view(bs, num_classes, -1) 79 | 80 | y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C 81 | y_true = y_true.permute(0, 2, 1) # H, C, H*W 82 | 83 | if self.mode == MULTILABEL_MODE: 84 | y_true = y_true.view(bs, num_classes, -1) 85 | y_pred = y_pred.view(bs, num_classes, -1) 86 | 87 | scores = soft_jaccard_score( 88 | y_pred, 89 | y_true.type(y_pred.dtype), 90 | smooth=self.smooth, 91 | eps=self.eps, 92 | dims=dims, 93 | ) 94 | 95 | if self.log_loss: 96 | loss = -torch.log(scores.clamp_min(self.eps)) 97 | else: 98 | loss = 1.0 - scores 99 | 100 | # IoU loss is defined for non-empty classes 101 | # So we zero contribution of channel that does not have true pixels 102 | # NOTE: A better workaround would be to use loss term `mean(y_pred)` 103 | # for this case, however it will be a modified jaccard loss 104 | 105 | mask = y_true.sum(dims) > 0 106 | loss *= mask.float() 107 | 108 | if self.classes is not None: 109 | loss = loss[self.classes] 110 | 111 | return loss.mean() 112 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/losses/mcc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.loss import _Loss 3 | 4 | 5 | class MCCLoss(_Loss): 6 | def __init__(self, eps: float = 1e-5): 7 | """Compute Matthews Correlation Coefficient Loss for image segmentation task. 8 | It only supports binary mode. 
9 | 10 | Args: 11 | eps (float): Small epsilon to handle situations where all the samples in the dataset belong to one class 12 | 13 | Reference: 14 | https://github.com/kakumarabhishek/MCC-Loss 15 | """ 16 | super().__init__() 17 | self.eps = eps 18 | 19 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 20 | """Compute MCC loss 21 | 22 | Args: 23 | y_pred (torch.Tensor): model prediction of shape (N, H, W) or (N, 1, H, W) 24 | y_true (torch.Tensor): ground truth labels of shape (N, H, W) or (N, 1, H, W) 25 | 26 | Returns: 27 | torch.Tensor: loss value (1 - mcc) 28 | """ 29 | 30 | bs = y_true.shape[0] 31 | 32 | y_true = y_true.view(bs, 1, -1) 33 | y_pred = y_pred.view(bs, 1, -1) 34 | 35 | tp = torch.sum(torch.mul(y_pred, y_true)) + self.eps 36 | tn = torch.sum(torch.mul((1 - y_pred), (1 - y_true))) + self.eps 37 | fp = torch.sum(torch.mul(y_pred, (1 - y_true))) + self.eps 38 | fn = torch.sum(torch.mul((1 - y_pred), y_true)) + self.eps 39 | 40 | numerator = torch.mul(tp, tn) - torch.mul(fp, fn) 41 | denominator = torch.sqrt(torch.add(tp, fp) * torch.add(tp, fn) * torch.add(tn, fp) * torch.add(tn, fn)) 42 | 43 | mcc = torch.div(numerator.sum(), denominator.sum()) 44 | loss = 1.0 - mcc 45 | 46 | return loss 47 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/losses/soft_bce.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, Tensor 6 | 7 | __all__ = ["SoftBCEWithLogitsLoss"] 8 | 9 | 10 | class SoftBCEWithLogitsLoss(nn.Module): 11 | 12 | __constants__ = [ 13 | "weight", 14 | "pos_weight", 15 | "reduction", 16 | "ignore_index", 17 | "smooth_factor", 18 | ] 19 | 20 | def __init__( 21 | self, 22 | weight: Optional[torch.Tensor] = None, 23 | ignore_index: Optional[int] = -100, 24 | reduction: str = "mean", 25 | smooth_factor: Optional[float] = None, 26 | pos_weight: Optional[torch.Tensor] = None, 27 | ): 28 | """Drop-in replacement for torch.nn.BCEWithLogitsLoss with few additions: ignore_index and label_smoothing 29 | 30 | Args: 31 | ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. 32 | smooth_factor: Factor to smooth target (e.g. 
if smooth_factor=0.1 then [1, 0, 1] -> [0.9, 0.1, 0.9]) 33 | 34 | Shape 35 | - **y_pred** - torch.Tensor of shape NxCxHxW 36 | - **y_true** - torch.Tensor of shape NxHxW or Nx1xHxW 37 | 38 | Reference 39 | https://github.com/BloodAxe/pytorch-toolbelt 40 | 41 | """ 42 | super().__init__() 43 | self.ignore_index = ignore_index 44 | self.reduction = reduction 45 | self.smooth_factor = smooth_factor 46 | self.register_buffer("weight", weight) 47 | self.register_buffer("pos_weight", pos_weight) 48 | 49 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 50 | """ 51 | Args: 52 | y_pred: torch.Tensor of shape (N, C, H, W) 53 | y_true: torch.Tensor of shape (N, H, W) or (N, 1, H, W) 54 | 55 | Returns: 56 | loss: torch.Tensor 57 | """ 58 | 59 | if self.smooth_factor is not None: 60 | soft_targets = (1 - y_true) * self.smooth_factor + y_true * (1 - self.smooth_factor) 61 | else: 62 | soft_targets = y_true 63 | 64 | loss = F.binary_cross_entropy_with_logits( 65 | y_pred, 66 | soft_targets, 67 | self.weight, 68 | pos_weight=self.pos_weight, 69 | reduction="none", 70 | ) 71 | 72 | if self.ignore_index is not None: 73 | not_ignored_mask = y_true != self.ignore_index 74 | loss *= not_ignored_mask.type_as(loss) 75 | 76 | if self.reduction == "mean": 77 | loss = loss.mean() 78 | 79 | if self.reduction == "sum": 80 | loss = loss.sum() 81 | 82 | return loss 83 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/losses/soft_ce.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from torch import nn, Tensor 3 | import torch 4 | import torch.nn.functional as F 5 | from ._functional import label_smoothed_nll_loss 6 | 7 | __all__ = ["SoftCrossEntropyLoss"] 8 | 9 | 10 | class SoftCrossEntropyLoss(nn.Module): 11 | 12 | __constants__ = ["reduction", "ignore_index", "smooth_factor"] 13 | 14 | def __init__( 15 | self, 16 | reduction: str = "mean", 17 | smooth_factor: Optional[float] = None, 18 | ignore_index: Optional[int] = -100, 19 | dim: int = 1, 20 | ): 21 | """Drop-in replacement for torch.nn.CrossEntropyLoss with label_smoothing 22 | 23 | Args: 24 | smooth_factor: Factor to smooth target (e.g. 
if smooth_factor=0.1 then [1, 0, 0] -> [0.9, 0.05, 0.05]) 25 | 26 | Shape 27 | - **y_pred** - torch.Tensor of shape (N, C, H, W) 28 | - **y_true** - torch.Tensor of shape (N, H, W) 29 | 30 | Reference 31 | https://github.com/BloodAxe/pytorch-toolbelt 32 | """ 33 | super().__init__() 34 | self.smooth_factor = smooth_factor 35 | self.ignore_index = ignore_index 36 | self.reduction = reduction 37 | self.dim = dim 38 | 39 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 40 | log_prob = F.log_softmax(y_pred, dim=self.dim) 41 | return label_smoothed_nll_loss( 42 | log_prob, 43 | y_true, 44 | epsilon=self.smooth_factor, 45 | ignore_index=self.ignore_index, 46 | reduction=self.reduction, 47 | dim=self.dim, 48 | ) 49 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/losses/tversky.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | from ._functional import soft_tversky_score 5 | from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE 6 | from .dice import DiceLoss 7 | 8 | __all__ = ["TverskyLoss"] 9 | 10 | 11 | class TverskyLoss(DiceLoss): 12 | """Tversky loss for the image segmentation task, 13 | where FP and FN are weighted by the alpha and beta params. 14 | With alpha == beta == 0.5, this loss becomes equal to DiceLoss. 15 | It supports binary, multiclass and multilabel cases. 16 | 17 | Args: 18 | mode: Metric mode {'binary', 'multiclass', 'multilabel'} 19 | classes: Optional list of classes that contribute to loss computation; 20 | By default, all channels are included. 21 | log_loss: If True, loss computed as ``-log(tversky)`` otherwise ``1 - tversky`` 22 | from_logits: If True, assumes input is raw logits 23 | smooth: Smoothness constant for tversky coefficient 24 | ignore_index: Label that indicates ignored pixels (does not contribute to loss) 25 | eps: Small epsilon for numerical stability 26 | alpha: Weight constant that penalizes the model for FPs (False Positives) 27 | beta: Weight constant that penalizes the model for FNs (False Negatives) 28 | gamma: Constant that squares the error function.
Defaults to ``1.0`` 29 | 30 | Return: 31 | loss: torch.Tensor 32 | 33 | """ 34 | 35 | def __init__( 36 | self, 37 | mode: str, 38 | classes: List[int] = None, 39 | log_loss: bool = False, 40 | from_logits: bool = True, 41 | smooth: float = 0.0, 42 | ignore_index: Optional[int] = None, 43 | eps: float = 1e-7, 44 | alpha: float = 0.5, 45 | beta: float = 0.5, 46 | gamma: float = 1.0, 47 | ): 48 | 49 | assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} 50 | super().__init__(mode, classes, log_loss, from_logits, smooth, ignore_index, eps) 51 | self.alpha = alpha 52 | self.beta = beta 53 | self.gamma = gamma 54 | 55 | def aggregate_loss(self, loss): 56 | return loss.mean() ** self.gamma 57 | 58 | def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor: 59 | return soft_tversky_score(output, target, self.alpha, self.beta, smooth, eps, dims) 60 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .functional import ( 2 | get_stats, 3 | fbeta_score, 4 | f1_score, 5 | iou_score, 6 | accuracy, 7 | precision, 8 | recall, 9 | sensitivity, 10 | specificity, 11 | balanced_accuracy, 12 | positive_predictive_value, 13 | negative_predictive_value, 14 | false_negative_rate, 15 | false_positive_rate, 16 | false_discovery_rate, 17 | false_omission_rate, 18 | positive_likelihood_ratio, 19 | negative_likelihood_ratio, 20 | ) 21 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from . import train 4 | from . import losses 5 | from . 
import metrics 6 | 7 | warnings.warn( 8 | "`smp.utils` module is deprecated and will be removed in future releases.", 9 | DeprecationWarning, 10 | ) 11 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/utils/base.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch.nn as nn 3 | 4 | 5 | class BaseObject(nn.Module): 6 | def __init__(self, name=None): 7 | super().__init__() 8 | self._name = name 9 | 10 | @property 11 | def __name__(self): 12 | if self._name is None: 13 | name = self.__class__.__name__ 14 | s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) 15 | return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() 16 | else: 17 | return self._name 18 | 19 | 20 | class Metric(BaseObject): 21 | pass 22 | 23 | 24 | class Loss(BaseObject): 25 | def __add__(self, other): 26 | if isinstance(other, Loss): 27 | return SumOfLosses(self, other) 28 | else: 29 | raise ValueError("Loss should be inherited from `Loss` class") 30 | 31 | def __radd__(self, other): 32 | return self.__add__(other) 33 | 34 | def __mul__(self, value): 35 | if isinstance(value, (int, float)): 36 | return MultipliedLoss(self, value) 37 | else: 38 | raise ValueError("Loss should be inherited from `BaseLoss` class") 39 | 40 | def __rmul__(self, other): 41 | return self.__mul__(other) 42 | 43 | 44 | class SumOfLosses(Loss): 45 | def __init__(self, l1, l2): 46 | name = "{} + {}".format(l1.__name__, l2.__name__) 47 | super().__init__(name=name) 48 | self.l1 = l1 49 | self.l2 = l2 50 | 51 | def __call__(self, *inputs): 52 | return self.l1.forward(*inputs) + self.l2.forward(*inputs) 53 | 54 | 55 | class MultipliedLoss(Loss): 56 | def __init__(self, loss, multiplier): 57 | 58 | # resolve name 59 | if len(loss.__name__.split("+")) > 1: 60 | name = "{} * ({})".format(multiplier, loss.__name__) 61 | else: 62 | name = "{} * {}".format(multiplier, loss.__name__) 63 | super().__init__(name=name) 64 | self.loss = loss 65 | self.multiplier = multiplier 66 | 67 | def __call__(self, *inputs): 68 | return self.multiplier * self.loss.forward(*inputs) 69 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/utils/convert_weights.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | __author__ = 'Roman Solovyev: https://github.com/ZFTurbo' 3 | 4 | import torch 5 | 6 | def convert_2d_weights_to_3d(state_dict, verbose=False): 7 | layers = list(state_dict.keys()) 8 | for layer in layers: 9 | if ( 10 | 'conv' in layer 11 | or 'downsample' in layer 12 | or '_se_expand' in layer 13 | or '_se_reduce' in layer 14 | or 'patch_embed' in layer 15 | or 'attn.sr.weight' in layer 16 | ): 17 | if len(state_dict[layer].shape) == 4: 18 | shape_init = state_dict[layer].shape 19 | state_dict[layer] = torch.stack([state_dict[layer]]*state_dict[layer].shape[-1], dim=-1) 20 | state_dict[layer] /= state_dict[layer].shape[-1] 21 | if verbose: 22 | print("Convert layer weights: {}. 
Shape: {} -> {}".format(layer, shape_init, state_dict[layer].shape)) 23 | return state_dict -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/utils/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def _take_channels(*xs, ignore_channels=None): 5 | if ignore_channels is None: 6 | return xs 7 | else: 8 | channels = [channel for channel in range(xs[0].shape[1]) if channel not in ignore_channels] 9 | xs = [torch.index_select(x, dim=1, index=torch.tensor(channels).to(x.device)) for x in xs] 10 | return xs 11 | 12 | 13 | def _threshold(x, threshold=None): 14 | if threshold is not None: 15 | return (x > threshold).type(x.dtype) 16 | else: 17 | return x 18 | 19 | 20 | def iou(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): 21 | """Calculate Intersection over Union between ground truth and prediction 22 | Args: 23 | pr (torch.Tensor): predicted tensor 24 | gt (torch.Tensor): ground truth tensor 25 | eps (float): epsilon to avoid zero division 26 | threshold: threshold for outputs binarization 27 | Returns: 28 | float: IoU (Jaccard) score 29 | """ 30 | 31 | pr = _threshold(pr, threshold=threshold) 32 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 33 | 34 | intersection = torch.sum(gt * pr) 35 | union = torch.sum(gt) + torch.sum(pr) - intersection + eps 36 | return (intersection + eps) / union 37 | 38 | 39 | jaccard = iou 40 | 41 | 42 | def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, ignore_channels=None): 43 | """Calculate F-score between ground truth and prediction 44 | Args: 45 | pr (torch.Tensor): predicted tensor 46 | gt (torch.Tensor): ground truth tensor 47 | beta (float): positive constant 48 | eps (float): epsilon to avoid zero division 49 | threshold: threshold for outputs binarization 50 | Returns: 51 | float: F score 52 | """ 53 | 54 | pr = _threshold(pr, threshold=threshold) 55 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 56 | 57 | tp = torch.sum(gt * pr) 58 | fp = torch.sum(pr) - tp 59 | fn = torch.sum(gt) - tp 60 | 61 | score = ((1 + beta**2) * tp + eps) / ((1 + beta**2) * tp + beta**2 * fn + fp + eps) 62 | 63 | return score 64 | 65 | 66 | def accuracy(pr, gt, threshold=0.5, ignore_channels=None): 67 | """Calculate accuracy score between ground truth and prediction 68 | Args: 69 | pr (torch.Tensor): predicted tensor 70 | gt (torch.Tensor): ground truth tensor 71 | eps (float): epsilon to avoid zero division 72 | threshold: threshold for outputs binarization 73 | Returns: 74 | float: precision score 75 | """ 76 | pr = _threshold(pr, threshold=threshold) 77 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 78 | 79 | tp = torch.sum(gt == pr, dtype=pr.dtype) 80 | score = tp / gt.view(-1).shape[0] 81 | return score 82 | 83 | 84 | def precision(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): 85 | """Calculate precision score between ground truth and prediction 86 | Args: 87 | pr (torch.Tensor): predicted tensor 88 | gt (torch.Tensor): ground truth tensor 89 | eps (float): epsilon to avoid zero division 90 | threshold: threshold for outputs binarization 91 | Returns: 92 | float: precision score 93 | """ 94 | 95 | pr = _threshold(pr, threshold=threshold) 96 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 97 | 98 | tp = torch.sum(gt * pr) 99 | fp = torch.sum(pr) - tp 100 | 101 | score = (tp + eps) / (tp + fp + eps) 102 | 103 | return score 104 | 105 | 106 | 
def recall(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): 107 | """Calculate Recall between ground truth and prediction 108 | Args: 109 | pr (torch.Tensor): A list of predicted elements 110 | gt (torch.Tensor): A list of elements that are to be predicted 111 | eps (float): epsilon to avoid zero division 112 | threshold: threshold for outputs binarization 113 | Returns: 114 | float: recall score 115 | """ 116 | 117 | pr = _threshold(pr, threshold=threshold) 118 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 119 | 120 | tp = torch.sum(gt * pr) 121 | fn = torch.sum(gt) - tp 122 | 123 | score = (tp + eps) / (tp + fn + eps) 124 | 125 | return score 126 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from . import base 4 | from . import functional as F 5 | from ..base.modules import Activation 6 | 7 | 8 | class JaccardLoss(base.Loss): 9 | def __init__(self, eps=1.0, activation=None, ignore_channels=None, **kwargs): 10 | super().__init__(**kwargs) 11 | self.eps = eps 12 | self.activation = Activation(activation) 13 | self.ignore_channels = ignore_channels 14 | 15 | def forward(self, y_pr, y_gt): 16 | y_pr = self.activation(y_pr) 17 | return 1 - F.jaccard( 18 | y_pr, 19 | y_gt, 20 | eps=self.eps, 21 | threshold=None, 22 | ignore_channels=self.ignore_channels, 23 | ) 24 | 25 | 26 | class DiceLoss(base.Loss): 27 | def __init__(self, eps=1.0, beta=1.0, activation=None, ignore_channels=None, **kwargs): 28 | super().__init__(**kwargs) 29 | self.eps = eps 30 | self.beta = beta 31 | self.activation = Activation(activation) 32 | self.ignore_channels = ignore_channels 33 | 34 | def forward(self, y_pr, y_gt): 35 | y_pr = self.activation(y_pr) 36 | return 1 - F.f_score( 37 | y_pr, 38 | y_gt, 39 | beta=self.beta, 40 | eps=self.eps, 41 | threshold=None, 42 | ignore_channels=self.ignore_channels, 43 | ) 44 | 45 | 46 | class L1Loss(nn.L1Loss, base.Loss): 47 | pass 48 | 49 | 50 | class MSELoss(nn.MSELoss, base.Loss): 51 | pass 52 | 53 | 54 | class CrossEntropyLoss(nn.CrossEntropyLoss, base.Loss): 55 | pass 56 | 57 | 58 | class NLLLoss(nn.NLLLoss, base.Loss): 59 | pass 60 | 61 | 62 | class BCELoss(nn.BCELoss, base.Loss): 63 | pass 64 | 65 | 66 | class BCEWithLogitsLoss(nn.BCEWithLogitsLoss, base.Loss): 67 | pass 68 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/utils/meter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Meter(object): 5 | """Meters provide a way to keep track of important statistics in an online manner. 6 | This class is abstract, but provides a standard interface for all meters to follow. 7 | """ 8 | 9 | def reset(self): 10 | """Reset the meter to default settings.""" 11 | pass 12 | 13 | def add(self, value): 14 | """Log a new value to the meter 15 | Args: 16 | value: Next result to include. 
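As an aside, the `Loss` arithmetic defined in `utils/base.py` lets the wrappers from `utils/losses.py` be weighted and combined directly; a short sketch under the assumption that the (deprecated) `smp.utils` API is still in use:

from segmentation_models_pytorch_3d.utils import losses

# `+` and `*` build SumOfLosses / MultipliedLoss objects automatically
total_loss = losses.DiceLoss() + 0.5 * losses.JaccardLoss()
print(total_loss.__name__)   # "dice_loss + 0.5 * jaccard_loss"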
17 | """ 18 | pass 19 | 20 | def value(self): 21 | """Get the value of the meter in the current state.""" 22 | pass 23 | 24 | 25 | class AverageValueMeter(Meter): 26 | def __init__(self): 27 | super(AverageValueMeter, self).__init__() 28 | self.reset() 29 | self.val = 0 30 | 31 | def add(self, value, n=1): 32 | self.val = value 33 | self.sum += value 34 | self.var += value * value 35 | self.n += n 36 | 37 | if self.n == 0: 38 | self.mean, self.std = np.nan, np.nan 39 | elif self.n == 1: 40 | self.mean = 0.0 + self.sum # This is to force a copy in torch/numpy 41 | self.std = np.inf 42 | self.mean_old = self.mean 43 | self.m_s = 0.0 44 | else: 45 | self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n) 46 | self.m_s += (value - self.mean_old) * (value - self.mean) 47 | self.mean_old = self.mean 48 | self.std = np.sqrt(self.m_s / (self.n - 1.0)) 49 | 50 | def value(self): 51 | return self.mean, self.std 52 | 53 | def reset(self): 54 | self.n = 0 55 | self.sum = 0.0 56 | self.var = 0.0 57 | self.val = 0.0 58 | self.mean = np.nan 59 | self.mean_old = 0.0 60 | self.m_s = 0.0 61 | self.std = np.nan 62 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from . import base 2 | from . import functional as F 3 | from ..base.modules import Activation 4 | 5 | 6 | class IoU(base.Metric): 7 | __name__ = "iou_score" 8 | 9 | def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 10 | super().__init__(**kwargs) 11 | self.eps = eps 12 | self.threshold = threshold 13 | self.activation = Activation(activation) 14 | self.ignore_channels = ignore_channels 15 | 16 | def forward(self, y_pr, y_gt): 17 | y_pr = self.activation(y_pr) 18 | return F.iou( 19 | y_pr, 20 | y_gt, 21 | eps=self.eps, 22 | threshold=self.threshold, 23 | ignore_channels=self.ignore_channels, 24 | ) 25 | 26 | 27 | class Fscore(base.Metric): 28 | def __init__(self, beta=1, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 29 | super().__init__(**kwargs) 30 | self.eps = eps 31 | self.beta = beta 32 | self.threshold = threshold 33 | self.activation = Activation(activation) 34 | self.ignore_channels = ignore_channels 35 | 36 | def forward(self, y_pr, y_gt): 37 | y_pr = self.activation(y_pr) 38 | return F.f_score( 39 | y_pr, 40 | y_gt, 41 | eps=self.eps, 42 | beta=self.beta, 43 | threshold=self.threshold, 44 | ignore_channels=self.ignore_channels, 45 | ) 46 | 47 | 48 | class Accuracy(base.Metric): 49 | def __init__(self, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 50 | super().__init__(**kwargs) 51 | self.threshold = threshold 52 | self.activation = Activation(activation) 53 | self.ignore_channels = ignore_channels 54 | 55 | def forward(self, y_pr, y_gt): 56 | y_pr = self.activation(y_pr) 57 | return F.accuracy( 58 | y_pr, 59 | y_gt, 60 | threshold=self.threshold, 61 | ignore_channels=self.ignore_channels, 62 | ) 63 | 64 | 65 | class Recall(base.Metric): 66 | def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 67 | super().__init__(**kwargs) 68 | self.eps = eps 69 | self.threshold = threshold 70 | self.activation = Activation(activation) 71 | self.ignore_channels = ignore_channels 72 | 73 | def forward(self, y_pr, y_gt): 74 | y_pr = self.activation(y_pr) 75 | return F.recall( 76 | y_pr, 77 | y_gt, 78 | eps=self.eps, 79 | 
threshold=self.threshold, 80 | ignore_channels=self.ignore_channels, 81 | ) 82 | 83 | 84 | class Precision(base.Metric): 85 | def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 86 | super().__init__(**kwargs) 87 | self.eps = eps 88 | self.threshold = threshold 89 | self.activation = Activation(activation) 90 | self.ignore_channels = ignore_channels 91 | 92 | def forward(self, y_pr, y_gt): 93 | y_pr = self.activation(y_pr) 94 | return F.precision( 95 | y_pr, 96 | y_gt, 97 | eps=self.eps, 98 | threshold=self.threshold, 99 | ignore_channels=self.ignore_channels, 100 | ) 101 | -------------------------------------------------------------------------------- /segmentation_models_pytorch_3d/utils/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | from tqdm import tqdm as tqdm 4 | from .meter import AverageValueMeter 5 | 6 | 7 | class Epoch: 8 | def __init__(self, model, loss, metrics, stage_name, device="cpu", verbose=True): 9 | self.model = model 10 | self.loss = loss 11 | self.metrics = metrics 12 | self.stage_name = stage_name 13 | self.verbose = verbose 14 | self.device = device 15 | 16 | self._to_device() 17 | 18 | def _to_device(self): 19 | self.model.to(self.device) 20 | self.loss.to(self.device) 21 | for metric in self.metrics: 22 | metric.to(self.device) 23 | 24 | def _format_logs(self, logs): 25 | str_logs = ["{} - {:.4}".format(k, v) for k, v in logs.items()] 26 | s = ", ".join(str_logs) 27 | return s 28 | 29 | def batch_update(self, x, y): 30 | raise NotImplementedError 31 | 32 | def on_epoch_start(self): 33 | pass 34 | 35 | def run(self, dataloader): 36 | 37 | self.on_epoch_start() 38 | 39 | logs = {} 40 | loss_meter = AverageValueMeter() 41 | metrics_meters = {metric.__name__: AverageValueMeter() for metric in self.metrics} 42 | 43 | with tqdm( 44 | dataloader, 45 | desc=self.stage_name, 46 | file=sys.stdout, 47 | disable=not (self.verbose), 48 | ) as iterator: 49 | for x, y in iterator: 50 | x, y = x.to(self.device), y.to(self.device) 51 | loss, y_pred = self.batch_update(x, y) 52 | 53 | # update loss logs 54 | loss_value = loss.cpu().detach().numpy() 55 | loss_meter.add(loss_value) 56 | loss_logs = {self.loss.__name__: loss_meter.mean} 57 | logs.update(loss_logs) 58 | 59 | # update metrics logs 60 | for metric_fn in self.metrics: 61 | metric_value = metric_fn(y_pred, y).cpu().detach().numpy() 62 | metrics_meters[metric_fn.__name__].add(metric_value) 63 | metrics_logs = {k: v.mean for k, v in metrics_meters.items()} 64 | logs.update(metrics_logs) 65 | 66 | if self.verbose: 67 | s = self._format_logs(logs) 68 | iterator.set_postfix_str(s) 69 | 70 | return logs 71 | 72 | 73 | class TrainEpoch(Epoch): 74 | def __init__(self, model, loss, metrics, optimizer, device="cpu", verbose=True): 75 | super().__init__( 76 | model=model, 77 | loss=loss, 78 | metrics=metrics, 79 | stage_name="train", 80 | device=device, 81 | verbose=verbose, 82 | ) 83 | self.optimizer = optimizer 84 | 85 | def on_epoch_start(self): 86 | self.model.train() 87 | 88 | def batch_update(self, x, y): 89 | self.optimizer.zero_grad() 90 | prediction = self.model.forward(x) 91 | loss = self.loss(prediction, y) 92 | loss.backward() 93 | self.optimizer.step() 94 | return loss, prediction 95 | 96 | 97 | class ValidEpoch(Epoch): 98 | def __init__(self, model, loss, metrics, device="cpu", verbose=True): 99 | super().__init__( 100 | model=model, 101 | loss=loss, 102 | metrics=metrics, 103 | stage_name="valid", 
104 | device=device, 105 | verbose=verbose, 106 | ) 107 | 108 | def on_epoch_start(self): 109 | self.model.eval() 110 | 111 | def batch_update(self, x, y): 112 | with torch.no_grad(): 113 | prediction = self.model.forward(x) 114 | loss = self.loss(prediction, y) 115 | return loss, prediction 116 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | try: 2 | from setuptools import setup 3 | except ImportError: 4 | from distutils.core import setup 5 | 6 | setup( 7 | name='segmentation_models_pytorch_3d', 8 | version='1.0.2', 9 | author='Roman Sol (ZFTurbo)', 10 | packages=[ 11 | 'segmentation_models_pytorch_3d', 12 | 'segmentation_models_pytorch_3d/losses', 13 | 'segmentation_models_pytorch_3d/datasets', 14 | 'segmentation_models_pytorch_3d/base', 15 | 'segmentation_models_pytorch_3d/encoders', 16 | 'segmentation_models_pytorch_3d/decoders', 17 | 'segmentation_models_pytorch_3d/utils', 18 | 'segmentation_models_pytorch_3d/metrics', 19 | 'segmentation_models_pytorch_3d/decoders/linknet', 20 | 'segmentation_models_pytorch_3d/decoders/unet', 21 | 'segmentation_models_pytorch_3d/decoders/deeplabv3', 22 | 'segmentation_models_pytorch_3d/decoders/pan', 23 | 'segmentation_models_pytorch_3d/decoders/pspnet', 24 | 'segmentation_models_pytorch_3d/decoders/fpn', 25 | 'segmentation_models_pytorch_3d/decoders/unetplusplus', 26 | 'segmentation_models_pytorch_3d/decoders/manet', 27 | ], 28 | url='https://github.com/ZFTurbo/segmentation_models_pytorch_3d', 29 | description='Set of models for segmentation of 3D volumes using PyTorch.', 30 | long_description='3D variants of popular models for segmentation like FPN, Unet, Linknet etc. implemented as PyTorch modules. ' 31 | 'Automatic conversion of 2D imagenet weights to 3D variant.', 32 | install_requires=[ 33 | 'torch', 34 | 'torchvision>=0.5.0', 35 | "pretrainedmodels==0.7.4", 36 | "efficientnet-pytorch==0.7.1", 37 | "timm==0.9.7", 38 | "timm-3d==1.0.1", 39 | "tqdm", 40 | "pillow", 41 | "six", 42 | ], 43 | ) 44 | --------------------------------------------------------------------------------
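For orientation, a minimal end-to-end sketch of how the pieces above fit together. It is illustrative only: the random tensors stand in for real volumes, the `activation` argument follows the conventions of the original 2D segmentation_models.pytorch API, and downloading "imagenet" encoder weights (converted to 3D by the library) requires network access.

import torch
from torch.utils.data import DataLoader, TensorDataset
import segmentation_models_pytorch_3d as smp
from segmentation_models_pytorch_3d.utils import losses, metrics, train as smp_train

# 3D Unet; 2D imagenet encoder weights are converted to 3D automatically
model = smp.Unet(encoder_name="resnet34", encoder_weights="imagenet",
                 in_channels=1, classes=1, activation="sigmoid")

loss = losses.DiceLoss()
metric_list = [metrics.IoU(threshold=0.5)]
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Dummy (volume, mask) pairs; spatial size must be divisible by the encoder stride (32)
x = torch.randn(4, 1, 32, 32, 32)
y = (torch.rand(4, 1, 32, 32, 32) > 0.5).float()
train_loader = DataLoader(TensorDataset(x, y), batch_size=2)
valid_loader = DataLoader(TensorDataset(x, y), batch_size=2)

train_epoch = smp_train.TrainEpoch(model, loss=loss, metrics=metric_list,
                                   optimizer=optimizer, device=device, verbose=True)
valid_epoch = smp_train.ValidEpoch(model, loss=loss, metrics=metric_list,
                                   device=device, verbose=True)

for epoch in range(10):
    train_logs = train_epoch.run(train_loader)   # dict of running loss / metric means
    valid_logs = valid_epoch.run(valid_loader)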