├── figures ├── architecture_gcascade.jpg ├── architecture_gcascade.png ├── qualitative_results_synapse.png ├── lists └── lists_Synapse │ ├── test_vol.txt │ └── all.lst ├── lib ├── gcn_lib │ ├── __init__.py │ ├── pos_embed.py │ ├── torch_nn.py │ └── torch_edge.py └── models_timm │ ├── layers │ ├── trace_utils.py │ ├── linear.py │ ├── helpers.py │ ├── conv2d_same.py │ ├── blur_pool.py │ ├── create_conv2d.py │ ├── patch_embed.py │ ├── median_pool.py │ ├── space_to_depth.py │ ├── create_norm.py │ ├── mixed_conv2d.py │ ├── test_time_pool.py │ ├── padding.py │ ├── classifier.py │ ├── global_context.py │ ├── fast_norm.py │ ├── __init__.py │ ├── filter_response_norm.py │ ├── activations_jit.py │ ├── separable_conv.py │ ├── pool2d_same.py │ ├── split_attn.py │ ├── conv_bn_act.py │ ├── config.py │ ├── inplace_abn.py │ ├── split_batchnorm.py │ ├── create_attn.py │ ├── create_norm_act.py │ ├── gather_excite.py │ ├── adaptive_avgmax_pool.py │ ├── activations.py │ ├── squeeze_excite.py │ ├── cbam.py │ ├── mlp.py │ ├── norm.py │ ├── weight_init.py │ ├── attention_pool2d.py │ ├── cond_conv2d.py │ ├── selective_kernel.py │ ├── create_act.py │ ├── std_conv.py │ └── lambda_layer.py │ ├── __init__.py │ ├── factory.py │ ├── pruned │ └── ecaresnet50d_pruned.txt │ ├── fx_features.py │ └── convmixer.py ├── requirements.txt ├── utils ├── README.md ├── format_conversion.py ├── transforms.py ├── lesion │ ├── make_dataset.py │ ├── lesion_dataset.py │ └── helpers.py ├── preprocess_synapse_data.py ├── preprocess_synapse_data_3d.py ├── dataset_ACDC.py └── dataset_synapse.py ├── README.md ├── train_synapse.py └── test_ACDC.py /figures: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /architecture_gcascade.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SLDGroup/G-CASCADE/HEAD/architecture_gcascade.jpg -------------------------------------------------------------------------------- /architecture_gcascade.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SLDGroup/G-CASCADE/HEAD/architecture_gcascade.png -------------------------------------------------------------------------------- /qualitative_results_synapse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SLDGroup/G-CASCADE/HEAD/qualitative_results_synapse.png -------------------------------------------------------------------------------- /lists/lists_Synapse/test_vol.txt: -------------------------------------------------------------------------------- 1 | case0008 2 | case0022 3 | case0038 4 | case0036 5 | case0032 6 | case0002 7 | case0029 8 | case0003 9 | case0001 10 | case0004 11 | case0025 12 | case0035 13 | -------------------------------------------------------------------------------- /lib/gcn_lib/__init__.py: -------------------------------------------------------------------------------- 1 | # 2022.06.17-Changed for building ViG model 2 | # Huawei Technologies Co., Ltd. 
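# (Editor's note, descriptive only) The star imports below re-export the graph-convolution
# building blocks of this package (basic NN helpers, edge/graph construction, and vertex
# graph convolutions) so callers can import them directly from lib.gcn_lib.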
3 | from .torch_nn import * 4 | from .torch_edge import * 5 | from .torch_vertex import * 6 | 7 | -------------------------------------------------------------------------------- /lib/models_timm/layers/trace_utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | from torch import _assert 3 | except ImportError: 4 | def _assert(condition: bool, message: str): 5 | assert condition, message 6 | 7 | 8 | def _float_to_int(x: float) -> int: 9 | """ 10 | Symbolic tracing helper to substitute for inbuilt `int`. 11 | Hint: Inbuilt `int` can't accept an argument of type `Proxy` 12 | """ 13 | return int(x) 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | loguru 2 | tqdm 3 | pyyaml 4 | pandas 5 | matplotlib 6 | scikit-learn 7 | scikit-image 8 | scipy 9 | opencv-python 10 | seaborn 11 | albumentations 12 | tabulate 13 | warmup-scheduler 14 | torch==1.11.0+cu113 15 | torchvision==0.12.0+cu113 16 | mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html 17 | einops 18 | pthflops 19 | torchsummary 20 | thop 21 | segmentation-mask-overlay==0.3.4 22 | -------------------------------------------------------------------------------- /utils/README.md: -------------------------------------------------------------------------------- 1 | # Data Preparing 2 | 3 | 1. Access to the synapse multi-organ dataset: 4 | 1. Sign up in the [official Synapse website](https://www.synapse.org/#!Synapse:syn3193805/wiki/) and download the dataset. Convert them to numpy format, clip the images within [-125, 275], normalize each 3D image to [0, 1], and extract 2D slices from 3D volume for training cases while keeping the 3D volume in h5 format for testing cases. 5 | 2. You can also send an Email directly to mostafijur.rahman AT utexas.edu to request the preprocessed data for reproduction. 6 | 7 | -------------------------------------------------------------------------------- /lists/lists_Synapse/all.lst: -------------------------------------------------------------------------------- 1 | case0031.npy.h5 2 | case0007.npy.h5 3 | case0009.npy.h5 4 | case0005.npy.h5 5 | case0026.npy.h5 6 | case0039.npy.h5 7 | case0024.npy.h5 8 | case0034.npy.h5 9 | case0033.npy.h5 10 | case0030.npy.h5 11 | case0023.npy.h5 12 | case0040.npy.h5 13 | case0010.npy.h5 14 | case0021.npy.h5 15 | case0006.npy.h5 16 | case0027.npy.h5 17 | case0028.npy.h5 18 | case0037.npy.h5 19 | case0008.npy.h5 20 | case0022.npy.h5 21 | case0038.npy.h5 22 | case0036.npy.h5 23 | case0032.npy.h5 24 | case0002.npy.h5 25 | case0029.npy.h5 26 | case0003.npy.h5 27 | case0001.npy.h5 28 | case0004.npy.h5 29 | case0025.npy.h5 30 | case0035.npy.h5 31 | -------------------------------------------------------------------------------- /lib/models_timm/layers/linear.py: -------------------------------------------------------------------------------- 1 | """ Linear layer (alternate definition) 2 | """ 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn as nn 6 | 7 | 8 | class Linear(nn.Linear): 9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` 10 | 11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting 12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. 
13 | """ 14 | def forward(self, input: torch.Tensor) -> torch.Tensor: 15 | if torch.jit.is_scripting(): 16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None 17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) 18 | else: 19 | return F.linear(input, self.weight, self.bias) 20 | -------------------------------------------------------------------------------- /utils/format_conversion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from libtiff import TIFF # pip install libtiff 4 | from scipy import misc 5 | import random 6 | 7 | 8 | def tif2png(_src_path, _dst_path): 9 | """ 10 | Usage: 11 | formatting `tif/tiff` files to `jpg/png` files 12 | :param _src_path: 13 | :param _dst_path: 14 | :return: 15 | """ 16 | tif = TIFF.open(_src_path, mode='r') 17 | image = tif.read_image() 18 | misc.imsave(_dst_path, image) 19 | 20 | 21 | def data_split(src_list): 22 | """ 23 | Usage: 24 | randomly spliting dataset 25 | :param src_list: 26 | :return: 27 | """ 28 | counter_list = random.sample(range(0, len(src_list)), 550) 29 | 30 | return counter_list 31 | 32 | 33 | if __name__ == '__main__': 34 | src_dir = '../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/test_split/masks_tif' 35 | dst_dir = '../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/test_split/masks' 36 | 37 | os.makedirs(dst_dir, exist_ok=True) 38 | for img_name in os.listdir(src_dir): 39 | tif2png(os.path.join(src_dir, img_name), 40 | os.path.join(dst_dir, img_name.replace('.tif', '.png'))) 41 | -------------------------------------------------------------------------------- /lib/models_timm/layers/helpers.py: -------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from itertools import repeat 6 | import collections.abc 7 | 8 | 9 | # From PyTorch internals 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 13 | return x 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | 18 | to_1tuple = _ntuple(1) 19 | to_2tuple = _ntuple(2) 20 | to_3tuple = _ntuple(3) 21 | to_4tuple = _ntuple(4) 22 | to_ntuple = _ntuple 23 | 24 | 25 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9): 26 | min_value = min_value or divisor 27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 28 | # Make sure that round down does not go down by more than 10%. 
29 | if new_v < round_limit * v: 30 | new_v += divisor 31 | return new_v 32 | 33 | 34 | def extend_tuple(x, n): 35 | # pdas a tuple to specified n by padding with last value 36 | if not isinstance(x, (tuple, list)): 37 | x = (x,) 38 | else: 39 | x = tuple(x) 40 | pad_n = n - len(x) 41 | if pad_n <= 0: 42 | return x[:n] 43 | return x + (x[-1],) * pad_n 44 | -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | from skimage.filters import gaussian 5 | import torch 6 | from PIL import Image, ImageFilter 7 | 8 | 9 | class RandomVerticalFlip(object): 10 | def __call__(self, img): 11 | if random.random() < 0.5: 12 | return img.transpose(Image.FLIP_TOP_BOTTOM) 13 | return img 14 | 15 | 16 | class DeNormalize(object): 17 | def __init__(self, mean, std): 18 | self.mean = mean 19 | self.std = std 20 | 21 | def __call__(self, tensor): 22 | for t, m, s in zip(tensor, self.mean, self.std): 23 | t.mul_(s).add_(m) 24 | return tensor 25 | 26 | 27 | class MaskToTensor(object): 28 | def __call__(self, img): 29 | return torch.from_numpy(np.array(img, dtype=np.int32)).long() 30 | 31 | 32 | class FreeScale(object): 33 | def __init__(self, size, interpolation=Image.BILINEAR): 34 | self.size = tuple(reversed(size)) # size: (h, w) 35 | self.interpolation = interpolation 36 | 37 | def __call__(self, img): 38 | return img.resize(self.size, self.interpolation) 39 | 40 | 41 | class FlipChannels(object): 42 | def __call__(self, img): 43 | img = np.array(img)[:, :, ::-1] 44 | return Image.fromarray(img.astype(np.uint8)) 45 | 46 | 47 | class RandomGaussianBlur(object): 48 | def __call__(self, img): 49 | sigma = 0.15 + random.random() * 1.15 50 | blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True) 51 | blurred_img *= 255 52 | return Image.fromarray(blurred_img.astype(np.uint8)) 53 | -------------------------------------------------------------------------------- /lib/models_timm/layers/conv2d_same.py: -------------------------------------------------------------------------------- 1 | """ Conv2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import Tuple, Optional 9 | 10 | from .padding import pad_same, get_padding_value 11 | 12 | 13 | def conv2d_same( 14 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), 15 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): 16 | x = pad_same(x, weight.shape[-2:], stride, dilation) 17 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 18 | 19 | 20 | class Conv2dSame(nn.Conv2d): 21 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 22 | """ 23 | 24 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 25 | padding=0, dilation=1, groups=1, bias=True): 26 | super(Conv2dSame, self).__init__( 27 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 28 | 29 | def forward(self, x): 30 | return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 31 | 32 | 33 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 34 | padding = kwargs.pop('padding', '') 35 | kwargs.setdefault('bias', False) 36 | padding, is_dynamic = get_padding_value(padding, kernel_size, 
**kwargs) 37 | if is_dynamic: 38 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 39 | else: 40 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 41 | 42 | 43 | -------------------------------------------------------------------------------- /lib/models_timm/layers/blur_pool.py: -------------------------------------------------------------------------------- 1 | """ 2 | BlurPool layer inspired by 3 | - Kornia's Max_BlurPool2d 4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` 5 | 6 | Hacked together by Chris Ha and Ross Wightman 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import numpy as np 13 | from .padding import get_padding 14 | 15 | 16 | class BlurPool2d(nn.Module): 17 | r"""Creates a module that computes blurs and downsample a given feature map. 18 | See :cite:`zhang2019shiftinvar` for more details. 19 | Corresponds to the Downsample class, which does blurring and subsampling 20 | 21 | Args: 22 | channels = Number of input channels 23 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. 24 | stride (int): downsampling filter stride 25 | 26 | Returns: 27 | torch.Tensor: the transformed tensor. 28 | """ 29 | def __init__(self, channels, filt_size=3, stride=2) -> None: 30 | super(BlurPool2d, self).__init__() 31 | assert filt_size > 1 32 | self.channels = channels 33 | self.filt_size = filt_size 34 | self.stride = stride 35 | self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 36 | coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) 37 | blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1) 38 | self.register_buffer('filt', blur_filter, persistent=False) 39 | 40 | def forward(self, x: torch.Tensor) -> torch.Tensor: 41 | x = F.pad(x, self.padding, 'reflect') 42 | return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels) 43 | -------------------------------------------------------------------------------- /lib/models_timm/layers/create_conv2d.py: -------------------------------------------------------------------------------- 1 | """ Create Conv2d Factory Method 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | from .mixed_conv2d import MixedConv2d 7 | from .cond_conv2d import CondConv2d 8 | from .conv2d_same import create_conv2d_pad 9 | 10 | 11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): 12 | """ Select a 2d convolution implementation based on arguments 13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. 14 | 15 | Used extensively by EfficientNet, MobileNetv3 and related networks. 16 | """ 17 | if isinstance(kernel_size, list): 18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently 19 | if 'groups' in kwargs: 20 | groups = kwargs.pop('groups') 21 | if groups == in_channels: 22 | kwargs['depthwise'] = True 23 | else: 24 | assert groups == 1 25 | # We're going to use only lists for defining the MixedConv2d kernel groups, 26 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 
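        # Illustrative call (editor's addition): create_conv2d(32, 64, kernel_size=[3, 5, 7], groups=32)
        # takes this branch and builds a depthwise MixedConv2d with three kernel groups.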
27 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) 28 | else: 29 | depthwise = kwargs.pop('depthwise', False) 30 | # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 31 | groups = in_channels if depthwise else kwargs.pop('groups', 1) 32 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0: 33 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 34 | else: 35 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 36 | return m 37 | -------------------------------------------------------------------------------- /lib/models_timm/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | """ Image to Patch Embedding using Conv2d 2 | 3 | A convolution based approach to patchifying a 2D image w/ embedding projection. 4 | 5 | Based on the impl in https://github.com/google-research/vision_transformer 6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | from torch import nn as nn 10 | 11 | from .helpers import to_2tuple 12 | from .trace_utils import _assert 13 | 14 | 15 | class PatchEmbed(nn.Module): 16 | """ 2D Image to Patch Embedding 17 | """ 18 | def __init__( 19 | self, 20 | img_size=224, 21 | patch_size=16, 22 | in_chans=3, 23 | embed_dim=768, 24 | norm_layer=None, 25 | flatten=True, 26 | bias=True, 27 | ): 28 | super().__init__() 29 | img_size = to_2tuple(img_size) 30 | patch_size = to_2tuple(patch_size) 31 | self.img_size = img_size 32 | self.patch_size = patch_size 33 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 34 | self.num_patches = self.grid_size[0] * self.grid_size[1] 35 | self.flatten = flatten 36 | 37 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) 38 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 39 | 40 | def forward(self, x): 41 | B, C, H, W = x.shape 42 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 43 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 44 | x = self.proj(x) 45 | if self.flatten: 46 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 47 | x = self.norm(x) 48 | return x 49 | -------------------------------------------------------------------------------- /lib/models_timm/layers/median_pool.py: -------------------------------------------------------------------------------- 1 | """ Median Pool 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .helpers import to_2tuple, to_4tuple 7 | 8 | 9 | class MedianPool2d(nn.Module): 10 | """ Median pool (usable as median filter when stride=1) module. 
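    Illustrative configuration (editor's addition): MedianPool2d(kernel_size=3, stride=1, same=True)
    behaves as a 3x3 median filter that preserves the input's spatial size.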
11 | 12 | Args: 13 | kernel_size: size of pooling kernel, int or 2-tuple 14 | stride: pool stride, int or 2-tuple 15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 16 | same: override padding and enforce same padding, boolean 17 | """ 18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 19 | super(MedianPool2d, self).__init__() 20 | self.k = to_2tuple(kernel_size) 21 | self.stride = to_2tuple(stride) 22 | self.padding = to_4tuple(padding) # convert to l, r, t, b 23 | self.same = same 24 | 25 | def _padding(self, x): 26 | if self.same: 27 | ih, iw = x.size()[2:] 28 | if ih % self.stride[0] == 0: 29 | ph = max(self.k[0] - self.stride[0], 0) 30 | else: 31 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 32 | if iw % self.stride[1] == 0: 33 | pw = max(self.k[1] - self.stride[1], 0) 34 | else: 35 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 36 | pl = pw // 2 37 | pr = pw - pl 38 | pt = ph // 2 39 | pb = ph - pt 40 | padding = (pl, pr, pt, pb) 41 | else: 42 | padding = self.padding 43 | return padding 44 | 45 | def forward(self, x): 46 | x = F.pad(x, self._padding(x), mode='reflect') 47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) 48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 49 | return x 50 | -------------------------------------------------------------------------------- /lib/models_timm/layers/space_to_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SpaceToDepth(nn.Module): 6 | def __init__(self, block_size=4): 7 | super().__init__() 8 | assert block_size == 4 9 | self.bs = block_size 10 | 11 | def forward(self, x): 12 | N, C, H, W = x.size() 13 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 14 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 15 | x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 16 | return x 17 | 18 | 19 | @torch.jit.script 20 | class SpaceToDepthJit(object): 21 | def __call__(self, x: torch.Tensor): 22 | # assuming hard-coded that block_size==4 for acceleration 23 | N, C, H, W = x.size() 24 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) 25 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 26 | x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) 27 | return x 28 | 29 | 30 | class SpaceToDepthModule(nn.Module): 31 | def __init__(self, no_jit=False): 32 | super().__init__() 33 | if not no_jit: 34 | self.op = SpaceToDepthJit() 35 | else: 36 | self.op = SpaceToDepth() 37 | 38 | def forward(self, x): 39 | return self.op(x) 40 | 41 | 42 | class DepthToSpace(nn.Module): 43 | 44 | def __init__(self, block_size): 45 | super().__init__() 46 | self.bs = block_size 47 | 48 | def forward(self, x): 49 | N, C, H, W = x.size() 50 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 51 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 52 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) 53 | return x 54 | -------------------------------------------------------------------------------- /lib/models_timm/layers/create_norm.py: -------------------------------------------------------------------------------- 1 | """ Norm Layer Factory 2 | 3 | Create norm modules by string (to mirror create_act and 
creat_norm-act fns) 4 | 5 | Copyright 2022 Ross Wightman 6 | """ 7 | import types 8 | import functools 9 | 10 | import torch.nn as nn 11 | 12 | from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d 13 | 14 | _NORM_MAP = dict( 15 | batchnorm=nn.BatchNorm2d, 16 | batchnorm2d=nn.BatchNorm2d, 17 | batchnorm1d=nn.BatchNorm1d, 18 | groupnorm=GroupNorm, 19 | groupnorm1=GroupNorm1, 20 | layernorm=LayerNorm, 21 | layernorm2d=LayerNorm2d, 22 | ) 23 | _NORM_TYPES = {m for n, m in _NORM_MAP.items()} 24 | 25 | 26 | def create_norm_layer(layer_name, num_features, act_layer=None, apply_act=True, **kwargs): 27 | layer = get_norm_layer(layer_name, act_layer=act_layer) 28 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs) 29 | return layer_instance 30 | 31 | 32 | def get_norm_layer(norm_layer): 33 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 34 | norm_kwargs = {} 35 | 36 | # unbind partial fn, so args can be rebound later 37 | if isinstance(norm_layer, functools.partial): 38 | norm_kwargs.update(norm_layer.keywords) 39 | norm_layer = norm_layer.func 40 | 41 | if isinstance(norm_layer, str): 42 | layer_name = norm_layer.replace('_', '') 43 | norm_layer = _NORM_MAP.get(layer_name, None) 44 | elif norm_layer in _NORM_TYPES: 45 | norm_layer = norm_layer 46 | elif isinstance(norm_layer, types.FunctionType): 47 | # if function type, assume it is a lambda/fn that creates a norm layer 48 | norm_layer = norm_layer 49 | else: 50 | type_name = norm_layer.__name__.lower().replace('_', '') 51 | norm_layer = _NORM_MAP.get(type_name, None) 52 | assert norm_layer is not None, f"No equivalent norm layer for {type_name}" 53 | 54 | if norm_kwargs: 55 | norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args 56 | return norm_layer 57 | -------------------------------------------------------------------------------- /lib/models_timm/layers/mixed_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Mixed Convolution 2 | 3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | 8 | import torch 9 | from torch import nn as nn 10 | 11 | from .conv2d_same import create_conv2d_pad 12 | 13 | 14 | def _split_channels(num_chan, num_groups): 15 | split = [num_chan // num_groups for _ in range(num_groups)] 16 | split[0] += num_chan - sum(split) 17 | return split 18 | 19 | 20 | class MixedConv2d(nn.ModuleDict): 21 | """ Mixed Grouped Convolution 22 | 23 | Based on MDConv and GroupedConv in MixNet impl: 24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py 25 | """ 26 | def __init__(self, in_channels, out_channels, kernel_size=3, 27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs): 28 | super(MixedConv2d, self).__init__() 29 | 30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] 31 | num_groups = len(kernel_size) 32 | in_splits = _split_channels(in_channels, num_groups) 33 | out_splits = _split_channels(out_channels, num_groups) 34 | self.in_channels = sum(in_splits) 35 | self.out_channels = sum(out_splits) 36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): 37 | conv_groups = in_ch if depthwise else 1 38 | # use add_module to keep key space clean 39 | self.add_module( 40 | str(idx), 41 | create_conv2d_pad( 42 | in_ch, out_ch, k, stride=stride, 43 | padding=padding, 
dilation=dilation, groups=conv_groups, **kwargs) 44 | ) 45 | self.splits = in_splits 46 | 47 | def forward(self, x): 48 | x_split = torch.split(x, self.splits, 1) 49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())] 50 | x = torch.cat(x_out, 1) 51 | return x 52 | -------------------------------------------------------------------------------- /lib/models_timm/layers/test_time_pool.py: -------------------------------------------------------------------------------- 1 | """ Test Time Pooling (Average-Max Pool) 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | import logging 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d 11 | 12 | 13 | _logger = logging.getLogger(__name__) 14 | 15 | 16 | class TestTimePoolHead(nn.Module): 17 | def __init__(self, base, original_pool=7): 18 | super(TestTimePoolHead, self).__init__() 19 | self.base = base 20 | self.original_pool = original_pool 21 | base_fc = self.base.get_classifier() 22 | if isinstance(base_fc, nn.Conv2d): 23 | self.fc = base_fc 24 | else: 25 | self.fc = nn.Conv2d( 26 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) 27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) 28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) 29 | self.base.reset_classifier(0) # delete original fc layer 30 | 31 | def forward(self, x): 32 | x = self.base.forward_features(x) 33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) 34 | x = self.fc(x) 35 | x = adaptive_avgmax_pool2d(x, 1) 36 | return x.view(x.size(0), -1) 37 | 38 | 39 | def apply_test_time_pool(model, config, use_test_size=False): 40 | test_time_pool = False 41 | if not hasattr(model, 'default_cfg') or not model.default_cfg: 42 | return model, False 43 | if use_test_size and 'test_input_size' in model.default_cfg: 44 | df_input_size = model.default_cfg['test_input_size'] 45 | else: 46 | df_input_size = model.default_cfg['input_size'] 47 | if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: 48 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' % 49 | (str(config['input_size'][-2:]), str(df_input_size[-2:]))) 50 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) 51 | test_time_pool = True 52 | return model, test_time_pool 53 | -------------------------------------------------------------------------------- /lib/models_timm/layers/padding.py: -------------------------------------------------------------------------------- 1 | """ Padding Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import math 6 | from typing import List, Tuple 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | # Calculate symmetric padding for a convolution 12 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 13 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 14 | return padding 15 | 16 | 17 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 18 | def get_same_padding(x: int, k: int, s: int, d: int): 19 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) 20 | 21 | 22 | # Can SAME padding for given args be done statically? 
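# Illustrative cases (editor's addition): kernel_size=3, stride=1, dilation=1 can use a fixed
# symmetric padding of 1, whereas any stride > 1 depends on the input size and must fall back
# to the dynamic pad_same() path.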
23 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 24 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 25 | 26 | 27 | # Dynamically pad input x with 'SAME' padding for conv with specified args 28 | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 29 | ih, iw = x.size()[-2:] 30 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 31 | if pad_h > 0 or pad_w > 0: 32 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 33 | return x 34 | 35 | 36 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 37 | dynamic = False 38 | if isinstance(padding, str): 39 | # for any string padding, the padding will be calculated for you, one of three ways 40 | padding = padding.lower() 41 | if padding == 'same': 42 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 43 | if is_static_pad(kernel_size, **kwargs): 44 | # static case, no extra overhead 45 | padding = get_padding(kernel_size, **kwargs) 46 | else: 47 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 48 | padding = 0 49 | dynamic = True 50 | elif padding == 'valid': 51 | # 'VALID' padding, same as padding=0 52 | padding = 0 53 | else: 54 | # Default to PyTorch style 'same'-ish symmetric padding 55 | padding = get_padding(kernel_size, **kwargs) 56 | return padding, dynamic 57 | -------------------------------------------------------------------------------- /lib/models_timm/__init__.py: -------------------------------------------------------------------------------- 1 | from .beit import * 2 | from .byoanet import * 3 | from .byobnet import * 4 | from .cait import * 5 | from .coat import * 6 | from .convit import * 7 | from .convmixer import * 8 | from .convnext import * 9 | from .crossvit import * 10 | from .cspnet import * 11 | from .deit import * 12 | from .densenet import * 13 | from .dla import * 14 | from .dpn import * 15 | from .edgenext import * 16 | from .efficientformer import * 17 | from .efficientnet import * 18 | from .gcvit import * 19 | from .ghostnet import * 20 | from .gluon_resnet import * 21 | from .gluon_xception import * 22 | from .hardcorenas import * 23 | from .hrnet import * 24 | from .inception_resnet_v2 import * 25 | from .inception_v3 import * 26 | from .inception_v4 import * 27 | from .levit import * 28 | from .maxxvit import * 29 | from .mlp_mixer import * 30 | from .mobilenetv3 import * 31 | from .mobilevit import * 32 | from .mvitv2 import * 33 | from .nasnet import * 34 | from .nest import * 35 | from .nfnet import * 36 | from .pit import * 37 | from .pnasnet import * 38 | from .poolformer import * 39 | from .pvt_v2 import * 40 | from .regnet import * 41 | from .res2net import * 42 | from .resnest import * 43 | from .resnet import * 44 | from .resnetv2 import * 45 | from .rexnet import * 46 | from .selecsls import * 47 | from .senet import * 48 | from .sequencer import * 49 | from .sknet import * 50 | from .swin_transformer import * 51 | from .swin_transformer_v2 import * 52 | from .swin_transformer_v2_cr import * 53 | from .tnt import * 54 | from .tresnet import * 55 | from .twins import * 56 | from .vgg import * 57 | from .visformer import * 58 | from .vision_transformer import * 59 | from .vision_transformer_hybrid import * 60 | from .vision_transformer_relpos import * 61 | from .volo import * 62 | from .vovnet import * 63 | from .xception import * 64 | from .xception_aligned import 
* 65 | from .xcit import * 66 | 67 | from .factory import create_model, parse_model_name, safe_model_name 68 | from .helpers import load_checkpoint, resume_checkpoint, model_parameters 69 | from .layers import TestTimePoolHead, apply_test_time_pool 70 | from .layers import convert_splitbn_model, convert_sync_batchnorm 71 | from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit 72 | from .layers import set_fast_norm 73 | from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ 74 | is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value 75 | -------------------------------------------------------------------------------- /lib/models_timm/layers/classifier.py: -------------------------------------------------------------------------------- 1 | """ Classifier head and layer factory 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | from torch.nn import functional as F 7 | 8 | from .adaptive_avgmax_pool import SelectAdaptivePool2d 9 | 10 | 11 | def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): 12 | flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling 13 | if not pool_type: 14 | assert num_classes == 0 or use_conv,\ 15 | 'Pooling can only be disabled if classifier is also removed or conv classifier is used' 16 | flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling) 17 | global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool) 18 | num_pooled_features = num_features * global_pool.feat_mult() 19 | return global_pool, num_pooled_features 20 | 21 | 22 | def _create_fc(num_features, num_classes, use_conv=False): 23 | if num_classes <= 0: 24 | fc = nn.Identity() # pass-through (no classifier) 25 | elif use_conv: 26 | fc = nn.Conv2d(num_features, num_classes, 1, bias=True) 27 | else: 28 | fc = nn.Linear(num_features, num_classes, bias=True) 29 | return fc 30 | 31 | 32 | def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): 33 | global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv) 34 | fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) 35 | return global_pool, fc 36 | 37 | 38 | class ClassifierHead(nn.Module): 39 | """Classifier head w/ configurable global pooling and dropout.""" 40 | 41 | def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): 42 | super(ClassifierHead, self).__init__() 43 | self.drop_rate = drop_rate 44 | self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) 45 | self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) 46 | self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() 47 | 48 | def forward(self, x, pre_logits: bool = False): 49 | x = self.global_pool(x) 50 | if self.drop_rate: 51 | x = F.dropout(x, p=float(self.drop_rate), training=self.training) 52 | if pre_logits: 53 | return x.flatten(1) 54 | else: 55 | x = self.fc(x) 56 | return self.flatten(x) 57 | -------------------------------------------------------------------------------- /utils/lesion/make_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import numpy as np 4 | from shutil import copyfile 5 | 6 | from 
torch.utils.data import random_split 7 | import albumentations.augmentations.functional as F 8 | import cv2 as cv 9 | 10 | #sys.path.append('../../') 11 | import helpers as h 12 | 13 | VALID_INPUT_FOLDER = '../../data/ISIC2018/ISIC2018_Task1-2_Validation_Input' 14 | TRAIN_INPUT_FOLDER = '../../data/ISIC2018/ISIC2018_Task1-2_Training_Input' 15 | VALID_GT_FOLDER = '../../data/ISIC2018/ISIC2018_Task1_Validation_GroundTruth' 16 | TRAIN_GT_FOLDER = '../../data/ISIC2018/ISIC2018_Task1_Training_GroundTruth' 17 | 18 | def get_files(folder): 19 | files = h.listdir(folder) 20 | files.sort() 21 | files = [f for f in files if not '.txt' in f] 22 | files = [os.path.join(folder, f) for f in files] 23 | return files 24 | 25 | valid_input = get_files(VALID_INPUT_FOLDER) 26 | train_input = get_files(TRAIN_INPUT_FOLDER) 27 | 28 | valid_gt = get_files(VALID_GT_FOLDER) 29 | train_gt = get_files(TRAIN_GT_FOLDER) 30 | 31 | inputs = valid_input + train_input 32 | gts = valid_gt + train_gt 33 | 34 | # split same as in Double U-Net paper: https://arxiv.org/pdf/2006.04868v2.pdf 35 | train_valid_test_split = (0.8, 0.1, 0.1) 36 | 37 | test_count = int(train_valid_test_split[2] * len(inputs)) 38 | valid_count = test_count 39 | train_count = len(inputs) - test_count * 2 40 | 41 | print(train_count, valid_count, test_count) 42 | assert(test_count + valid_count + train_count == len(inputs)) 43 | 44 | all_files = np.array(list(zip(inputs, gts))) 45 | np.random.seed(42) 46 | np.random.shuffle(all_files) 47 | 48 | train_files = all_files[:train_count] 49 | valid_files = all_files[train_count : train_count + valid_count] 50 | test_files = all_files[-test_count:] 51 | 52 | def save_files(files, folder): 53 | h.mkdir(folder) 54 | h.mkdir(os.path.join(folder, 'images')) 55 | h.mkdir(os.path.join(folder, 'masks')) 56 | 57 | for input_file, gt_file in files: 58 | file_name = input_file.split('/')[-1] 59 | input_destination = os.path.join(folder, 'images', file_name) 60 | 61 | input_img = cv.imread(input_file) 62 | input_img = F.resize(input_img, 384, 512) 63 | 64 | gt_img = cv.imread(gt_file, cv.IMREAD_GRAYSCALE) 65 | gt_img = F.resize(gt_img, 384, 512) 66 | 67 | cv.imwrite(input_destination, input_img) 68 | 69 | gt_destionation = os.path.join(folder, 'masks', file_name.replace('.jpg', '.png')) 70 | cv.imwrite(gt_destionation, gt_img) 71 | 72 | save_files(train_files, '../../data/ISIC2018/train') 73 | save_files(valid_files, '../../data/ISIC2018/valid') 74 | save_files(test_files, '../../data/ISIC2018/test') 75 | -------------------------------------------------------------------------------- /utils/lesion/lesion_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as p 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | import numpy as np 7 | import cv2 as cv 8 | import matplotlib.pyplot as plt 9 | import albumentations.augmentations.functional as F 10 | 11 | sys.path.append('..') 12 | import helpers as h 13 | import polar_transformations 14 | 15 | class LesionDataset(Dataset): 16 | 17 | width = 512 18 | height = 384 19 | 20 | in_channels = 3 21 | out_channels = 1 22 | 23 | def __init__(self, directory, polar=True, manual_centers=None, center_augmentation=False, percent=None): 24 | self.directory = p.join('datasets/lesion', directory) 25 | self.polar = polar 26 | self.manual_centers = manual_centers 27 | self.center_augmentation = center_augmentation 28 | self.percent = percent 29 | 30 | self.file_names = h.listdir(p.join(self.directory, 
'label')) 31 | self.file_names.sort() 32 | 33 | def __len__(self): 34 | length = len(self.file_names) 35 | if self.percent is not None: 36 | length = int(length * self.percent) 37 | return length 38 | 39 | def __getitem__(self, idx): 40 | file_name = self.file_names[idx] 41 | label_file = p.join(self.directory, 'label', file_name) 42 | input_file = p.join(self.directory, 'input', file_name.replace('.png', '.jpg')) 43 | 44 | label = cv.imread(label_file, cv.IMREAD_GRAYSCALE) 45 | label = label.astype(np.float32) 46 | label /= 255.0 47 | 48 | input = cv.imread(input_file) 49 | input = cv.cvtColor(input, cv.COLOR_BGR2RGB) 50 | input = input.astype(np.float32) 51 | input /= 255.0 52 | input -= 0.5 53 | 54 | # convert to polar 55 | if self.polar: 56 | if self.manual_centers is not None: 57 | center = self.manual_centers[idx] 58 | else: 59 | center = polar_transformations.centroid(label) 60 | 61 | if self.center_augmentation and np.random.uniform() < 0.3: 62 | center_max_shift = 0.05 * LesionDataset.height 63 | center = np.array(center) 64 | center = ( 65 | center[0] + np.random.uniform(-center_max_shift, center_max_shift), 66 | center[1] + np.random.uniform(-center_max_shift, center_max_shift)) 67 | 68 | input = polar_transformations.to_polar(input, center) 69 | label = polar_transformations.to_polar(label, center) 70 | 71 | # to PyTorch expected format 72 | input = input.transpose(2, 0, 1) 73 | label = np.expand_dims(label, axis=-1) 74 | label = label.transpose(2, 0, 1) 75 | 76 | input_tensor = torch.from_numpy(input) 77 | label_tensor = torch.from_numpy(label) 78 | 79 | return input_tensor, label_tensor 80 | -------------------------------------------------------------------------------- /lib/models_timm/layers/global_context.py: -------------------------------------------------------------------------------- 1 | """ Global Context Attention Block 2 | 3 | Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` 4 | - https://arxiv.org/abs/1904.11492 5 | 6 | Official code consulted as reference: https://github.com/xvjiarui/GCNet 7 | 8 | Hacked together by / Copyright 2021 Ross Wightman 9 | """ 10 | from torch import nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .create_act import create_act_layer, get_act_layer 14 | from .helpers import make_divisible 15 | from .mlp import ConvMlp 16 | from .norm import LayerNorm2d 17 | 18 | 19 | class GlobalContext(nn.Module): 20 | 21 | def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False, 22 | rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'): 23 | super(GlobalContext, self).__init__() 24 | act_layer = get_act_layer(act_layer) 25 | 26 | self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None 27 | 28 | if rd_channels is None: 29 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
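        # Illustrative numbers (editor's addition): channels=256 with the default rd_ratio=1/8
        # and rd_divisor=1 gives rd_channels=32 for the context MLPs below.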
30 | if fuse_add: 31 | self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 32 | else: 33 | self.mlp_add = None 34 | if fuse_scale: 35 | self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 36 | else: 37 | self.mlp_scale = None 38 | 39 | self.gate = create_act_layer(gate_layer) 40 | self.init_last_zero = init_last_zero 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | if self.conv_attn is not None: 45 | nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') 46 | if self.mlp_add is not None: 47 | nn.init.zeros_(self.mlp_add.fc2.weight) 48 | 49 | def forward(self, x): 50 | B, C, H, W = x.shape 51 | 52 | if self.conv_attn is not None: 53 | attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) 54 | attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) 55 | context = x.reshape(B, C, H * W).unsqueeze(1) @ attn 56 | context = context.view(B, C, 1, 1) 57 | else: 58 | context = x.mean(dim=(2, 3), keepdim=True) 59 | 60 | if self.mlp_scale is not None: 61 | mlp_x = self.mlp_scale(context) 62 | x = x * self.gate(mlp_x) 63 | if self.mlp_add is not None: 64 | mlp_x = self.mlp_add(context) 65 | x = x + mlp_x 66 | 67 | return x 68 | -------------------------------------------------------------------------------- /lib/models_timm/layers/fast_norm.py: -------------------------------------------------------------------------------- 1 | """ 'Fast' Normalization Functions 2 | 3 | For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32. 4 | 5 | Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast) 6 | 7 | Hacked together by / Copyright 2022 Ross Wightman 8 | """ 9 | from typing import List, Optional 10 | 11 | import torch 12 | from torch.nn import functional as F 13 | 14 | try: 15 | from apex.normalization.fused_layer_norm import fused_layer_norm_affine 16 | has_apex = True 17 | except ImportError: 18 | has_apex = False 19 | 20 | 21 | # fast (ie lower precision LN) can be disabled with this flag if issues crop up 22 | _USE_FAST_NORM = False # defaulting to False for now 23 | 24 | 25 | def is_fast_norm(): 26 | return _USE_FAST_NORM 27 | 28 | 29 | def set_fast_norm(enable=True): 30 | global _USE_FAST_NORM 31 | _USE_FAST_NORM = enable 32 | 33 | 34 | def fast_group_norm( 35 | x: torch.Tensor, 36 | num_groups: int, 37 | weight: Optional[torch.Tensor] = None, 38 | bias: Optional[torch.Tensor] = None, 39 | eps: float = 1e-5 40 | ) -> torch.Tensor: 41 | if torch.jit.is_scripting(): 42 | # currently cannot use is_autocast_enabled within torchscript 43 | return F.group_norm(x, num_groups, weight, bias, eps) 44 | 45 | if torch.is_autocast_enabled(): 46 | # normally native AMP casts GN inputs to float32 47 | # here we use the low precision autocast dtype 48 | # FIXME what to do re CPU autocast? 
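        # Editor's note (illustrative): under torch.cuda.amp.autocast(dtype=torch.bfloat16),
        # dt below is torch.bfloat16, so the group norm runs in bf16 rather than being
        # upcast to float32 as native AMP would normally do.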
49 | dt = torch.get_autocast_gpu_dtype() 50 | x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) 51 | 52 | with torch.cuda.amp.autocast(enabled=False): 53 | return F.group_norm(x, num_groups, weight, bias, eps) 54 | 55 | 56 | def fast_layer_norm( 57 | x: torch.Tensor, 58 | normalized_shape: List[int], 59 | weight: Optional[torch.Tensor] = None, 60 | bias: Optional[torch.Tensor] = None, 61 | eps: float = 1e-5 62 | ) -> torch.Tensor: 63 | if torch.jit.is_scripting(): 64 | # currently cannot use is_autocast_enabled within torchscript 65 | return F.layer_norm(x, normalized_shape, weight, bias, eps) 66 | 67 | if has_apex: 68 | return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps) 69 | 70 | if torch.is_autocast_enabled(): 71 | # normally native AMP casts LN inputs to float32 72 | # apex LN does not, this is behaving like Apex 73 | dt = torch.get_autocast_gpu_dtype() 74 | # FIXME what to do re CPU autocast? 75 | x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) 76 | 77 | with torch.cuda.amp.autocast(enabled=False): 78 | return F.layer_norm(x, normalized_shape, weight, bias, eps) 79 | -------------------------------------------------------------------------------- /utils/lesion/helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import nibabel as nib 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import pydicom 7 | from skimage.metrics import adapted_rand_error 8 | from medpy.metric.binary import precision as mp_precision 9 | from medpy.metric.binary import recall as mp_recall 10 | from medpy.metric.binary import dc 11 | 12 | def _thresh(img): 13 | img[img > 0.5] = 1 14 | img[img <= 0.5] = 0 15 | return img 16 | 17 | def dsc(y_pred, y_true): 18 | y_pred = _thresh(y_pred) 19 | y_true = _thresh(y_true) 20 | 21 | return dc(y_pred, y_true) 22 | 23 | def iou(y_pred, y_true): 24 | y_pred = _thresh(y_pred) 25 | y_true = _thresh(y_true) 26 | 27 | intersection = np.logical_and(y_pred, y_true) 28 | union = np.logical_or(y_pred, y_true) 29 | if not np.any(union): 30 | return 0 if np.any(y_pred) else 1 31 | 32 | return intersection.sum() / float(union.sum()) 33 | 34 | def precision(y_pred, y_true): 35 | y_pred = _thresh(y_pred).astype(np.int) 36 | y_true = _thresh(y_true).astype(np.int) 37 | 38 | if y_true.sum() <= 5: 39 | # when the example is nearly empty, avoid division by 0 40 | # if the prediction is also empty, precision is 1 41 | # otherwise it's 0 42 | return 1 if y_pred.sum() <= 5 else 0 43 | 44 | if y_pred.sum() <= 5: 45 | return 0. 46 | 47 | return mp_precision(y_pred, y_true) 48 | 49 | def recall(y_pred, y_true): 50 | y_pred = _thresh(y_pred).astype(np.int) 51 | y_true = _thresh(y_true).astype(np.int) 52 | 53 | if y_true.sum() <= 5: 54 | # when the example is nearly empty, avoid division by 0 55 | # if the prediction is also empty, recall is 1 56 | # otherwise it's 0 57 | return 1 if y_pred.sum() <= 5 else 0 58 | 59 | if y_pred.sum() <= 5: 60 | return 0. 
61 | 62 | r = mp_recall(y_pred, y_true) 63 | return r 64 | 65 | def listdir(path): 66 | """ List files but remove hidden files from list """ 67 | return [item for item in os.listdir(path) if item[0] != '.'] 68 | 69 | def mkdir(path): 70 | if not os.path.exists(path): 71 | os.makedirs(path) 72 | 73 | def show_images_row(imgs, titles=None, rows=1, figsize=(6.4, 4.8), **kwargs): 74 | ''' 75 | Display grid of cv2 images 76 | :param img: list [cv::mat] 77 | :param title: titles 78 | :return: None 79 | ''' 80 | assert ((titles is None) or (len(imgs) == len(titles))) 81 | num_images = len(imgs) 82 | 83 | if titles is None: 84 | titles = ['Image (%d)' % i for i in range(1, num_images + 1)] 85 | 86 | fig = plt.figure(figsize=figsize) 87 | for n, (image, title) in enumerate(zip(imgs, titles)): 88 | ax = fig.add_subplot(rows, np.ceil(num_images / float(rows)), n + 1) 89 | plt.imshow(image, **kwargs) 90 | ax.set_title(title) 91 | plt.axis('off') 92 | -------------------------------------------------------------------------------- /utils/preprocess_synapse_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from time import time 4 | 5 | import numpy as np 6 | import SimpleITK as sitk 7 | import nibabel as nib 8 | import scipy.ndimage as ndimage 9 | import h5py 10 | 11 | splits = ['train', 'test'] 12 | #train = True # Set True to process training set and set False for testset 13 | 14 | for split in splits: 15 | if(split == 'train'): 16 | ct_path = './data/synapse/Abdomen/RawData/TrainSet/img' # set your path to your trainset directory 17 | seg_path = './data/synapse/Abdomen/RawData/TrainSet/label' 18 | save_path = './data/synapse/train_npz_new/' 19 | else: 20 | ct_path = './data/synapse/Abdomen/RawData/TestSet/img' # set your path to your testset directory 21 | seg_path = './data/synapse/Abdomen/RawData/TestSet/label' 22 | save_path = './data/synapse/test_vol_h5_new/' 23 | 24 | if os.path.exists(save_path) is False: 25 | os.mkdir(save_path) 26 | 27 | upper = 275 28 | lower = -125 29 | 30 | start_time = time() 31 | 32 | for ct_file in os.listdir(ct_path): 33 | 34 | ct = nib.load(os.path.join(ct_path, ct_file)) 35 | seg = nib.load(os.path.join(seg_path, ct_file.replace('img', 'label'))) 36 | 37 | #Convert them to numpy format, 38 | ct_array = ct.get_fdata() 39 | seg_array = seg.get_fdata() 40 | 41 | ct_array = np.clip(ct_array, lower, upper) 42 | 43 | #print([np.min(ct_array), np.max(ct_array)]) 44 | 45 | #normalize each 3D image to [0, 1] 46 | ct_array = (ct_array - lower) / (upper - lower) 47 | 48 | #print([np.min(ct_array), np.max(ct_array)]) 49 | 50 | ct_array = np.transpose(ct_array, (2, 0, 1)) 51 | seg_array = np.transpose(seg_array, (2, 0, 1)) 52 | 53 | print('file name:', ct_file) 54 | print('shape:', ct_array.shape) 55 | 56 | ct_number = ct_file.split('.')[0] 57 | if(split == 'test'): 58 | new_ct_name = ct_number.replace('img', 'case')+'.npy.h5' 59 | hf = h5py.File(os.path.join(save_path, new_ct_name), 'w') 60 | hf.create_dataset('image', data=ct_array) 61 | hf.create_dataset('label', data=seg_array) 62 | hf.close() 63 | continue 64 | 65 | for s_idx in range(ct_array.shape[0]): 66 | ct_array_s = ct_array[s_idx, :, :] 67 | seg_array_s = seg_array[s_idx, :, :] 68 | slice_no = "{:03d}".format(s_idx) 69 | new_ct_name = ct_number.replace('img', 'case') + '_slice' + slice_no 70 | np.savez(os.path.join(save_path, new_ct_name), image=ct_array_s, label=seg_array_s) 71 | 72 | print('already use {:.3f} min'.format((time() - 
start_time) / 60)) 73 | print('-----------') 74 | 75 | -------------------------------------------------------------------------------- /lib/models_timm/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import * 2 | from .adaptive_avgmax_pool import \ 3 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d 4 | from .blur_pool import BlurPool2d 5 | from .classifier import ClassifierHead, create_classifier 6 | from .cond_conv2d import CondConv2d, get_condconv_initializer 7 | from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ 8 | set_layer_config 9 | from .conv2d_same import Conv2dSame, conv2d_same 10 | from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct 11 | from .create_act import create_act_layer, get_act_layer, get_act_fn 12 | from .create_attn import get_attn, create_attn 13 | from .create_conv2d import create_conv2d 14 | from .create_norm import get_norm_layer, create_norm_layer 15 | from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer 16 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path 17 | from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn 18 | from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ 19 | EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a 20 | from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm 21 | from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d 22 | from .gather_excite import GatherExcite 23 | from .global_context import GlobalContext 24 | from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple 25 | from .inplace_abn import InplaceAbn 26 | from .linear import Linear 27 | from .mixed_conv2d import MixedConv2d 28 | from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp 29 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn 30 | from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d 31 | from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm 32 | from .padding import get_padding, get_same_padding, pad_same 33 | from .patch_embed import PatchEmbed 34 | from .pool2d_same import AvgPool2dSame, create_pool2d 35 | from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite 36 | from .selective_kernel import SelectiveKernel 37 | from .separable_conv import SeparableConv2d, SeparableConvNormAct 38 | from .space_to_depth import SpaceToDepthModule 39 | from .split_attn import SplitAttn 40 | from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model 41 | from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame 42 | from .test_time_pool import TestTimePoolHead, apply_test_time_pool 43 | from .trace_utils import _assert, _float_to_int 44 | from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ 45 | -------------------------------------------------------------------------------- /lib/models_timm/layers/filter_response_norm.py: -------------------------------------------------------------------------------- 1 | """ Filter Response Norm in PyTorch 2 | 3 | Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737 4 | 5 | Hacked together by / Copyright 2021 Ross Wightman 6 | """ 7 | import torch 8 | import torch.nn as 
nn 9 | 10 | from .create_act import create_act_layer 11 | from .trace_utils import _assert 12 | 13 | 14 | def inv_instance_rms(x, eps: float = 1e-5): 15 | rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype) 16 | return rms.expand(x.shape) 17 | 18 | 19 | class FilterResponseNormTlu2d(nn.Module): 20 | def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_): 21 | super(FilterResponseNormTlu2d, self).__init__() 22 | self.apply_act = apply_act # apply activation (non-linearity) 23 | self.rms = rms 24 | self.eps = eps 25 | self.weight = nn.Parameter(torch.ones(num_features)) 26 | self.bias = nn.Parameter(torch.zeros(num_features)) 27 | self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | nn.init.ones_(self.weight) 32 | nn.init.zeros_(self.bias) 33 | if self.tau is not None: 34 | nn.init.zeros_(self.tau) 35 | 36 | def forward(self, x): 37 | _assert(x.dim() == 4, 'expected 4D input') 38 | x_dtype = x.dtype 39 | v_shape = (1, -1, 1, 1) 40 | x = x * inv_instance_rms(x, self.eps) 41 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 42 | return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x 43 | 44 | 45 | class FilterResponseNormAct2d(nn.Module): 46 | def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_): 47 | super(FilterResponseNormAct2d, self).__init__() 48 | if act_layer is not None and apply_act: 49 | self.act = create_act_layer(act_layer, inplace=inplace) 50 | else: 51 | self.act = nn.Identity() 52 | self.rms = rms 53 | self.eps = eps 54 | self.weight = nn.Parameter(torch.ones(num_features)) 55 | self.bias = nn.Parameter(torch.zeros(num_features)) 56 | self.reset_parameters() 57 | 58 | def reset_parameters(self): 59 | nn.init.ones_(self.weight) 60 | nn.init.zeros_(self.bias) 61 | 62 | def forward(self, x): 63 | _assert(x.dim() == 4, 'expected 4D input') 64 | x_dtype = x.dtype 65 | v_shape = (1, -1, 1, 1) 66 | x = x * inv_instance_rms(x, self.eps) 67 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 68 | return self.act(x) 69 | -------------------------------------------------------------------------------- /lib/models_timm/layers/activations_jit.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of jit-scripted activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not 7 | currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted 8 | versions if they contain in-place ops. 
9 | 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | 13 | import torch 14 | from torch import nn as nn 15 | from torch.nn import functional as F 16 | 17 | 18 | @torch.jit.script 19 | def swish_jit(x, inplace: bool = False): 20 | """Swish - Described in: https://arxiv.org/abs/1710.05941 21 | """ 22 | return x.mul(x.sigmoid()) 23 | 24 | 25 | @torch.jit.script 26 | def mish_jit(x, _inplace: bool = False): 27 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 28 | """ 29 | return x.mul(F.softplus(x).tanh()) 30 | 31 | 32 | class SwishJit(nn.Module): 33 | def __init__(self, inplace: bool = False): 34 | super(SwishJit, self).__init__() 35 | 36 | def forward(self, x): 37 | return swish_jit(x) 38 | 39 | 40 | class MishJit(nn.Module): 41 | def __init__(self, inplace: bool = False): 42 | super(MishJit, self).__init__() 43 | 44 | def forward(self, x): 45 | return mish_jit(x) 46 | 47 | 48 | @torch.jit.script 49 | def hard_sigmoid_jit(x, inplace: bool = False): 50 | # return F.relu6(x + 3.) / 6. 51 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 52 | 53 | 54 | class HardSigmoidJit(nn.Module): 55 | def __init__(self, inplace: bool = False): 56 | super(HardSigmoidJit, self).__init__() 57 | 58 | def forward(self, x): 59 | return hard_sigmoid_jit(x) 60 | 61 | 62 | @torch.jit.script 63 | def hard_swish_jit(x, inplace: bool = False): 64 | # return x * (F.relu6(x + 3.) / 6) 65 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 66 | 67 | 68 | class HardSwishJit(nn.Module): 69 | def __init__(self, inplace: bool = False): 70 | super(HardSwishJit, self).__init__() 71 | 72 | def forward(self, x): 73 | return hard_swish_jit(x) 74 | 75 | 76 | @torch.jit.script 77 | def hard_mish_jit(x, inplace: bool = False): 78 | """ Hard Mish 79 | Experimental, based on notes by Mish author Diganta Misra at 80 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 81 | """ 82 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 83 | 84 | 85 | class HardMishJit(nn.Module): 86 | def __init__(self, inplace: bool = False): 87 | super(HardMishJit, self).__init__() 88 | 89 | def forward(self, x): 90 | return hard_mish_jit(x) 91 | -------------------------------------------------------------------------------- /lib/models_timm/layers/separable_conv.py: -------------------------------------------------------------------------------- 1 | """ Depthwise Separable Conv Modules 2 | 3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the 4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. 
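Here, SeparableConv2d is the bare depthwise + pointwise pair, while SeparableConvNormAct appends a fused norm + activation after the pointwise conv.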
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | from torch import nn as nn 9 | 10 | from .create_conv2d import create_conv2d 11 | from .create_norm_act import get_norm_act_layer 12 | 13 | 14 | class SeparableConvNormAct(nn.Module): 15 | """ Separable Conv w/ trailing Norm and Activation 16 | """ 17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 18 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, 19 | apply_act=True, drop_layer=None): 20 | super(SeparableConvNormAct, self).__init__() 21 | 22 | self.conv_dw = create_conv2d( 23 | in_channels, int(in_channels * channel_multiplier), kernel_size, 24 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 25 | 26 | self.conv_pw = create_conv2d( 27 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 28 | 29 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 30 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 31 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 32 | 33 | @property 34 | def in_channels(self): 35 | return self.conv_dw.in_channels 36 | 37 | @property 38 | def out_channels(self): 39 | return self.conv_pw.out_channels 40 | 41 | def forward(self, x): 42 | x = self.conv_dw(x) 43 | x = self.conv_pw(x) 44 | x = self.bn(x) 45 | return x 46 | 47 | 48 | SeparableConvBnAct = SeparableConvNormAct 49 | 50 | 51 | class SeparableConv2d(nn.Module): 52 | """ Separable Conv 53 | """ 54 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 55 | channel_multiplier=1.0, pw_kernel_size=1): 56 | super(SeparableConv2d, self).__init__() 57 | 58 | self.conv_dw = create_conv2d( 59 | in_channels, int(in_channels * channel_multiplier), kernel_size, 60 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 61 | 62 | self.conv_pw = create_conv2d( 63 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 64 | 65 | @property 66 | def in_channels(self): 67 | return self.conv_dw.in_channels 68 | 69 | @property 70 | def out_channels(self): 71 | return self.conv_pw.out_channels 72 | 73 | def forward(self, x): 74 | x = self.conv_dw(x) 75 | x = self.conv_pw(x) 76 | return x 77 | -------------------------------------------------------------------------------- /utils/preprocess_synapse_data_3d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from time import time 4 | 5 | import numpy as np 6 | import SimpleITK as sitk 7 | import nibabel as nib 8 | import scipy.ndimage as ndimage 9 | import h5py 10 | 11 | splits = ['train', 'test'] 12 | #train = True # Set True to process training set and set False for testset 13 | 14 | for split in splits: 15 | if(split == 'train'): 16 | ct_path = './data/synapse/Abdomen/RawData/TrainSet/img' # set your path to your trainset directory 17 | seg_path = './data/synapse/Abdomen/RawData/TrainSet/label' 18 | save_path = './data/synapse/train_npz_mframes/' 19 | else: 20 | ct_path = './data/synapse/Abdomen/RawData/TestSet/img' # set your path to your testset directory 21 | seg_path = './data/synapse/Abdomen/RawData/TestSet/label' 22 | save_path = './data/synapse/test_vol_h5_mframes/' 23 | 24 | if os.path.exists(save_path) is False: 25 | os.mkdir(save_path) 26 | 27 | upper = 275 28 | lower = -125 29 | 30 | start_time = 
time() 31 | min_size= 10000 32 | for ct_file in os.listdir(ct_path): 33 | 34 | ct = nib.load(os.path.join(ct_path, ct_file)) 35 | seg = nib.load(os.path.join(seg_path, ct_file.replace('img', 'label'))) 36 | 37 | #Convert them to numpy format, 38 | ct_array = ct.get_fdata() 39 | seg_array = seg.get_fdata() 40 | 41 | ct_array = np.clip(ct_array, lower, upper) 42 | 43 | #print([np.min(ct_array), np.max(ct_array)]) 44 | 45 | #normalize each 3D image to [0, 1] 46 | ct_array = (ct_array - lower) / (upper - lower) 47 | 48 | #print([np.min(ct_array), np.max(ct_array)]) 49 | 50 | ct_array = np.transpose(ct_array, (2, 0, 1)) 51 | seg_array = np.transpose(seg_array, (2, 0, 1)) 52 | 53 | print('file name:', ct_file) 54 | print('shape:', ct_array.shape) 55 | 56 | if(ct_array.shape[0] < min_size): 57 | min_size = ct_array.shape[0] 58 | 59 | ct_number = ct_file.split('.')[0] 60 | if(split == 'test'): 61 | new_ct_name = ct_number.replace('img', 'case')+'.npy.h5' 62 | hf = h5py.File(os.path.join(save_path, new_ct_name), 'w') 63 | hf.create_dataset('image', data=ct_array) 64 | hf.create_dataset('label', data=seg_array) 65 | hf.close() 66 | continue 67 | 68 | for s_idx in range(ct_array.shape[0]-2): 69 | #ct_array_s = np.zeros() 70 | ct_array_s = np.transpose(ct_array, (1, 2, 0))[:, :, s_idx:s_idx+3] 71 | print(ct_array_s.shape) 72 | seg_array_s = seg_array[s_idx+1, :, :] 73 | slice_no = "{:03d}".format(s_idx) 74 | new_ct_name = ct_number.replace('img', 'case') + '_slice' + slice_no 75 | np.savez(os.path.join(save_path, new_ct_name), image=ct_array_s, label=seg_array_s) 76 | 77 | 78 | print('already use {:.3f} min'.format((time() - start_time) / 60)) 79 | print('-----------') 80 | print('max_size '+str(min_size)) 81 | -------------------------------------------------------------------------------- /utils/dataset_ACDC.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | 4 | import os 5 | import random 6 | import numpy as np 7 | import torch 8 | from scipy import ndimage 9 | from scipy.ndimage.interpolation import zoom 10 | from torch.utils.data import Dataset 11 | 12 | 13 | def random_rot_flip(image, label): 14 | k = np.random.randint(0, 4) 15 | image = np.rot90(image, k) 16 | label = np.rot90(label, k) 17 | axis = np.random.randint(0, 2) 18 | image = np.flip(image, axis=axis).copy() 19 | label = np.flip(label, axis=axis).copy() 20 | return image, label 21 | 22 | 23 | def random_rotate(image, label): 24 | angle = np.random.randint(-20, 20) 25 | image = ndimage.rotate(image, angle, order=0, reshape=False) 26 | label = ndimage.rotate(label, angle, order=0, reshape=False) 27 | return image, label 28 | 29 | 30 | class RandomGenerator(object): 31 | def __init__(self, output_size): 32 | self.output_size = output_size 33 | 34 | def __call__(self, sample): 35 | image, label = sample['image'], sample['label'] 36 | 37 | if random.random() > 0.5: 38 | image, label = random_rot_flip(image, label) 39 | elif random.random() > 0.5: 40 | image, label = random_rotate(image, label) 41 | x, y = image.shape 42 | if x != self.output_size[0] or y != self.output_size[1]: 43 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3) # why not 3? 
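            # order=3 (cubic) suits the image intensities; the label below uses order=0 (nearest) so class ids are never blended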
44 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 45 | image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) 46 | label = torch.from_numpy(label.astype(np.float32)) 47 | sample = {'image': image, 'label': label.long()} 48 | return sample 49 | 50 | 51 | class ACDCdataset(Dataset): 52 | def __init__(self, base_dir, list_dir, split, transform=None): 53 | self.transform = transform # using transform in torch! 54 | self.split = split 55 | self.sample_list = open(os.path.join(list_dir, self.split+'.txt')).readlines() 56 | self.data_dir = base_dir 57 | 58 | def __len__(self): 59 | return len(self.sample_list) 60 | 61 | def __getitem__(self, idx): 62 | if self.split == "train" or self.split == "valid": 63 | slice_name = self.sample_list[idx].strip('\n') 64 | data_path = os.path.join(self.data_dir, self.split, slice_name) 65 | data = np.load(data_path) 66 | image, label = data['img'], data['label'] 67 | else: 68 | vol_name = self.sample_list[idx].strip('\n') 69 | filepath = self.data_dir + "/{}".format(vol_name) 70 | data = np.load(filepath) 71 | image, label = data['img'], data['label'] 72 | 73 | sample = {'image': image, 'label': label} 74 | if self.transform and self.split == "train": 75 | sample = self.transform(sample) 76 | sample['case_name'] = self.sample_list[idx].strip('\n') 77 | return sample 78 | -------------------------------------------------------------------------------- /lib/models_timm/layers/pool2d_same.py: -------------------------------------------------------------------------------- 1 | """ AvgPool2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import List, Tuple, Optional 9 | 10 | from .helpers import to_2tuple 11 | from .padding import pad_same, get_padding_value 12 | 13 | 14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 15 | ceil_mode: bool = False, count_include_pad: bool = True): 16 | # FIXME how to deal with count_include_pad vs not for external padding? 
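    # emulate TF 'SAME' pooling: pad asymmetrically for this input size, then pool with no built-in padding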
17 | x = pad_same(x, kernel_size, stride) 18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 19 | 20 | 21 | class AvgPool2dSame(nn.AvgPool2d): 22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling 23 | """ 24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): 25 | kernel_size = to_2tuple(kernel_size) 26 | stride = to_2tuple(stride) 27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 28 | 29 | def forward(self, x): 30 | x = pad_same(x, self.kernel_size, self.stride) 31 | return F.avg_pool2d( 32 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) 33 | 34 | 35 | def max_pool2d_same( 36 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 37 | dilation: List[int] = (1, 1), ceil_mode: bool = False): 38 | x = pad_same(x, kernel_size, stride, value=-float('inf')) 39 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) 40 | 41 | 42 | class MaxPool2dSame(nn.MaxPool2d): 43 | """ Tensorflow like 'SAME' wrapper for 2D max pooling 44 | """ 45 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): 46 | kernel_size = to_2tuple(kernel_size) 47 | stride = to_2tuple(stride) 48 | dilation = to_2tuple(dilation) 49 | super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) 50 | 51 | def forward(self, x): 52 | x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) 53 | return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) 54 | 55 | 56 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): 57 | stride = stride or kernel_size 58 | padding = kwargs.pop('padding', '') 59 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) 60 | if is_dynamic: 61 | if pool_type == 'avg': 62 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs) 63 | elif pool_type == 'max': 64 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs) 65 | else: 66 | assert False, f'Unsupported pool type {pool_type}' 67 | else: 68 | if pool_type == 'avg': 69 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 70 | elif pool_type == 'max': 71 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 72 | else: 73 | assert False, f'Unsupported pool type {pool_type}' 74 | -------------------------------------------------------------------------------- /lib/models_timm/layers/split_attn.py: -------------------------------------------------------------------------------- 1 | """ Split Attention Conv2d (for ResNeSt Models) 2 | 3 | Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 4 | 5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt 6 | 7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman 8 | """ 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | from .helpers import make_divisible 14 | 15 | 16 | class RadixSoftmax(nn.Module): 17 | def __init__(self, radix, cardinality): 18 | super(RadixSoftmax, self).__init__() 19 | self.radix = radix 20 | self.cardinality = cardinality 21 | 22 | def forward(self, x): 23 | batch = x.size(0) 24 | if self.radix > 1: 25 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 26 | x = F.softmax(x, dim=1) 27 | x = x.reshape(batch, 
-1) 28 | else: 29 | x = torch.sigmoid(x) 30 | return x 31 | 32 | 33 | class SplitAttn(nn.Module): 34 | """Split-Attention (aka Splat) 35 | """ 36 | def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, 37 | dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, 38 | act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs): 39 | super(SplitAttn, self).__init__() 40 | out_channels = out_channels or in_channels 41 | self.radix = radix 42 | mid_chs = out_channels * radix 43 | if rd_channels is None: 44 | attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) 45 | else: 46 | attn_chs = rd_channels * radix 47 | 48 | padding = kernel_size // 2 if padding is None else padding 49 | self.conv = nn.Conv2d( 50 | in_channels, mid_chs, kernel_size, stride, padding, dilation, 51 | groups=groups * radix, bias=bias, **kwargs) 52 | self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() 53 | self.drop = drop_layer() if drop_layer is not None else nn.Identity() 54 | self.act0 = act_layer(inplace=True) 55 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) 56 | self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() 57 | self.act1 = act_layer(inplace=True) 58 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) 59 | self.rsoftmax = RadixSoftmax(radix, groups) 60 | 61 | def forward(self, x): 62 | x = self.conv(x) 63 | x = self.bn0(x) 64 | x = self.drop(x) 65 | x = self.act0(x) 66 | 67 | B, RC, H, W = x.shape 68 | if self.radix > 1: 69 | x = x.reshape((B, self.radix, RC // self.radix, H, W)) 70 | x_gap = x.sum(dim=1) 71 | else: 72 | x_gap = x 73 | x_gap = x_gap.mean((2, 3), keepdim=True) 74 | x_gap = self.fc1(x_gap) 75 | x_gap = self.bn1(x_gap) 76 | x_gap = self.act1(x_gap) 77 | x_attn = self.fc2(x_gap) 78 | 79 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) 80 | if self.radix > 1: 81 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) 82 | else: 83 | out = x * x_attn 84 | return out.contiguous() 85 | -------------------------------------------------------------------------------- /lib/models_timm/factory.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import urlsplit, urlunsplit 2 | import os 3 | 4 | from .registry import is_model, is_model_in_modules, model_entrypoint 5 | from .helpers import load_checkpoint 6 | from .layers import set_layer_config 7 | from .hub import load_model_config_from_hf 8 | 9 | 10 | def parse_model_name(model_name): 11 | model_name = model_name.replace('hf_hub', 'hf-hub') # NOTE for backwards compat, to deprecate hf_hub use 12 | parsed = urlsplit(model_name) 13 | assert parsed.scheme in ('', 'timm', 'hf-hub') 14 | if parsed.scheme == 'hf-hub': 15 | # FIXME may use fragment as revision, currently `@` in URI path 16 | return parsed.scheme, parsed.path 17 | else: 18 | model_name = os.path.split(parsed.path)[-1] 19 | return 'timm', model_name 20 | 21 | 22 | def safe_model_name(model_name, remove_source=True): 23 | def make_safe(name): 24 | return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_') 25 | if remove_source: 26 | model_name = parse_model_name(model_name)[-1] 27 | return make_safe(model_name) 28 | 29 | 30 | def create_model( 31 | model_name, 32 | pretrained=False, 33 | pretrained_cfg=None, 34 | checkpoint_path='', 35 | scriptable=None, 36 | exportable=None, 37 | no_jit=None, 38 | **kwargs): 39 | """Create a model 40 | 41 | Args: 42 
| model_name (str): name of model to instantiate 43 | pretrained (bool): load pretrained ImageNet-1k weights if true 44 | checkpoint_path (str): path of checkpoint to load after model is initialized 45 | scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) 46 | exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) 47 | no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) 48 | 49 | Keyword Args: 50 | drop_rate (float): dropout rate for training (default: 0.0) 51 | global_pool (str): global pool type (default: 'avg') 52 | **: other kwargs are model specific 53 | """ 54 | # Parameters that aren't supported by all models or are intended to only override model defaults if set 55 | # should default to None in command line args/cfg. Remove them if they are present and not set so that 56 | # non-supporting models don't break and default args remain in effect. 57 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 58 | 59 | model_source, model_name = parse_model_name(model_name) 60 | if model_source == 'hf-hub': 61 | # FIXME hf-hub source overrides any passed in pretrained_cfg, warn? 62 | # For model names specified in the form `hf-hub:path/architecture_name@revision`, 63 | # load model weights + pretrained_cfg from Hugging Face hub. 64 | pretrained_cfg, model_name = load_model_config_from_hf(model_name) 65 | 66 | if not is_model(model_name): 67 | raise RuntimeError('Unknown model (%s)' % model_name) 68 | 69 | create_fn = model_entrypoint(model_name) 70 | with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): 71 | model = create_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg, **kwargs) 72 | 73 | if checkpoint_path: 74 | load_checkpoint(model, checkpoint_path) 75 | 76 | return model 77 | -------------------------------------------------------------------------------- /lib/gcn_lib/pos_embed.py: -------------------------------------------------------------------------------- 1 | # 2022.06.17-Changed for building ViG model 2 | # Huawei Technologies Co., Ltd. 3 | # modified from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 
9 | # -------------------------------------------------------- 10 | # Position embedding utils 11 | # -------------------------------------------------------- 12 | 13 | import numpy as np 14 | 15 | import torch 16 | 17 | # -------------------------------------------------------- 18 | # relative position embedding 19 | # References: https://arxiv.org/abs/2009.13658 20 | # -------------------------------------------------------- 21 | def get_2d_relative_pos_embed(embed_dim, grid_size): 22 | """ 23 | grid_size: int of the grid height and width 24 | return: 25 | pos_embed: [grid_size*grid_size, grid_size*grid_size] 26 | """ 27 | pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size) 28 | relative_pos = 2 * np.matmul(pos_embed, pos_embed.transpose()) / pos_embed.shape[1] 29 | return relative_pos 30 | 31 | 32 | # -------------------------------------------------------- 33 | # 2D sine-cosine position embedding 34 | # References: 35 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 36 | # MoCo v3: https://github.com/facebookresearch/moco-v3 37 | # -------------------------------------------------------- 38 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 39 | """ 40 | grid_size: int of the grid height and width 41 | return: 42 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 43 | """ 44 | grid_h = np.arange(grid_size, dtype=np.float32) 45 | grid_w = np.arange(grid_size, dtype=np.float32) 46 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 47 | grid = np.stack(grid, axis=0) 48 | 49 | grid = grid.reshape([2, 1, grid_size, grid_size]) 50 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 51 | if cls_token: 52 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 53 | return pos_embed 54 | 55 | 56 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 57 | assert embed_dim % 2 == 0 58 | 59 | # use half of dimensions to encode grid_h 60 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 61 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 62 | 63 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 64 | return emb 65 | 66 | 67 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 68 | """ 69 | embed_dim: output dimension for each position 70 | pos: a list of positions to be encoded: size (M,) 71 | out: (M, D) 72 | """ 73 | assert embed_dim % 2 == 0 74 | omega = np.arange(embed_dim // 2, dtype=np.float) 75 | omega /= embed_dim / 2. 76 | omega = 1. 
/ 10000**omega # (D/2,) 77 | 78 | pos = pos.reshape(-1) # (M,) 79 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 80 | 81 | emb_sin = np.sin(out) # (M, D/2) 82 | emb_cos = np.cos(out) # (M, D/2) 83 | 84 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 85 | return emb 86 | -------------------------------------------------------------------------------- /lib/models_timm/layers/conv_bn_act.py: -------------------------------------------------------------------------------- 1 | """ Conv2d + BN + Act 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import functools 6 | from torch import nn as nn 7 | 8 | from .create_conv2d import create_conv2d 9 | from .create_norm_act import get_norm_act_layer 10 | 11 | 12 | class ConvNormAct(nn.Module): 13 | def __init__( 14 | self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, 15 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, drop_layer=None): 16 | super(ConvNormAct, self).__init__() 17 | self.conv = create_conv2d( 18 | in_channels, out_channels, kernel_size, stride=stride, 19 | padding=padding, dilation=dilation, groups=groups, bias=bias) 20 | 21 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 22 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 23 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` 24 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 25 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 26 | 27 | @property 28 | def in_channels(self): 29 | return self.conv.in_channels 30 | 31 | @property 32 | def out_channels(self): 33 | return self.conv.out_channels 34 | 35 | def forward(self, x): 36 | x = self.conv(x) 37 | x = self.bn(x) 38 | return x 39 | 40 | 41 | ConvBnAct = ConvNormAct 42 | 43 | 44 | def create_aa(aa_layer, channels, stride=2, enable=True): 45 | if not aa_layer or not enable: 46 | return nn.Identity() 47 | if isinstance(aa_layer, functools.partial): 48 | if issubclass(aa_layer.func, nn.AvgPool2d): 49 | return aa_layer() 50 | else: 51 | return aa_layer(channels) 52 | elif issubclass(aa_layer, nn.AvgPool2d): 53 | return aa_layer(stride) 54 | else: 55 | return aa_layer(channels=channels, stride=stride) 56 | 57 | 58 | class ConvNormActAa(nn.Module): 59 | def __init__( 60 | self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, 61 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None): 62 | super(ConvNormActAa, self).__init__() 63 | use_aa = aa_layer is not None and stride == 2 64 | 65 | self.conv = create_conv2d( 66 | in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, 67 | padding=padding, dilation=dilation, groups=groups, bias=bias) 68 | 69 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 70 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 71 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` 72 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 73 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 74 | self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa) 75 | 76 | @property 77 | def in_channels(self): 78 | return self.conv.in_channels 79 | 80 | @property 81 | def out_channels(self): 82 | return 
self.conv.out_channels 83 | 84 | def forward(self, x): 85 | x = self.conv(x) 86 | x = self.bn(x) 87 | x = self.aa(x) 88 | return x 89 | -------------------------------------------------------------------------------- /lib/models_timm/layers/config.py: -------------------------------------------------------------------------------- 1 | """ Model / Layer Config singleton state 2 | """ 3 | from typing import Any, Optional 4 | 5 | __all__ = [ 6 | 'is_exportable', 'is_scriptable', 'is_no_jit', 7 | 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' 8 | ] 9 | 10 | # Set to True if prefer to have layers with no jit optimization (includes activations) 11 | _NO_JIT = False 12 | 13 | # Set to True if prefer to have activation layers with no jit optimization 14 | # NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying 15 | # the jit flags so far are activations. This will change as more layers are updated and/or added. 16 | _NO_ACTIVATION_JIT = False 17 | 18 | # Set to True if exporting a model with Same padding via ONNX 19 | _EXPORTABLE = False 20 | 21 | # Set to True if wanting to use torch.jit.script on a model 22 | _SCRIPTABLE = False 23 | 24 | 25 | def is_no_jit(): 26 | return _NO_JIT 27 | 28 | 29 | class set_no_jit: 30 | def __init__(self, mode: bool) -> None: 31 | global _NO_JIT 32 | self.prev = _NO_JIT 33 | _NO_JIT = mode 34 | 35 | def __enter__(self) -> None: 36 | pass 37 | 38 | def __exit__(self, *args: Any) -> bool: 39 | global _NO_JIT 40 | _NO_JIT = self.prev 41 | return False 42 | 43 | 44 | def is_exportable(): 45 | return _EXPORTABLE 46 | 47 | 48 | class set_exportable: 49 | def __init__(self, mode: bool) -> None: 50 | global _EXPORTABLE 51 | self.prev = _EXPORTABLE 52 | _EXPORTABLE = mode 53 | 54 | def __enter__(self) -> None: 55 | pass 56 | 57 | def __exit__(self, *args: Any) -> bool: 58 | global _EXPORTABLE 59 | _EXPORTABLE = self.prev 60 | return False 61 | 62 | 63 | def is_scriptable(): 64 | return _SCRIPTABLE 65 | 66 | 67 | class set_scriptable: 68 | def __init__(self, mode: bool) -> None: 69 | global _SCRIPTABLE 70 | self.prev = _SCRIPTABLE 71 | _SCRIPTABLE = mode 72 | 73 | def __enter__(self) -> None: 74 | pass 75 | 76 | def __exit__(self, *args: Any) -> bool: 77 | global _SCRIPTABLE 78 | _SCRIPTABLE = self.prev 79 | return False 80 | 81 | 82 | class set_layer_config: 83 | """ Layer config context manager that allows setting all layer config flags at once. 84 | If a flag arg is None, it will not change the current value. 
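    Example (illustrative sketch; ``create_model`` here refers to this repo's factory in ``lib/models_timm/factory.py``)::

        >>> with set_layer_config(exportable=True, no_jit=True):
        ...     model = create_model(model_name)  # model_name is a placeholder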
85 | """ 86 | def __init__( 87 | self, 88 | scriptable: Optional[bool] = None, 89 | exportable: Optional[bool] = None, 90 | no_jit: Optional[bool] = None, 91 | no_activation_jit: Optional[bool] = None): 92 | global _SCRIPTABLE 93 | global _EXPORTABLE 94 | global _NO_JIT 95 | global _NO_ACTIVATION_JIT 96 | self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT 97 | if scriptable is not None: 98 | _SCRIPTABLE = scriptable 99 | if exportable is not None: 100 | _EXPORTABLE = exportable 101 | if no_jit is not None: 102 | _NO_JIT = no_jit 103 | if no_activation_jit is not None: 104 | _NO_ACTIVATION_JIT = no_activation_jit 105 | 106 | def __enter__(self) -> None: 107 | pass 108 | 109 | def __exit__(self, *args: Any) -> bool: 110 | global _SCRIPTABLE 111 | global _EXPORTABLE 112 | global _NO_JIT 113 | global _NO_ACTIVATION_JIT 114 | _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev 115 | return False 116 | -------------------------------------------------------------------------------- /lib/models_timm/layers/inplace_abn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | try: 5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync 6 | has_iabn = True 7 | except ImportError: 8 | has_iabn = False 9 | 10 | def inplace_abn(x, weight, bias, running_mean, running_var, 11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): 12 | raise ImportError( 13 | "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'") 14 | 15 | def inplace_abn_sync(**kwargs): 16 | inplace_abn(**kwargs) 17 | 18 | 19 | class InplaceAbn(nn.Module): 20 | """Activated Batch Normalization 21 | 22 | This gathers a BatchNorm and an activation function in a single module 23 | 24 | Parameters 25 | ---------- 26 | num_features : int 27 | Number of feature channels in the input and output. 28 | eps : float 29 | Small constant to prevent numerical issues. 30 | momentum : float 31 | Momentum factor applied to compute running statistics. 32 | affine : bool 33 | If `True` apply learned scale and shift transformation after normalization. 34 | act_layer : str or nn.Module type 35 | Name or type of the activation functions, one of: `leaky_relu`, `elu` 36 | act_param : float 37 | Negative slope for the `leaky_relu` activation. 
38 | """ 39 | 40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, 41 | act_layer="leaky_relu", act_param=0.01, drop_layer=None): 42 | super(InplaceAbn, self).__init__() 43 | self.num_features = num_features 44 | self.affine = affine 45 | self.eps = eps 46 | self.momentum = momentum 47 | if apply_act: 48 | if isinstance(act_layer, str): 49 | assert act_layer in ('leaky_relu', 'elu', 'identity', '') 50 | self.act_name = act_layer if act_layer else 'identity' 51 | else: 52 | # convert act layer passed as type to string 53 | if act_layer == nn.ELU: 54 | self.act_name = 'elu' 55 | elif act_layer == nn.LeakyReLU: 56 | self.act_name = 'leaky_relu' 57 | elif act_layer is None or act_layer == nn.Identity: 58 | self.act_name = 'identity' 59 | else: 60 | assert False, f'Invalid act layer {act_layer.__name__} for IABN' 61 | else: 62 | self.act_name = 'identity' 63 | self.act_param = act_param 64 | if self.affine: 65 | self.weight = nn.Parameter(torch.ones(num_features)) 66 | self.bias = nn.Parameter(torch.zeros(num_features)) 67 | else: 68 | self.register_parameter('weight', None) 69 | self.register_parameter('bias', None) 70 | self.register_buffer('running_mean', torch.zeros(num_features)) 71 | self.register_buffer('running_var', torch.ones(num_features)) 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | nn.init.constant_(self.running_mean, 0) 76 | nn.init.constant_(self.running_var, 1) 77 | if self.affine: 78 | nn.init.constant_(self.weight, 1) 79 | nn.init.constant_(self.bias, 0) 80 | 81 | def forward(self, x): 82 | output = inplace_abn( 83 | x, self.weight, self.bias, self.running_mean, self.running_var, 84 | self.training, self.momentum, self.eps, self.act_name, self.act_param) 85 | if isinstance(output, tuple): 86 | output = output[0] 87 | return output 88 | -------------------------------------------------------------------------------- /lib/models_timm/layers/split_batchnorm.py: -------------------------------------------------------------------------------- 1 | """ Split BatchNorm 2 | 3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through 4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias 5 | keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' 6 | namespace. 
7 | 8 | This allows easily removing the auxiliary BN layers after training to efficiently 9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, 10 | 'Disentangled Learning via An Auxiliary BN' 11 | 12 | Hacked together by / Copyright 2020 Ross Wightman 13 | """ 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d): 19 | 20 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 21 | track_running_stats=True, num_splits=2): 22 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 23 | assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' 24 | self.num_splits = num_splits 25 | self.aux_bn = nn.ModuleList([ 26 | nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) 27 | 28 | def forward(self, input: torch.Tensor): 29 | if self.training: # aux BN only relevant while training 30 | split_size = input.shape[0] // self.num_splits 31 | assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" 32 | split_input = input.split(split_size) 33 | x = [super().forward(split_input[0])] 34 | for i, a in enumerate(self.aux_bn): 35 | x.append(a(split_input[i + 1])) 36 | return torch.cat(x, dim=0) 37 | else: 38 | return super().forward(input) 39 | 40 | 41 | def convert_splitbn_model(module, num_splits=2): 42 | """ 43 | Recursively traverse module and its children to replace all instances of 44 | ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`. 45 | Args: 46 | module (torch.nn.Module): input module 47 | num_splits: number of separate batchnorm layers to split input across 48 | Example:: 49 | >>> # model is an instance of torch.nn.Module 50 | >>> model = timm.models.convert_splitbn_model(model, num_splits=2) 51 | """ 52 | mod = module 53 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): 54 | return module 55 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 56 | mod = SplitBatchNorm2d( 57 | module.num_features, module.eps, module.momentum, module.affine, 58 | module.track_running_stats, num_splits=num_splits) 59 | mod.running_mean = module.running_mean 60 | mod.running_var = module.running_var 61 | mod.num_batches_tracked = module.num_batches_tracked 62 | if module.affine: 63 | mod.weight.data = module.weight.data.clone().detach() 64 | mod.bias.data = module.bias.data.clone().detach() 65 | for aux in mod.aux_bn: 66 | aux.running_mean = module.running_mean.clone() 67 | aux.running_var = module.running_var.clone() 68 | aux.num_batches_tracked = module.num_batches_tracked.clone() 69 | if module.affine: 70 | aux.weight.data = module.weight.data.clone().detach() 71 | aux.bias.data = module.bias.data.clone().detach() 72 | for name, child in module.named_children(): 73 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) 74 | del module 75 | return mod 76 | -------------------------------------------------------------------------------- /lib/models_timm/layers/create_attn.py: -------------------------------------------------------------------------------- 1 | """ Attention Factory 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | import torch 6 | from functools import partial 7 | 8 | from .bottleneck_attn import BottleneckAttn 9 | from .cbam import CbamModule, LightCbamModule 10 | from .eca import EcaModule, CecaModule 11 | from .gather_excite import GatherExcite 12 | from 
.global_context import GlobalContext 13 | from .halo_attn import HaloAttn 14 | from .lambda_layer import LambdaLayer 15 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn 16 | from .selective_kernel import SelectiveKernel 17 | from .split_attn import SplitAttn 18 | from .squeeze_excite import SEModule, EffectiveSEModule 19 | 20 | 21 | def get_attn(attn_type): 22 | if isinstance(attn_type, torch.nn.Module): 23 | return attn_type 24 | module_cls = None 25 | if attn_type: 26 | if isinstance(attn_type, str): 27 | attn_type = attn_type.lower() 28 | # Lightweight attention modules (channel and/or coarse spatial). 29 | # Typically added to existing network architecture blocks in addition to existing convolutions. 30 | if attn_type == 'se': 31 | module_cls = SEModule 32 | elif attn_type == 'ese': 33 | module_cls = EffectiveSEModule 34 | elif attn_type == 'eca': 35 | module_cls = EcaModule 36 | elif attn_type == 'ecam': 37 | module_cls = partial(EcaModule, use_mlp=True) 38 | elif attn_type == 'ceca': 39 | module_cls = CecaModule 40 | elif attn_type == 'ge': 41 | module_cls = GatherExcite 42 | elif attn_type == 'gc': 43 | module_cls = GlobalContext 44 | elif attn_type == 'gca': 45 | module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False) 46 | elif attn_type == 'cbam': 47 | module_cls = CbamModule 48 | elif attn_type == 'lcbam': 49 | module_cls = LightCbamModule 50 | 51 | # Attention / attention-like modules w/ significant params 52 | # Typically replace some of the existing workhorse convs in a network architecture. 53 | # All of these accept a stride argument and can spatially downsample the input. 54 | elif attn_type == 'sk': 55 | module_cls = SelectiveKernel 56 | elif attn_type == 'splat': 57 | module_cls = SplitAttn 58 | 59 | # Self-attention / attention-like modules w/ significant compute and/or params 60 | # Typically replace some of the existing workhorse convs in a network architecture. 61 | # All of these accept a stride argument and can spatially downsample the input. 62 | elif attn_type == 'lambda': 63 | return LambdaLayer 64 | elif attn_type == 'bottleneck': 65 | return BottleneckAttn 66 | elif attn_type == 'halo': 67 | return HaloAttn 68 | elif attn_type == 'nl': 69 | module_cls = NonLocalAttn 70 | elif attn_type == 'bat': 71 | module_cls = BatNonLocalAttn 72 | 73 | # Woops! 
74 | else: 75 | assert False, "Invalid attn module (%s)" % attn_type 76 | elif isinstance(attn_type, bool): 77 | if attn_type: 78 | module_cls = SEModule 79 | else: 80 | module_cls = attn_type 81 | return module_cls 82 | 83 | 84 | def create_attn(attn_type, channels, **kwargs): 85 | module_cls = get_attn(attn_type) 86 | if module_cls is not None: 87 | # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels 88 | return module_cls(channels, **kwargs) 89 | return None 90 | -------------------------------------------------------------------------------- /utils/dataset_synapse.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import h5py 4 | import numpy as np 5 | import torch 6 | import cv2 7 | from scipy import ndimage 8 | from scipy.ndimage.interpolation import zoom 9 | from torch.utils.data import Dataset 10 | 11 | 12 | def random_rot_flip(image, label): 13 | k = np.random.randint(0, 4) 14 | image = np.rot90(image, k) 15 | label = np.rot90(label, k) 16 | axis = np.random.randint(0, 2) 17 | image = np.flip(image, axis=axis).copy() 18 | label = np.flip(label, axis=axis).copy() 19 | return image, label 20 | 21 | 22 | def random_rotate(image, label): 23 | angle = np.random.randint(-20, 20) 24 | image = ndimage.rotate(image, angle, order=0, reshape=False) 25 | label = ndimage.rotate(label, angle, order=0, reshape=False) 26 | return image, label 27 | 28 | 29 | class RandomGenerator(object): 30 | def __init__(self, output_size): 31 | self.output_size = output_size 32 | 33 | def __call__(self, sample): 34 | image, label = sample['image'], sample['label'] 35 | 36 | if random.random() > 0.5: 37 | image, label = random_rot_flip(image, label) 38 | elif random.random() > 0.5: 39 | image, label = random_rotate(image, label) 40 | x, y = image.shape 41 | if x != self.output_size[0] or y != self.output_size[1]: 42 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3) # why not 3? 43 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 44 | image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) 45 | label = torch.from_numpy(label.astype(np.float32)) 46 | sample = {'image': image, 'label': label.long()} 47 | return sample 48 | 49 | 50 | class Synapse_dataset(Dataset): 51 | def __init__(self, base_dir, list_dir, split, nclass=9, transform=None): 52 | self.transform = transform # using transform in torch! 
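        # self.split selects the list file (<split>.txt): slice names for training, volume names for testing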
53 | self.split = split 54 | self.sample_list = open(os.path.join(list_dir, self.split+'.txt')).readlines() 55 | self.data_dir = base_dir 56 | self.nclass = nclass 57 | 58 | def __len__(self): 59 | return len(self.sample_list) 60 | 61 | def __getitem__(self, idx): 62 | if self.split == "train": 63 | slice_name = self.sample_list[idx].strip('\n') 64 | data_path = os.path.join(self.data_dir, slice_name+'.npz') 65 | data = np.load(data_path) 66 | image, label = data['image'], data['label'] 67 | #print(image.shape) 68 | #image = np.reshape(image, (512, 512)) 69 | #image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 70 | 71 | #label = np.reshape(label, (512, 512)) 72 | 73 | 74 | else: 75 | vol_name = self.sample_list[idx].strip('\n') 76 | filepath = self.data_dir + "/{}.npy.h5".format(vol_name) 77 | data = h5py.File(filepath) 78 | image, label = data['image'][:], data['label'][:] 79 | #image = np.reshape(image, (image.shape[2], 512, 512)) 80 | #label = np.reshape(label, (label.shape[2], 512, 512)) 81 | #label[label==5]= 0 82 | #label[label==9]= 0 83 | #label[label==10]= 0 84 | #label[label==12]= 0 85 | #label[label==13]= 0 86 | #label[label==11]= 5 87 | 88 | if self.nclass == 9: 89 | label[label==5]= 0 90 | label[label==9]= 0 91 | label[label==10]= 0 92 | label[label==12]= 0 93 | label[label==13]= 0 94 | label[label==11]= 5 95 | 96 | sample = {'image': image, 'label': label} 97 | if self.transform: 98 | sample = self.transform(sample) 99 | sample['case_name'] = self.sample_list[idx].strip('\n') 100 | return sample 101 | -------------------------------------------------------------------------------- /lib/models_timm/layers/create_norm_act.py: -------------------------------------------------------------------------------- 1 | """ NormAct (Normalizaiton + Activation Layer) Factory 2 | 3 | Create norm + act combo modules that attempt to be backwards compatible with separate norm + act 4 | isntances in models. Where these are used it will be possible to swap separate BN + act layers with 5 | combined modules like IABN or EvoNorms. 
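get_norm_act_layer() resolves a norm layer given by name, type or functools.partial to its fused NormAct class; create_norm_act_layer() instantiates it.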
6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | import types 10 | import functools 11 | 12 | from .evo_norm import * 13 | from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d 14 | from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d 15 | from .inplace_abn import InplaceAbn 16 | 17 | _NORM_ACT_MAP = dict( 18 | batchnorm=BatchNormAct2d, 19 | batchnorm2d=BatchNormAct2d, 20 | groupnorm=GroupNormAct, 21 | groupnorm1=functools.partial(GroupNormAct, num_groups=1), 22 | layernorm=LayerNormAct, 23 | layernorm2d=LayerNormAct2d, 24 | evonormb0=EvoNorm2dB0, 25 | evonormb1=EvoNorm2dB1, 26 | evonormb2=EvoNorm2dB2, 27 | evonorms0=EvoNorm2dS0, 28 | evonorms0a=EvoNorm2dS0a, 29 | evonorms1=EvoNorm2dS1, 30 | evonorms1a=EvoNorm2dS1a, 31 | evonorms2=EvoNorm2dS2, 32 | evonorms2a=EvoNorm2dS2a, 33 | frn=FilterResponseNormAct2d, 34 | frntlu=FilterResponseNormTlu2d, 35 | inplaceabn=InplaceAbn, 36 | iabn=InplaceAbn, 37 | ) 38 | _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()} 39 | # has act_layer arg to define act type 40 | _NORM_ACT_REQUIRES_ARG = { 41 | BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn} 42 | 43 | 44 | def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs): 45 | layer = get_norm_act_layer(layer_name, act_layer=act_layer) 46 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs) 47 | if jit: 48 | layer_instance = torch.jit.script(layer_instance) 49 | return layer_instance 50 | 51 | 52 | def get_norm_act_layer(norm_layer, act_layer=None): 53 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 54 | assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) 55 | norm_act_kwargs = {} 56 | 57 | # unbind partial fn, so args can be rebound later 58 | if isinstance(norm_layer, functools.partial): 59 | norm_act_kwargs.update(norm_layer.keywords) 60 | norm_layer = norm_layer.func 61 | 62 | if isinstance(norm_layer, str): 63 | layer_name = norm_layer.replace('_', '').lower().split('-')[0] 64 | norm_act_layer = _NORM_ACT_MAP.get(layer_name, None) 65 | elif norm_layer in _NORM_ACT_TYPES: 66 | norm_act_layer = norm_layer 67 | elif isinstance(norm_layer, types.FunctionType): 68 | # if function type, must be a lambda/fn that creates a norm_act layer 69 | norm_act_layer = norm_layer 70 | else: 71 | type_name = norm_layer.__name__.lower() 72 | if type_name.startswith('batchnorm'): 73 | norm_act_layer = BatchNormAct2d 74 | elif type_name.startswith('groupnorm'): 75 | norm_act_layer = GroupNormAct 76 | elif type_name.startswith('groupnorm1'): 77 | norm_act_layer = functools.partial(GroupNormAct, num_groups=1) 78 | elif type_name.startswith('layernorm2d'): 79 | norm_act_layer = LayerNormAct2d 80 | elif type_name.startswith('layernorm'): 81 | norm_act_layer = LayerNormAct 82 | else: 83 | assert False, f"No equivalent norm_act layer for {type_name}" 84 | 85 | if norm_act_layer in _NORM_ACT_REQUIRES_ARG: 86 | # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. 
87 | # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types 88 | norm_act_kwargs.setdefault('act_layer', act_layer) 89 | if norm_act_kwargs: 90 | norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args 91 | return norm_act_layer 92 | -------------------------------------------------------------------------------- /lib/models_timm/layers/gather_excite.py: -------------------------------------------------------------------------------- 1 | """ Gather-Excite Attention Block 2 | 3 | Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348 4 | 5 | Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet 6 | 7 | I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another 8 | impl that covers all of the cases. 9 | 10 | NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation 11 | 12 | Hacked together by / Copyright 2021 Ross Wightman 13 | """ 14 | import math 15 | 16 | from torch import nn as nn 17 | import torch.nn.functional as F 18 | 19 | from .create_act import create_act_layer, get_act_layer 20 | from .create_conv2d import create_conv2d 21 | from .helpers import make_divisible 22 | from .mlp import ConvMlp 23 | 24 | 25 | class GatherExcite(nn.Module): 26 | """ Gather-Excite Attention Module 27 | """ 28 | def __init__( 29 | self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True, 30 | rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False, 31 | act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'): 32 | super(GatherExcite, self).__init__() 33 | self.add_maxpool = add_maxpool 34 | act_layer = get_act_layer(act_layer) 35 | self.extent = extent 36 | if extra_params: 37 | self.gather = nn.Sequential() 38 | if extent == 0: 39 | assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params' 40 | self.gather.add_module( 41 | 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True)) 42 | if norm_layer: 43 | self.gather.add_module(f'norm1', nn.BatchNorm2d(channels)) 44 | else: 45 | assert extent % 2 == 0 46 | num_conv = int(math.log2(extent)) 47 | for i in range(num_conv): 48 | self.gather.add_module( 49 | f'conv{i + 1}', 50 | create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True)) 51 | if norm_layer: 52 | self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels)) 53 | if i != num_conv - 1: 54 | self.gather.add_module(f'act{i + 1}', act_layer(inplace=True)) 55 | else: 56 | self.gather = None 57 | if self.extent == 0: 58 | self.gk = 0 59 | self.gs = 0 60 | else: 61 | assert extent % 2 == 0 62 | self.gk = self.extent * 2 - 1 63 | self.gs = self.extent 64 | 65 | if not rd_channels: 66 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
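        # excite step: squeeze to rd_channels and back with a small ConvMlp (SE-style) when use_mlp, then gate the input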
67 | self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity() 68 | self.gate = create_act_layer(gate_layer) 69 | 70 | def forward(self, x): 71 | size = x.shape[-2:] 72 | if self.gather is not None: 73 | x_ge = self.gather(x) 74 | else: 75 | if self.extent == 0: 76 | # global extent 77 | x_ge = x.mean(dim=(2, 3), keepdims=True) 78 | if self.add_maxpool: 79 | # experimental codepath, may remove or change 80 | x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True) 81 | else: 82 | x_ge = F.avg_pool2d( 83 | x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False) 84 | if self.add_maxpool: 85 | # experimental codepath, may remove or change 86 | x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2) 87 | x_ge = self.mlp(x_ge) 88 | if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1: 89 | x_ge = F.interpolate(x_ge, size=size) 90 | return x * self.gate(x_ge) 91 | -------------------------------------------------------------------------------- /lib/models_timm/pruned/ecaresnet50d_pruned.txt: -------------------------------------------------------------------------------- 1 | conv1.0.weight:[32, 3, 3, 3]***conv1.1.weight:[32]***conv1.3.weight:[32, 32, 3, 3]***conv1.4.weight:[32]***conv1.6.weight:[64, 32, 3, 3]***bn1.weight:[64]***layer1.0.conv1.weight:[47, 64, 1, 1]***layer1.0.bn1.weight:[47]***layer1.0.conv2.weight:[18, 47, 3, 3]***layer1.0.bn2.weight:[18]***layer1.0.conv3.weight:[19, 18, 1, 1]***layer1.0.bn3.weight:[19]***layer1.0.se.conv.weight:[1, 1, 5]***layer1.0.downsample.1.weight:[19, 64, 1, 1]***layer1.0.downsample.2.weight:[19]***layer1.1.conv1.weight:[52, 19, 1, 1]***layer1.1.bn1.weight:[52]***layer1.1.conv2.weight:[22, 52, 3, 3]***layer1.1.bn2.weight:[22]***layer1.1.conv3.weight:[19, 22, 1, 1]***layer1.1.bn3.weight:[19]***layer1.1.se.conv.weight:[1, 1, 5]***layer1.2.conv1.weight:[64, 19, 1, 1]***layer1.2.bn1.weight:[64]***layer1.2.conv2.weight:[35, 64, 3, 3]***layer1.2.bn2.weight:[35]***layer1.2.conv3.weight:[19, 35, 1, 1]***layer1.2.bn3.weight:[19]***layer1.2.se.conv.weight:[1, 1, 5]***layer2.0.conv1.weight:[85, 19, 1, 1]***layer2.0.bn1.weight:[85]***layer2.0.conv2.weight:[37, 85, 3, 3]***layer2.0.bn2.weight:[37]***layer2.0.conv3.weight:[171, 37, 1, 1]***layer2.0.bn3.weight:[171]***layer2.0.se.conv.weight:[1, 1, 5]***layer2.0.downsample.1.weight:[171, 19, 1, 1]***layer2.0.downsample.2.weight:[171]***layer2.1.conv1.weight:[107, 171, 1, 1]***layer2.1.bn1.weight:[107]***layer2.1.conv2.weight:[80, 107, 3, 3]***layer2.1.bn2.weight:[80]***layer2.1.conv3.weight:[171, 80, 1, 1]***layer2.1.bn3.weight:[171]***layer2.1.se.conv.weight:[1, 1, 5]***layer2.2.conv1.weight:[120, 171, 1, 1]***layer2.2.bn1.weight:[120]***layer2.2.conv2.weight:[85, 120, 3, 3]***layer2.2.bn2.weight:[85]***layer2.2.conv3.weight:[171, 85, 1, 1]***layer2.2.bn3.weight:[171]***layer2.2.se.conv.weight:[1, 1, 5]***layer2.3.conv1.weight:[125, 171, 1, 1]***layer2.3.bn1.weight:[125]***layer2.3.conv2.weight:[87, 125, 3, 3]***layer2.3.bn2.weight:[87]***layer2.3.conv3.weight:[171, 87, 1, 1]***layer2.3.bn3.weight:[171]***layer2.3.se.conv.weight:[1, 1, 5]***layer3.0.conv1.weight:[198, 171, 1, 1]***layer3.0.bn1.weight:[198]***layer3.0.conv2.weight:[126, 198, 3, 3]***layer3.0.bn2.weight:[126]***layer3.0.conv3.weight:[818, 126, 1, 1]***layer3.0.bn3.weight:[818]***layer3.0.se.conv.weight:[1, 1, 5]***layer3.0.downsample.1.weight:[818, 171, 1, 1]***layer3.0.downsample.2.weight:[818]***layer3.1.conv1.weight:[255, 818, 1, 
1]***layer3.1.bn1.weight:[255]***layer3.1.conv2.weight:[232, 255, 3, 3]***layer3.1.bn2.weight:[232]***layer3.1.conv3.weight:[818, 232, 1, 1]***layer3.1.bn3.weight:[818]***layer3.1.se.conv.weight:[1, 1, 5]***layer3.2.conv1.weight:[256, 818, 1, 1]***layer3.2.bn1.weight:[256]***layer3.2.conv2.weight:[233, 256, 3, 3]***layer3.2.bn2.weight:[233]***layer3.2.conv3.weight:[818, 233, 1, 1]***layer3.2.bn3.weight:[818]***layer3.2.se.conv.weight:[1, 1, 5]***layer3.3.conv1.weight:[253, 818, 1, 1]***layer3.3.bn1.weight:[253]***layer3.3.conv2.weight:[235, 253, 3, 3]***layer3.3.bn2.weight:[235]***layer3.3.conv3.weight:[818, 235, 1, 1]***layer3.3.bn3.weight:[818]***layer3.3.se.conv.weight:[1, 1, 5]***layer3.4.conv1.weight:[256, 818, 1, 1]***layer3.4.bn1.weight:[256]***layer3.4.conv2.weight:[225, 256, 3, 3]***layer3.4.bn2.weight:[225]***layer3.4.conv3.weight:[818, 225, 1, 1]***layer3.4.bn3.weight:[818]***layer3.4.se.conv.weight:[1, 1, 5]***layer3.5.conv1.weight:[256, 818, 1, 1]***layer3.5.bn1.weight:[256]***layer3.5.conv2.weight:[239, 256, 3, 3]***layer3.5.bn2.weight:[239]***layer3.5.conv3.weight:[818, 239, 1, 1]***layer3.5.bn3.weight:[818]***layer3.5.se.conv.weight:[1, 1, 5]***layer4.0.conv1.weight:[492, 818, 1, 1]***layer4.0.bn1.weight:[492]***layer4.0.conv2.weight:[237, 492, 3, 3]***layer4.0.bn2.weight:[237]***layer4.0.conv3.weight:[2022, 237, 1, 1]***layer4.0.bn3.weight:[2022]***layer4.0.se.conv.weight:[1, 1, 7]***layer4.0.downsample.1.weight:[2022, 818, 1, 1]***layer4.0.downsample.2.weight:[2022]***layer4.1.conv1.weight:[512, 2022, 1, 1]***layer4.1.bn1.weight:[512]***layer4.1.conv2.weight:[500, 512, 3, 3]***layer4.1.bn2.weight:[500]***layer4.1.conv3.weight:[2022, 500, 1, 1]***layer4.1.bn3.weight:[2022]***layer4.1.se.conv.weight:[1, 1, 7]***layer4.2.conv1.weight:[512, 2022, 1, 1]***layer4.2.bn1.weight:[512]***layer4.2.conv2.weight:[490, 512, 3, 3]***layer4.2.bn2.weight:[490]***layer4.2.conv3.weight:[2022, 490, 1, 1]***layer4.2.bn3.weight:[2022]***layer4.2.se.conv.weight:[1, 1, 7]***fc.weight:[1000, 2022]***layer1_2_conv3_M.weight:[256, 19]***layer2_3_conv3_M.weight:[512, 171]***layer3_5_conv3_M.weight:[1024, 818]***layer4_2_conv3_M.weight:[2048, 2022] -------------------------------------------------------------------------------- /lib/gcn_lib/torch_nn.py: -------------------------------------------------------------------------------- 1 | # 2022.06.17-Changed for building ViG model 2 | # Huawei Technologies Co., Ltd. 
3 | import torch 4 | from torch import nn 5 | from torch.nn import Sequential as Seq, Linear as Lin, Conv2d 6 | 7 | 8 | ############################## 9 | # Basic layers 10 | ############################## 11 | def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1): 12 | # activation layer 13 | 14 | act = act.lower() 15 | if act == 'relu': 16 | layer = nn.ReLU(inplace) 17 | elif act == 'leakyrelu': 18 | layer = nn.LeakyReLU(neg_slope, inplace) 19 | elif act == 'prelu': 20 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) 21 | elif act == 'gelu': 22 | layer = nn.GELU() 23 | elif act == 'hswish': 24 | layer = nn.Hardswish(inplace) 25 | else: 26 | raise NotImplementedError('activation layer [%s] is not found' % act) 27 | return layer 28 | 29 | 30 | def norm_layer(norm, nc): 31 | # normalization layer 2d 32 | norm = norm.lower() 33 | if norm == 'batch': 34 | layer = nn.BatchNorm2d(nc, affine=True) 35 | elif norm == 'instance': 36 | layer = nn.InstanceNorm2d(nc, affine=False) 37 | else: 38 | raise NotImplementedError('normalization layer [%s] is not found' % norm) 39 | return layer 40 | 41 | 42 | class MLP(Seq): 43 | def __init__(self, channels, act='relu', norm=None, bias=True): 44 | m = [] 45 | for i in range(1, len(channels)): 46 | m.append(Lin(channels[i - 1], channels[i], bias)) 47 | if act is not None and act.lower() != 'none': 48 | m.append(act_layer(act)) 49 | if norm is not None and norm.lower() != 'none': 50 | m.append(norm_layer(norm, channels[-1])) 51 | super(MLP, self).__init__(*m) 52 | 53 | 54 | class BasicConv(Seq): 55 | def __init__(self, channels, act='relu', norm=None, bias=True, drop=0., kernel_size=1, padding=0, groups=4): 56 | m = [] 57 | for i in range(1, len(channels)): 58 | m.append(Conv2d(channels[i - 1], channels[i], kernel_size, padding=padding, bias=bias, groups=groups)) 59 | if norm is not None and norm.lower() != 'none': 60 | m.append(norm_layer(norm, channels[-1])) 61 | if act is not None and act.lower() != 'none': 62 | m.append(act_layer(act)) 63 | if drop > 0: 64 | m.append(nn.Dropout2d(drop)) 65 | 66 | super(BasicConv, self).__init__(*m) 67 | 68 | self.reset_parameters() 69 | 70 | def reset_parameters(self): 71 | for m in self.modules(): 72 | if isinstance(m, nn.Conv2d): 73 | nn.init.kaiming_normal_(m.weight) 74 | if m.bias is not None: 75 | nn.init.zeros_(m.bias) 76 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): 77 | m.weight.data.fill_(1) 78 | m.bias.data.zero_() 79 | 80 | 81 | def batched_index_select(x, idx): 82 | r"""fetches neighbors features from a given neighbor idx 83 | 84 | Args: 85 | x (Tensor): input feature Tensor 86 | :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times 1}`. 87 | idx (Tensor): edge_idx 88 | :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times l}`. 89 | Returns: 90 | Tensor: output neighbors features 91 | :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times k}`. 
92 | """ 93 | batch_size, num_dims, num_vertices_reduced = x.shape[:3] 94 | _, num_vertices, k = idx.shape 95 | #print([batch_size,num_dims,num_vertices_reduced, num_vertices, k, x.shape]) 96 | idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices_reduced 97 | #print(idx_base.shape) 98 | idx = idx + idx_base 99 | idx = idx.contiguous().view(-1) 100 | #print(x.shape) 101 | x = x.transpose(2, 1) 102 | #print(x.shape) 103 | x = x.contiguous().view(batch_size * num_vertices_reduced, -1) 104 | #print(x.shape) 105 | feature = x[idx, :] 106 | #print(feature.shape) 107 | feature = feature.view(batch_size, num_vertices, k, num_dims) 108 | #print(feature.shape) 109 | feature = feature.permute(0, 3, 1, 2).contiguous() 110 | #print(feature.shape) 111 | return feature 112 | -------------------------------------------------------------------------------- /lib/models_timm/layers/adaptive_avgmax_pool.py: -------------------------------------------------------------------------------- 1 | """ PyTorch selectable adaptive pooling 2 | Adaptive pooling with the ability to select the type of pooling from: 3 | * 'avg' - Average pooling 4 | * 'max' - Max pooling 5 | * 'avgmax' - Sum of average and max pooling re-scaled by 0.5 6 | * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim 7 | 8 | Both a functional and a nn.Module version of the pooling is provided. 9 | 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | 17 | def adaptive_pool_feat_mult(pool_type='avg'): 18 | if pool_type == 'catavgmax': 19 | return 2 20 | else: 21 | return 1 22 | 23 | 24 | def adaptive_avgmax_pool2d(x, output_size=1): 25 | x_avg = F.adaptive_avg_pool2d(x, output_size) 26 | x_max = F.adaptive_max_pool2d(x, output_size) 27 | return 0.5 * (x_avg + x_max) 28 | 29 | 30 | def adaptive_catavgmax_pool2d(x, output_size=1): 31 | x_avg = F.adaptive_avg_pool2d(x, output_size) 32 | x_max = F.adaptive_max_pool2d(x, output_size) 33 | return torch.cat((x_avg, x_max), 1) 34 | 35 | 36 | def select_adaptive_pool2d(x, pool_type='avg', output_size=1): 37 | """Selectable global pooling function with dynamic input kernel size 38 | """ 39 | if pool_type == 'avg': 40 | x = F.adaptive_avg_pool2d(x, output_size) 41 | elif pool_type == 'avgmax': 42 | x = adaptive_avgmax_pool2d(x, output_size) 43 | elif pool_type == 'catavgmax': 44 | x = adaptive_catavgmax_pool2d(x, output_size) 45 | elif pool_type == 'max': 46 | x = F.adaptive_max_pool2d(x, output_size) 47 | else: 48 | assert False, 'Invalid pool type: %s' % pool_type 49 | return x 50 | 51 | 52 | class FastAdaptiveAvgPool2d(nn.Module): 53 | def __init__(self, flatten=False): 54 | super(FastAdaptiveAvgPool2d, self).__init__() 55 | self.flatten = flatten 56 | 57 | def forward(self, x): 58 | return x.mean((2, 3), keepdim=not self.flatten) 59 | 60 | 61 | class AdaptiveAvgMaxPool2d(nn.Module): 62 | def __init__(self, output_size=1): 63 | super(AdaptiveAvgMaxPool2d, self).__init__() 64 | self.output_size = output_size 65 | 66 | def forward(self, x): 67 | return adaptive_avgmax_pool2d(x, self.output_size) 68 | 69 | 70 | class AdaptiveCatAvgMaxPool2d(nn.Module): 71 | def __init__(self, output_size=1): 72 | super(AdaptiveCatAvgMaxPool2d, self).__init__() 73 | self.output_size = output_size 74 | 75 | def forward(self, x): 76 | return adaptive_catavgmax_pool2d(x, self.output_size) 77 | 78 | 79 | class SelectAdaptivePool2d(nn.Module): 80 | 
"""Selectable global pooling layer with dynamic input kernel size 81 | """ 82 | def __init__(self, output_size=1, pool_type='fast', flatten=False): 83 | super(SelectAdaptivePool2d, self).__init__() 84 | self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing 85 | self.flatten = nn.Flatten(1) if flatten else nn.Identity() 86 | if pool_type == '': 87 | self.pool = nn.Identity() # pass through 88 | elif pool_type == 'fast': 89 | assert output_size == 1 90 | self.pool = FastAdaptiveAvgPool2d(flatten) 91 | self.flatten = nn.Identity() 92 | elif pool_type == 'avg': 93 | self.pool = nn.AdaptiveAvgPool2d(output_size) 94 | elif pool_type == 'avgmax': 95 | self.pool = AdaptiveAvgMaxPool2d(output_size) 96 | elif pool_type == 'catavgmax': 97 | self.pool = AdaptiveCatAvgMaxPool2d(output_size) 98 | elif pool_type == 'max': 99 | self.pool = nn.AdaptiveMaxPool2d(output_size) 100 | else: 101 | assert False, 'Invalid pool type: %s' % pool_type 102 | 103 | def is_identity(self): 104 | return not self.pool_type 105 | 106 | def forward(self, x): 107 | x = self.pool(x) 108 | x = self.flatten(x) 109 | return x 110 | 111 | def feat_mult(self): 112 | return adaptive_pool_feat_mult(self.pool_type) 113 | 114 | def __repr__(self): 115 | return self.__class__.__name__ + ' (' \ 116 | + 'pool_type=' + self.pool_type \ 117 | + ', flatten=' + str(self.flatten) + ')' 118 | 119 | -------------------------------------------------------------------------------- /lib/models_timm/fx_features.py: -------------------------------------------------------------------------------- 1 | """ PyTorch FX Based Feature Extraction Helpers 2 | Using https://pytorch.org/vision/stable/feature_extraction.html 3 | """ 4 | from typing import Callable, List, Dict, Union, Type 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from .features import _get_feature_info 10 | 11 | try: 12 | from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor 13 | has_fx_feature_extraction = True 14 | except ImportError: 15 | has_fx_feature_extraction = False 16 | 17 | # Layers we went to treat as leaf modules 18 | from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame 19 | from .layers.non_local_attn import BilinearAttnTransform 20 | from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame 21 | 22 | # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here 23 | # BUT modules from timm.models should use the registration mechanism below 24 | _leaf_modules = { 25 | BilinearAttnTransform, # reason: flow control t <= 1 26 | # Reason: get_same_padding has a max which raises a control flow error 27 | Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, 28 | CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) 29 | } 30 | 31 | try: 32 | from .layers import InplaceAbn 33 | _leaf_modules.add(InplaceAbn) 34 | except ImportError: 35 | pass 36 | 37 | 38 | def register_notrace_module(module: Type[nn.Module]): 39 | """ 40 | Any module not under timm.models.layers should get this decorator if we don't want to trace through it. 
41 | """ 42 | _leaf_modules.add(module) 43 | return module 44 | 45 | 46 | # Functions we want to autowrap (treat them as leaves) 47 | _autowrap_functions = set() 48 | 49 | 50 | def register_notrace_function(func: Callable): 51 | """ 52 | Decorator for functions which ought not to be traced through 53 | """ 54 | _autowrap_functions.add(func) 55 | return func 56 | 57 | 58 | def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]): 59 | assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' 60 | return _create_feature_extractor( 61 | model, return_nodes, 62 | tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)} 63 | ) 64 | 65 | 66 | class FeatureGraphNet(nn.Module): 67 | """ A FX Graph based feature extractor that works with the model feature_info metadata 68 | """ 69 | def __init__(self, model, out_indices, out_map=None): 70 | super().__init__() 71 | assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' 72 | self.feature_info = _get_feature_info(model, out_indices) 73 | if out_map is not None: 74 | assert len(out_map) == len(out_indices) 75 | return_nodes = { 76 | info['module']: out_map[i] if out_map is not None else info['module'] 77 | for i, info in enumerate(self.feature_info) if i in out_indices} 78 | self.graph_module = create_feature_extractor(model, return_nodes) 79 | 80 | def forward(self, x): 81 | return list(self.graph_module(x).values()) 82 | 83 | 84 | class GraphExtractNet(nn.Module): 85 | """ A standalone feature extraction wrapper that maps dict -> list or single tensor 86 | NOTE: 87 | * one can use feature_extractor directly if dictionary output is desired 88 | * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info 89 | metadata for builtin feature extraction mode 90 | * create_feature_extractor can be used directly if dictionary output is acceptable 91 | 92 | Args: 93 | model: model to extract features from 94 | return_nodes: node names to return features from (dict or list) 95 | squeeze_out: if only one output, and output in list format, flatten to single tensor 96 | """ 97 | def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True): 98 | super().__init__() 99 | self.squeeze_out = squeeze_out 100 | self.graph_module = create_feature_extractor(model, return_nodes) 101 | 102 | def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]: 103 | out = list(self.graph_module(x).values()) 104 | if self.squeeze_out and len(out) == 1: 105 | return out[0] 106 | return out 107 | -------------------------------------------------------------------------------- /lib/models_timm/layers/activations.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | 9 | import torch 10 | from torch import nn as nn 11 | from torch.nn import functional as F 12 | 13 | 14 | def swish(x, inplace: bool = False): 15 | """Swish - Described in: https://arxiv.org/abs/1710.05941 16 | """ 17 | return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) 18 | 19 | 20 | class Swish(nn.Module): 21 | def __init__(self, inplace: bool = False): 22 | super(Swish, self).__init__() 23 | self.inplace = inplace 24 | 25 | def forward(self, x): 26 | return swish(x, self.inplace) 27 | 28 | 29 | def mish(x, inplace: bool = False): 30 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 31 | NOTE: I don't have a working inplace variant 32 | """ 33 | return x.mul(F.softplus(x).tanh()) 34 | 35 | 36 | class Mish(nn.Module): 37 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 38 | """ 39 | def __init__(self, inplace: bool = False): 40 | super(Mish, self).__init__() 41 | 42 | def forward(self, x): 43 | return mish(x) 44 | 45 | 46 | def sigmoid(x, inplace: bool = False): 47 | return x.sigmoid_() if inplace else x.sigmoid() 48 | 49 | 50 | # PyTorch has this, but not with a consistent inplace argmument interface 51 | class Sigmoid(nn.Module): 52 | def __init__(self, inplace: bool = False): 53 | super(Sigmoid, self).__init__() 54 | self.inplace = inplace 55 | 56 | def forward(self, x): 57 | return x.sigmoid_() if self.inplace else x.sigmoid() 58 | 59 | 60 | def tanh(x, inplace: bool = False): 61 | return x.tanh_() if inplace else x.tanh() 62 | 63 | 64 | # PyTorch has this, but not with a consistent inplace argmument interface 65 | class Tanh(nn.Module): 66 | def __init__(self, inplace: bool = False): 67 | super(Tanh, self).__init__() 68 | self.inplace = inplace 69 | 70 | def forward(self, x): 71 | return x.tanh_() if self.inplace else x.tanh() 72 | 73 | 74 | def hard_swish(x, inplace: bool = False): 75 | inner = F.relu6(x + 3.).div_(6.) 76 | return x.mul_(inner) if inplace else x.mul(inner) 77 | 78 | 79 | class HardSwish(nn.Module): 80 | def __init__(self, inplace: bool = False): 81 | super(HardSwish, self).__init__() 82 | self.inplace = inplace 83 | 84 | def forward(self, x): 85 | return hard_swish(x, self.inplace) 86 | 87 | 88 | def hard_sigmoid(x, inplace: bool = False): 89 | if inplace: 90 | return x.add_(3.).clamp_(0., 6.).div_(6.) 91 | else: 92 | return F.relu6(x + 3.) / 6. 
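# Note: hard_sigmoid(x) is the piecewise-linear approximation clamp((x + 3) / 6, 0, 1);
# as with hard_swish above, the inplace variant mutates x, so use it only when the
# input tensor is not needed afterwards.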
93 | 94 | 95 | class HardSigmoid(nn.Module): 96 | def __init__(self, inplace: bool = False): 97 | super(HardSigmoid, self).__init__() 98 | self.inplace = inplace 99 | 100 | def forward(self, x): 101 | return hard_sigmoid(x, self.inplace) 102 | 103 | 104 | def hard_mish(x, inplace: bool = False): 105 | """ Hard Mish 106 | Experimental, based on notes by Mish author Diganta Misra at 107 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 108 | """ 109 | if inplace: 110 | return x.mul_(0.5 * (x + 2).clamp(min=0, max=2)) 111 | else: 112 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 113 | 114 | 115 | class HardMish(nn.Module): 116 | def __init__(self, inplace: bool = False): 117 | super(HardMish, self).__init__() 118 | self.inplace = inplace 119 | 120 | def forward(self, x): 121 | return hard_mish(x, self.inplace) 122 | 123 | 124 | class PReLU(nn.PReLU): 125 | """Applies PReLU (w/ dummy inplace arg) 126 | """ 127 | def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None: 128 | super(PReLU, self).__init__(num_parameters=num_parameters, init=init) 129 | 130 | def forward(self, input: torch.Tensor) -> torch.Tensor: 131 | return F.prelu(input, self.weight) 132 | 133 | 134 | def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: 135 | return F.gelu(x) 136 | 137 | 138 | class GELU(nn.Module): 139 | """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg) 140 | """ 141 | def __init__(self, inplace: bool = False): 142 | super(GELU, self).__init__() 143 | 144 | def forward(self, input: torch.Tensor) -> torch.Tensor: 145 | return F.gelu(input) 146 | -------------------------------------------------------------------------------- /lib/models_timm/layers/squeeze_excite.py: -------------------------------------------------------------------------------- 1 | """ Squeeze-and-Excitation Channel Attention 2 | 3 | An SE implementation originally based on PyTorch SE-Net impl. 4 | Has since evolved with additional functionality / configuration. 5 | 6 | Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507 7 | 8 | Also included is Effective Squeeze-Excitation (ESE). 9 | Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 10 | 11 | Hacked together by / Copyright 2021 Ross Wightman 12 | """ 13 | from torch import nn as nn 14 | 15 | from .create_act import create_act_layer 16 | from .helpers import make_divisible 17 | 18 | 19 | class SEModule(nn.Module): 20 | """ SE Module as defined in original SE-Nets with a few additions 21 | Additions include: 22 | * divisor can be specified to keep channels % div == 0 (default: 8) 23 | * reduction channels can be specified directly by arg (if rd_channels is set) 24 | * reduction channels can be specified by float rd_ratio (default: 1/16) 25 | * global max pooling can be added to the squeeze aggregation 26 | * customizable activation, normalization, and gate layer 27 | """ 28 | def __init__( 29 | self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, 30 | bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): 31 | super(SEModule, self).__init__() 32 | self.add_maxpool = add_maxpool 33 | if not rd_channels: 34 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
35 | self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias) 36 | self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() 37 | self.act = create_act_layer(act_layer, inplace=True) 38 | self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias) 39 | self.gate = create_act_layer(gate_layer) 40 | 41 | def forward(self, x): 42 | x_se = x.mean((2, 3), keepdim=True) 43 | if self.add_maxpool: 44 | # experimental codepath, may remove or change 45 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) 46 | x_se = self.fc1(x_se) 47 | x_se = self.act(self.bn(x_se)) 48 | x_se = self.fc2(x_se) 49 | return x * self.gate(x_se) 50 | 51 | 52 | SqueezeExcite = SEModule # alias 53 | 54 | 55 | class EffectiveSEModule(nn.Module): 56 | """ 'Effective Squeeze-Excitation 57 | From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 58 | """ 59 | def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_): 60 | super(EffectiveSEModule, self).__init__() 61 | self.add_maxpool = add_maxpool 62 | self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) 63 | self.gate = create_act_layer(gate_layer) 64 | 65 | def forward(self, x): 66 | x_se = x.mean((2, 3), keepdim=True) 67 | if self.add_maxpool: 68 | # experimental codepath, may remove or change 69 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) 70 | x_se = self.fc(x_se) 71 | return x * self.gate(x_se) 72 | 73 | 74 | EffectiveSqueezeExcite = EffectiveSEModule # alias 75 | 76 | 77 | class SqueezeExciteCl(nn.Module): 78 | """ SE Module as defined in original SE-Nets with a few additions 79 | Additions include: 80 | * divisor can be specified to keep channels % div == 0 (default: 8) 81 | * reduction channels can be specified directly by arg (if rd_channels is set) 82 | * reduction channels can be specified by float rd_ratio (default: 1/16) 83 | * global max pooling can be added to the squeeze aggregation 84 | * customizable activation, normalization, and gate layer 85 | """ 86 | def __init__( 87 | self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, 88 | bias=True, act_layer=nn.ReLU, gate_layer='sigmoid'): 89 | super().__init__() 90 | if not rd_channels: 91 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 92 | self.fc1 = nn.Linear(channels, rd_channels, bias=bias) 93 | self.act = create_act_layer(act_layer, inplace=True) 94 | self.fc2 = nn.Linear(rd_channels, channels, bias=bias) 95 | self.gate = create_act_layer(gate_layer) 96 | 97 | def forward(self, x): 98 | x_se = x.mean((1, 2), keepdims=True) # FIXME avg dim [1:n-1], don't assume 2D NHWC 99 | x_se = self.fc1(x_se) 100 | x_se = self.act(x_se) 101 | x_se = self.fc2(x_se) 102 | return x * self.gate(x_se) -------------------------------------------------------------------------------- /lib/models_timm/layers/cbam.py: -------------------------------------------------------------------------------- 1 | """ CBAM (sort-of) Attention 2 | 3 | Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 4 | 5 | WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on 6 | some tasks, especially fine-grained it seems. I may end up removing this impl. 
7 | 8 | Hacked together by / Copyright 2020 Ross Wightman 9 | """ 10 | import torch 11 | from torch import nn as nn 12 | import torch.nn.functional as F 13 | 14 | from .conv_bn_act import ConvNormAct 15 | from .create_act import create_act_layer, get_act_layer 16 | from .helpers import make_divisible 17 | 18 | 19 | class ChannelAttn(nn.Module): 20 | """ Original CBAM channel attention module, currently avg + max pool variant only. 21 | """ 22 | def __init__( 23 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 24 | act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 25 | super(ChannelAttn, self).__init__() 26 | if not rd_channels: 27 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 28 | self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias) 29 | self.act = act_layer(inplace=True) 30 | self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias) 31 | self.gate = create_act_layer(gate_layer) 32 | 33 | def forward(self, x): 34 | x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True)))) 35 | x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True)))) 36 | return x * self.gate(x_avg + x_max) 37 | 38 | 39 | class LightChannelAttn(ChannelAttn): 40 | """An experimental 'lightweight' that sums avg + max pool first 41 | """ 42 | def __init__( 43 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 44 | act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 45 | super(LightChannelAttn, self).__init__( 46 | channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias) 47 | 48 | def forward(self, x): 49 | x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True) 50 | x_attn = self.fc2(self.act(self.fc1(x_pool))) 51 | return x * F.sigmoid(x_attn) 52 | 53 | 54 | class SpatialAttn(nn.Module): 55 | """ Original CBAM spatial attention module 56 | """ 57 | def __init__(self, kernel_size=7, gate_layer='sigmoid'): 58 | super(SpatialAttn, self).__init__() 59 | self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False) 60 | self.gate = create_act_layer(gate_layer) 61 | 62 | def forward(self, x): 63 | x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1) 64 | x_attn = self.conv(x_attn) 65 | return x * self.gate(x_attn) 66 | 67 | 68 | class LightSpatialAttn(nn.Module): 69 | """An experimental 'lightweight' variant that sums avg_pool and max_pool results. 
70 | """ 71 | def __init__(self, kernel_size=7, gate_layer='sigmoid'): 72 | super(LightSpatialAttn, self).__init__() 73 | self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False) 74 | self.gate = create_act_layer(gate_layer) 75 | 76 | def forward(self, x): 77 | x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True) 78 | x_attn = self.conv(x_attn) 79 | return x * self.gate(x_attn) 80 | 81 | 82 | class CbamModule(nn.Module): 83 | def __init__( 84 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 85 | spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 86 | super(CbamModule, self).__init__() 87 | self.channel = ChannelAttn( 88 | channels, rd_ratio=rd_ratio, rd_channels=rd_channels, 89 | rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) 90 | self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer) 91 | 92 | def forward(self, x): 93 | x = self.channel(x) 94 | x = self.spatial(x) 95 | return x 96 | 97 | 98 | class LightCbamModule(nn.Module): 99 | def __init__( 100 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 101 | spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 102 | super(LightCbamModule, self).__init__() 103 | self.channel = LightChannelAttn( 104 | channels, rd_ratio=rd_ratio, rd_channels=rd_channels, 105 | rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) 106 | self.spatial = LightSpatialAttn(spatial_kernel_size) 107 | 108 | def forward(self, x): 109 | x = self.channel(x) 110 | x = self.spatial(x) 111 | return x 112 | 113 | -------------------------------------------------------------------------------- /lib/models_timm/convmixer.py: -------------------------------------------------------------------------------- 1 | """ ConvMixer 2 | 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 8 | from timm.models.registry import register_model 9 | from .helpers import build_model_with_cfg, checkpoint_seq 10 | from .layers import SelectAdaptivePool2d 11 | 12 | 13 | def _cfg(url='', **kwargs): 14 | return { 15 | 'url': url, 16 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 17 | 'crop_pct': .96, 'interpolation': 'bicubic', 18 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', 19 | 'first_conv': 'stem.0', 20 | **kwargs 21 | } 22 | 23 | 24 | default_cfgs = { 25 | 'convmixer_1536_20': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1536_20_ks9_p7.pth.tar'), 26 | 'convmixer_768_32': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_768_32_ks7_p7_relu.pth.tar'), 27 | 'convmixer_1024_20_ks9_p14': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1024_20_ks9_p14.pth.tar') 28 | } 29 | 30 | 31 | class Residual(nn.Module): 32 | def __init__(self, fn): 33 | super().__init__() 34 | self.fn = fn 35 | 36 | def forward(self, x): 37 | return self.fn(x) + x 38 | 39 | 40 | class ConvMixer(nn.Module): 41 | def __init__( 42 | self, dim, depth, kernel_size=9, patch_size=7, in_chans=3, num_classes=1000, global_pool='avg', 43 | act_layer=nn.GELU, **kwargs): 44 | super().__init__() 45 | self.num_classes = num_classes 46 | self.num_features = dim 47 | self.grad_checkpointing = False 48 | 49 | self.stem = nn.Sequential( 50 | nn.Conv2d(in_chans, dim, kernel_size=patch_size, 
stride=patch_size), 51 | act_layer(), 52 | nn.BatchNorm2d(dim) 53 | ) 54 | self.blocks = nn.Sequential( 55 | *[nn.Sequential( 56 | Residual(nn.Sequential( 57 | nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"), 58 | act_layer(), 59 | nn.BatchNorm2d(dim) 60 | )), 61 | nn.Conv2d(dim, dim, kernel_size=1), 62 | act_layer(), 63 | nn.BatchNorm2d(dim) 64 | ) for i in range(depth)] 65 | ) 66 | self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) 67 | self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity() 68 | 69 | @torch.jit.ignore 70 | def group_matcher(self, coarse=False): 71 | matcher = dict(stem=r'^stem', blocks=r'^blocks\.(\d+)') 72 | return matcher 73 | 74 | @torch.jit.ignore 75 | def set_grad_checkpointing(self, enable=True): 76 | self.grad_checkpointing = enable 77 | 78 | @torch.jit.ignore 79 | def get_classifier(self): 80 | return self.head 81 | 82 | def reset_classifier(self, num_classes, global_pool=None): 83 | self.num_classes = num_classes 84 | if global_pool is not None: 85 | self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) 86 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 87 | 88 | def forward_features(self, x): 89 | x = self.stem(x) 90 | if self.grad_checkpointing and not torch.jit.is_scripting(): 91 | x = checkpoint_seq(self.blocks, x) 92 | else: 93 | x = self.blocks(x) 94 | return x 95 | 96 | def forward_head(self, x, pre_logits: bool = False): 97 | x = self.pooling(x) 98 | return x if pre_logits else self.head(x) 99 | 100 | def forward(self, x): 101 | x = self.forward_features(x) 102 | x = self.forward_head(x) 103 | return x 104 | 105 | 106 | def _create_convmixer(variant, pretrained=False, **kwargs): 107 | return build_model_with_cfg(ConvMixer, variant, pretrained, **kwargs) 108 | 109 | 110 | @register_model 111 | def convmixer_1536_20(pretrained=False, **kwargs): 112 | model_args = dict(dim=1536, depth=20, kernel_size=9, patch_size=7, **kwargs) 113 | return _create_convmixer('convmixer_1536_20', pretrained, **model_args) 114 | 115 | 116 | @register_model 117 | def convmixer_768_32(pretrained=False, **kwargs): 118 | model_args = dict(dim=768, depth=32, kernel_size=7, patch_size=7, act_layer=nn.ReLU, **kwargs) 119 | return _create_convmixer('convmixer_768_32', pretrained, **model_args) 120 | 121 | 122 | @register_model 123 | def convmixer_1024_20_ks9_p14(pretrained=False, **kwargs): 124 | model_args = dict(dim=1024, depth=20, kernel_size=9, patch_size=14, **kwargs) 125 | return _create_convmixer('convmixer_1024_20_ks9_p14', pretrained, **model_args) -------------------------------------------------------------------------------- /lib/models_timm/layers/mlp.py: -------------------------------------------------------------------------------- 1 | """ MLP module w/ dropout and configurable activation layer 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | 7 | from .helpers import to_2tuple 8 | 9 | 10 | class Mlp(nn.Module): 11 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 12 | """ 13 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.): 14 | super().__init__() 15 | out_features = out_features or in_features 16 | hidden_features = hidden_features or in_features 17 | bias = to_2tuple(bias) 18 | drop_probs = to_2tuple(drop) 19 | 20 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) 21 | self.act = act_layer() 22 
| self.drop1 = nn.Dropout(drop_probs[0]) 23 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) 24 | self.drop2 = nn.Dropout(drop_probs[1]) 25 | 26 | def forward(self, x): 27 | x = self.fc1(x) 28 | x = self.act(x) 29 | x = self.drop1(x) 30 | x = self.fc2(x) 31 | x = self.drop2(x) 32 | return x 33 | 34 | 35 | class GluMlp(nn.Module): 36 | """ MLP w/ GLU style gating 37 | See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202 38 | """ 39 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.): 40 | super().__init__() 41 | out_features = out_features or in_features 42 | hidden_features = hidden_features or in_features 43 | assert hidden_features % 2 == 0 44 | bias = to_2tuple(bias) 45 | drop_probs = to_2tuple(drop) 46 | 47 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) 48 | self.act = act_layer() 49 | self.drop1 = nn.Dropout(drop_probs[0]) 50 | self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1]) 51 | self.drop2 = nn.Dropout(drop_probs[1]) 52 | 53 | def init_weights(self): 54 | # override init of fc1 w/ gate portion set to weight near zero, bias=1 55 | fc1_mid = self.fc1.bias.shape[0] // 2 56 | nn.init.ones_(self.fc1.bias[fc1_mid:]) 57 | nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6) 58 | 59 | def forward(self, x): 60 | x = self.fc1(x) 61 | x, gates = x.chunk(2, dim=-1) 62 | x = x * self.act(gates) 63 | x = self.drop1(x) 64 | x = self.fc2(x) 65 | x = self.drop2(x) 66 | return x 67 | 68 | 69 | class GatedMlp(nn.Module): 70 | """ MLP as used in gMLP 71 | """ 72 | def __init__( 73 | self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, 74 | gate_layer=None, bias=True, drop=0.): 75 | super().__init__() 76 | out_features = out_features or in_features 77 | hidden_features = hidden_features or in_features 78 | bias = to_2tuple(bias) 79 | drop_probs = to_2tuple(drop) 80 | 81 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) 82 | self.act = act_layer() 83 | self.drop1 = nn.Dropout(drop_probs[0]) 84 | if gate_layer is not None: 85 | assert hidden_features % 2 == 0 86 | self.gate = gate_layer(hidden_features) 87 | hidden_features = hidden_features // 2 # FIXME base reduction on gate property? 
88 | else: 89 | self.gate = nn.Identity() 90 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) 91 | self.drop2 = nn.Dropout(drop_probs[1]) 92 | 93 | def forward(self, x): 94 | x = self.fc1(x) 95 | x = self.act(x) 96 | x = self.drop1(x) 97 | x = self.gate(x) 98 | x = self.fc2(x) 99 | x = self.drop2(x) 100 | return x 101 | 102 | 103 | class ConvMlp(nn.Module): 104 | """ MLP using 1x1 convs that keeps spatial dims 105 | """ 106 | def __init__( 107 | self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, 108 | norm_layer=None, bias=True, drop=0.): 109 | super().__init__() 110 | out_features = out_features or in_features 111 | hidden_features = hidden_features or in_features 112 | bias = to_2tuple(bias) 113 | 114 | self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0]) 115 | self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() 116 | self.act = act_layer() 117 | self.drop = nn.Dropout(drop) 118 | self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1]) 119 | 120 | def forward(self, x): 121 | x = self.fc1(x) 122 | x = self.norm(x) 123 | x = self.act(x) 124 | x = self.drop(x) 125 | x = self.fc2(x) 126 | return x 127 | -------------------------------------------------------------------------------- /lib/models_timm/layers/norm.py: -------------------------------------------------------------------------------- 1 | """ Normalization layers and wrappers 2 | 3 | Norm layer definitions that support fast norm and consistent channel arg order (always first arg). 4 | 5 | Hacked together by / Copyright 2022 Ross Wightman 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm 13 | 14 | 15 | class GroupNorm(nn.GroupNorm): 16 | def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True): 17 | # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN 18 | super().__init__(num_groups, num_channels, eps=eps, affine=affine) 19 | self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) 20 | 21 | def forward(self, x): 22 | if self.fast_norm: 23 | return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 24 | else: 25 | return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 26 | 27 | 28 | class GroupNorm1(nn.GroupNorm): 29 | """ Group Normalization with 1 group. 
30 | Input: tensor in shape [B, C, *] 31 | """ 32 | 33 | def __init__(self, num_channels, **kwargs): 34 | super().__init__(1, num_channels, **kwargs) 35 | self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) 36 | 37 | def forward(self, x: torch.Tensor) -> torch.Tensor: 38 | if self.fast_norm: 39 | return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 40 | else: 41 | return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 42 | 43 | 44 | class LayerNorm(nn.LayerNorm): 45 | """ LayerNorm w/ fast norm option 46 | """ 47 | def __init__(self, num_channels, eps=1e-6, affine=True): 48 | super().__init__(num_channels, eps=eps, elementwise_affine=affine) 49 | self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) 50 | 51 | def forward(self, x: torch.Tensor) -> torch.Tensor: 52 | if self._fast_norm: 53 | x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 54 | else: 55 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 56 | return x 57 | 58 | 59 | class LayerNorm2d(nn.LayerNorm): 60 | """ LayerNorm for channels of '2D' spatial NCHW tensors """ 61 | def __init__(self, num_channels, eps=1e-6, affine=True): 62 | super().__init__(num_channels, eps=eps, elementwise_affine=affine) 63 | self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) 64 | 65 | def forward(self, x: torch.Tensor) -> torch.Tensor: 66 | x = x.permute(0, 2, 3, 1) 67 | if self._fast_norm: 68 | x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 69 | else: 70 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 71 | x = x.permute(0, 3, 1, 2) 72 | return x 73 | 74 | 75 | def _is_contiguous(tensor: torch.Tensor) -> bool: 76 | # jit is oh so lovely :/ 77 | if torch.jit.is_scripting(): 78 | return tensor.is_contiguous() 79 | else: 80 | return tensor.is_contiguous(memory_format=torch.contiguous_format) 81 | 82 | 83 | @torch.jit.script 84 | def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float): 85 | s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True) 86 | x = (x - u) * torch.rsqrt(s + eps) 87 | x = x * weight[:, None, None] + bias[:, None, None] 88 | return x 89 | 90 | 91 | def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float): 92 | u = x.mean(dim=1, keepdim=True) 93 | s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0) 94 | x = (x - u) * torch.rsqrt(s + eps) 95 | x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1) 96 | return x 97 | 98 | 99 | class LayerNormExp2d(nn.LayerNorm): 100 | """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). 101 | 102 | Experimental implementation w/ manual norm for tensors non-contiguous tensors. 103 | 104 | This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last 105 | layout. However, benefits are not always clear and can perform worse on other GPUs. 
106 | """ 107 | 108 | def __init__(self, num_channels, eps=1e-6): 109 | super().__init__(num_channels, eps=eps) 110 | 111 | def forward(self, x) -> torch.Tensor: 112 | if _is_contiguous(x): 113 | x = F.layer_norm( 114 | x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) 115 | else: 116 | x = _layer_norm_cf(x, self.weight, self.bias, self.eps) 117 | return x 118 | -------------------------------------------------------------------------------- /lib/models_timm/layers/weight_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | 5 | from torch.nn.init import _calculate_fan_in_and_fan_out 6 | 7 | 8 | def _trunc_normal_(tensor, mean, std, a, b): 9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 11 | def norm_cdf(x): 12 | # Computes standard normal cumulative distribution function 13 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 14 | 15 | if (mean < a - 2 * std) or (mean > b + 2 * std): 16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 17 | "The distribution of values may be incorrect.", 18 | stacklevel=2) 19 | 20 | # Values are generated by using a truncated uniform distribution and 21 | # then using the inverse CDF for the normal distribution. 22 | # Get upper and lower cdf values 23 | l = norm_cdf((a - mean) / std) 24 | u = norm_cdf((b - mean) / std) 25 | 26 | # Uniformly fill tensor with values from [l, u], then translate to 27 | # [2l-1, 2u-1]. 28 | tensor.uniform_(2 * l - 1, 2 * u - 1) 29 | 30 | # Use inverse cdf transform for normal distribution to get truncated 31 | # standard normal 32 | tensor.erfinv_() 33 | 34 | # Transform to proper mean, std 35 | tensor.mul_(std * math.sqrt(2.)) 36 | tensor.add_(mean) 37 | 38 | # Clamp to ensure it's in the proper range 39 | tensor.clamp_(min=a, max=b) 40 | return tensor 41 | 42 | 43 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 44 | # type: (Tensor, float, float, float, float) -> Tensor 45 | r"""Fills the input Tensor with values drawn from a truncated 46 | normal distribution. The values are effectively drawn from the 47 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 48 | with values outside :math:`[a, b]` redrawn until they are within 49 | the bounds. The method used for generating the random values works 50 | best when :math:`a \leq \text{mean} \leq b`. 51 | 52 | NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are 53 | applied while sampling the normal with mean/std applied, therefore a, b args 54 | should be adjusted to match the range of mean, std args. 55 | 56 | Args: 57 | tensor: an n-dimensional `torch.Tensor` 58 | mean: the mean of the normal distribution 59 | std: the standard deviation of the normal distribution 60 | a: the minimum cutoff value 61 | b: the maximum cutoff value 62 | Examples: 63 | >>> w = torch.empty(3, 5) 64 | >>> nn.init.trunc_normal_(w) 65 | """ 66 | with torch.no_grad(): 67 | return _trunc_normal_(tensor, mean, std, a, b) 68 | 69 | 70 | def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.): 71 | # type: (Tensor, float, float, float, float) -> Tensor 72 | r"""Fills the input Tensor with values drawn from a truncated 73 | normal distribution. 
The values are effectively drawn from the 74 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 75 | with values outside :math:`[a, b]` redrawn until they are within 76 | the bounds. The method used for generating the random values works 77 | best when :math:`a \leq \text{mean} \leq b`. 78 | 79 | NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the 80 | bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 81 | and the result is subsquently scaled and shifted by the mean and std args. 82 | 83 | Args: 84 | tensor: an n-dimensional `torch.Tensor` 85 | mean: the mean of the normal distribution 86 | std: the standard deviation of the normal distribution 87 | a: the minimum cutoff value 88 | b: the maximum cutoff value 89 | Examples: 90 | >>> w = torch.empty(3, 5) 91 | >>> nn.init.trunc_normal_(w) 92 | """ 93 | with torch.no_grad(): 94 | _trunc_normal_(tensor, 0, 1.0, a, b) 95 | tensor.mul_(std).add_(mean) 96 | return tensor 97 | 98 | 99 | def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): 100 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 101 | if mode == 'fan_in': 102 | denom = fan_in 103 | elif mode == 'fan_out': 104 | denom = fan_out 105 | elif mode == 'fan_avg': 106 | denom = (fan_in + fan_out) / 2 107 | 108 | variance = scale / denom 109 | 110 | if distribution == "truncated_normal": 111 | # constant is stddev of standard normal truncated to (-2, 2) 112 | trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978) 113 | elif distribution == "normal": 114 | with torch.no_grad(): 115 | tensor.normal_(std=math.sqrt(variance)) 116 | elif distribution == "uniform": 117 | bound = math.sqrt(3 * variance) 118 | with torch.no_grad(): 119 | tensor.uniform_(-bound, bound) 120 | else: 121 | raise ValueError(f"invalid distribution {distribution}") 122 | 123 | 124 | def lecun_normal_(tensor): 125 | variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal') 126 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # G-CASCADE 2 | 3 | Official Pytorch implementation of [G-CASCADE: Efficient Cascaded Graph Convolutional Decoding for 2D Medical Image Segmentation](https://openaccess.thecvf.com/content/WACV2024/html/Rahman_G-CASCADE_Efficient_Cascaded_Graph_Convolutional_Decoding_for_2D_Medical_Image_WACV_2024_paper.html) WACV 2024. [arxiv](https://arxiv.org/abs/2310.16175) [code](https://github.com/SLDGroup/G-CASCADE) 4 |
5 | [Md Mostafijur Rahman](https://github.com/mostafij-rahman), [Radu Marculescu](https://radum.ece.utexas.edu/) 6 | The University of Texas at Austin
7 | 8 | ### 🔍 **Check out our CVPR 2024 paper! [EMCAD](https://github.com/SLDGroup/EMCAD)** 9 | ### 🔍 **Check out our CVPRW 2024 paper! [PP-SAM](https://github.com/SLDGroup/PP-SAM)** 10 | ### 🔍 **Check out our MIDL 2023 paper! [MERIT](https://github.com/SLDGroup/MERIT)** 11 | ### 🔍 **Check out our WACV 2023 paper! [CASCADE](https://github.com/SLDGroup/CASCADE)** 12 | 13 | ## Architecture 14 | 15 |

16 | (figure: overall G-CASCADE architecture) 18 |

19 | 20 | ## Qualitative Results 21 | 22 |

23 | (figure: qualitative results) 25 |

26 | 27 | ## Usage: 28 | ### Recommended environment: 29 | ``` 30 | Python 3.8 31 | Pytorch 1.11.0 32 | torchvision 0.12.0 33 | ``` 34 | Please use ```pip install -r requirements.txt``` to install the dependencies. 35 | 36 | ### Data preparation: 37 | - **Synapse Multi-organ dataset:** 38 | Sign up in the [official Synapse website](https://www.synapse.org/#!Synapse:syn3193805/wiki/89480) and download the dataset. Then split the 'RawData' folder into 'TrainSet' (18 scans) and 'TestSet' (12 scans) following the [TransUNet's](https://github.com/Beckschen/TransUNet/blob/main/datasets/README.md) lists and put in the './data/synapse/Abdomen/RawData/' folder. Finally, preprocess using ```python ./utils/preprocess_synapse_data.py``` or download the [preprocessed data](https://drive.google.com/file/d/1wvmw8DVyDKr5sOAFn5zUpfhbK4Vxjze4/view) and save in the './data/synapse/' folder. 39 | Note: If you use the preprocessed data from [TransUNet](https://drive.google.com/drive/folders/1ACJEoTp-uqfFJ73qS3eUObQh52nGuzCd), please make necessary changes (i.e., remove the code segment (line# 88-94) to convert groundtruth labels from 14 to 9 classes) in the utils/dataset_synapse.py. 40 | 41 | - **ACDC dataset:** 42 | Download the preprocessed ACDC dataset from [Google Drive](https://drive.google.com/file/d/1CruCQ-jjvA97BX-LIYwXaRMLmp3DN9zc/view) and move into './data/ACDC/' folder. 43 | 44 | - **Polyp datasets:** 45 | Download the training and testing datasets [Google Drive](https://drive.google.com/file/d/1pFxb9NbM8mj_rlSawTlcXG1OdVGAbRQC/view?usp=sharing) and move them into './data/polyp/' folder. 46 | 47 | - **ISIC2018 dataset:** 48 | Download the training and validation datasets from https://challenge.isic-archive.com/landing/2018/ and merge them together. Afterwards, split the dataset into 80%, 10%, and 10% training, validation, and testing datasets, respectively. Move the splited dataset into './data/ISIC2018/' folder. 49 | 50 | ### Pretrained model: 51 | You should download the pretrained PVTv2 model from [Google Drive](https://drive.google.com/drive/folders/1d5F1VjEF1AtTkNO93JwVBBSivE8zImiF?usp=share), and then put it in the './pretrained_pth/pvt/' folder for initialization. Similarly, you should download the pretrained MaxViT models from [Google Drive](https://drive.google.com/drive/folders/1wuJ8zekQpPfmydqVhfO1LYSKxrik4Ukf?usp=share_link), and then put it in the './pretrained_pth/maxvit/' folder for initialization. 
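A quick way to confirm the checkpoints are in place before launching training is a small check such as the one below (a minimal sketch, not part of the repository; it only verifies that the two folders mentioned above exist and are non-empty, since the exact checkpoint filenames depend on which PVTv2/MaxViT variants you downloaded):
```
import os

# Hypothetical sanity check: the two folders below are the ones this README asks
# you to populate with the downloaded PVTv2 and MaxViT weights.
for d in ['./pretrained_pth/pvt/', './pretrained_pth/maxvit/']:
    assert os.path.isdir(d) and os.listdir(d), f'Expected pretrained weights in {d}'
```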
52 | 53 | ### Training: 54 | ``` 55 | cd into G-CASCADE 56 | ``` 57 | 58 | For Synapse Multi-organ dataset training, run ```CUDA_VISIBLE_DEVICES=0 python -W ignore train_synapse.py``` 59 | 60 | For ACDC dataset training, run ```CUDA_VISIBLE_DEVICES=0 python -W ignore train_ACDC.py``` 61 | 62 | For Polyp datasets training, run ```CUDA_VISIBLE_DEVICES=0 python -W ignore train_polyp.py``` 63 | 64 | For ISIC2018 dataset training, run ```CUDA_VISIBLE_DEVICES=0 python -W ignore train_ISIC2018.py``` 65 | 66 | ### Testing: 67 | ``` 68 | cd into G-CASCADE 69 | ``` 70 | 71 | For Synapse Multi-organ dataset testing, run ```CUDA_VISIBLE_DEVICES=0 python -W ignore test_synapse.py``` 72 | 73 | For ACDC dataset testing, run ```CUDA_VISIBLE_DEVICES=0 python -W ignore test_ACDC.py``` 74 | 75 | For Polyp dataset testing, run ```CUDA_VISIBLE_DEVICES=0 python -W ignore test_polyp.py``` 76 | 77 | For ISIC2018 dataset testing, run ```CUDA_VISIBLE_DEVICES=0 python -W ignore test_ISIC2018.py``` 78 | 79 | ## Acknowledgement 80 | We are very grateful for these excellent works [timm](https://github.com/huggingface/pytorch-image-models), [MERIT](https://github.com/SLDGroup/MERIT), [CASCADE](https://github.com/SLDGroup/CASCADE), [PraNet](https://github.com/DengPingFan/PraNet), [Polyp-PVT](https://github.com/DengPingFan/Polyp-PVT) and [TransUNet](https://github.com/Beckschen/TransUNet), which have provided the basis for our framework. 81 | 82 | ## Citations 83 | ``` 84 | @inproceedings{rahman2024g, 85 | title={G-CASCADE: Efficient Cascaded Graph Convolutional Decoding for 2D Medical Image Segmentation}, 86 | author={Rahman, Md Mostafijur and Marculescu, Radu}, 87 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision}, 88 | pages={7728--7737}, 89 | year={2024} 90 | } 91 | ``` 92 | -------------------------------------------------------------------------------- /lib/models_timm/layers/attention_pool2d.py: -------------------------------------------------------------------------------- 1 | """ Attention Pool 2D 2 | 3 | Implementations of 2D spatial feature pooling using multi-head attention instead of average pool. 4 | 5 | Based on idea in CLIP by OpenAI, licensed Apache 2.0 6 | https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py 7 | 8 | Hacked together by / Copyright 2021 Ross Wightman 9 | """ 10 | from typing import Union, Tuple 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | from .helpers import to_2tuple 16 | from .pos_embed import apply_rot_embed, RotaryEmbedding 17 | from .weight_init import trunc_normal_ 18 | 19 | 20 | class RotAttentionPool2d(nn.Module): 21 | """ Attention based 2D feature pooling w/ rotary (relative) pos embedding. 22 | This is a multi-head attention based replacement for (spatial) average pooling in NN architectures. 23 | 24 | Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed. 25 | https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py 26 | 27 | NOTE: While this impl does not require a fixed feature size, performance at differeing resolutions from 28 | train varies widely and falls off dramatically. I'm not sure if there is a way around this... 
-RW 29 | """ 30 | def __init__( 31 | self, 32 | in_features: int, 33 | out_features: int = None, 34 | embed_dim: int = None, 35 | num_heads: int = 4, 36 | qkv_bias: bool = True, 37 | ): 38 | super().__init__() 39 | embed_dim = embed_dim or in_features 40 | out_features = out_features or in_features 41 | self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) 42 | self.proj = nn.Linear(embed_dim, out_features) 43 | self.num_heads = num_heads 44 | assert embed_dim % num_heads == 0 45 | self.head_dim = embed_dim // num_heads 46 | self.scale = self.head_dim ** -0.5 47 | self.pos_embed = RotaryEmbedding(self.head_dim) 48 | 49 | trunc_normal_(self.qkv.weight, std=in_features ** -0.5) 50 | nn.init.zeros_(self.qkv.bias) 51 | 52 | def forward(self, x): 53 | B, _, H, W = x.shape 54 | N = H * W 55 | x = x.reshape(B, -1, N).permute(0, 2, 1) 56 | 57 | x = torch.cat([x.mean(1, keepdim=True), x], dim=1) 58 | 59 | x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 60 | q, k, v = x[0], x[1], x[2] 61 | 62 | qc, q = q[:, :, :1], q[:, :, 1:] 63 | sin_emb, cos_emb = self.pos_embed.get_embed((H, W)) 64 | q = apply_rot_embed(q, sin_emb, cos_emb) 65 | q = torch.cat([qc, q], dim=2) 66 | 67 | kc, k = k[:, :, :1], k[:, :, 1:] 68 | k = apply_rot_embed(k, sin_emb, cos_emb) 69 | k = torch.cat([kc, k], dim=2) 70 | 71 | attn = (q @ k.transpose(-2, -1)) * self.scale 72 | attn = attn.softmax(dim=-1) 73 | 74 | x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) 75 | x = self.proj(x) 76 | return x[:, 0] 77 | 78 | 79 | class AttentionPool2d(nn.Module): 80 | """ Attention based 2D feature pooling w/ learned (absolute) pos embedding. 81 | This is a multi-head attention based replacement for (spatial) average pooling in NN architectures. 82 | 83 | It was based on impl in CLIP by OpenAI 84 | https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py 85 | 86 | NOTE: This requires feature size upon construction and well prevent adaptive sizing of the network. 
87 | """ 88 | def __init__( 89 | self, 90 | in_features: int, 91 | feat_size: Union[int, Tuple[int, int]], 92 | out_features: int = None, 93 | embed_dim: int = None, 94 | num_heads: int = 4, 95 | qkv_bias: bool = True, 96 | ): 97 | super().__init__() 98 | 99 | embed_dim = embed_dim or in_features 100 | out_features = out_features or in_features 101 | assert embed_dim % num_heads == 0 102 | self.feat_size = to_2tuple(feat_size) 103 | self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) 104 | self.proj = nn.Linear(embed_dim, out_features) 105 | self.num_heads = num_heads 106 | self.head_dim = embed_dim // num_heads 107 | self.scale = self.head_dim ** -0.5 108 | 109 | spatial_dim = self.feat_size[0] * self.feat_size[1] 110 | self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features)) 111 | trunc_normal_(self.pos_embed, std=in_features ** -0.5) 112 | trunc_normal_(self.qkv.weight, std=in_features ** -0.5) 113 | nn.init.zeros_(self.qkv.bias) 114 | 115 | def forward(self, x): 116 | B, _, H, W = x.shape 117 | N = H * W 118 | assert self.feat_size[0] == H 119 | assert self.feat_size[1] == W 120 | x = x.reshape(B, -1, N).permute(0, 2, 1) 121 | x = torch.cat([x.mean(1, keepdim=True), x], dim=1) 122 | x = x + self.pos_embed.unsqueeze(0).to(x.dtype) 123 | 124 | x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 125 | q, k, v = x[0], x[1], x[2] 126 | attn = (q @ k.transpose(-2, -1)) * self.scale 127 | attn = attn.softmax(dim=-1) 128 | 129 | x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) 130 | x = self.proj(x) 131 | return x[:, 0] 132 | -------------------------------------------------------------------------------- /train_synapse.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import numpy as np 6 | import time 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | 11 | from lib.networks import PVT_GCASCADE, MERIT_GCASCADE 12 | 13 | from trainer import trainer_synapse 14 | from torchsummaryX import summary 15 | from ptflops import get_model_complexity_info 16 | 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--encoder', type=str, 20 | default='PVT', help='Name of encoder: PVT or MERIT') 21 | parser.add_argument('--skip_aggregation', type=str, 22 | default='additive', help='Type of skip-aggregation: additive or concatenation') 23 | parser.add_argument('--root_path', type=str, 24 | default='./data/synapse/train_npz', help='root dir for data') 25 | parser.add_argument('--volume_path', type=str, 26 | default='./data/synapse/test_vol_h5', help='root dir for validation volume data') 27 | parser.add_argument('--dataset', type=str, 28 | default='Synapse', help='experiment_name') 29 | parser.add_argument('--list_dir', type=str, 30 | default='./lists/lists_Synapse', help='list dir') 31 | parser.add_argument('--num_classes', type=int, 32 | default=9, help='output channel of network') 33 | parser.add_argument('--max_iterations', type=int, 34 | default=30000, help='maximum epoch number to train') 35 | parser.add_argument('--max_epochs', type=int, 36 | default=300, help='maximum epoch number to train') 37 | parser.add_argument('--batch_size', type=int, 38 | default=6, help='batch_size per gpu') #6 39 | parser.add_argument('--n_gpu', type=int, default=1, help='total gpu') 40 | parser.add_argument('--deterministic', type=int, default=1, 41 | help='whether use deterministic training') 42 | 
parser.add_argument('--base_lr', type=float, default=0.0001, 43 | help='segmentation network learning rate') 44 | parser.add_argument('--img_size', type=int, 45 | default=224, help='input patch size of network input') #256 46 | parser.add_argument('--seed', type=int, 47 | default=2222, help='random seed') 48 | 49 | args = parser.parse_args() 50 | 51 | if __name__ == "__main__": 52 | if not args.deterministic: 53 | cudnn.benchmark = True 54 | cudnn.deterministic = False 55 | else: 56 | cudnn.benchmark = False 57 | cudnn.deterministic = True 58 | 59 | random.seed(args.seed) 60 | np.random.seed(args.seed) 61 | torch.manual_seed(args.seed) 62 | torch.cuda.manual_seed(args.seed) 63 | 64 | 65 | dataset_name = args.dataset 66 | dataset_config = { 67 | 'Synapse': { 68 | 'root_path': args.root_path, 69 | 'volume_path': args.volume_path, 70 | 'list_dir': args.list_dir, 71 | 'num_classes': args.num_classes, 72 | 'z_spacing': 1, 73 | }, 74 | } 75 | args.num_classes = dataset_config[dataset_name]['num_classes'] 76 | args.root_path = dataset_config[dataset_name]['root_path'] 77 | args.volume_path = dataset_config[dataset_name]['volume_path'] 78 | args.z_spacing = dataset_config[dataset_name]['z_spacing'] 79 | args.list_dir = dataset_config[dataset_name]['list_dir'] 80 | args.is_pretrain = True 81 | args.exp = 'PVT_GCASCADE_MUTATION_w3_7_Run1_' + dataset_name + str(args.img_size) 82 | snapshot_path = "model_pth/{}/{}".format(args.exp, 'PVT_GCASCADE_MUTATION_w3_7_Run1') 83 | snapshot_path = snapshot_path + '_pretrain' if args.is_pretrain else snapshot_path 84 | snapshot_path = snapshot_path+'_'+str(args.max_iterations)[0:2]+'k' if args.max_iterations != 30000 else snapshot_path 85 | snapshot_path = snapshot_path + '_epo' +str(args.max_epochs) if args.max_epochs != 30 else snapshot_path 86 | snapshot_path = snapshot_path+'_bs'+str(args.batch_size) 87 | snapshot_path = snapshot_path + '_lr' + str(args.base_lr) if args.base_lr != 0.01 else snapshot_path 88 | snapshot_path = snapshot_path + '_'+str(args.img_size) 89 | snapshot_path = snapshot_path + '_s'+str(args.seed) if args.seed!=1234 else snapshot_path 90 | #current_time = time.strftime("%H%M%S") 91 | #print("The current time is", current_time) 92 | #snapshot_path = snapshot_path +'_t'+current_time 93 | 94 | if not os.path.exists(snapshot_path): 95 | os.makedirs(snapshot_path) 96 | 97 | if args.encoder=='PVT': 98 | net = PVT_GCASCADE(n_class=args.num_classes, img_size=args.img_size, k=11, padding=5, conv='mr', gcb_act='gelu', skip_aggregation=args.skip_aggregation) 99 | elif args.encoder=='MERIT': 100 | net = MERIT_GCASCADE(n_class=args.num_classes, img_size_s1=(args.img_size,args.img_size), img_size_s2=(224,224), k=11, padding=5, conv='mr', gcb_act='gelu', skip_aggregation=args.skip_aggregation) 101 | else: 102 | print('Implementation not found for this encoder. 
Exiting!') 103 | sys.exit() 104 | 105 | print('Model %s created' % (args.encoder+'-GCASCADE: ')) 106 | 107 | net = net.cuda() 108 | 109 | macs, params = get_model_complexity_info(net, (3, args.img_size, args.img_size), as_strings=True, 110 | print_per_layer_stat=True, verbose=True) 111 | print('{:<30} {:<8}'.format('Computational complexity: ', macs)) 112 | print('{:<30} {:<8}'.format('Number of parameters: ', params)) 113 | 114 | 115 | trainer = {'Synapse': trainer_synapse,} 116 | trainer[dataset_name](args, net, snapshot_path) 117 | -------------------------------------------------------------------------------- /lib/models_timm/layers/cond_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Conditionally Parameterized Convolution (CondConv) 2 | 3 | Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference 4 | (https://arxiv.org/abs/1904.04971) 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | 9 | import math 10 | from functools import partial 11 | import numpy as np 12 | import torch 13 | from torch import nn as nn 14 | from torch.nn import functional as F 15 | 16 | from .helpers import to_2tuple 17 | from .conv2d_same import conv2d_same 18 | from .padding import get_padding_value 19 | 20 | 21 | def get_condconv_initializer(initializer, num_experts, expert_shape): 22 | def condconv_initializer(weight): 23 | """CondConv initializer function.""" 24 | num_params = np.prod(expert_shape) 25 | if (len(weight.shape) != 2 or weight.shape[0] != num_experts or 26 | weight.shape[1] != num_params): 27 | raise (ValueError( 28 | 'CondConv variables must have shape [num_experts, num_params]')) 29 | for i in range(num_experts): 30 | initializer(weight[i].view(expert_shape)) 31 | return condconv_initializer 32 | 33 | 34 | class CondConv2d(nn.Module): 35 | """ Conditionally Parameterized Convolution 36 | Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py 37 | 38 | Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: 39 | https://github.com/pytorch/pytorch/issues/17983 40 | """ 41 | __constants__ = ['in_channels', 'out_channels', 'dynamic_padding'] 42 | 43 | def __init__(self, in_channels, out_channels, kernel_size=3, 44 | stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): 45 | super(CondConv2d, self).__init__() 46 | 47 | self.in_channels = in_channels 48 | self.out_channels = out_channels 49 | self.kernel_size = to_2tuple(kernel_size) 50 | self.stride = to_2tuple(stride) 51 | padding_val, is_padding_dynamic = get_padding_value( 52 | padding, kernel_size, stride=stride, dilation=dilation) 53 | self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript 54 | self.padding = to_2tuple(padding_val) 55 | self.dilation = to_2tuple(dilation) 56 | self.groups = groups 57 | self.num_experts = num_experts 58 | 59 | self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size 60 | weight_num_param = 1 61 | for wd in self.weight_shape: 62 | weight_num_param *= wd 63 | self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) 64 | 65 | if bias: 66 | self.bias_shape = (self.out_channels,) 67 | self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) 68 | else: 69 | self.register_parameter('bias', None) 70 | 71 | self.reset_parameters() 72 | 73 | def reset_parameters(self): 74 
| init_weight = get_condconv_initializer( 75 | partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) 76 | init_weight(self.weight) 77 | if self.bias is not None: 78 | fan_in = np.prod(self.weight_shape[1:]) 79 | bound = 1 / math.sqrt(fan_in) 80 | init_bias = get_condconv_initializer( 81 | partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) 82 | init_bias(self.bias) 83 | 84 | def forward(self, x, routing_weights): 85 | B, C, H, W = x.shape 86 | weight = torch.matmul(routing_weights, self.weight) 87 | new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size 88 | weight = weight.view(new_weight_shape) 89 | bias = None 90 | if self.bias is not None: 91 | bias = torch.matmul(routing_weights, self.bias) 92 | bias = bias.view(B * self.out_channels) 93 | # move batch elements with channels so each batch element can be efficiently convolved with separate kernel 94 | # reshape instead of view to work with channels_last input 95 | x = x.reshape(1, B * C, H, W) 96 | if self.dynamic_padding: 97 | out = conv2d_same( 98 | x, weight, bias, stride=self.stride, padding=self.padding, 99 | dilation=self.dilation, groups=self.groups * B) 100 | else: 101 | out = F.conv2d( 102 | x, weight, bias, stride=self.stride, padding=self.padding, 103 | dilation=self.dilation, groups=self.groups * B) 104 | out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) 105 | 106 | # Literal port (from TF definition) 107 | # x = torch.split(x, 1, 0) 108 | # weight = torch.split(weight, 1, 0) 109 | # if self.bias is not None: 110 | # bias = torch.matmul(routing_weights, self.bias) 111 | # bias = torch.split(bias, 1, 0) 112 | # else: 113 | # bias = [None] * B 114 | # out = [] 115 | # for xi, wi, bi in zip(x, weight, bias): 116 | # wi = wi.view(*self.weight_shape) 117 | # if bi is not None: 118 | # bi = bi.view(*self.bias_shape) 119 | # out.append(self.conv_fn( 120 | # xi, wi, bi, stride=self.stride, padding=self.padding, 121 | # dilation=self.dilation, groups=self.groups)) 122 | # out = torch.cat(out, 0) 123 | return out 124 | -------------------------------------------------------------------------------- /lib/models_timm/layers/selective_kernel.py: -------------------------------------------------------------------------------- 1 | """ Selective Kernel Convolution/Attention 2 | 3 | Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | from torch import nn as nn 9 | 10 | from .conv_bn_act import ConvNormActAa 11 | from .helpers import make_divisible 12 | from .trace_utils import _assert 13 | 14 | 15 | def _kernel_valid(k): 16 | if isinstance(k, (list, tuple)): 17 | for ki in k: 18 | return _kernel_valid(ki) 19 | assert k >= 3 and k % 2 20 | 21 | 22 | class SelectiveKernelAttn(nn.Module): 23 | def __init__(self, channels, num_paths=2, attn_channels=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): 24 | """ Selective Kernel Attention Module 25 | 26 | Selective Kernel attention mechanism factored out into its own module. 
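Expected input is the stack of branch outputs, shape (B, num_paths, channels, H, W); the module returns softmax selection weights of shape (B, num_paths, channels, 1, 1) that sum to one across the path dimension (see SelectiveKernel.forward below).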
27 | 28 | """ 29 | super(SelectiveKernelAttn, self).__init__() 30 | self.num_paths = num_paths 31 | self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) 32 | self.bn = norm_layer(attn_channels) 33 | self.act = act_layer(inplace=True) 34 | self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) 35 | 36 | def forward(self, x): 37 | _assert(x.shape[1] == self.num_paths, '') 38 | x = x.sum(1).mean((2, 3), keepdim=True) 39 | x = self.fc_reduce(x) 40 | x = self.bn(x) 41 | x = self.act(x) 42 | x = self.fc_select(x) 43 | B, C, H, W = x.shape 44 | x = x.view(B, self.num_paths, C // self.num_paths, H, W) 45 | x = torch.softmax(x, dim=1) 46 | return x 47 | 48 | 49 | class SelectiveKernel(nn.Module): 50 | 51 | def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1, 52 | rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True, 53 | act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_layer=None): 54 | """ Selective Kernel Convolution Module 55 | 56 | As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications. 57 | 58 | Largest change is the input split, which divides the input channels across each convolution path, this can 59 | be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps 60 | the parameter count from ballooning when the convolutions themselves don't have groups, but still provides 61 | a noteworthy increase in performance over similar param count models without this attention layer. -Ross W 62 | 63 | Args: 64 | in_channels (int): module input (feature) channel count 65 | out_channels (int): module output (feature) channel count 66 | kernel_size (int, list): kernel size for each convolution branch 67 | stride (int): stride for convolutions 68 | dilation (int): dilation for module as a whole, impacts dilation of each branch 69 | groups (int): number of groups for each branch 70 | rd_ratio (int, float): reduction factor for attention features 71 | keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations 72 | split_input (bool): split input channels evenly across each convolution branch, keeps param count lower, 73 | can be viewed as grouping by path, output expands to module out_channels count 74 | act_layer (nn.Module): activation layer to use 75 | norm_layer (nn.Module): batchnorm/norm layer to use 76 | aa_layer (nn.Module): anti-aliasing module 77 | drop_layer (nn.Module): spatial drop module in convs (drop block, etc) 78 | """ 79 | super(SelectiveKernel, self).__init__() 80 | out_channels = out_channels or in_channels 81 | kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 
5x5 -> 3x3 + dilation 82 | _kernel_valid(kernel_size) 83 | if not isinstance(kernel_size, list): 84 | kernel_size = [kernel_size] * 2 85 | if keep_3x3: 86 | dilation = [dilation * (k - 1) // 2 for k in kernel_size] 87 | kernel_size = [3] * len(kernel_size) 88 | else: 89 | dilation = [dilation] * len(kernel_size) 90 | self.num_paths = len(kernel_size) 91 | self.in_channels = in_channels 92 | self.out_channels = out_channels 93 | self.split_input = split_input 94 | if self.split_input: 95 | assert in_channels % self.num_paths == 0 96 | in_channels = in_channels // self.num_paths 97 | groups = min(out_channels, groups) 98 | 99 | conv_kwargs = dict( 100 | stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer, 101 | aa_layer=aa_layer, drop_layer=drop_layer) 102 | self.paths = nn.ModuleList([ 103 | ConvNormActAa(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) 104 | for k, d in zip(kernel_size, dilation)]) 105 | 106 | attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor) 107 | self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) 108 | 109 | def forward(self, x): 110 | if self.split_input: 111 | x_split = torch.split(x, self.in_channels // self.num_paths, 1) 112 | x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] 113 | else: 114 | x_paths = [op(x) for op in self.paths] 115 | x = torch.stack(x_paths, dim=1) 116 | x_attn = self.attn(x) 117 | x = x * x_attn 118 | x = torch.sum(x, dim=1) 119 | return x 120 | -------------------------------------------------------------------------------- /lib/models_timm/layers/create_act.py: -------------------------------------------------------------------------------- 1 | """ Activation Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from typing import Union, Callable, Type 5 | 6 | from .activations import * 7 | from .activations_jit import * 8 | from .activations_me import * 9 | from .config import is_exportable, is_scriptable, is_no_jit 10 | 11 | # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. 12 | # Also hardsigmoid, hardswish, and soon mish. This code will use native version if present. 13 | # Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used. 
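# Illustrative note (not from the original source): with a recent PyTorch, get_act_layer('swish') below
# resolves to nn.SiLU and get_act_fn('hard_swish') to F.hardswish; on older builds the custom
# implementations from activations*.py / activations_jit.py are returned instead.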
14 | _has_silu = 'silu' in dir(torch.nn.functional) 15 | _has_hardswish = 'hardswish' in dir(torch.nn.functional) 16 | _has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional) 17 | _has_mish = 'mish' in dir(torch.nn.functional) 18 | 19 | 20 | _ACT_FN_DEFAULT = dict( 21 | silu=F.silu if _has_silu else swish, 22 | swish=F.silu if _has_silu else swish, 23 | mish=F.mish if _has_mish else mish, 24 | relu=F.relu, 25 | relu6=F.relu6, 26 | leaky_relu=F.leaky_relu, 27 | elu=F.elu, 28 | celu=F.celu, 29 | selu=F.selu, 30 | gelu=gelu, 31 | sigmoid=sigmoid, 32 | tanh=tanh, 33 | hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid, 34 | hard_swish=F.hardswish if _has_hardswish else hard_swish, 35 | hard_mish=hard_mish, 36 | ) 37 | 38 | _ACT_FN_JIT = dict( 39 | silu=F.silu if _has_silu else swish_jit, 40 | swish=F.silu if _has_silu else swish_jit, 41 | mish=F.mish if _has_mish else mish_jit, 42 | hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit, 43 | hard_swish=F.hardswish if _has_hardswish else hard_swish_jit, 44 | hard_mish=hard_mish_jit 45 | ) 46 | 47 | _ACT_FN_ME = dict( 48 | silu=F.silu if _has_silu else swish_me, 49 | swish=F.silu if _has_silu else swish_me, 50 | mish=F.mish if _has_mish else mish_me, 51 | hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me, 52 | hard_swish=F.hardswish if _has_hardswish else hard_swish_me, 53 | hard_mish=hard_mish_me, 54 | ) 55 | 56 | _ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT) 57 | for a in _ACT_FNS: 58 | a.setdefault('hardsigmoid', a.get('hard_sigmoid')) 59 | a.setdefault('hardswish', a.get('hard_swish')) 60 | 61 | 62 | _ACT_LAYER_DEFAULT = dict( 63 | silu=nn.SiLU if _has_silu else Swish, 64 | swish=nn.SiLU if _has_silu else Swish, 65 | mish=nn.Mish if _has_mish else Mish, 66 | relu=nn.ReLU, 67 | relu6=nn.ReLU6, 68 | leaky_relu=nn.LeakyReLU, 69 | elu=nn.ELU, 70 | prelu=PReLU, 71 | celu=nn.CELU, 72 | selu=nn.SELU, 73 | gelu=GELU, 74 | sigmoid=Sigmoid, 75 | tanh=Tanh, 76 | hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid, 77 | hard_swish=nn.Hardswish if _has_hardswish else HardSwish, 78 | hard_mish=HardMish, 79 | ) 80 | 81 | _ACT_LAYER_JIT = dict( 82 | silu=nn.SiLU if _has_silu else SwishJit, 83 | swish=nn.SiLU if _has_silu else SwishJit, 84 | mish=nn.Mish if _has_mish else MishJit, 85 | hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit, 86 | hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit, 87 | hard_mish=HardMishJit 88 | ) 89 | 90 | _ACT_LAYER_ME = dict( 91 | silu=nn.SiLU if _has_silu else SwishMe, 92 | swish=nn.SiLU if _has_silu else SwishMe, 93 | mish=nn.Mish if _has_mish else MishMe, 94 | hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe, 95 | hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe, 96 | hard_mish=HardMishMe, 97 | ) 98 | 99 | _ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT) 100 | for a in _ACT_LAYERS: 101 | a.setdefault('hardsigmoid', a.get('hard_sigmoid')) 102 | a.setdefault('hardswish', a.get('hard_swish')) 103 | 104 | 105 | def get_act_fn(name: Union[Callable, str] = 'relu'): 106 | """ Activation Function Factory 107 | Fetching activation fns by name with this function allows export or torch script friendly 108 | functions to be returned dynamically based on current config. 
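Example (sketch): get_act_fn('relu') returns F.relu, while get_act_fn('mish') returns the native F.mish when available and otherwise falls back to the memory-efficient, JIT or plain custom version depending on the export/scripting config.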
109 | """ 110 | if not name: 111 | return None 112 | if isinstance(name, Callable): 113 | return name 114 | if not (is_no_jit() or is_exportable() or is_scriptable()): 115 | # If not exporting or scripting the model, first look for a memory-efficient version with 116 | # custom autograd, then fallback 117 | if name in _ACT_FN_ME: 118 | return _ACT_FN_ME[name] 119 | if not (is_no_jit() or is_exportable()): 120 | if name in _ACT_FN_JIT: 121 | return _ACT_FN_JIT[name] 122 | return _ACT_FN_DEFAULT[name] 123 | 124 | 125 | def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): 126 | """ Activation Layer Factory 127 | Fetching activation layers by name with this function allows export or torch script friendly 128 | functions to be returned dynamically based on current config. 129 | """ 130 | if not name: 131 | return None 132 | if not isinstance(name, str): 133 | # callable, module, etc 134 | return name 135 | if not (is_no_jit() or is_exportable() or is_scriptable()): 136 | if name in _ACT_LAYER_ME: 137 | return _ACT_LAYER_ME[name] 138 | if not (is_no_jit() or is_exportable()): 139 | if name in _ACT_LAYER_JIT: 140 | return _ACT_LAYER_JIT[name] 141 | return _ACT_LAYER_DEFAULT[name] 142 | 143 | 144 | def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs): 145 | act_layer = get_act_layer(name) 146 | if act_layer is None: 147 | return None 148 | if inplace is None: 149 | return act_layer(**kwargs) 150 | try: 151 | return act_layer(inplace=inplace, **kwargs) 152 | except TypeError: 153 | # recover if act layer doesn't have inplace arg 154 | return act_layer(**kwargs) 155 | -------------------------------------------------------------------------------- /lib/models_timm/layers/std_conv.py: -------------------------------------------------------------------------------- 1 | """ Convolution with Weight Standardization (StdConv and ScaledStdConv) 2 | 3 | StdConv: 4 | @article{weightstandardization, 5 | author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille}, 6 | title = {Weight Standardization}, 7 | journal = {arXiv preprint arXiv:1903.10520}, 8 | year = {2019}, 9 | } 10 | Code: https://github.com/joe-siyuan-qiao/WeightStandardization 11 | 12 | ScaledStdConv: 13 | Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` 14 | - https://arxiv.org/abs/2101.08692 15 | Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets 16 | 17 | Hacked together by / copyright Ross Wightman, 2021. 18 | """ 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | 23 | from .padding import get_padding, get_padding_value, pad_same 24 | 25 | 26 | class StdConv2d(nn.Conv2d): 27 | """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models. 
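In forward(), each output filter is standardized on the fly, roughly W_hat = (W - mean(W)) / sqrt(var(W) + eps), with the statistics computed per output channel via F.batch_norm over the reshaped weight.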
28 | 29 | Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - 30 | https://arxiv.org/abs/1903.10520v2 31 | """ 32 | def __init__( 33 | self, in_channel, out_channels, kernel_size, stride=1, padding=None, 34 | dilation=1, groups=1, bias=False, eps=1e-6): 35 | if padding is None: 36 | padding = get_padding(kernel_size, stride, dilation) 37 | super().__init__( 38 | in_channel, out_channels, kernel_size, stride=stride, 39 | padding=padding, dilation=dilation, groups=groups, bias=bias) 40 | self.eps = eps 41 | 42 | def forward(self, x): 43 | weight = F.batch_norm( 44 | self.weight.reshape(1, self.out_channels, -1), None, None, 45 | training=True, momentum=0., eps=self.eps).reshape_as(self.weight) 46 | x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 47 | return x 48 | 49 | 50 | class StdConv2dSame(nn.Conv2d): 51 | """Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model. 52 | 53 | Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - 54 | https://arxiv.org/abs/1903.10520v2 55 | """ 56 | def __init__( 57 | self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', 58 | dilation=1, groups=1, bias=False, eps=1e-6): 59 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) 60 | super().__init__( 61 | in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, 62 | groups=groups, bias=bias) 63 | self.same_pad = is_dynamic 64 | self.eps = eps 65 | 66 | def forward(self, x): 67 | if self.same_pad: 68 | x = pad_same(x, self.kernel_size, self.stride, self.dilation) 69 | weight = F.batch_norm( 70 | self.weight.reshape(1, self.out_channels, -1), None, None, 71 | training=True, momentum=0., eps=self.eps).reshape_as(self.weight) 72 | x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 73 | return x 74 | 75 | 76 | class ScaledStdConv2d(nn.Conv2d): 77 | """Conv2d layer with Scaled Weight Standardization. 78 | 79 | Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - 80 | https://arxiv.org/abs/2101.08692 81 | 82 | NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor. 
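Sketch of the transform applied in forward(): W_hat = gain * (gamma / sqrt(fan_in)) * (W - mean(W)) / sqrt(var(W) + eps), with statistics taken per output filter; gain is the learnable per-filter parameter initialised from gain_init and fan_in = weight[0].numel().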
83 | """ 84 | 85 | def __init__( 86 | self, in_channels, out_channels, kernel_size, stride=1, padding=None, 87 | dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): 88 | if padding is None: 89 | padding = get_padding(kernel_size, stride, dilation) 90 | super().__init__( 91 | in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, 92 | groups=groups, bias=bias) 93 | self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) 94 | self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in) 95 | self.eps = eps 96 | 97 | def forward(self, x): 98 | weight = F.batch_norm( 99 | self.weight.reshape(1, self.out_channels, -1), None, None, 100 | weight=(self.gain * self.scale).view(-1), 101 | training=True, momentum=0., eps=self.eps).reshape_as(self.weight) 102 | return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 103 | 104 | 105 | class ScaledStdConv2dSame(nn.Conv2d): 106 | """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support 107 | 108 | Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - 109 | https://arxiv.org/abs/2101.08692 110 | 111 | NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor. 112 | """ 113 | 114 | def __init__( 115 | self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', 116 | dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): 117 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) 118 | super().__init__( 119 | in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, 120 | groups=groups, bias=bias) 121 | self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) 122 | self.scale = gamma * self.weight[0].numel() ** -0.5 123 | self.same_pad = is_dynamic 124 | self.eps = eps 125 | 126 | def forward(self, x): 127 | if self.same_pad: 128 | x = pad_same(x, self.kernel_size, self.stride, self.dilation) 129 | weight = F.batch_norm( 130 | self.weight.reshape(1, self.out_channels, -1), None, None, 131 | weight=(self.gain * self.scale).view(-1), 132 | training=True, momentum=0., eps=self.eps).reshape_as(self.weight) 133 | return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 134 | -------------------------------------------------------------------------------- /lib/gcn_lib/torch_edge.py: -------------------------------------------------------------------------------- 1 | # 2022.06.17-Changed for building ViG model 2 | # Huawei Technologies Co., Ltd. 3 | import math 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def pairwise_distance(x): 10 | """ 11 | Compute pairwise distance of a point cloud. 12 | Args: 13 | x: tensor (batch_size, num_points, num_dims) 14 | Returns: 15 | pairwise distance: (batch_size, num_points, num_points) 16 | """ 17 | with torch.no_grad(): 18 | x_inner = -2*torch.matmul(x, x.transpose(2, 1)) 19 | x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True) 20 | return x_square + x_inner + x_square.transpose(2, 1) 21 | 22 | 23 | def part_pairwise_distance(x, start_idx=0, end_idx=1): 24 | """ 25 | Compute pairwise distance of a point cloud. 
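Uses the expansion ||x_i - x_j||^2 = ||x_i||^2 - 2 * x_i . x_j + ||x_j||^2, evaluated only for the row block x[:, start_idx:end_idx] to limit peak memory.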
26 | Args: 27 | x: tensor (batch_size, num_points, num_dims) 28 | Returns: 29 | pairwise distance: (batch_size, num_points, num_points) 30 | """ 31 | with torch.no_grad(): 32 | x_part = x[:, start_idx:end_idx] 33 | x_square_part = torch.sum(torch.mul(x_part, x_part), dim=-1, keepdim=True) 34 | x_inner = -2*torch.matmul(x_part, x.transpose(2, 1)) 35 | x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True) 36 | return x_square_part + x_inner + x_square.transpose(2, 1) 37 | 38 | 39 | def xy_pairwise_distance(x, y): 40 | """ 41 | Compute pairwise distance of a point cloud. 42 | Args: 43 | x: tensor (batch_size, num_points, num_dims) 44 | Returns: 45 | pairwise distance: (batch_size, num_points, num_points) 46 | """ 47 | with torch.no_grad(): 48 | xy_inner = -2*torch.matmul(x, y.transpose(2, 1)) 49 | x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True) 50 | y_square = torch.sum(torch.mul(y, y), dim=-1, keepdim=True) 51 | return x_square + xy_inner + y_square.transpose(2, 1) 52 | 53 | 54 | def dense_knn_matrix(x, k=16, relative_pos=None): 55 | """Get KNN based on the pairwise distance. 56 | Args: 57 | x: (batch_size, num_dims, num_points, 1) 58 | k: int 59 | Returns: 60 | nearest neighbors: (batch_size, num_points, k) (batch_size, num_points, k) 61 | """ 62 | with torch.no_grad(): 63 | x = x.transpose(2, 1).squeeze(-1) 64 | batch_size, n_points, n_dims = x.shape 65 | ### memory efficient implementation ### 66 | n_part = 10000 67 | if n_points > n_part: 68 | nn_idx_list = [] 69 | groups = math.ceil(n_points / n_part) 70 | for i in range(groups): 71 | start_idx = n_part * i 72 | end_idx = min(n_points, n_part * (i + 1)) 73 | dist = part_pairwise_distance(x.detach(), start_idx, end_idx) 74 | if relative_pos is not None: 75 | dist += relative_pos[:, start_idx:end_idx] 76 | _, nn_idx_part = torch.topk(-dist, k=k) 77 | nn_idx_list += [nn_idx_part] 78 | nn_idx = torch.cat(nn_idx_list, dim=1) 79 | else: 80 | dist = pairwise_distance(x.detach()) 81 | if relative_pos is not None: 82 | dist += relative_pos 83 | _, nn_idx = torch.topk(-dist, k=k) # b, n, k 84 | ###### 85 | center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1) 86 | return torch.stack((nn_idx, center_idx), dim=0) 87 | 88 | 89 | def xy_dense_knn_matrix(x, y, k=16, relative_pos=None): 90 | """Get KNN based on the pairwise distance. 
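For every point in x, torch.topk on the negated x-to-y distance matrix gives the indices of its k nearest neighbours in y; these are stacked with the repeated centre indices to form the edge index.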
91 | Args: 92 | x: (batch_size, num_dims, num_points, 1) 93 | k: int 94 | Returns: 95 | nearest neighbors: (batch_size, num_points, k) (batch_size, num_points, k) 96 | """ 97 | with torch.no_grad(): 98 | x = x.transpose(2, 1).squeeze(-1) 99 | y = y.transpose(2, 1).squeeze(-1) 100 | batch_size, n_points, n_dims = x.shape 101 | dist = xy_pairwise_distance(x.detach(), y.detach()) 102 | if relative_pos is not None: 103 | dist += relative_pos 104 | #print(dist.shape, k) 105 | _, nn_idx = torch.topk(-dist, k=k) 106 | center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1) 107 | return torch.stack((nn_idx, center_idx), dim=0) 108 | 109 | 110 | class DenseDilated(nn.Module): 111 | """ 112 | Find dilated neighbor from neighbor list 113 | 114 | edge_index: (2, batch_size, num_points, k) 115 | """ 116 | def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0): 117 | super(DenseDilated, self).__init__() 118 | self.dilation = dilation 119 | self.stochastic = stochastic 120 | self.epsilon = epsilon 121 | self.k = k 122 | 123 | def forward(self, edge_index): 124 | if self.stochastic: 125 | if torch.rand(1) < self.epsilon and self.training: 126 | num = self.k * self.dilation 127 | randnum = torch.randperm(num)[:self.k] 128 | edge_index = edge_index[:, :, :, randnum] 129 | else: 130 | edge_index = edge_index[:, :, :, ::self.dilation] 131 | else: 132 | edge_index = edge_index[:, :, :, ::self.dilation] 133 | return edge_index 134 | 135 | 136 | class DenseDilatedKnnGraph(nn.Module): 137 | """ 138 | Find the neighbors' indices based on dilated knn 139 | """ 140 | def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0): 141 | super(DenseDilatedKnnGraph, self).__init__() 142 | self.dilation = dilation 143 | self.stochastic = stochastic 144 | self.epsilon = epsilon 145 | self.k = k 146 | self._dilated = DenseDilated(k, dilation, stochastic, epsilon) 147 | 148 | def forward(self, x, y=None, relative_pos=None): 149 | if y is not None: 150 | #### normalize 151 | x = F.normalize(x, p=2.0, dim=1) 152 | y = F.normalize(y, p=2.0, dim=1) 153 | #### 154 | edge_index = xy_dense_knn_matrix(x, y, self.k * self.dilation, relative_pos) 155 | else: 156 | #### normalize 157 | x = F.normalize(x, p=2.0, dim=1) 158 | #### 159 | edge_index = dense_knn_matrix(x, self.k * self.dilation, relative_pos) 160 | return self._dilated(edge_index) 161 | -------------------------------------------------------------------------------- /lib/models_timm/layers/lambda_layer.py: -------------------------------------------------------------------------------- 1 | """ Lambda Layer 2 | 3 | Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` 4 | - https://arxiv.org/abs/2102.08602 5 | 6 | @misc{2102.08602, 7 | Author = {Irwan Bello}, 8 | Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention}, 9 | Year = {2021}, 10 | } 11 | 12 | Status: 13 | This impl is a WIP. Code snippets in the paper were used as reference but 14 | good chance some details are missing/wrong. 15 | 16 | I've only implemented local lambda conv based pos embeddings. 
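In this impl the output at each position is content_out + position_out: the query applied to a shared content lambda, softmax(k) @ v, plus the query applied to a per-position lambda built from v (via the local lambda convolution, or the relative position embedding when r is None).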
17 | 18 | For a PyTorch impl that includes other embedding options checkout 19 | https://github.com/lucidrains/lambda-networks 20 | 21 | Hacked together by / Copyright 2021 Ross Wightman 22 | """ 23 | import torch 24 | from torch import nn 25 | import torch.nn.functional as F 26 | 27 | from .helpers import to_2tuple, make_divisible 28 | from .weight_init import trunc_normal_ 29 | 30 | 31 | def rel_pos_indices(size): 32 | size = to_2tuple(size) 33 | pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1) 34 | rel_pos = pos[:, None, :] - pos[:, :, None] 35 | rel_pos[0] += size[0] - 1 36 | rel_pos[1] += size[1] - 1 37 | return rel_pos # 2, H * W, H * W 38 | 39 | 40 | class LambdaLayer(nn.Module): 41 | """Lambda Layer 42 | 43 | Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` 44 | - https://arxiv.org/abs/2102.08602 45 | 46 | NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add. 47 | 48 | The internal dimensions of the lambda module are controlled via the interaction of several arguments. 49 | * the output dimension of the module is specified by dim_out, which falls back to input dim if not set 50 | * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim 51 | * the query (q) and key (k) dimension are determined by 52 | * dim_head = (dim_out * attn_ratio // num_heads) if dim_head is None 53 | * q = num_heads * dim_head, k = dim_head 54 | * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not set 55 | 56 | Args: 57 | dim (int): input dimension to the module 58 | dim_out (int): output dimension of the module, same as dim if not set 59 | feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W 60 | stride (int): output stride of the module, avg pool used if stride == 2 61 | num_heads (int): parallel attention heads. 62 | dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set 63 | r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9) 64 | qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. 
(default: 1.0) 65 | qkv_bias (bool): add bias to q, k, and v projections 66 | """ 67 | def __init__( 68 | self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9, 69 | qk_ratio=1.0, qkv_bias=False): 70 | super().__init__() 71 | dim_out = dim_out or dim 72 | assert dim_out % num_heads == 0, ' should be divided by num_heads' 73 | self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads 74 | self.num_heads = num_heads 75 | self.dim_v = dim_out // num_heads 76 | 77 | self.qkv = nn.Conv2d( 78 | dim, 79 | num_heads * self.dim_qk + self.dim_qk + self.dim_v, 80 | kernel_size=1, bias=qkv_bias) 81 | self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk) 82 | self.norm_v = nn.BatchNorm2d(self.dim_v) 83 | 84 | if r is not None: 85 | # local lambda convolution for pos 86 | self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0)) 87 | self.pos_emb = None 88 | self.rel_pos_indices = None 89 | else: 90 | # relative pos embedding 91 | assert feat_size is not None 92 | feat_size = to_2tuple(feat_size) 93 | rel_size = [2 * s - 1 for s in feat_size] 94 | self.conv_lambda = None 95 | self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk)) 96 | self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False) 97 | 98 | self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() 99 | 100 | self.reset_parameters() 101 | 102 | def reset_parameters(self): 103 | trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in 104 | if self.conv_lambda is not None: 105 | trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5) 106 | if self.pos_emb is not None: 107 | trunc_normal_(self.pos_emb, std=.02) 108 | 109 | def forward(self, x): 110 | B, C, H, W = x.shape 111 | M = H * W 112 | qkv = self.qkv(x) 113 | q, k, v = torch.split(qkv, [ 114 | self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1) 115 | q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2) # B, num_heads, M, K 116 | v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V 117 | k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M 118 | 119 | content_lam = k @ v # B, K, V 120 | content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V 121 | 122 | if self.pos_emb is None: 123 | position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K 124 | position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V 125 | else: 126 | # FIXME relative pos embedding path not fully verified 127 | pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1) 128 | position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1) # B, 1, M, K, V 129 | position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2) # B, num_heads, M, V 130 | 131 | out = (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W) # B, C (num_heads * V), H, W 132 | out = self.pool(out) 133 | return out 134 | -------------------------------------------------------------------------------- /test_ACDC.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | import argparse 5 | import random 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import DataLoader 9 | import torch.backends.cudnn as cudnn 10 | from tqdm import tqdm 11 | 12 | from utils.utils import test_single_volume 13 | 
from utils.dataset_ACDC import ACDCdataset, RandomGenerator 14 | from lib.networks import PVT_GCASCADE, MERIT_GCASCADE 15 | 16 | def inference(args, model, testloader, test_save_path=None): 17 | logging.info("{} test iterations per epoch".format(len(testloader))) 18 | model.eval() 19 | metric_list = 0.0 20 | with torch.no_grad(): 21 | for i_batch, sampled_batch in tqdm(enumerate(testloader)): 22 | h, w = sampled_batch["image"].size()[2:] 23 | image, label, case_name = sampled_batch["image"], sampled_batch["label"], sampled_batch['case_name'][0] 24 | metric_i = test_single_volume(image, label, model, classes=args.num_classes, patch_size=[args.img_size, args.img_size], 25 | test_save_path=test_save_path, case=case_name, z_spacing=args.z_spacing) 26 | metric_list += np.array(metric_i) 27 | logging.info('idx %d case %s mean_dice %f mean_hd95 %f, mean_jacard %f mean_asd %f' % (i_batch, case_name, np.mean(metric_i, axis=0)[0], np.mean(metric_i, axis=0)[1], np.mean(metric_i, axis=0)[2], np.mean(metric_i, axis=0)[3])) 28 | metric_list = metric_list / len(testloader) 29 | for i in range(1, args.num_classes): 30 | logging.info('Mean class (%d) mean_dice %f mean_hd95 %f, mean_jacard %f mean_asd %f' % (i, metric_list[i-1][0], metric_list[i-1][1], metric_list[i-1][2], metric_list[i-1][3])) 31 | performance = np.mean(metric_list, axis=0)[0] 32 | mean_hd95 = np.mean(metric_list, axis=0)[1] 33 | mean_jacard = np.mean(metric_list, axis=0)[2] 34 | mean_asd = np.mean(metric_list, axis=0)[3] 35 | logging.info('Testing performance in best val model: mean_dice : %f mean_hd95 : %f, mean_jacard : %f mean_asd : %f' % (performance, mean_hd95, mean_jacard, mean_asd)) 36 | logging.info("Testing Finished!") 37 | return performance, mean_hd95, mean_jacard, mean_asd 38 | 39 | if __name__ == "__main__": 40 | 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('--encoder', default='PVT', help='Name of encoder: PVT or MERIT') 43 | parser.add_argument('--skip_aggregation', default='additive', help='Type of skip-aggregation: additive or concatenation') 44 | parser.add_argument("--batch_size", default=12, help="batch size") 45 | parser.add_argument("--lr", default=0.0001, help="learning rate") 46 | parser.add_argument("--max_epochs", default=400) 47 | parser.add_argument("--img_size", default=224) 48 | parser.add_argument("--save_path", default="./model_pth/ACDC") 49 | parser.add_argument("--n_gpu", default=1) 50 | parser.add_argument("--checkpoint", default=None) 51 | parser.add_argument("--list_dir", default="./data/ACDC/lists_ACDC") 52 | parser.add_argument("--root_dir", default="./data/ACDC/") 53 | parser.add_argument("--volume_path", default="./data/ACDC/test") 54 | parser.add_argument("--z_spacing", default=10) 55 | parser.add_argument("--num_classes", default=4) 56 | parser.add_argument('--test_save_dir', default='./predictions', help='saving prediction as nii!') 57 | parser.add_argument('--deterministic', type=int, default=1, 58 | help='whether use deterministic training') 59 | parser.add_argument('--seed', type=int, 60 | default=2222, help='random seed') 61 | args = parser.parse_args() 62 | 63 | if not args.deterministic: 64 | cudnn.benchmark = True 65 | cudnn.deterministic = False 66 | else: 67 | cudnn.benchmark = False 68 | cudnn.deterministic = True 69 | 70 | random.seed(args.seed) 71 | np.random.seed(args.seed) 72 | torch.manual_seed(args.seed) 73 | torch.cuda.manual_seed(args.seed) 74 | 75 | args.is_pretrain = True 76 | args.exp = 'PVT_GCASCADE_MUTATION_w3_7_Run1_' + str(args.img_size) 77 | 
snapshot_path = "{}/{}/{}".format(args.save_path, args.exp, 'PVT_GCASCADE_MUTATION_w3_7_Run1') 78 | snapshot_path = snapshot_path + '_pretrain' if args.is_pretrain else snapshot_path 79 | snapshot_path = snapshot_path + '_epo' +str(args.max_epochs) if args.max_epochs != 30 else snapshot_path 80 | snapshot_path = snapshot_path+'_bs'+str(args.batch_size) 81 | snapshot_path = snapshot_path + '_lr' + str(args.lr) if args.lr != 0.01 else snapshot_path 82 | snapshot_path = snapshot_path + '_'+str(args.img_size) 83 | snapshot_path = snapshot_path + '_s'+str(args.seed) if args.seed!=1234 else snapshot_path 84 | 85 | if args.encoder=='PVT': 86 | net = PVT_GCASCADE(n_class=args.num_classes, img_size=args.img_size, k=11, padding=5, conv='mr', gcb_act='gelu', skip_aggregation=args.skip_aggregation).cuda() 87 | elif args.encoder=='MERIT': 88 | net = MERIT_GCASCADE(n_class=args.num_classes, img_size_s1=(args.img_size,args.img_size), img_size_s2=(224,224), k=11, padding=5, conv='mr', gcb_act='gelu', skip_aggregation=args.skip_aggregation).cuda() 89 | else: 90 | print('Implementation not found for this encoder. Exiting!') 91 | sys.exit() 92 | 93 | snapshot = os.path.join(snapshot_path, 'best.pth') 94 | if not os.path.exists(snapshot): snapshot = snapshot.replace('best', 'epoch_'+str(args.max_epochs-1)) 95 | net.load_state_dict(torch.load(snapshot)) 96 | snapshot_name = snapshot_path.split('/')[-1] 97 | 98 | log_folder = 'test_log/test_log_' + args.exp 99 | os.makedirs(log_folder, exist_ok=True) 100 | logging.basicConfig(filename=log_folder + '/'+snapshot_name+".txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 101 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 102 | logging.info(str(args)) 103 | logging.info(snapshot_name) 104 | 105 | args.test_save_dir = os.path.join(snapshot_path, args.test_save_dir) 106 | test_save_path = os.path.join(args.test_save_dir, args.exp, snapshot_name) 107 | os.makedirs(test_save_path, exist_ok=True) 108 | 109 | 110 | db_test =ACDCdataset(base_dir=args.volume_path,list_dir=args.list_dir, split="test") 111 | testloader = DataLoader(db_test, batch_size=1, shuffle=False) 112 | 113 | results = inference(args, net, testloader, test_save_path) 114 | 115 | 116 | --------------------------------------------------------------------------------