├── LICENSE.md ├── README.md ├── datasets ├── __pycache__ │ └── datasets_synapse.cpython-38.pyc └── datasets_synapse.py ├── images ├── MSA2Net-1.png ├── MSA2Net-git.png ├── MSA2Net.pdf ├── isic-git.PNG ├── isic-results-git.png ├── isic2018-1.png ├── isic2018-refcs.PNG ├── isic2018.pdf ├── isic2018_results.PNG ├── synapse-1.png ├── synapse-git.PNG ├── synapse-main-results-git.png ├── synapse-refcs.PNG ├── synapse.pdf └── synapse_results.PNG ├── lists └── lists_Synapse │ ├── all.lst │ ├── test_vol.txt │ └── train.txt ├── networks ├── masag.py ├── merit_lib │ ├── decoders.py │ ├── maxxvit_4out.py │ ├── models_timm │ │ ├── convnext.py │ │ ├── factory.py │ │ ├── features.py │ │ ├── fx_features.py │ │ ├── helpers.py │ │ ├── hub.py │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── activations.py │ │ │ ├── activations_jit.py │ │ │ ├── activations_me.py │ │ │ ├── adaptive_avgmax_pool.py │ │ │ ├── attention_pool2d.py │ │ │ ├── blur_pool.py │ │ │ ├── bottleneck_attn.py │ │ │ ├── cbam.py │ │ │ ├── classifier.py │ │ │ ├── cond_conv2d.py │ │ │ ├── config.py │ │ │ ├── conv2d_same.py │ │ │ ├── conv_bn_act.py │ │ │ ├── create_act.py │ │ │ ├── create_attn.py │ │ │ ├── create_conv2d.py │ │ │ ├── create_norm.py │ │ │ ├── create_norm_act.py │ │ │ ├── drop.py │ │ │ ├── eca.py │ │ │ ├── evo_norm.py │ │ │ ├── fast_norm.py │ │ │ ├── filter_response_norm.py │ │ │ ├── gather_excite.py │ │ │ ├── global_context.py │ │ │ ├── halo_attn.py │ │ │ ├── helpers.py │ │ │ ├── inplace_abn.py │ │ │ ├── lambda_layer.py │ │ │ ├── linear.py │ │ │ ├── median_pool.py │ │ │ ├── mixed_conv2d.py │ │ │ ├── ml_decoder.py │ │ │ ├── mlp.py │ │ │ ├── non_local_attn.py │ │ │ ├── norm.py │ │ │ ├── norm_act.py │ │ │ ├── padding.py │ │ │ ├── patch_embed.py │ │ │ ├── pool2d_same.py │ │ │ ├── pos_embed.py │ │ │ ├── selective_kernel.py │ │ │ ├── separable_conv.py │ │ │ ├── space_to_depth.py │ │ │ ├── split_attn.py │ │ │ ├── split_batchnorm.py │ │ │ ├── squeeze_excite.py │ │ │ ├── std_conv.py │ │ │ ├── test_time_pool.py │ │ │ ├── trace_utils.py │ │ │ └── weight_init.py │ │ ├── levit.py │ │ ├── maxxvit.py │ │ ├── mlp_mixer.py │ │ ├── registry.py │ │ └── vision_transformer_relpos.py │ └── networks.py ├── msa2net.py └── segformer.py ├── requirements.txt ├── test.py ├── train.py ├── trainer.py └── utils.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2024 Colorless Tsukuru Tazaki 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MSA2Net: Multi-scale Adaptive Attention-guided Network for Medical Image Segmentation
BMVC 2024 2 | [![arXiv](https://img.shields.io/badge/arXiv-2407.21640-b31b1b.svg)](https://arxiv.org/abs/2407.21640) 3 | 4 | Medical image segmentation involves identifying and separating object instances in a medical image to delineate various tissues and structures, a task complicated by the significant variations in size, shape, and density of these features. Convolutional neural networks (CNNs) have traditionally been used for this task but have limitations in capturing long-range dependencies. Transformers, equipped with self-attention mechanisms, aim to address this problem. However, in medical image segmentation it is beneficial to merge local and global features so that feature maps across various scales are integrated effectively, capturing both fine details and broader semantic context to handle variations in structures. In this paper, we introduce MSA2Net, a new deep segmentation framework featuring an expedient design of skip-connections. These connections facilitate feature fusion by dynamically weighting and combining coarse-grained encoder features with fine-grained decoder feature maps. Specifically, we propose a Multi-Scale Adaptive Spatial Attention Gate (MASAG), which dynamically adjusts the receptive field (local and global contextual information) to ensure that spatially relevant features are selectively highlighted while background distractions are minimized. Extensive evaluations on dermatological and radiological datasets demonstrate that MSA2Net outperforms or matches state-of-the-art (SOTA) methods. 5 | 6 | ![Proposed Model](https://github.com/xmindflow/MSA-2Net/blob/main/images/MSA2Net-git.png?raw=true) 7 | 8 | ## Updates 9 | - **`14.10.2024`** | Accepted as an oral presentation 🎉 10 | - **`20.07.2024`** | Accepted at BMVC 2024! 🥳 11 | 12 | ## Citation 13 | ``` 14 | @article{kolahi2024msa2net, 15 | title={MSA2Net: Multi-scale Adaptive Attention-guided Network for Medical Image Segmentation}, 16 | author={Kolahi, Sina Ghorbani and Chaharsooghi, Seyed Kamal and Khatibi, Toktam and Bozorgpour, Afshin and Azad, Reza and Heidari, Moein and Hacihaliloglu, Ilker and Merhof, Dorit}, 17 | journal={arXiv preprint arXiv:2407.21640}, 18 | year={2024} 19 | } 20 | ``` 21 | 22 | ## How to use 23 | 24 | ### Requirements 25 | 26 | - Ubuntu 16.04 or higher 27 | - CUDA 11.1 or higher 28 | - Python v3.7 or higher 29 | - PyTorch v1.7 or higher 30 | - Hardware Spec 31 | - A single GPU with at least 12GB of memory (_we used an RTX 3090_) 32 | 33 | ``` 34 | einops 35 | h5py 36 | imgaug 37 | fvcore 38 | MedPy 39 | numpy 40 | opencv_python 41 | pandas 42 | PyWavelets 43 | scipy 44 | SimpleITK 45 | tensorboardX 46 | timm 47 | torch 48 | torchvision 49 | tqdm 50 | ``` 51 | 52 | ### Model weights 53 | You can download the ImageNet-pretrained encoder weights and the trained MSA2Net weights from the links below. 54 | Dataset | Model | Download link 55 | -----------|-------|---------------- 56 | ImageNet | MaxViT small 224 | [Download](https://drive.google.com/file/d/1MaWFYadsYFEROLNvYG8hZAYnlGCkPLaN/view?usp=sharing) 57 | Synapse | MSA2Net | [Download](https://drive.google.com/file/d/19CwKKw18KYNNohFb7dFOOspQqmoQQ679/view?usp=sharing) 58 | 59 | ### Training and Testing 60 | 61 | 1) Download the Synapse dataset from [here](https://drive.google.com/uc?export=download&id=18I9JHH_i0uuEDg-N6d7bfMdf7Ut6bhBi). 
62 | 63 | 2) Download the MaxViT small 224x224 pretrained weights [here](https://drive.google.com/file/d/1MaWFYadsYFEROLNvYG8hZAYnlGCkPLaN/view?usp=sharing) and set their path in 'networks/merit_lib/networks.py' so the encoder is initialized from them. 64 | 3) Run the following command to install the requirements. 65 | 66 | `pip install -r requirements.txt` 67 | 68 | 4) Run the command below to train MSA2Net on the Synapse dataset. 69 | ```bash 70 | python train.py --root_path ./data/Synapse/train_npz --test_path ./data/Synapse/test_vol_h5 --batch_size 20 --eval_interval 20 --max_epochs 700 71 | ``` 72 | **--root_path** [Train data path] 73 | 74 | **--test_path** [Test data path] 75 | 76 | **--eval_interval** [Evaluation interval in epochs] 77 | 5) Run the command below to test MSA2Net on the Synapse dataset. 78 | ```bash 79 | python test.py --volume_path ./data/Synapse/ --output_dir ./model_out 80 | ``` 81 | **--volume_path** [Root directory of the test data] 82 | 83 | **--output_dir** [Directory containing the trained weights] 84 | 85 | ## Experiments 86 | 87 | ### Synapse Dataset 88 |

89 | Synapse images 90 | Synapse results 91 |

92 | 93 | ### ISIC 2018 Dataset 94 |

95 | ISIC images 96 | ISIC results 97 |

98 | 99 | ## References 100 | - DAEFormer [https://github.com/mindflow-institue/DAEFormer] 101 | - D-LKA Net [https://github.com/xmindflow/deformableLKA] 102 | - SKNet [https://github.com/osmr/imgclsmob/tree/68335927ba27f2356093b985bada0bc3989836b1] 103 | - FAT [https://github.com/qhfan/FAT] 104 | -------------------------------------------------------------------------------- /datasets/__pycache__/datasets_synapse.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/MSA-2Net/209ca72263c903048ae03077dfc023dd14d09c6c/datasets/__pycache__/datasets_synapse.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/datasets_synapse.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import h5py 4 | import numpy as np 5 | import torch 6 | from scipy import ndimage 7 | from scipy.ndimage.interpolation import zoom 8 | from torch.utils.data import Dataset 9 | import imgaug as ia 10 | import imgaug.augmenters as iaa 11 | 12 | 13 | def mask_to_onehot(mask, ): 14 | """ 15 | Converts a segmentation mask (H, W, C) to (H, W, K) where the last dim is a one 16 | hot encoding vector, C is usually 1 or 3, and K is the number of class. 17 | """ 18 | semantic_map = [] 19 | mask = np.expand_dims(mask,-1) 20 | for colour in range (9): 21 | equality = np.equal(mask, colour) 22 | class_map = np.all(equality, axis=-1) 23 | semantic_map.append(class_map) 24 | semantic_map = np.stack(semantic_map, axis=-1).astype(np.int32) 25 | return semantic_map 26 | 27 | def augment_seg(img_aug, img, seg ): 28 | seg = mask_to_onehot(seg) 29 | aug_det = img_aug.to_deterministic() 30 | image_aug = aug_det.augment_image( img ) 31 | 32 | segmap = ia.SegmentationMapOnImage( seg , nb_classes=np.max(seg)+1 , shape=img.shape ) 33 | segmap_aug = aug_det.augment_segmentation_maps( segmap ) 34 | segmap_aug = segmap_aug.get_arr_int() 35 | segmap_aug = np.argmax(segmap_aug, axis=-1).astype(np.float32) 36 | return image_aug , segmap_aug 37 | 38 | def random_rot_flip(image, label): 39 | k = np.random.randint(0, 4) 40 | image = np.rot90(image, k) 41 | label = np.rot90(label, k) 42 | axis = np.random.randint(0, 2) 43 | image = np.flip(image, axis=axis).copy() 44 | label = np.flip(label, axis=axis).copy() 45 | return image, label 46 | 47 | def random_rotate(image, label): 48 | angle = np.random.randint(-20, 20) 49 | image = ndimage.rotate(image, angle, order=0, reshape=False) 50 | label = ndimage.rotate(label, angle, order=0, reshape=False) 51 | return image, label 52 | 53 | 54 | class RandomGenerator(object): 55 | def __init__(self, output_size): 56 | self.output_size = output_size 57 | 58 | def __call__(self, sample): 59 | image, label = sample['image'], sample['label'] 60 | 61 | if random.random() > 0.5: 62 | image, label = random_rot_flip(image, label) 63 | elif random.random() > 0.5: 64 | image, label = random_rotate(image, label) 65 | x, y = image.shape 66 | if x != self.output_size[0] or y != self.output_size[1]: 67 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3) # why not 3? 
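            # Clarifying the comment above: order=3 gives cubic-spline interpolation for the
            # image, while the label below is resized with order=0 (nearest neighbour) so that
            # class indices stay integer-valued after zooming.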
68 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 69 | image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) 70 | label = torch.from_numpy(label.astype(np.float32)) 71 | sample = {'image': image, 'label': label.long()} 72 | return sample 73 | 74 | 75 | class Synapse_dataset(Dataset): 76 | def __init__(self, base_dir, list_dir, split, img_size, norm_x_transform=None, norm_y_transform=None): 77 | self.norm_x_transform = norm_x_transform 78 | self.norm_y_transform = norm_y_transform 79 | self.split = split 80 | self.sample_list = open(os.path.join(list_dir, self.split+'.txt')).readlines() 81 | self.data_dir = base_dir 82 | self.img_size = img_size 83 | 84 | self.img_aug = iaa.SomeOf((0,4),[ 85 | iaa.Flipud(0.5, name="Flipud"), 86 | iaa.Fliplr(0.5, name="Fliplr"), 87 | iaa.AdditiveGaussianNoise(scale=0.005 * 255), 88 | iaa.GaussianBlur(sigma=(1.0)), 89 | iaa.LinearContrast((0.5, 1.5), per_channel=0.5), 90 | iaa.Affine(scale={"x": (0.5, 2), "y": (0.5, 2)}), 91 | iaa.Affine(rotate=(-40, 40)), 92 | iaa.Affine(shear=(-16, 16)), 93 | iaa.PiecewiseAffine(scale=(0.008, 0.03)), 94 | iaa.Affine(translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)}) 95 | ], random_order=True) 96 | 97 | 98 | 99 | def __len__(self): 100 | return len(self.sample_list) 101 | 102 | def __getitem__(self, idx): 103 | if self.split == "train": 104 | slice_name = self.sample_list[idx].strip('\n') 105 | data_path = os.path.join(self.data_dir, slice_name+'.npz') 106 | data = np.load(data_path) 107 | image, label = data['image'], data['label'] 108 | image,label = augment_seg(self.img_aug, image, label) 109 | x, y = image.shape 110 | if x != self.img_size or y != self.img_size: 111 | image = zoom(image, (self.img_size / x, self.img_size / y), order=3) # why not 3? 
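                # Same convention as in RandomGenerator above: cubic interpolation (order=3) for
                # the image, nearest neighbour (order=0) below for the label mask.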
112 | label = zoom(label, (self.img_size / x, self.img_size / y), order=0) 113 | 114 | else: 115 | vol_name = self.sample_list[idx].strip('\n') 116 | filepath = self.data_dir + "/{}.npy.h5".format(vol_name) 117 | data = h5py.File(filepath) 118 | image, label = data['image'][:], data['label'][:] 119 | 120 | 121 | 122 | sample = {'image': image, 'label': label} 123 | if self.norm_x_transform is not None: 124 | sample['image'] = self.norm_x_transform(sample['image'].copy()) 125 | if self.norm_y_transform is not None: 126 | sample['label'] = self.norm_y_transform(sample['label'].copy()) 127 | sample['case_name'] = self.sample_list[idx].strip('\n') 128 | return sample -------------------------------------------------------------------------------- /images/MSA2Net-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/MSA-2Net/209ca72263c903048ae03077dfc023dd14d09c6c/images/MSA2Net-1.png -------------------------------------------------------------------------------- /images/MSA2Net-git.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/MSA-2Net/209ca72263c903048ae03077dfc023dd14d09c6c/images/MSA2Net-git.png -------------------------------------------------------------------------------- /images/MSA2Net.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/MSA-2Net/209ca72263c903048ae03077dfc023dd14d09c6c/images/MSA2Net.pdf -------------------------------------------------------------------------------- /images/isic-git.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/MSA-2Net/209ca72263c903048ae03077dfc023dd14d09c6c/images/isic-git.PNG -------------------------------------------------------------------------------- /images/isic-results-git.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/MSA-2Net/209ca72263c903048ae03077dfc023dd14d09c6c/images/isic-results-git.png -------------------------------------------------------------------------------- /images/isic2018-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/MSA-2Net/209ca72263c903048ae03077dfc023dd14d09c6c/images/isic2018-1.png -------------------------------------------------------------------------------- /images/isic2018-refcs.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/MSA-2Net/209ca72263c903048ae03077dfc023dd14d09c6c/images/isic2018-refcs.PNG -------------------------------------------------------------------------------- /images/isic2018.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/MSA-2Net/209ca72263c903048ae03077dfc023dd14d09c6c/images/isic2018.pdf -------------------------------------------------------------------------------- /images/isic2018_results.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/MSA-2Net/209ca72263c903048ae03077dfc023dd14d09c6c/images/isic2018_results.PNG -------------------------------------------------------------------------------- /images/synapse-1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/MSA-2Net/209ca72263c903048ae03077dfc023dd14d09c6c/images/synapse-1.png -------------------------------------------------------------------------------- /images/synapse-git.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/MSA-2Net/209ca72263c903048ae03077dfc023dd14d09c6c/images/synapse-git.PNG -------------------------------------------------------------------------------- /images/synapse-main-results-git.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/MSA-2Net/209ca72263c903048ae03077dfc023dd14d09c6c/images/synapse-main-results-git.png -------------------------------------------------------------------------------- /images/synapse-refcs.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/MSA-2Net/209ca72263c903048ae03077dfc023dd14d09c6c/images/synapse-refcs.PNG -------------------------------------------------------------------------------- /images/synapse.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/MSA-2Net/209ca72263c903048ae03077dfc023dd14d09c6c/images/synapse.pdf -------------------------------------------------------------------------------- /images/synapse_results.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmindflow/MSA-2Net/209ca72263c903048ae03077dfc023dd14d09c6c/images/synapse_results.PNG -------------------------------------------------------------------------------- /lists/lists_Synapse/all.lst: -------------------------------------------------------------------------------- 1 | case0031.npy.h5 2 | case0007.npy.h5 3 | case0009.npy.h5 4 | case0005.npy.h5 5 | case0026.npy.h5 6 | case0039.npy.h5 7 | case0024.npy.h5 8 | case0034.npy.h5 9 | case0033.npy.h5 10 | case0030.npy.h5 11 | case0023.npy.h5 12 | case0040.npy.h5 13 | case0010.npy.h5 14 | case0021.npy.h5 15 | case0006.npy.h5 16 | case0027.npy.h5 17 | case0028.npy.h5 18 | case0037.npy.h5 19 | case0008.npy.h5 20 | case0022.npy.h5 21 | case0038.npy.h5 22 | case0036.npy.h5 23 | case0032.npy.h5 24 | case0002.npy.h5 25 | case0029.npy.h5 26 | case0003.npy.h5 27 | case0001.npy.h5 28 | case0004.npy.h5 29 | case0025.npy.h5 30 | case0035.npy.h5 -------------------------------------------------------------------------------- /lists/lists_Synapse/test_vol.txt: -------------------------------------------------------------------------------- 1 | case0008 2 | case0022 3 | case0038 4 | case0036 5 | case0032 6 | case0002 7 | case0029 8 | case0003 9 | case0001 10 | case0004 11 | case0025 12 | case0035 -------------------------------------------------------------------------------- /networks/masag.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | import functools 7 | import math 8 | import timm 9 | from timm.models.layers import DropPath, to_2tuple 10 | import einops 11 | from fvcore.nn import FlopCountAnalysis 12 | 13 | 14 | def num_trainable_params(model): 15 | nums = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6 16 | return nums 17 | 18 | class 
GlobalExtraction(nn.Module): 19 | def __init__(self,dim = None): 20 | super().__init__() 21 | self.avgpool = self.globalavgchannelpool 22 | self.maxpool = self.globalmaxchannelpool 23 | self.proj = nn.Sequential( 24 | nn.Conv2d(2, 1, 1,1), 25 | nn.BatchNorm2d(1) 26 | ) 27 | def globalavgchannelpool(self, x): 28 | x = x.mean(1, keepdim = True) 29 | return x 30 | 31 | def globalmaxchannelpool(self, x): 32 | x = x.max(dim = 1, keepdim=True)[0] 33 | return x 34 | 35 | def forward(self, x): 36 | x_ = x.clone() 37 | x = self.avgpool(x) 38 | x2 = self.maxpool(x_) 39 | 40 | cat = torch.cat((x,x2), dim = 1) 41 | 42 | proj = self.proj(cat) 43 | return proj 44 | 45 | class ContextExtraction(nn.Module): 46 | def __init__(self, dim, reduction = None): 47 | super().__init__() 48 | self.reduction = 1 if reduction == None else 2 49 | 50 | self.dconv = self.DepthWiseConv2dx2(dim) 51 | self.proj = self.Proj(dim) 52 | 53 | def DepthWiseConv2dx2(self, dim): 54 | dconv = nn.Sequential( 55 | nn.Conv2d(in_channels = dim, 56 | out_channels = dim, 57 | kernel_size = 3, 58 | padding = 1, 59 | groups = dim), 60 | nn.BatchNorm2d(num_features = dim), 61 | nn.ReLU(inplace = True), 62 | nn.Conv2d(in_channels = dim, 63 | out_channels = dim, 64 | kernel_size = 3, 65 | padding = 2, 66 | dilation = 2), 67 | nn.BatchNorm2d(num_features = dim), 68 | nn.ReLU(inplace = True) 69 | ) 70 | return dconv 71 | 72 | def Proj(self, dim): 73 | proj = nn.Sequential( 74 | nn.Conv2d(in_channels = dim, 75 | out_channels = dim //self.reduction, 76 | kernel_size = 1 77 | ), 78 | nn.BatchNorm2d(num_features = dim//self.reduction) 79 | ) 80 | return proj 81 | def forward(self,x): 82 | x = self.dconv(x) 83 | x = self.proj(x) 84 | return x 85 | 86 | class MultiscaleFusion(nn.Module): 87 | def __init__(self, dim): 88 | super().__init__() 89 | self.local= ContextExtraction(dim) 90 | self.global_ = GlobalExtraction() 91 | self.bn = nn.BatchNorm2d(num_features=dim) 92 | 93 | def forward(self, x, g,): 94 | x = self.local(x) 95 | g = self.global_(g) 96 | 97 | fuse = self.bn(x + g) 98 | return fuse 99 | 100 | 101 | class MultiScaleGatedAttn(nn.Module): 102 | # Version 1 103 | def __init__(self, dim): 104 | super().__init__() 105 | self.multi = MultiscaleFusion(dim) 106 | self.selection = nn.Conv2d(dim, 2,1) 107 | self.proj = nn.Conv2d(dim, dim,1) 108 | self.bn = nn.BatchNorm2d(dim) 109 | self.bn_2 = nn.BatchNorm2d(dim) 110 | self.conv_block = nn.Sequential( 111 | nn.Conv2d(in_channels=dim, out_channels=dim, 112 | kernel_size=1, stride=1)) 113 | 114 | def forward(self,x,g): 115 | x_ = x.clone() 116 | g_ = g.clone() 117 | 118 | #stacked = torch.stack((x_, g_), dim = 1) # B, 2, C, H, W 119 | 120 | multi = self.multi(x, g) # B, C, H, W 121 | 122 | ### Option 2 ### 123 | multi = self.selection(multi) # B, num_path, H, W 124 | 125 | attention_weights = F.softmax(multi, dim=1) # Shape: [B, 2, H, W] 126 | #attention_weights = torch.sigmoid(multi) 127 | A, B = attention_weights.split(1, dim=1) # Each will have shape [B, 1, H, W] 128 | 129 | x_att = A.expand_as(x_) * x_ # Using expand_as to match the channel dimensions 130 | g_att = B.expand_as(g_) * g_ 131 | 132 | x_att = x_att + x_ 133 | g_att = g_att + g_ 134 | ## Bidirectional Interaction 135 | 136 | x_sig = torch.sigmoid(x_att) 137 | g_att_2 = x_sig * g_att 138 | 139 | 140 | g_sig = torch.sigmoid(g_att) 141 | x_att_2 = g_sig * x_att 142 | 143 | interaction = x_att_2 * g_att_2 144 | 145 | projected = torch.sigmoid(self.bn(self.proj(interaction))) 146 | 147 | weighted = projected * x_ 148 | 149 | y = 
self.conv_block(weighted) 150 | 151 | #y = self.bn_2(weighted + y) 152 | y = self.bn_2(y) 153 | return y 154 | 155 | if __name__ == "__main__": 156 | xi = torch.randn(1, 192, 28, 28).cuda() 157 | #xi_1 = torch.randn(1, 384, 14, 14) 158 | g = torch.randn(1, 192, 28, 28).cuda() 159 | #ff = ContextBridge(dim=192) 160 | 161 | attn = MultiScaleGatedAttn(dim = xi.shape[1]).cuda() 162 | 163 | print(attn(xi, g).shape) 164 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/factory.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import urlsplit, urlunsplit 2 | import os 3 | 4 | from .registry import is_model, is_model_in_modules, model_entrypoint 5 | from .helpers import load_checkpoint 6 | from .layers import set_layer_config 7 | from .hub import load_model_config_from_hf 8 | 9 | 10 | def parse_model_name(model_name): 11 | model_name = model_name.replace('hf_hub', 'hf-hub') # NOTE for backwards compat, to deprecate hf_hub use 12 | parsed = urlsplit(model_name) 13 | assert parsed.scheme in ('', 'timm', 'hf-hub') 14 | if parsed.scheme == 'hf-hub': 15 | # FIXME may use fragment as revision, currently `@` in URI path 16 | return parsed.scheme, parsed.path 17 | else: 18 | model_name = os.path.split(parsed.path)[-1] 19 | return 'timm', model_name 20 | 21 | 22 | def safe_model_name(model_name, remove_source=True): 23 | def make_safe(name): 24 | return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_') 25 | if remove_source: 26 | model_name = parse_model_name(model_name)[-1] 27 | return make_safe(model_name) 28 | 29 | 30 | def create_model( 31 | model_name, 32 | pretrained=False, 33 | pretrained_cfg=None, 34 | checkpoint_path='', 35 | scriptable=None, 36 | exportable=None, 37 | no_jit=None, 38 | **kwargs): 39 | """Create a model 40 | 41 | Args: 42 | model_name (str): name of model to instantiate 43 | pretrained (bool): load pretrained ImageNet-1k weights if true 44 | checkpoint_path (str): path of checkpoint to load after model is initialized 45 | scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) 46 | exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) 47 | no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) 48 | 49 | Keyword Args: 50 | drop_rate (float): dropout rate for training (default: 0.0) 51 | global_pool (str): global pool type (default: 'avg') 52 | **: other kwargs are model specific 53 | """ 54 | # Parameters that aren't supported by all models or are intended to only override model defaults if set 55 | # should default to None in command line args/cfg. Remove them if they are present and not set so that 56 | # non-supporting models don't break and default args remain in effect. 57 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 58 | 59 | model_source, model_name = parse_model_name(model_name) 60 | if model_source == 'hf-hub': 61 | # FIXME hf-hub source overrides any passed in pretrained_cfg, warn? 62 | # For model names specified in the form `hf-hub:path/architecture_name@revision`, 63 | # load model weights + pretrained_cfg from Hugging Face hub. 
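        # Illustrative note (not in the original source): a name such as
        # 'hf-hub:timm/maxvit_small_tf_224.in1k' would be resolved through the Hub here,
        # whereas a bare name like 'maxvit_small_tf_224' falls through to the local
        # timm registry lookup below.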
64 | pretrained_cfg, model_name = load_model_config_from_hf(model_name) 65 | 66 | if not is_model(model_name): 67 | raise RuntimeError('Unknown model (%s)' % model_name) 68 | 69 | create_fn = model_entrypoint(model_name) 70 | with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): 71 | model = create_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg, **kwargs) 72 | 73 | if checkpoint_path: 74 | load_checkpoint(model, checkpoint_path) 75 | 76 | return model 77 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/fx_features.py: -------------------------------------------------------------------------------- 1 | """ PyTorch FX Based Feature Extraction Helpers 2 | Using https://pytorch.org/vision/stable/feature_extraction.html 3 | """ 4 | from typing import Callable, List, Dict, Union, Type 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from .features import _get_feature_info 10 | 11 | try: 12 | from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor 13 | has_fx_feature_extraction = True 14 | except ImportError: 15 | has_fx_feature_extraction = False 16 | 17 | # Layers we went to treat as leaf modules 18 | from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame 19 | from .layers.non_local_attn import BilinearAttnTransform 20 | from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame 21 | 22 | # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here 23 | # BUT modules from timm.models should use the registration mechanism below 24 | _leaf_modules = { 25 | BilinearAttnTransform, # reason: flow control t <= 1 26 | # Reason: get_same_padding has a max which raises a control flow error 27 | Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, 28 | CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) 29 | } 30 | 31 | try: 32 | from .layers import InplaceAbn 33 | _leaf_modules.add(InplaceAbn) 34 | except ImportError: 35 | pass 36 | 37 | 38 | def register_notrace_module(module: Type[nn.Module]): 39 | """ 40 | Any module not under timm.models.layers should get this decorator if we don't want to trace through it. 
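    Example (illustrative; ``DynamicControlFlowModule`` is a hypothetical user module)::

        @register_notrace_module
        class DynamicControlFlowModule(nn.Module):
            ...  # contains data-dependent control flow that FX symbolic tracing cannot handle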
41 | """ 42 | _leaf_modules.add(module) 43 | return module 44 | 45 | 46 | # Functions we want to autowrap (treat them as leaves) 47 | _autowrap_functions = set() 48 | 49 | 50 | def register_notrace_function(func: Callable): 51 | """ 52 | Decorator for functions which ought not to be traced through 53 | """ 54 | _autowrap_functions.add(func) 55 | return func 56 | 57 | 58 | def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]): 59 | assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' 60 | return _create_feature_extractor( 61 | model, return_nodes, 62 | tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)} 63 | ) 64 | 65 | 66 | class FeatureGraphNet(nn.Module): 67 | """ A FX Graph based feature extractor that works with the model feature_info metadata 68 | """ 69 | def __init__(self, model, out_indices, out_map=None): 70 | super().__init__() 71 | assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' 72 | self.feature_info = _get_feature_info(model, out_indices) 73 | if out_map is not None: 74 | assert len(out_map) == len(out_indices) 75 | return_nodes = { 76 | info['module']: out_map[i] if out_map is not None else info['module'] 77 | for i, info in enumerate(self.feature_info) if i in out_indices} 78 | self.graph_module = create_feature_extractor(model, return_nodes) 79 | 80 | def forward(self, x): 81 | return list(self.graph_module(x).values()) 82 | 83 | 84 | class GraphExtractNet(nn.Module): 85 | """ A standalone feature extraction wrapper that maps dict -> list or single tensor 86 | NOTE: 87 | * one can use feature_extractor directly if dictionary output is desired 88 | * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info 89 | metadata for builtin feature extraction mode 90 | * create_feature_extractor can be used directly if dictionary output is acceptable 91 | 92 | Args: 93 | model: model to extract features from 94 | return_nodes: node names to return features from (dict or list) 95 | squeeze_out: if only one output, and output in list format, flatten to single tensor 96 | """ 97 | def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True): 98 | super().__init__() 99 | self.squeeze_out = squeeze_out 100 | self.graph_module = create_feature_extractor(model, return_nodes) 101 | 102 | def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]: 103 | out = list(self.graph_module(x).values()) 104 | if self.squeeze_out and len(out) == 1: 105 | return out[0] 106 | return out 107 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/hub.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from functools import partial 5 | from pathlib import Path 6 | from typing import Union 7 | 8 | import torch 9 | from torch.hub import HASH_REGEX, download_url_to_file, urlparse 10 | try: 11 | from torch.hub import get_dir 12 | except ImportError: 13 | from torch.hub import _get_torch_home as get_dir 14 | 15 | from timm import __version__ 16 | 17 | try: 18 | from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url 19 | hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__) 20 | _has_hf_hub = True 21 | except 
ImportError: 22 | hf_hub_download = None 23 | _has_hf_hub = False 24 | 25 | _logger = logging.getLogger(__name__) 26 | 27 | 28 | def get_cache_dir(child_dir=''): 29 | """ 30 | Returns the location of the directory where models are cached (and creates it if necessary). 31 | """ 32 | # Issue warning to move data if old env is set 33 | if os.getenv('TORCH_MODEL_ZOO'): 34 | _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') 35 | 36 | hub_dir = get_dir() 37 | child_dir = () if not child_dir else (child_dir,) 38 | model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir) 39 | os.makedirs(model_dir, exist_ok=True) 40 | return model_dir 41 | 42 | 43 | def download_cached_file(url, check_hash=True, progress=False): 44 | parts = urlparse(url) 45 | filename = os.path.basename(parts.path) 46 | cached_file = os.path.join(get_cache_dir(), filename) 47 | if not os.path.exists(cached_file): 48 | _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) 49 | hash_prefix = None 50 | if check_hash: 51 | r = HASH_REGEX.search(filename) # r is Optional[Match[str]] 52 | hash_prefix = r.group(1) if r else None 53 | download_url_to_file(url, cached_file, hash_prefix, progress=progress) 54 | return cached_file 55 | 56 | 57 | def has_hf_hub(necessary=False): 58 | if not _has_hf_hub and necessary: 59 | # if no HF Hub module installed, and it is necessary to continue, raise error 60 | raise RuntimeError( 61 | 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') 62 | return _has_hf_hub 63 | 64 | 65 | def hf_split(hf_id): 66 | # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme 67 | rev_split = hf_id.split('@') 68 | assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.' 
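    # e.g. 'organization/model@main' -> ('organization/model', 'main'); with no '@' the
    # revision stays None and the Hub's default revision is used downstream.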
69 | hf_model_id = rev_split[0] 70 | hf_revision = rev_split[-1] if len(rev_split) > 1 else None 71 | return hf_model_id, hf_revision 72 | 73 | 74 | def load_cfg_from_json(json_file: Union[str, os.PathLike]): 75 | with open(json_file, "r", encoding="utf-8") as reader: 76 | text = reader.read() 77 | return json.loads(text) 78 | 79 | 80 | def _download_from_hf(model_id: str, filename: str): 81 | hf_model_id, hf_revision = hf_split(model_id) 82 | return hf_hub_download(hf_model_id, filename, revision=hf_revision) 83 | 84 | 85 | def load_model_config_from_hf(model_id: str): 86 | assert has_hf_hub(True) 87 | cached_file = _download_from_hf(model_id, 'config.json') 88 | pretrained_cfg = load_cfg_from_json(cached_file) 89 | pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation 90 | pretrained_cfg['source'] = 'hf-hub' 91 | model_name = pretrained_cfg.get('architecture') 92 | return pretrained_cfg, model_name 93 | 94 | 95 | def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'): 96 | assert has_hf_hub(True) 97 | cached_file = _download_from_hf(model_id, filename) 98 | state_dict = torch.load(cached_file, map_location='cpu') 99 | return state_dict 100 | 101 | 102 | def save_for_hf(model, save_directory, model_config=None): 103 | assert has_hf_hub(True) 104 | model_config = model_config or {} 105 | save_directory = Path(save_directory) 106 | save_directory.mkdir(exist_ok=True, parents=True) 107 | 108 | weights_path = save_directory / 'pytorch_model.bin' 109 | torch.save(model.state_dict(), weights_path) 110 | 111 | config_path = save_directory / 'config.json' 112 | hf_config = model.pretrained_cfg 113 | hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes) 114 | hf_config['num_features'] = model_config.pop('num_features', model.num_features) 115 | hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])]) 116 | hf_config.update(model_config) 117 | 118 | with config_path.open('w') as f: 119 | json.dump(hf_config, f, indent=2) 120 | 121 | 122 | def push_to_hf_hub( 123 | model, 124 | local_dir, 125 | repo_namespace_or_url=None, 126 | commit_message='Add model', 127 | use_auth_token=True, 128 | git_email=None, 129 | git_user=None, 130 | revision=None, 131 | model_config=None, 132 | ): 133 | if repo_namespace_or_url: 134 | repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:] 135 | else: 136 | if isinstance(use_auth_token, str): 137 | token = use_auth_token 138 | else: 139 | token = HfFolder.get_token() 140 | 141 | if token is None: 142 | raise ValueError( 143 | "You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and " 144 | "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own " 145 | "token as the `use_auth_token` argument." 146 | ) 147 | 148 | repo_owner = HfApi().whoami(token)['name'] 149 | repo_name = Path(local_dir).name 150 | 151 | repo_url = f'https://huggingface.co/{repo_owner}/{repo_name}' 152 | 153 | repo = Repository( 154 | local_dir, 155 | clone_from=repo_url, 156 | use_auth_token=use_auth_token, 157 | git_user=git_user, 158 | git_email=git_email, 159 | revision=revision, 160 | ) 161 | 162 | # Prepare a default model card that includes the necessary tags to enable inference. 
163 | readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}' 164 | with repo.commit(commit_message): 165 | # Save model weights and config. 166 | save_for_hf(model, repo.local_dir, model_config=model_config) 167 | 168 | # Save a model card if it doesn't exist. 169 | readme_path = Path(repo.local_dir) / 'README.md' 170 | if not readme_path.exists(): 171 | readme_path.write_text(readme_text) 172 | 173 | return repo.git_remote_url() 174 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import * 2 | from .adaptive_avgmax_pool import \ 3 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d 4 | from .blur_pool import BlurPool2d 5 | from .classifier import ClassifierHead, create_classifier 6 | from .cond_conv2d import CondConv2d, get_condconv_initializer 7 | from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ 8 | set_layer_config 9 | from .conv2d_same import Conv2dSame, conv2d_same 10 | from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct 11 | from .create_act import create_act_layer, get_act_layer, get_act_fn 12 | from .create_attn import get_attn, create_attn 13 | from .create_conv2d import create_conv2d 14 | from .create_norm import get_norm_layer, create_norm_layer 15 | from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer 16 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path 17 | from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn 18 | from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ 19 | EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a 20 | from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm 21 | from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d 22 | from .gather_excite import GatherExcite 23 | from .global_context import GlobalContext 24 | from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple 25 | from .inplace_abn import InplaceAbn 26 | from .linear import Linear 27 | from .mixed_conv2d import MixedConv2d 28 | from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp 29 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn 30 | from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d 31 | from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm 32 | from .padding import get_padding, get_same_padding, pad_same 33 | from .patch_embed import PatchEmbed 34 | from .pool2d_same import AvgPool2dSame, create_pool2d 35 | from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite 36 | from .selective_kernel import SelectiveKernel 37 | from .separable_conv import SeparableConv2d, SeparableConvNormAct 38 | from .space_to_depth import SpaceToDepthModule 39 | from .split_attn import SplitAttn 40 | from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model 41 | from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame 42 | from .test_time_pool import TestTimePoolHead, apply_test_time_pool 43 | from .trace_utils import _assert, _float_to_int 44 | from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, 
lecun_normal_ 45 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/activations.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | 9 | import torch 10 | from torch import nn as nn 11 | from torch.nn import functional as F 12 | 13 | 14 | def swish(x, inplace: bool = False): 15 | """Swish - Described in: https://arxiv.org/abs/1710.05941 16 | """ 17 | return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) 18 | 19 | 20 | class Swish(nn.Module): 21 | def __init__(self, inplace: bool = False): 22 | super(Swish, self).__init__() 23 | self.inplace = inplace 24 | 25 | def forward(self, x): 26 | return swish(x, self.inplace) 27 | 28 | 29 | def mish(x, inplace: bool = False): 30 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 31 | NOTE: I don't have a working inplace variant 32 | """ 33 | return x.mul(F.softplus(x).tanh()) 34 | 35 | 36 | class Mish(nn.Module): 37 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 38 | """ 39 | def __init__(self, inplace: bool = False): 40 | super(Mish, self).__init__() 41 | 42 | def forward(self, x): 43 | return mish(x) 44 | 45 | 46 | def sigmoid(x, inplace: bool = False): 47 | return x.sigmoid_() if inplace else x.sigmoid() 48 | 49 | 50 | # PyTorch has this, but not with a consistent inplace argmument interface 51 | class Sigmoid(nn.Module): 52 | def __init__(self, inplace: bool = False): 53 | super(Sigmoid, self).__init__() 54 | self.inplace = inplace 55 | 56 | def forward(self, x): 57 | return x.sigmoid_() if self.inplace else x.sigmoid() 58 | 59 | 60 | def tanh(x, inplace: bool = False): 61 | return x.tanh_() if inplace else x.tanh() 62 | 63 | 64 | # PyTorch has this, but not with a consistent inplace argmument interface 65 | class Tanh(nn.Module): 66 | def __init__(self, inplace: bool = False): 67 | super(Tanh, self).__init__() 68 | self.inplace = inplace 69 | 70 | def forward(self, x): 71 | return x.tanh_() if self.inplace else x.tanh() 72 | 73 | 74 | def hard_swish(x, inplace: bool = False): 75 | inner = F.relu6(x + 3.).div_(6.) 76 | return x.mul_(inner) if inplace else x.mul(inner) 77 | 78 | 79 | class HardSwish(nn.Module): 80 | def __init__(self, inplace: bool = False): 81 | super(HardSwish, self).__init__() 82 | self.inplace = inplace 83 | 84 | def forward(self, x): 85 | return hard_swish(x, self.inplace) 86 | 87 | 88 | def hard_sigmoid(x, inplace: bool = False): 89 | if inplace: 90 | return x.add_(3.).clamp_(0., 6.).div_(6.) 91 | else: 92 | return F.relu6(x + 3.) / 6. 
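# Quick sanity check (illustrative, not part of the original file): hard_sigmoid(-3) == 0.0,
# hard_sigmoid(0) == 0.5 and hard_sigmoid(3) == 1.0, i.e. relu6(x + 3) / 6 is a piecewise-linear
# approximation of the sigmoid.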
93 | 94 | 95 | class HardSigmoid(nn.Module): 96 | def __init__(self, inplace: bool = False): 97 | super(HardSigmoid, self).__init__() 98 | self.inplace = inplace 99 | 100 | def forward(self, x): 101 | return hard_sigmoid(x, self.inplace) 102 | 103 | 104 | def hard_mish(x, inplace: bool = False): 105 | """ Hard Mish 106 | Experimental, based on notes by Mish author Diganta Misra at 107 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 108 | """ 109 | if inplace: 110 | return x.mul_(0.5 * (x + 2).clamp(min=0, max=2)) 111 | else: 112 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 113 | 114 | 115 | class HardMish(nn.Module): 116 | def __init__(self, inplace: bool = False): 117 | super(HardMish, self).__init__() 118 | self.inplace = inplace 119 | 120 | def forward(self, x): 121 | return hard_mish(x, self.inplace) 122 | 123 | 124 | class PReLU(nn.PReLU): 125 | """Applies PReLU (w/ dummy inplace arg) 126 | """ 127 | def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None: 128 | super(PReLU, self).__init__(num_parameters=num_parameters, init=init) 129 | 130 | def forward(self, input: torch.Tensor) -> torch.Tensor: 131 | return F.prelu(input, self.weight) 132 | 133 | 134 | def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: 135 | return F.gelu(x) 136 | 137 | 138 | class GELU(nn.Module): 139 | """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg) 140 | """ 141 | def __init__(self, inplace: bool = False): 142 | super(GELU, self).__init__() 143 | 144 | def forward(self, input: torch.Tensor) -> torch.Tensor: 145 | return F.gelu(input) 146 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/activations_jit.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of jit-scripted activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not 7 | currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted 8 | versions if they contain in-place ops. 9 | 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | 13 | import torch 14 | from torch import nn as nn 15 | from torch.nn import functional as F 16 | 17 | 18 | @torch.jit.script 19 | def swish_jit(x, inplace: bool = False): 20 | """Swish - Described in: https://arxiv.org/abs/1710.05941 21 | """ 22 | return x.mul(x.sigmoid()) 23 | 24 | 25 | @torch.jit.script 26 | def mish_jit(x, _inplace: bool = False): 27 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 28 | """ 29 | return x.mul(F.softplus(x).tanh()) 30 | 31 | 32 | class SwishJit(nn.Module): 33 | def __init__(self, inplace: bool = False): 34 | super(SwishJit, self).__init__() 35 | 36 | def forward(self, x): 37 | return swish_jit(x) 38 | 39 | 40 | class MishJit(nn.Module): 41 | def __init__(self, inplace: bool = False): 42 | super(MishJit, self).__init__() 43 | 44 | def forward(self, x): 45 | return mish_jit(x) 46 | 47 | 48 | @torch.jit.script 49 | def hard_sigmoid_jit(x, inplace: bool = False): 50 | # return F.relu6(x + 3.) / 6. 51 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 
52 | 53 | 54 | class HardSigmoidJit(nn.Module): 55 | def __init__(self, inplace: bool = False): 56 | super(HardSigmoidJit, self).__init__() 57 | 58 | def forward(self, x): 59 | return hard_sigmoid_jit(x) 60 | 61 | 62 | @torch.jit.script 63 | def hard_swish_jit(x, inplace: bool = False): 64 | # return x * (F.relu6(x + 3.) / 6) 65 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 66 | 67 | 68 | class HardSwishJit(nn.Module): 69 | def __init__(self, inplace: bool = False): 70 | super(HardSwishJit, self).__init__() 71 | 72 | def forward(self, x): 73 | return hard_swish_jit(x) 74 | 75 | 76 | @torch.jit.script 77 | def hard_mish_jit(x, inplace: bool = False): 78 | """ Hard Mish 79 | Experimental, based on notes by Mish author Diganta Misra at 80 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 81 | """ 82 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 83 | 84 | 85 | class HardMishJit(nn.Module): 86 | def __init__(self, inplace: bool = False): 87 | super(HardMishJit, self).__init__() 88 | 89 | def forward(self, x): 90 | return hard_mish_jit(x) 91 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/activations_me.py: -------------------------------------------------------------------------------- 1 | """ Activations (memory-efficient w/ custom autograd) 2 | 3 | A collection of activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | These activations are not compatible with jit scripting or ONNX export of the model, please use either 7 | the JIT or basic versions of the activations. 8 | 9 | Hacked together by / Copyright 2020 Ross Wightman 10 | """ 11 | 12 | import torch 13 | from torch import nn as nn 14 | from torch.nn import functional as F 15 | 16 | 17 | @torch.jit.script 18 | def swish_jit_fwd(x): 19 | return x.mul(torch.sigmoid(x)) 20 | 21 | 22 | @torch.jit.script 23 | def swish_jit_bwd(x, grad_output): 24 | x_sigmoid = torch.sigmoid(x) 25 | return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) 26 | 27 | 28 | class SwishJitAutoFn(torch.autograd.Function): 29 | """ torch.jit.script optimised Swish w/ memory-efficient checkpoint 30 | Inspired by conversation btw Jeremy Howard & Adam Pazske 31 | https://twitter.com/jeremyphoward/status/1188251041835315200 32 | """ 33 | @staticmethod 34 | def symbolic(g, x): 35 | return g.op("Mul", x, g.op("Sigmoid", x)) 36 | 37 | @staticmethod 38 | def forward(ctx, x): 39 | ctx.save_for_backward(x) 40 | return swish_jit_fwd(x) 41 | 42 | @staticmethod 43 | def backward(ctx, grad_output): 44 | x = ctx.saved_tensors[0] 45 | return swish_jit_bwd(x, grad_output) 46 | 47 | 48 | def swish_me(x, inplace=False): 49 | return SwishJitAutoFn.apply(x) 50 | 51 | 52 | class SwishMe(nn.Module): 53 | def __init__(self, inplace: bool = False): 54 | super(SwishMe, self).__init__() 55 | 56 | def forward(self, x): 57 | return SwishJitAutoFn.apply(x) 58 | 59 | 60 | @torch.jit.script 61 | def mish_jit_fwd(x): 62 | return x.mul(torch.tanh(F.softplus(x))) 63 | 64 | 65 | @torch.jit.script 66 | def mish_jit_bwd(x, grad_output): 67 | x_sigmoid = torch.sigmoid(x) 68 | x_tanh_sp = F.softplus(x).tanh() 69 | return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) 70 | 71 | 72 | class MishJitAutoFn(torch.autograd.Function): 73 | """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - 
https://arxiv.org/abs/1908.08681 74 | A memory efficient, jit scripted variant of Mish 75 | """ 76 | @staticmethod 77 | def forward(ctx, x): 78 | ctx.save_for_backward(x) 79 | return mish_jit_fwd(x) 80 | 81 | @staticmethod 82 | def backward(ctx, grad_output): 83 | x = ctx.saved_tensors[0] 84 | return mish_jit_bwd(x, grad_output) 85 | 86 | 87 | def mish_me(x, inplace=False): 88 | return MishJitAutoFn.apply(x) 89 | 90 | 91 | class MishMe(nn.Module): 92 | def __init__(self, inplace: bool = False): 93 | super(MishMe, self).__init__() 94 | 95 | def forward(self, x): 96 | return MishJitAutoFn.apply(x) 97 | 98 | 99 | @torch.jit.script 100 | def hard_sigmoid_jit_fwd(x, inplace: bool = False): 101 | return (x + 3).clamp(min=0, max=6).div(6.) 102 | 103 | 104 | @torch.jit.script 105 | def hard_sigmoid_jit_bwd(x, grad_output): 106 | m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6. 107 | return grad_output * m 108 | 109 | 110 | class HardSigmoidJitAutoFn(torch.autograd.Function): 111 | @staticmethod 112 | def forward(ctx, x): 113 | ctx.save_for_backward(x) 114 | return hard_sigmoid_jit_fwd(x) 115 | 116 | @staticmethod 117 | def backward(ctx, grad_output): 118 | x = ctx.saved_tensors[0] 119 | return hard_sigmoid_jit_bwd(x, grad_output) 120 | 121 | 122 | def hard_sigmoid_me(x, inplace: bool = False): 123 | return HardSigmoidJitAutoFn.apply(x) 124 | 125 | 126 | class HardSigmoidMe(nn.Module): 127 | def __init__(self, inplace: bool = False): 128 | super(HardSigmoidMe, self).__init__() 129 | 130 | def forward(self, x): 131 | return HardSigmoidJitAutoFn.apply(x) 132 | 133 | 134 | @torch.jit.script 135 | def hard_swish_jit_fwd(x): 136 | return x * (x + 3).clamp(min=0, max=6).div(6.) 137 | 138 | 139 | @torch.jit.script 140 | def hard_swish_jit_bwd(x, grad_output): 141 | m = torch.ones_like(x) * (x >= 3.) 142 | m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m) 143 | return grad_output * m 144 | 145 | 146 | class HardSwishJitAutoFn(torch.autograd.Function): 147 | """A memory efficient, jit-scripted HardSwish activation""" 148 | @staticmethod 149 | def forward(ctx, x): 150 | ctx.save_for_backward(x) 151 | return hard_swish_jit_fwd(x) 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): 155 | x = ctx.saved_tensors[0] 156 | return hard_swish_jit_bwd(x, grad_output) 157 | 158 | @staticmethod 159 | def symbolic(g, self): 160 | input = g.op("Add", self, g.op('Constant', value_t=torch.tensor(3, dtype=torch.float))) 161 | hardtanh_ = g.op("Clip", input, g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)), g.op('Constant', value_t=torch.tensor(6, dtype=torch.float))) 162 | hardtanh_ = g.op("Div", hardtanh_, g.op('Constant', value_t=torch.tensor(6, dtype=torch.float))) 163 | return g.op("Mul", self, hardtanh_) 164 | 165 | 166 | def hard_swish_me(x, inplace=False): 167 | return HardSwishJitAutoFn.apply(x) 168 | 169 | 170 | class HardSwishMe(nn.Module): 171 | def __init__(self, inplace: bool = False): 172 | super(HardSwishMe, self).__init__() 173 | 174 | def forward(self, x): 175 | return HardSwishJitAutoFn.apply(x) 176 | 177 | 178 | @torch.jit.script 179 | def hard_mish_jit_fwd(x): 180 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 181 | 182 | 183 | @torch.jit.script 184 | def hard_mish_jit_bwd(x, grad_output): 185 | m = torch.ones_like(x) * (x >= -2.) 186 | m = torch.where((x >= -2.) 
& (x <= 0.), x + 1., m) 187 | return grad_output * m 188 | 189 | 190 | class HardMishJitAutoFn(torch.autograd.Function): 191 | """ A memory efficient, jit scripted variant of Hard Mish 192 | Experimental, based on notes by Mish author Diganta Misra at 193 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 194 | """ 195 | @staticmethod 196 | def forward(ctx, x): 197 | ctx.save_for_backward(x) 198 | return hard_mish_jit_fwd(x) 199 | 200 | @staticmethod 201 | def backward(ctx, grad_output): 202 | x = ctx.saved_tensors[0] 203 | return hard_mish_jit_bwd(x, grad_output) 204 | 205 | 206 | def hard_mish_me(x, inplace: bool = False): 207 | return HardMishJitAutoFn.apply(x) 208 | 209 | 210 | class HardMishMe(nn.Module): 211 | def __init__(self, inplace: bool = False): 212 | super(HardMishMe, self).__init__() 213 | 214 | def forward(self, x): 215 | return HardMishJitAutoFn.apply(x) 216 | 217 | 218 | 219 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/adaptive_avgmax_pool.py: -------------------------------------------------------------------------------- 1 | """ PyTorch selectable adaptive pooling 2 | Adaptive pooling with the ability to select the type of pooling from: 3 | * 'avg' - Average pooling 4 | * 'max' - Max pooling 5 | * 'avgmax' - Sum of average and max pooling re-scaled by 0.5 6 | * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim 7 | 8 | Both a functional and a nn.Module version of the pooling is provided. 9 | 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | 17 | def adaptive_pool_feat_mult(pool_type='avg'): 18 | if pool_type == 'catavgmax': 19 | return 2 20 | else: 21 | return 1 22 | 23 | 24 | def adaptive_avgmax_pool2d(x, output_size=1): 25 | x_avg = F.adaptive_avg_pool2d(x, output_size) 26 | x_max = F.adaptive_max_pool2d(x, output_size) 27 | return 0.5 * (x_avg + x_max) 28 | 29 | 30 | def adaptive_catavgmax_pool2d(x, output_size=1): 31 | x_avg = F.adaptive_avg_pool2d(x, output_size) 32 | x_max = F.adaptive_max_pool2d(x, output_size) 33 | return torch.cat((x_avg, x_max), 1) 34 | 35 | 36 | def select_adaptive_pool2d(x, pool_type='avg', output_size=1): 37 | """Selectable global pooling function with dynamic input kernel size 38 | """ 39 | if pool_type == 'avg': 40 | x = F.adaptive_avg_pool2d(x, output_size) 41 | elif pool_type == 'avgmax': 42 | x = adaptive_avgmax_pool2d(x, output_size) 43 | elif pool_type == 'catavgmax': 44 | x = adaptive_catavgmax_pool2d(x, output_size) 45 | elif pool_type == 'max': 46 | x = F.adaptive_max_pool2d(x, output_size) 47 | else: 48 | assert False, 'Invalid pool type: %s' % pool_type 49 | return x 50 | 51 | 52 | class FastAdaptiveAvgPool2d(nn.Module): 53 | def __init__(self, flatten=False): 54 | super(FastAdaptiveAvgPool2d, self).__init__() 55 | self.flatten = flatten 56 | 57 | def forward(self, x): 58 | return x.mean((2, 3), keepdim=not self.flatten) 59 | 60 | 61 | class AdaptiveAvgMaxPool2d(nn.Module): 62 | def __init__(self, output_size=1): 63 | super(AdaptiveAvgMaxPool2d, self).__init__() 64 | self.output_size = output_size 65 | 66 | def forward(self, x): 67 | return adaptive_avgmax_pool2d(x, self.output_size) 68 | 69 | 70 | class AdaptiveCatAvgMaxPool2d(nn.Module): 71 | def __init__(self, output_size=1): 72 | super(AdaptiveCatAvgMaxPool2d, self).__init__() 73 | self.output_size = 
output_size 74 | 75 | def forward(self, x): 76 | return adaptive_catavgmax_pool2d(x, self.output_size) 77 | 78 | 79 | class SelectAdaptivePool2d(nn.Module): 80 | """Selectable global pooling layer with dynamic input kernel size 81 | """ 82 | def __init__(self, output_size=1, pool_type='fast', flatten=False): 83 | super(SelectAdaptivePool2d, self).__init__() 84 | self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing 85 | self.flatten = nn.Flatten(1) if flatten else nn.Identity() 86 | if pool_type == '': 87 | self.pool = nn.Identity() # pass through 88 | elif pool_type == 'fast': 89 | assert output_size == 1 90 | self.pool = FastAdaptiveAvgPool2d(flatten) 91 | self.flatten = nn.Identity() 92 | elif pool_type == 'avg': 93 | self.pool = nn.AdaptiveAvgPool2d(output_size) 94 | elif pool_type == 'avgmax': 95 | self.pool = AdaptiveAvgMaxPool2d(output_size) 96 | elif pool_type == 'catavgmax': 97 | self.pool = AdaptiveCatAvgMaxPool2d(output_size) 98 | elif pool_type == 'max': 99 | self.pool = nn.AdaptiveMaxPool2d(output_size) 100 | else: 101 | assert False, 'Invalid pool type: %s' % pool_type 102 | 103 | def is_identity(self): 104 | return not self.pool_type 105 | 106 | def forward(self, x): 107 | x = self.pool(x) 108 | x = self.flatten(x) 109 | return x 110 | 111 | def feat_mult(self): 112 | return adaptive_pool_feat_mult(self.pool_type) 113 | 114 | def __repr__(self): 115 | return self.__class__.__name__ + ' (' \ 116 | + 'pool_type=' + self.pool_type \ 117 | + ', flatten=' + str(self.flatten) + ')' 118 | 119 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/attention_pool2d.py: -------------------------------------------------------------------------------- 1 | """ Attention Pool 2D 2 | 3 | Implementations of 2D spatial feature pooling using multi-head attention instead of average pool. 4 | 5 | Based on idea in CLIP by OpenAI, licensed Apache 2.0 6 | https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py 7 | 8 | Hacked together by / Copyright 2021 Ross Wightman 9 | """ 10 | from typing import Union, Tuple 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | from .helpers import to_2tuple 16 | from .pos_embed import apply_rot_embed, RotaryEmbedding 17 | from .weight_init import trunc_normal_ 18 | 19 | 20 | class RotAttentionPool2d(nn.Module): 21 | """ Attention based 2D feature pooling w/ rotary (relative) pos embedding. 22 | This is a multi-head attention based replacement for (spatial) average pooling in NN architectures. 23 | 24 | Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed. 25 | https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py 26 | 27 | NOTE: While this impl does not require a fixed feature size, performance at differeing resolutions from 28 | train varies widely and falls off dramatically. I'm not sure if there is a way around this... 
-RW 29 | """ 30 | def __init__( 31 | self, 32 | in_features: int, 33 | out_features: int = None, 34 | embed_dim: int = None, 35 | num_heads: int = 4, 36 | qkv_bias: bool = True, 37 | ): 38 | super().__init__() 39 | embed_dim = embed_dim or in_features 40 | out_features = out_features or in_features 41 | self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) 42 | self.proj = nn.Linear(embed_dim, out_features) 43 | self.num_heads = num_heads 44 | assert embed_dim % num_heads == 0 45 | self.head_dim = embed_dim // num_heads 46 | self.scale = self.head_dim ** -0.5 47 | self.pos_embed = RotaryEmbedding(self.head_dim) 48 | 49 | trunc_normal_(self.qkv.weight, std=in_features ** -0.5) 50 | nn.init.zeros_(self.qkv.bias) 51 | 52 | def forward(self, x): 53 | B, _, H, W = x.shape 54 | N = H * W 55 | x = x.reshape(B, -1, N).permute(0, 2, 1) 56 | 57 | x = torch.cat([x.mean(1, keepdim=True), x], dim=1) 58 | 59 | x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 60 | q, k, v = x[0], x[1], x[2] 61 | 62 | qc, q = q[:, :, :1], q[:, :, 1:] 63 | sin_emb, cos_emb = self.pos_embed.get_embed((H, W)) 64 | q = apply_rot_embed(q, sin_emb, cos_emb) 65 | q = torch.cat([qc, q], dim=2) 66 | 67 | kc, k = k[:, :, :1], k[:, :, 1:] 68 | k = apply_rot_embed(k, sin_emb, cos_emb) 69 | k = torch.cat([kc, k], dim=2) 70 | 71 | attn = (q @ k.transpose(-2, -1)) * self.scale 72 | attn = attn.softmax(dim=-1) 73 | 74 | x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) 75 | x = self.proj(x) 76 | return x[:, 0] 77 | 78 | 79 | class AttentionPool2d(nn.Module): 80 | """ Attention based 2D feature pooling w/ learned (absolute) pos embedding. 81 | This is a multi-head attention based replacement for (spatial) average pooling in NN architectures. 82 | 83 | It was based on impl in CLIP by OpenAI 84 | https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py 85 | 86 | NOTE: This requires feature size upon construction and well prevent adaptive sizing of the network. 
87 | """ 88 | def __init__( 89 | self, 90 | in_features: int, 91 | feat_size: Union[int, Tuple[int, int]], 92 | out_features: int = None, 93 | embed_dim: int = None, 94 | num_heads: int = 4, 95 | qkv_bias: bool = True, 96 | ): 97 | super().__init__() 98 | 99 | embed_dim = embed_dim or in_features 100 | out_features = out_features or in_features 101 | assert embed_dim % num_heads == 0 102 | self.feat_size = to_2tuple(feat_size) 103 | self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) 104 | self.proj = nn.Linear(embed_dim, out_features) 105 | self.num_heads = num_heads 106 | self.head_dim = embed_dim // num_heads 107 | self.scale = self.head_dim ** -0.5 108 | 109 | spatial_dim = self.feat_size[0] * self.feat_size[1] 110 | self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features)) 111 | trunc_normal_(self.pos_embed, std=in_features ** -0.5) 112 | trunc_normal_(self.qkv.weight, std=in_features ** -0.5) 113 | nn.init.zeros_(self.qkv.bias) 114 | 115 | def forward(self, x): 116 | B, _, H, W = x.shape 117 | N = H * W 118 | assert self.feat_size[0] == H 119 | assert self.feat_size[1] == W 120 | x = x.reshape(B, -1, N).permute(0, 2, 1) 121 | x = torch.cat([x.mean(1, keepdim=True), x], dim=1) 122 | x = x + self.pos_embed.unsqueeze(0).to(x.dtype) 123 | 124 | x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 125 | q, k, v = x[0], x[1], x[2] 126 | attn = (q @ k.transpose(-2, -1)) * self.scale 127 | attn = attn.softmax(dim=-1) 128 | 129 | x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) 130 | x = self.proj(x) 131 | return x[:, 0] 132 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/blur_pool.py: -------------------------------------------------------------------------------- 1 | """ 2 | BlurPool layer inspired by 3 | - Kornia's Max_BlurPool2d 4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` 5 | 6 | Hacked together by Chris Ha and Ross Wightman 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import numpy as np 13 | from .padding import get_padding 14 | 15 | 16 | class BlurPool2d(nn.Module): 17 | r"""Creates a module that computes blurs and downsample a given feature map. 18 | See :cite:`zhang2019shiftinvar` for more details. 19 | Corresponds to the Downsample class, which does blurring and subsampling 20 | 21 | Args: 22 | channels = Number of input channels 23 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. 24 | stride (int): downsampling filter stride 25 | 26 | Returns: 27 | torch.Tensor: the transformed tensor. 
28 | """ 29 | def __init__(self, channels, filt_size=3, stride=2) -> None: 30 | super(BlurPool2d, self).__init__() 31 | assert filt_size > 1 32 | self.channels = channels 33 | self.filt_size = filt_size 34 | self.stride = stride 35 | self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 36 | coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) 37 | blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1) 38 | self.register_buffer('filt', blur_filter, persistent=False) 39 | 40 | def forward(self, x: torch.Tensor) -> torch.Tensor: 41 | x = F.pad(x, self.padding, 'reflect') 42 | return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels) 43 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/bottleneck_attn.py: -------------------------------------------------------------------------------- 1 | """ Bottleneck Self Attention (Bottleneck Transformers) 2 | 3 | Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 4 | 5 | @misc{2101.11605, 6 | Author = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani}, 7 | Title = {Bottleneck Transformers for Visual Recognition}, 8 | Year = {2021}, 9 | } 10 | 11 | Based on ref gist at: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 12 | 13 | This impl is a WIP but given that it is based on the ref gist likely not too far off. 14 | 15 | Hacked together by / Copyright 2021 Ross Wightman 16 | """ 17 | from typing import List 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | 23 | from .helpers import to_2tuple, make_divisible 24 | from .weight_init import trunc_normal_ 25 | from .trace_utils import _assert 26 | 27 | 28 | def rel_logits_1d(q, rel_k, permute_mask: List[int]): 29 | """ Compute relative logits along one dimension 30 | 31 | As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 32 | Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 33 | 34 | Args: 35 | q: (batch, heads, height, width, dim) 36 | rel_k: (2 * width - 1, dim) 37 | permute_mask: permute output dim according to this 38 | """ 39 | B, H, W, dim = q.shape 40 | x = (q @ rel_k.transpose(-1, -2)) 41 | x = x.reshape(-1, W, 2 * W -1) 42 | 43 | # pad to shift from relative to absolute indexing 44 | x_pad = F.pad(x, [0, 1]).flatten(1) 45 | x_pad = F.pad(x_pad, [0, W - 1]) 46 | 47 | # reshape and slice out the padded elements 48 | x_pad = x_pad.reshape(-1, W + 1, 2 * W - 1) 49 | x = x_pad[:, :W, W - 1:] 50 | 51 | # reshape and tile 52 | x = x.reshape(B, H, 1, W, W).expand(-1, -1, H, -1, -1) 53 | return x.permute(permute_mask) 54 | 55 | 56 | class PosEmbedRel(nn.Module): 57 | """ Relative Position Embedding 58 | As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 59 | Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 60 | """ 61 | def __init__(self, feat_size, dim_head, scale): 62 | super().__init__() 63 | self.height, self.width = to_2tuple(feat_size) 64 | self.dim_head = dim_head 65 | self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * scale) 66 | self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * scale) 67 | 68 | def forward(self, q): 69 | B, HW, _ = q.shape 70 | 71 | # 
relative logits in width dimension. 72 | q = q.reshape(B, self.height, self.width, -1) 73 | rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) 74 | 75 | # relative logits in height dimension. 76 | q = q.transpose(1, 2) 77 | rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2)) 78 | 79 | rel_logits = rel_logits_h + rel_logits_w 80 | rel_logits = rel_logits.reshape(B, HW, HW) 81 | return rel_logits 82 | 83 | 84 | class BottleneckAttn(nn.Module): 85 | """ Bottleneck Attention 86 | Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 87 | 88 | The internal dimensions of the attention module are controlled by the interaction of several arguments. 89 | * the output dimension of the module is specified by dim_out, which falls back to input dim if not set 90 | * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim 91 | * the query and key (qk) dimensions are determined by 92 | * num_heads * dim_head if dim_head is not None 93 | * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None 94 | * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used 95 | 96 | Args: 97 | dim (int): input dimension to the module 98 | dim_out (int): output dimension of the module, same as dim if not set 99 | stride (int): output stride of the module, avg pool used if stride == 2 (default: 1). 100 | num_heads (int): parallel attention heads (default: 4) 101 | dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set 102 | qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0) 103 | qkv_bias (bool): add bias to q, k, and v projections 104 | scale_pos_embed (bool): scale the position embedding as well as Q @ K 105 | """ 106 | def __init__( 107 | self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=None, 108 | qk_ratio=1.0, qkv_bias=False, scale_pos_embed=False): 109 | super().__init__() 110 | assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required' 111 | dim_out = dim_out or dim 112 | assert dim_out % num_heads == 0 113 | self.num_heads = num_heads 114 | self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads 115 | self.dim_head_v = dim_out // self.num_heads 116 | self.dim_out_qk = num_heads * self.dim_head_qk 117 | self.dim_out_v = num_heads * self.dim_head_v 118 | self.scale = self.dim_head_qk ** -0.5 119 | self.scale_pos_embed = scale_pos_embed 120 | 121 | self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias) 122 | 123 | # NOTE I'm only supporting relative pos embedding for now 124 | self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale) 125 | 126 | self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() 127 | 128 | self.reset_parameters() 129 | 130 | def reset_parameters(self): 131 | trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in 132 | trunc_normal_(self.pos_embed.height_rel, std=self.scale) 133 | trunc_normal_(self.pos_embed.width_rel, std=self.scale) 134 | 135 | def forward(self, x): 136 | B, C, H, W = x.shape 137 | _assert(H == self.pos_embed.height, '') 138 | _assert(W == self.pos_embed.width, '') 139 | 140 | x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W 141 | 142 | # NOTE head vs channel split ordering in qkv 
projection was decided before I allowed qk to differ from v 143 | # So, this is more verbose than if heads were before qkv splits, but throughput is not impacted. 144 | q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1) 145 | q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2) 146 | k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k 147 | v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2) 148 | 149 | if self.scale_pos_embed: 150 | attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W 151 | else: 152 | attn = (q @ k) * self.scale + self.pos_embed(q) 153 | attn = attn.softmax(dim=-1) 154 | 155 | out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W 156 | out = self.pool(out) 157 | return out 158 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/cbam.py: -------------------------------------------------------------------------------- 1 | """ CBAM (sort-of) Attention 2 | 3 | Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 4 | 5 | WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on 6 | some tasks, especially fine-grained it seems. I may end up removing this impl. 7 | 8 | Hacked together by / Copyright 2020 Ross Wightman 9 | """ 10 | import torch 11 | from torch import nn as nn 12 | import torch.nn.functional as F 13 | 14 | from .conv_bn_act import ConvNormAct 15 | from .create_act import create_act_layer, get_act_layer 16 | from .helpers import make_divisible 17 | 18 | 19 | class ChannelAttn(nn.Module): 20 | """ Original CBAM channel attention module, currently avg + max pool variant only. 21 | """ 22 | def __init__( 23 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 24 | act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 25 | super(ChannelAttn, self).__init__() 26 | if not rd_channels: 27 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
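# (editor's note) fc1 -> act -> fc2 below form the shared MLP of CBAM channel attention:
# the same two 1x1 convs are applied to both the average-pooled and max-pooled descriptors
# in forward(), and the two results are summed before the sigmoid gate, following
# https://arxiv.org/abs/1807.06521.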
28 | self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias) 29 | self.act = act_layer(inplace=True) 30 | self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias) 31 | self.gate = create_act_layer(gate_layer) 32 | 33 | def forward(self, x): 34 | x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True)))) 35 | x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True)))) 36 | return x * self.gate(x_avg + x_max) 37 | 38 | 39 | class LightChannelAttn(ChannelAttn): 40 | """An experimental 'lightweight' that sums avg + max pool first 41 | """ 42 | def __init__( 43 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 44 | act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 45 | super(LightChannelAttn, self).__init__( 46 | channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias) 47 | 48 | def forward(self, x): 49 | x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True) 50 | x_attn = self.fc2(self.act(self.fc1(x_pool))) 51 | return x * F.sigmoid(x_attn) 52 | 53 | 54 | class SpatialAttn(nn.Module): 55 | """ Original CBAM spatial attention module 56 | """ 57 | def __init__(self, kernel_size=7, gate_layer='sigmoid'): 58 | super(SpatialAttn, self).__init__() 59 | self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False) 60 | self.gate = create_act_layer(gate_layer) 61 | 62 | def forward(self, x): 63 | x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1) 64 | x_attn = self.conv(x_attn) 65 | return x * self.gate(x_attn) 66 | 67 | 68 | class LightSpatialAttn(nn.Module): 69 | """An experimental 'lightweight' variant that sums avg_pool and max_pool results. 70 | """ 71 | def __init__(self, kernel_size=7, gate_layer='sigmoid'): 72 | super(LightSpatialAttn, self).__init__() 73 | self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False) 74 | self.gate = create_act_layer(gate_layer) 75 | 76 | def forward(self, x): 77 | x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True) 78 | x_attn = self.conv(x_attn) 79 | return x * self.gate(x_attn) 80 | 81 | 82 | class CbamModule(nn.Module): 83 | def __init__( 84 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 85 | spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 86 | super(CbamModule, self).__init__() 87 | self.channel = ChannelAttn( 88 | channels, rd_ratio=rd_ratio, rd_channels=rd_channels, 89 | rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) 90 | self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer) 91 | 92 | def forward(self, x): 93 | x = self.channel(x) 94 | x = self.spatial(x) 95 | return x 96 | 97 | 98 | class LightCbamModule(nn.Module): 99 | def __init__( 100 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 101 | spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 102 | super(LightCbamModule, self).__init__() 103 | self.channel = LightChannelAttn( 104 | channels, rd_ratio=rd_ratio, rd_channels=rd_channels, 105 | rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) 106 | self.spatial = LightSpatialAttn(spatial_kernel_size) 107 | 108 | def forward(self, x): 109 | x = self.channel(x) 110 | x = self.spatial(x) 111 | return x 112 | 113 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/classifier.py: 
-------------------------------------------------------------------------------- 1 | """ Classifier head and layer factory 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | from torch.nn import functional as F 7 | 8 | from .adaptive_avgmax_pool import SelectAdaptivePool2d 9 | 10 | 11 | def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): 12 | flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling 13 | if not pool_type: 14 | assert num_classes == 0 or use_conv,\ 15 | 'Pooling can only be disabled if classifier is also removed or conv classifier is used' 16 | flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling) 17 | global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool) 18 | num_pooled_features = num_features * global_pool.feat_mult() 19 | return global_pool, num_pooled_features 20 | 21 | 22 | def _create_fc(num_features, num_classes, use_conv=False): 23 | if num_classes <= 0: 24 | fc = nn.Identity() # pass-through (no classifier) 25 | elif use_conv: 26 | fc = nn.Conv2d(num_features, num_classes, 1, bias=True) 27 | else: 28 | fc = nn.Linear(num_features, num_classes, bias=True) 29 | return fc 30 | 31 | 32 | def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): 33 | global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv) 34 | fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) 35 | return global_pool, fc 36 | 37 | 38 | class ClassifierHead(nn.Module): 39 | """Classifier head w/ configurable global pooling and dropout.""" 40 | 41 | def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): 42 | super(ClassifierHead, self).__init__() 43 | self.drop_rate = drop_rate 44 | self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) 45 | self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) 46 | self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() 47 | 48 | def forward(self, x, pre_logits: bool = False): 49 | x = self.global_pool(x) 50 | if self.drop_rate: 51 | x = F.dropout(x, p=float(self.drop_rate), training=self.training) 52 | if pre_logits: 53 | return x.flatten(1) 54 | else: 55 | x = self.fc(x) 56 | return self.flatten(x) 57 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/cond_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Conditionally Parameterized Convolution (CondConv) 2 | 3 | Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference 4 | (https://arxiv.org/abs/1904.04971) 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | 9 | import math 10 | from functools import partial 11 | import numpy as np 12 | import torch 13 | from torch import nn as nn 14 | from torch.nn import functional as F 15 | 16 | from .helpers import to_2tuple 17 | from .conv2d_same import conv2d_same 18 | from .padding import get_padding_value 19 | 20 | 21 | def get_condconv_initializer(initializer, num_experts, expert_shape): 22 | def condconv_initializer(weight): 23 | """CondConv initializer function.""" 24 | num_params = np.prod(expert_shape) 25 | if (len(weight.shape) != 2 or weight.shape[0] != num_experts or 26 | weight.shape[1] != num_params): 27 | raise (ValueError( 28 | 
'CondConv variables must have shape [num_experts, num_params]')) 29 | for i in range(num_experts): 30 | initializer(weight[i].view(expert_shape)) 31 | return condconv_initializer 32 | 33 | 34 | class CondConv2d(nn.Module): 35 | """ Conditionally Parameterized Convolution 36 | Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py 37 | 38 | Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: 39 | https://github.com/pytorch/pytorch/issues/17983 40 | """ 41 | __constants__ = ['in_channels', 'out_channels', 'dynamic_padding'] 42 | 43 | def __init__(self, in_channels, out_channels, kernel_size=3, 44 | stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): 45 | super(CondConv2d, self).__init__() 46 | 47 | self.in_channels = in_channels 48 | self.out_channels = out_channels 49 | self.kernel_size = to_2tuple(kernel_size) 50 | self.stride = to_2tuple(stride) 51 | padding_val, is_padding_dynamic = get_padding_value( 52 | padding, kernel_size, stride=stride, dilation=dilation) 53 | self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript 54 | self.padding = to_2tuple(padding_val) 55 | self.dilation = to_2tuple(dilation) 56 | self.groups = groups 57 | self.num_experts = num_experts 58 | 59 | self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size 60 | weight_num_param = 1 61 | for wd in self.weight_shape: 62 | weight_num_param *= wd 63 | self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) 64 | 65 | if bias: 66 | self.bias_shape = (self.out_channels,) 67 | self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) 68 | else: 69 | self.register_parameter('bias', None) 70 | 71 | self.reset_parameters() 72 | 73 | def reset_parameters(self): 74 | init_weight = get_condconv_initializer( 75 | partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) 76 | init_weight(self.weight) 77 | if self.bias is not None: 78 | fan_in = np.prod(self.weight_shape[1:]) 79 | bound = 1 / math.sqrt(fan_in) 80 | init_bias = get_condconv_initializer( 81 | partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) 82 | init_bias(self.bias) 83 | 84 | def forward(self, x, routing_weights): 85 | B, C, H, W = x.shape 86 | weight = torch.matmul(routing_weights, self.weight) 87 | new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size 88 | weight = weight.view(new_weight_shape) 89 | bias = None 90 | if self.bias is not None: 91 | bias = torch.matmul(routing_weights, self.bias) 92 | bias = bias.view(B * self.out_channels) 93 | # move batch elements with channels so each batch element can be efficiently convolved with separate kernel 94 | # reshape instead of view to work with channels_last input 95 | x = x.reshape(1, B * C, H, W) 96 | if self.dynamic_padding: 97 | out = conv2d_same( 98 | x, weight, bias, stride=self.stride, padding=self.padding, 99 | dilation=self.dilation, groups=self.groups * B) 100 | else: 101 | out = F.conv2d( 102 | x, weight, bias, stride=self.stride, padding=self.padding, 103 | dilation=self.dilation, groups=self.groups * B) 104 | out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) 105 | 106 | # Literal port (from TF definition) 107 | # x = torch.split(x, 1, 0) 108 | # weight = torch.split(weight, 1, 0) 109 | # if self.bias is not None: 
110 | # bias = torch.matmul(routing_weights, self.bias) 111 | # bias = torch.split(bias, 1, 0) 112 | # else: 113 | # bias = [None] * B 114 | # out = [] 115 | # for xi, wi, bi in zip(x, weight, bias): 116 | # wi = wi.view(*self.weight_shape) 117 | # if bi is not None: 118 | # bi = bi.view(*self.bias_shape) 119 | # out.append(self.conv_fn( 120 | # xi, wi, bi, stride=self.stride, padding=self.padding, 121 | # dilation=self.dilation, groups=self.groups)) 122 | # out = torch.cat(out, 0) 123 | return out 124 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/config.py: -------------------------------------------------------------------------------- 1 | """ Model / Layer Config singleton state 2 | """ 3 | from typing import Any, Optional 4 | 5 | __all__ = [ 6 | 'is_exportable', 'is_scriptable', 'is_no_jit', 7 | 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' 8 | ] 9 | 10 | # Set to True if prefer to have layers with no jit optimization (includes activations) 11 | _NO_JIT = False 12 | 13 | # Set to True if prefer to have activation layers with no jit optimization 14 | # NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying 15 | # the jit flags so far are activations. This will change as more layers are updated and/or added. 16 | _NO_ACTIVATION_JIT = False 17 | 18 | # Set to True if exporting a model with Same padding via ONNX 19 | _EXPORTABLE = False 20 | 21 | # Set to True if wanting to use torch.jit.script on a model 22 | _SCRIPTABLE = False 23 | 24 | 25 | def is_no_jit(): 26 | return _NO_JIT 27 | 28 | 29 | class set_no_jit: 30 | def __init__(self, mode: bool) -> None: 31 | global _NO_JIT 32 | self.prev = _NO_JIT 33 | _NO_JIT = mode 34 | 35 | def __enter__(self) -> None: 36 | pass 37 | 38 | def __exit__(self, *args: Any) -> bool: 39 | global _NO_JIT 40 | _NO_JIT = self.prev 41 | return False 42 | 43 | 44 | def is_exportable(): 45 | return _EXPORTABLE 46 | 47 | 48 | class set_exportable: 49 | def __init__(self, mode: bool) -> None: 50 | global _EXPORTABLE 51 | self.prev = _EXPORTABLE 52 | _EXPORTABLE = mode 53 | 54 | def __enter__(self) -> None: 55 | pass 56 | 57 | def __exit__(self, *args: Any) -> bool: 58 | global _EXPORTABLE 59 | _EXPORTABLE = self.prev 60 | return False 61 | 62 | 63 | def is_scriptable(): 64 | return _SCRIPTABLE 65 | 66 | 67 | class set_scriptable: 68 | def __init__(self, mode: bool) -> None: 69 | global _SCRIPTABLE 70 | self.prev = _SCRIPTABLE 71 | _SCRIPTABLE = mode 72 | 73 | def __enter__(self) -> None: 74 | pass 75 | 76 | def __exit__(self, *args: Any) -> bool: 77 | global _SCRIPTABLE 78 | _SCRIPTABLE = self.prev 79 | return False 80 | 81 | 82 | class set_layer_config: 83 | """ Layer config context manager that allows setting all layer config flags at once. 84 | If a flag arg is None, it will not change the current value. 
85 | """ 86 | def __init__( 87 | self, 88 | scriptable: Optional[bool] = None, 89 | exportable: Optional[bool] = None, 90 | no_jit: Optional[bool] = None, 91 | no_activation_jit: Optional[bool] = None): 92 | global _SCRIPTABLE 93 | global _EXPORTABLE 94 | global _NO_JIT 95 | global _NO_ACTIVATION_JIT 96 | self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT 97 | if scriptable is not None: 98 | _SCRIPTABLE = scriptable 99 | if exportable is not None: 100 | _EXPORTABLE = exportable 101 | if no_jit is not None: 102 | _NO_JIT = no_jit 103 | if no_activation_jit is not None: 104 | _NO_ACTIVATION_JIT = no_activation_jit 105 | 106 | def __enter__(self) -> None: 107 | pass 108 | 109 | def __exit__(self, *args: Any) -> bool: 110 | global _SCRIPTABLE 111 | global _EXPORTABLE 112 | global _NO_JIT 113 | global _NO_ACTIVATION_JIT 114 | _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev 115 | return False 116 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/conv2d_same.py: -------------------------------------------------------------------------------- 1 | """ Conv2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import Tuple, Optional 9 | 10 | from .padding import pad_same, get_padding_value 11 | 12 | 13 | def conv2d_same( 14 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), 15 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): 16 | x = pad_same(x, weight.shape[-2:], stride, dilation) 17 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 18 | 19 | 20 | class Conv2dSame(nn.Conv2d): 21 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 22 | """ 23 | 24 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 25 | padding=0, dilation=1, groups=1, bias=True): 26 | super(Conv2dSame, self).__init__( 27 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 28 | 29 | def forward(self, x): 30 | return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 31 | 32 | 33 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 34 | padding = kwargs.pop('padding', '') 35 | kwargs.setdefault('bias', False) 36 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) 37 | if is_dynamic: 38 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 39 | else: 40 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 41 | 42 | 43 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/conv_bn_act.py: -------------------------------------------------------------------------------- 1 | """ Conv2d + BN + Act 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import functools 6 | from torch import nn as nn 7 | 8 | from .create_conv2d import create_conv2d 9 | from .create_norm_act import get_norm_act_layer 10 | 11 | 12 | class ConvNormAct(nn.Module): 13 | def __init__( 14 | self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, 15 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, drop_layer=None): 16 | super(ConvNormAct, self).__init__() 17 | self.conv = create_conv2d( 18 | in_channels, out_channels, 
kernel_size, stride=stride, 19 | padding=padding, dilation=dilation, groups=groups, bias=bias) 20 | 21 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 22 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 23 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` 24 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 25 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 26 | 27 | @property 28 | def in_channels(self): 29 | return self.conv.in_channels 30 | 31 | @property 32 | def out_channels(self): 33 | return self.conv.out_channels 34 | 35 | def forward(self, x): 36 | x = self.conv(x) 37 | x = self.bn(x) 38 | return x 39 | 40 | 41 | ConvBnAct = ConvNormAct 42 | 43 | 44 | def create_aa(aa_layer, channels, stride=2, enable=True): 45 | if not aa_layer or not enable: 46 | return nn.Identity() 47 | if isinstance(aa_layer, functools.partial): 48 | if issubclass(aa_layer.func, nn.AvgPool2d): 49 | return aa_layer() 50 | else: 51 | return aa_layer(channels) 52 | elif issubclass(aa_layer, nn.AvgPool2d): 53 | return aa_layer(stride) 54 | else: 55 | return aa_layer(channels=channels, stride=stride) 56 | 57 | 58 | class ConvNormActAa(nn.Module): 59 | def __init__( 60 | self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, 61 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None): 62 | super(ConvNormActAa, self).__init__() 63 | use_aa = aa_layer is not None and stride == 2 64 | 65 | self.conv = create_conv2d( 66 | in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, 67 | padding=padding, dilation=dilation, groups=groups, bias=bias) 68 | 69 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 70 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 71 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` 72 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 73 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 74 | self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa) 75 | 76 | @property 77 | def in_channels(self): 78 | return self.conv.in_channels 79 | 80 | @property 81 | def out_channels(self): 82 | return self.conv.out_channels 83 | 84 | def forward(self, x): 85 | x = self.conv(x) 86 | x = self.bn(x) 87 | x = self.aa(x) 88 | return x 89 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/create_act.py: -------------------------------------------------------------------------------- 1 | """ Activation Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from typing import Union, Callable, Type 5 | 6 | from .activations import * 7 | from .activations_jit import * 8 | from .activations_me import * 9 | from .config import is_exportable, is_scriptable, is_no_jit 10 | 11 | # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. 12 | # Also hardsigmoid, hardswish, and soon mish. This code will use native version if present. 13 | # Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used. 
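# (editor's sketch, illustrative only -- not part of the original source) Typical use of the
# factories defined further down in this file:
#
#     act_cls = get_act_layer('hard_swish')              # a class, e.g. nn.Hardswish
#     act = create_act_layer('hard_swish', inplace=True)  # an instantiated module
#     act_fn = get_act_fn('mish')                          # a callable, e.g. F.mish
#
# Which concrete implementation comes back depends on the PyTorch version detected just
# below and on the exportable / scriptable / no-jit flags from .config.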
14 | _has_silu = 'silu' in dir(torch.nn.functional) 15 | _has_hardswish = 'hardswish' in dir(torch.nn.functional) 16 | _has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional) 17 | _has_mish = 'mish' in dir(torch.nn.functional) 18 | 19 | 20 | _ACT_FN_DEFAULT = dict( 21 | silu=F.silu if _has_silu else swish, 22 | swish=F.silu if _has_silu else swish, 23 | mish=F.mish if _has_mish else mish, 24 | relu=F.relu, 25 | relu6=F.relu6, 26 | leaky_relu=F.leaky_relu, 27 | elu=F.elu, 28 | celu=F.celu, 29 | selu=F.selu, 30 | gelu=gelu, 31 | sigmoid=sigmoid, 32 | tanh=tanh, 33 | hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid, 34 | hard_swish=F.hardswish if _has_hardswish else hard_swish, 35 | hard_mish=hard_mish, 36 | ) 37 | 38 | _ACT_FN_JIT = dict( 39 | silu=F.silu if _has_silu else swish_jit, 40 | swish=F.silu if _has_silu else swish_jit, 41 | mish=F.mish if _has_mish else mish_jit, 42 | hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit, 43 | hard_swish=F.hardswish if _has_hardswish else hard_swish_jit, 44 | hard_mish=hard_mish_jit 45 | ) 46 | 47 | _ACT_FN_ME = dict( 48 | silu=F.silu if _has_silu else swish_me, 49 | swish=F.silu if _has_silu else swish_me, 50 | mish=F.mish if _has_mish else mish_me, 51 | hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me, 52 | hard_swish=F.hardswish if _has_hardswish else hard_swish_me, 53 | hard_mish=hard_mish_me, 54 | ) 55 | 56 | _ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT) 57 | for a in _ACT_FNS: 58 | a.setdefault('hardsigmoid', a.get('hard_sigmoid')) 59 | a.setdefault('hardswish', a.get('hard_swish')) 60 | 61 | 62 | _ACT_LAYER_DEFAULT = dict( 63 | silu=nn.SiLU if _has_silu else Swish, 64 | swish=nn.SiLU if _has_silu else Swish, 65 | mish=nn.Mish if _has_mish else Mish, 66 | relu=nn.ReLU, 67 | relu6=nn.ReLU6, 68 | leaky_relu=nn.LeakyReLU, 69 | elu=nn.ELU, 70 | prelu=PReLU, 71 | celu=nn.CELU, 72 | selu=nn.SELU, 73 | gelu=GELU, 74 | sigmoid=Sigmoid, 75 | tanh=Tanh, 76 | hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid, 77 | hard_swish=nn.Hardswish if _has_hardswish else HardSwish, 78 | hard_mish=HardMish, 79 | ) 80 | 81 | _ACT_LAYER_JIT = dict( 82 | silu=nn.SiLU if _has_silu else SwishJit, 83 | swish=nn.SiLU if _has_silu else SwishJit, 84 | mish=nn.Mish if _has_mish else MishJit, 85 | hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit, 86 | hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit, 87 | hard_mish=HardMishJit 88 | ) 89 | 90 | _ACT_LAYER_ME = dict( 91 | silu=nn.SiLU if _has_silu else SwishMe, 92 | swish=nn.SiLU if _has_silu else SwishMe, 93 | mish=nn.Mish if _has_mish else MishMe, 94 | hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe, 95 | hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe, 96 | hard_mish=HardMishMe, 97 | ) 98 | 99 | _ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT) 100 | for a in _ACT_LAYERS: 101 | a.setdefault('hardsigmoid', a.get('hard_sigmoid')) 102 | a.setdefault('hardswish', a.get('hard_swish')) 103 | 104 | 105 | def get_act_fn(name: Union[Callable, str] = 'relu'): 106 | """ Activation Function Factory 107 | Fetching activation fns by name with this function allows export or torch script friendly 108 | functions to be returned dynamically based on current config. 
109 | """ 110 | if not name: 111 | return None 112 | if isinstance(name, Callable): 113 | return name 114 | if not (is_no_jit() or is_exportable() or is_scriptable()): 115 | # If not exporting or scripting the model, first look for a memory-efficient version with 116 | # custom autograd, then fallback 117 | if name in _ACT_FN_ME: 118 | return _ACT_FN_ME[name] 119 | if not (is_no_jit() or is_exportable()): 120 | if name in _ACT_FN_JIT: 121 | return _ACT_FN_JIT[name] 122 | return _ACT_FN_DEFAULT[name] 123 | 124 | 125 | def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): 126 | """ Activation Layer Factory 127 | Fetching activation layers by name with this function allows export or torch script friendly 128 | functions to be returned dynamically based on current config. 129 | """ 130 | if not name: 131 | return None 132 | if not isinstance(name, str): 133 | # callable, module, etc 134 | return name 135 | if not (is_no_jit() or is_exportable() or is_scriptable()): 136 | if name in _ACT_LAYER_ME: 137 | return _ACT_LAYER_ME[name] 138 | if not (is_no_jit() or is_exportable()): 139 | if name in _ACT_LAYER_JIT: 140 | return _ACT_LAYER_JIT[name] 141 | return _ACT_LAYER_DEFAULT[name] 142 | 143 | 144 | def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs): 145 | act_layer = get_act_layer(name) 146 | if act_layer is None: 147 | return None 148 | if inplace is None: 149 | return act_layer(**kwargs) 150 | try: 151 | return act_layer(inplace=inplace, **kwargs) 152 | except TypeError: 153 | # recover if act layer doesn't have inplace arg 154 | return act_layer(**kwargs) 155 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/create_attn.py: -------------------------------------------------------------------------------- 1 | """ Attention Factory 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | import torch 6 | from functools import partial 7 | 8 | from .bottleneck_attn import BottleneckAttn 9 | from .cbam import CbamModule, LightCbamModule 10 | from .eca import EcaModule, CecaModule 11 | from .gather_excite import GatherExcite 12 | from .global_context import GlobalContext 13 | from .halo_attn import HaloAttn 14 | from .lambda_layer import LambdaLayer 15 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn 16 | from .selective_kernel import SelectiveKernel 17 | from .split_attn import SplitAttn 18 | from .squeeze_excite import SEModule, EffectiveSEModule 19 | 20 | 21 | def get_attn(attn_type): 22 | if isinstance(attn_type, torch.nn.Module): 23 | return attn_type 24 | module_cls = None 25 | if attn_type: 26 | if isinstance(attn_type, str): 27 | attn_type = attn_type.lower() 28 | # Lightweight attention modules (channel and/or coarse spatial). 29 | # Typically added to existing network architecture blocks in addition to existing convolutions. 
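# (editor's note) The string keys handled below map to attention module classes;
# create_attn() at the bottom of this file instantiates the selected class, passing the
# channel count as the first positional argument.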
30 | if attn_type == 'se': 31 | module_cls = SEModule 32 | elif attn_type == 'ese': 33 | module_cls = EffectiveSEModule 34 | elif attn_type == 'eca': 35 | module_cls = EcaModule 36 | elif attn_type == 'ecam': 37 | module_cls = partial(EcaModule, use_mlp=True) 38 | elif attn_type == 'ceca': 39 | module_cls = CecaModule 40 | elif attn_type == 'ge': 41 | module_cls = GatherExcite 42 | elif attn_type == 'gc': 43 | module_cls = GlobalContext 44 | elif attn_type == 'gca': 45 | module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False) 46 | elif attn_type == 'cbam': 47 | module_cls = CbamModule 48 | elif attn_type == 'lcbam': 49 | module_cls = LightCbamModule 50 | 51 | # Attention / attention-like modules w/ significant params 52 | # Typically replace some of the existing workhorse convs in a network architecture. 53 | # All of these accept a stride argument and can spatially downsample the input. 54 | elif attn_type == 'sk': 55 | module_cls = SelectiveKernel 56 | elif attn_type == 'splat': 57 | module_cls = SplitAttn 58 | 59 | # Self-attention / attention-like modules w/ significant compute and/or params 60 | # Typically replace some of the existing workhorse convs in a network architecture. 61 | # All of these accept a stride argument and can spatially downsample the input. 62 | elif attn_type == 'lambda': 63 | return LambdaLayer 64 | elif attn_type == 'bottleneck': 65 | return BottleneckAttn 66 | elif attn_type == 'halo': 67 | return HaloAttn 68 | elif attn_type == 'nl': 69 | module_cls = NonLocalAttn 70 | elif attn_type == 'bat': 71 | module_cls = BatNonLocalAttn 72 | 73 | # Woops! 74 | else: 75 | assert False, "Invalid attn module (%s)" % attn_type 76 | elif isinstance(attn_type, bool): 77 | if attn_type: 78 | module_cls = SEModule 79 | else: 80 | module_cls = attn_type 81 | return module_cls 82 | 83 | 84 | def create_attn(attn_type, channels, **kwargs): 85 | module_cls = get_attn(attn_type) 86 | if module_cls is not None: 87 | # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels 88 | return module_cls(channels, **kwargs) 89 | return None 90 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/create_conv2d.py: -------------------------------------------------------------------------------- 1 | """ Create Conv2d Factory Method 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | from .mixed_conv2d import MixedConv2d 7 | from .cond_conv2d import CondConv2d 8 | from .conv2d_same import create_conv2d_pad 9 | 10 | 11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): 12 | """ Select a 2d convolution implementation based on arguments 13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. 14 | 15 | Used extensively by EfficientNet, MobileNetv3 and related networks. 16 | """ 17 | if isinstance(kernel_size, list): 18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently 19 | if 'groups' in kwargs: 20 | groups = kwargs.pop('groups') 21 | if groups == in_channels: 22 | kwargs['depthwise'] = True 23 | else: 24 | assert groups == 1 25 | # We're going to use only lists for defining the MixedConv2d kernel groups, 26 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 
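# (editor's note) e.g. kernel_size=[3, 5, 7] builds a MixedConv2d that splits the channels
# into groups, one group per listed kernel size (as in MixNet); scalar or tuple kernel
# sizes fall through to the plain / CondConv branch below.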
27 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) 28 | else: 29 | depthwise = kwargs.pop('depthwise', False) 30 | # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 31 | groups = in_channels if depthwise else kwargs.pop('groups', 1) 32 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0: 33 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 34 | else: 35 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 36 | return m 37 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/create_norm.py: -------------------------------------------------------------------------------- 1 | """ Norm Layer Factory 2 | 3 | Create norm modules by string (to mirror create_act and creat_norm-act fns) 4 | 5 | Copyright 2022 Ross Wightman 6 | """ 7 | import types 8 | import functools 9 | 10 | import torch.nn as nn 11 | 12 | from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d 13 | 14 | _NORM_MAP = dict( 15 | batchnorm=nn.BatchNorm2d, 16 | batchnorm2d=nn.BatchNorm2d, 17 | batchnorm1d=nn.BatchNorm1d, 18 | groupnorm=GroupNorm, 19 | groupnorm1=GroupNorm1, 20 | layernorm=LayerNorm, 21 | layernorm2d=LayerNorm2d, 22 | ) 23 | _NORM_TYPES = {m for n, m in _NORM_MAP.items()} 24 | 25 | 26 | def create_norm_layer(layer_name, num_features, act_layer=None, apply_act=True, **kwargs): 27 | layer = get_norm_layer(layer_name, act_layer=act_layer) 28 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs) 29 | return layer_instance 30 | 31 | 32 | def get_norm_layer(norm_layer): 33 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 34 | norm_kwargs = {} 35 | 36 | # unbind partial fn, so args can be rebound later 37 | if isinstance(norm_layer, functools.partial): 38 | norm_kwargs.update(norm_layer.keywords) 39 | norm_layer = norm_layer.func 40 | 41 | if isinstance(norm_layer, str): 42 | layer_name = norm_layer.replace('_', '') 43 | norm_layer = _NORM_MAP.get(layer_name, None) 44 | elif norm_layer in _NORM_TYPES: 45 | norm_layer = norm_layer 46 | elif isinstance(norm_layer, types.FunctionType): 47 | # if function type, assume it is a lambda/fn that creates a norm layer 48 | norm_layer = norm_layer 49 | else: 50 | type_name = norm_layer.__name__.lower().replace('_', '') 51 | norm_layer = _NORM_MAP.get(type_name, None) 52 | assert norm_layer is not None, f"No equivalent norm layer for {type_name}" 53 | 54 | if norm_kwargs: 55 | norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args 56 | return norm_layer 57 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/create_norm_act.py: -------------------------------------------------------------------------------- 1 | """ NormAct (Normalizaiton + Activation Layer) Factory 2 | 3 | Create norm + act combo modules that attempt to be backwards compatible with separate norm + act 4 | isntances in models. Where these are used it will be possible to swap separate BN + act layers with 5 | combined modules like IABN or EvoNorms. 
6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | import types 10 | import functools 11 | 12 | from .evo_norm import * 13 | from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d 14 | from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d 15 | from .inplace_abn import InplaceAbn 16 | 17 | _NORM_ACT_MAP = dict( 18 | batchnorm=BatchNormAct2d, 19 | batchnorm2d=BatchNormAct2d, 20 | groupnorm=GroupNormAct, 21 | groupnorm1=functools.partial(GroupNormAct, num_groups=1), 22 | layernorm=LayerNormAct, 23 | layernorm2d=LayerNormAct2d, 24 | evonormb0=EvoNorm2dB0, 25 | evonormb1=EvoNorm2dB1, 26 | evonormb2=EvoNorm2dB2, 27 | evonorms0=EvoNorm2dS0, 28 | evonorms0a=EvoNorm2dS0a, 29 | evonorms1=EvoNorm2dS1, 30 | evonorms1a=EvoNorm2dS1a, 31 | evonorms2=EvoNorm2dS2, 32 | evonorms2a=EvoNorm2dS2a, 33 | frn=FilterResponseNormAct2d, 34 | frntlu=FilterResponseNormTlu2d, 35 | inplaceabn=InplaceAbn, 36 | iabn=InplaceAbn, 37 | ) 38 | _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()} 39 | # has act_layer arg to define act type 40 | _NORM_ACT_REQUIRES_ARG = { 41 | BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn} 42 | 43 | 44 | def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs): 45 | layer = get_norm_act_layer(layer_name, act_layer=act_layer) 46 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs) 47 | if jit: 48 | layer_instance = torch.jit.script(layer_instance) 49 | return layer_instance 50 | 51 | 52 | def get_norm_act_layer(norm_layer, act_layer=None): 53 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 54 | assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) 55 | norm_act_kwargs = {} 56 | 57 | # unbind partial fn, so args can be rebound later 58 | if isinstance(norm_layer, functools.partial): 59 | norm_act_kwargs.update(norm_layer.keywords) 60 | norm_layer = norm_layer.func 61 | 62 | if isinstance(norm_layer, str): 63 | layer_name = norm_layer.replace('_', '').lower().split('-')[0] 64 | norm_act_layer = _NORM_ACT_MAP.get(layer_name, None) 65 | elif norm_layer in _NORM_ACT_TYPES: 66 | norm_act_layer = norm_layer 67 | elif isinstance(norm_layer, types.FunctionType): 68 | # if function type, must be a lambda/fn that creates a norm_act layer 69 | norm_act_layer = norm_layer 70 | else: 71 | type_name = norm_layer.__name__.lower() 72 | if type_name.startswith('batchnorm'): 73 | norm_act_layer = BatchNormAct2d 74 | elif type_name.startswith('groupnorm'): 75 | norm_act_layer = GroupNormAct 76 | elif type_name.startswith('groupnorm1'): 77 | norm_act_layer = functools.partial(GroupNormAct, num_groups=1) 78 | elif type_name.startswith('layernorm2d'): 79 | norm_act_layer = LayerNormAct2d 80 | elif type_name.startswith('layernorm'): 81 | norm_act_layer = LayerNormAct 82 | else: 83 | assert False, f"No equivalent norm_act layer for {type_name}" 84 | 85 | if norm_act_layer in _NORM_ACT_REQUIRES_ARG: 86 | # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. 
87 | # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types 88 | norm_act_kwargs.setdefault('act_layer', act_layer) 89 | if norm_act_kwargs: 90 | norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args 91 | return norm_act_layer 92 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/eca.py: -------------------------------------------------------------------------------- 1 | """ 2 | ECA module from ECAnet 3 | 4 | paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks 5 | https://arxiv.org/abs/1910.03151 6 | 7 | Original ECA model borrowed from https://github.com/BangguWu/ECANet 8 | 9 | Modified circular ECA implementation and adaption for use in timm package 10 | by Chris Ha https://github.com/VRandme 11 | 12 | Original License: 13 | 14 | MIT License 15 | 16 | Copyright (c) 2019 BangguWu, Qilong Wang 17 | 18 | Permission is hereby granted, free of charge, to any person obtaining a copy 19 | of this software and associated documentation files (the "Software"), to deal 20 | in the Software without restriction, including without limitation the rights 21 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 22 | copies of the Software, and to permit persons to whom the Software is 23 | furnished to do so, subject to the following conditions: 24 | 25 | The above copyright notice and this permission notice shall be included in all 26 | copies or substantial portions of the Software. 27 | 28 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 29 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 30 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 31 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 32 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 33 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 34 | SOFTWARE. 35 | """ 36 | import math 37 | from torch import nn 38 | import torch.nn.functional as F 39 | 40 | 41 | from .create_act import create_act_layer 42 | from .helpers import make_divisible 43 | 44 | 45 | class EcaModule(nn.Module): 46 | """Constructs an ECA module. 47 | 48 | Args: 49 | channels: Number of channels of the input feature map for use in adaptive kernel sizes 50 | for actual calculations according to channel. 51 | gamma, beta: when channel is given parameters of mapping function 52 | refer to original paper https://arxiv.org/pdf/1910.03151.pdf 53 | (default=None. if channel size not given, use k_size given for kernel size.) 
54 | kernel_size: Adaptive selection of kernel size (default=3) 55 | gamm: used in kernel_size calc, see above 56 | beta: used in kernel_size calc, see above 57 | act_layer: optional non-linearity after conv, enables conv bias, this is an experiment 58 | gate_layer: gating non-linearity to use 59 | """ 60 | def __init__( 61 | self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid', 62 | rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False): 63 | super(EcaModule, self).__init__() 64 | if channels is not None: 65 | t = int(abs(math.log(channels, 2) + beta) / gamma) 66 | kernel_size = max(t if t % 2 else t + 1, 3) 67 | assert kernel_size % 2 == 1 68 | padding = (kernel_size - 1) // 2 69 | if use_mlp: 70 | # NOTE 'mlp' mode is a timm experiment, not in paper 71 | assert channels is not None 72 | if rd_channels is None: 73 | rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor) 74 | act_layer = act_layer or nn.ReLU 75 | self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True) 76 | self.act = create_act_layer(act_layer) 77 | self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True) 78 | else: 79 | self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) 80 | self.act = None 81 | self.conv2 = None 82 | self.gate = create_act_layer(gate_layer) 83 | 84 | def forward(self, x): 85 | y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv 86 | y = self.conv(y) 87 | if self.conv2 is not None: 88 | y = self.act(y) 89 | y = self.conv2(y) 90 | y = self.gate(y).view(x.shape[0], -1, 1, 1) 91 | return x * y.expand_as(x) 92 | 93 | 94 | EfficientChannelAttn = EcaModule # alias 95 | 96 | 97 | class CecaModule(nn.Module): 98 | """Constructs a circular ECA module. 99 | 100 | ECA module where the conv uses circular padding rather than zero padding. 101 | Unlike the spatial dimension, the channels do not have inherent ordering nor 102 | locality. Although this module in essence, applies such an assumption, it is unnecessary 103 | to limit the channels on either "edge" from being circularly adapted to each other. 104 | This will fundamentally increase connectivity and possibly increase performance metrics 105 | (accuracy, robustness), without significantly impacting resource metrics 106 | (parameter size, throughput,latency, etc) 107 | 108 | Args: 109 | channels: Number of channels of the input feature map for use in adaptive kernel sizes 110 | for actual calculations according to channel. 111 | gamma, beta: when channel is given parameters of mapping function 112 | refer to original paper https://arxiv.org/pdf/1910.03151.pdf 113 | (default=None. if channel size not given, use k_size given for kernel size.) 
114 | kernel_size: Adaptive selection of kernel size (default=3) 115 | gamma: used in kernel_size calc, see above 116 | beta: used in kernel_size calc, see above 117 | act_layer: optional non-linearity after conv, enables conv bias, this is an experiment 118 | gate_layer: gating non-linearity to use 119 | """ 120 | 121 | def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'): 122 | super(CecaModule, self).__init__() 123 | if channels is not None: 124 | t = int(abs(math.log(channels, 2) + beta) / gamma) 125 | kernel_size = max(t if t % 2 else t + 1, 3) 126 | has_act = act_layer is not None 127 | assert kernel_size % 2 == 1 128 | 129 | # PyTorch circular padding mode is buggy as of pytorch 1.4 130 | # see https://github.com/pytorch/pytorch/pull/17240 131 | # implement manual circular padding 132 | self.padding = (kernel_size - 1) // 2 133 | self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act) 134 | self.gate = create_act_layer(gate_layer) 135 | 136 | def forward(self, x): 137 | y = x.mean((2, 3)).view(x.shape[0], 1, -1) 138 | # Manually implement circular padding, F.pad does not seem to be bugged 139 | y = F.pad(y, (self.padding, self.padding), mode='circular') 140 | y = self.conv(y) 141 | y = self.gate(y).view(x.shape[0], -1, 1, 1) 142 | return x * y.expand_as(x) 143 | 144 | 145 | CircularEfficientChannelAttn = CecaModule 146 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/fast_norm.py: -------------------------------------------------------------------------------- 1 | """ 'Fast' Normalization Functions 2 | 3 | For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32. 4 | 5 | Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast) 6 | 7 | Hacked together by / Copyright 2022 Ross Wightman 8 | """ 9 | from typing import List, Optional 10 | 11 | import torch 12 | from torch.nn import functional as F 13 | 14 | try: 15 | from apex.normalization.fused_layer_norm import fused_layer_norm_affine 16 | has_apex = True 17 | except ImportError: 18 | has_apex = False 19 | 20 | 21 | # fast (ie lower precision LN) can be disabled with this flag if issues crop up 22 | _USE_FAST_NORM = False # defaulting to False for now 23 | 24 | 25 | def is_fast_norm(): 26 | return _USE_FAST_NORM 27 | 28 | 29 | def set_fast_norm(enable=True): 30 | global _USE_FAST_NORM 31 | _USE_FAST_NORM = enable 32 | 33 | 34 | def fast_group_norm( 35 | x: torch.Tensor, 36 | num_groups: int, 37 | weight: Optional[torch.Tensor] = None, 38 | bias: Optional[torch.Tensor] = None, 39 | eps: float = 1e-5 40 | ) -> torch.Tensor: 41 | if torch.jit.is_scripting(): 42 | # currently cannot use is_autocast_enabled within torchscript 43 | return F.group_norm(x, num_groups, weight, bias, eps) 44 | 45 | if torch.is_autocast_enabled(): 46 | # normally native AMP casts GN inputs to float32 47 | # here we use the low precision autocast dtype 48 | # FIXME what to do re CPU autocast?
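# the cast below keeps the tensors in the active autocast dtype (fp16/bf16); autocast is then
# disabled for the F.group_norm call so the normalization runs directly in that lower-precision dtype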
49 | dt = torch.get_autocast_gpu_dtype() 50 | x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) 51 | 52 | with torch.cuda.amp.autocast(enabled=False): 53 | return F.group_norm(x, num_groups, weight, bias, eps) 54 | 55 | 56 | def fast_layer_norm( 57 | x: torch.Tensor, 58 | normalized_shape: List[int], 59 | weight: Optional[torch.Tensor] = None, 60 | bias: Optional[torch.Tensor] = None, 61 | eps: float = 1e-5 62 | ) -> torch.Tensor: 63 | if torch.jit.is_scripting(): 64 | # currently cannot use is_autocast_enabled within torchscript 65 | return F.layer_norm(x, normalized_shape, weight, bias, eps) 66 | 67 | if has_apex: 68 | return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps) 69 | 70 | if torch.is_autocast_enabled(): 71 | # normally native AMP casts LN inputs to float32 72 | # apex LN does not, this is behaving like Apex 73 | dt = torch.get_autocast_gpu_dtype() 74 | # FIXME what to do re CPU autocast? 75 | x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) 76 | 77 | with torch.cuda.amp.autocast(enabled=False): 78 | return F.layer_norm(x, normalized_shape, weight, bias, eps) 79 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/filter_response_norm.py: -------------------------------------------------------------------------------- 1 | """ Filter Response Norm in PyTorch 2 | 3 | Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737 4 | 5 | Hacked together by / Copyright 2021 Ross Wightman 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .create_act import create_act_layer 11 | from .trace_utils import _assert 12 | 13 | 14 | def inv_instance_rms(x, eps: float = 1e-5): 15 | rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype) 16 | return rms.expand(x.shape) 17 | 18 | 19 | class FilterResponseNormTlu2d(nn.Module): 20 | def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_): 21 | super(FilterResponseNormTlu2d, self).__init__() 22 | self.apply_act = apply_act # apply activation (non-linearity) 23 | self.rms = rms 24 | self.eps = eps 25 | self.weight = nn.Parameter(torch.ones(num_features)) 26 | self.bias = nn.Parameter(torch.zeros(num_features)) 27 | self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | nn.init.ones_(self.weight) 32 | nn.init.zeros_(self.bias) 33 | if self.tau is not None: 34 | nn.init.zeros_(self.tau) 35 | 36 | def forward(self, x): 37 | _assert(x.dim() == 4, 'expected 4D input') 38 | x_dtype = x.dtype 39 | v_shape = (1, -1, 1, 1) 40 | x = x * inv_instance_rms(x, self.eps) 41 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 42 | return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x 43 | 44 | 45 | class FilterResponseNormAct2d(nn.Module): 46 | def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_): 47 | super(FilterResponseNormAct2d, self).__init__() 48 | if act_layer is not None and apply_act: 49 | self.act = create_act_layer(act_layer, inplace=inplace) 50 | else: 51 | self.act = nn.Identity() 52 | self.rms = rms 53 | self.eps = eps 54 | self.weight = nn.Parameter(torch.ones(num_features)) 55 | self.bias = nn.Parameter(torch.zeros(num_features)) 56 | self.reset_parameters() 57 | 58 | def reset_parameters(self): 59 | nn.init.ones_(self.weight) 60 | 
nn.init.zeros_(self.bias) 61 | 62 | def forward(self, x): 63 | _assert(x.dim() == 4, 'expected 4D input') 64 | x_dtype = x.dtype 65 | v_shape = (1, -1, 1, 1) 66 | x = x * inv_instance_rms(x, self.eps) 67 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 68 | return self.act(x) 69 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/gather_excite.py: -------------------------------------------------------------------------------- 1 | """ Gather-Excite Attention Block 2 | 3 | Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348 4 | 5 | Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet 6 | 7 | I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another 8 | impl that covers all of the cases. 9 | 10 | NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation 11 | 12 | Hacked together by / Copyright 2021 Ross Wightman 13 | """ 14 | import math 15 | 16 | from torch import nn as nn 17 | import torch.nn.functional as F 18 | 19 | from .create_act import create_act_layer, get_act_layer 20 | from .create_conv2d import create_conv2d 21 | from .helpers import make_divisible 22 | from .mlp import ConvMlp 23 | 24 | 25 | class GatherExcite(nn.Module): 26 | """ Gather-Excite Attention Module 27 | """ 28 | def __init__( 29 | self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True, 30 | rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False, 31 | act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'): 32 | super(GatherExcite, self).__init__() 33 | self.add_maxpool = add_maxpool 34 | act_layer = get_act_layer(act_layer) 35 | self.extent = extent 36 | if extra_params: 37 | self.gather = nn.Sequential() 38 | if extent == 0: 39 | assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params' 40 | self.gather.add_module( 41 | 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True)) 42 | if norm_layer: 43 | self.gather.add_module(f'norm1', nn.BatchNorm2d(channels)) 44 | else: 45 | assert extent % 2 == 0 46 | num_conv = int(math.log2(extent)) 47 | for i in range(num_conv): 48 | self.gather.add_module( 49 | f'conv{i + 1}', 50 | create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True)) 51 | if norm_layer: 52 | self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels)) 53 | if i != num_conv - 1: 54 | self.gather.add_module(f'act{i + 1}', act_layer(inplace=True)) 55 | else: 56 | self.gather = None 57 | if self.extent == 0: 58 | self.gk = 0 59 | self.gs = 0 60 | else: 61 | assert extent % 2 == 0 62 | self.gk = self.extent * 2 - 1 63 | self.gs = self.extent 64 | 65 | if not rd_channels: 66 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
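# 'excite' step: a 1x1-conv MLP (channels -> rd_channels -> channels) turns the gathered
# context into per-channel weights, which are passed through the gate and multiplied onto the input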
67 | self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity() 68 | self.gate = create_act_layer(gate_layer) 69 | 70 | def forward(self, x): 71 | size = x.shape[-2:] 72 | if self.gather is not None: 73 | x_ge = self.gather(x) 74 | else: 75 | if self.extent == 0: 76 | # global extent 77 | x_ge = x.mean(dim=(2, 3), keepdims=True) 78 | if self.add_maxpool: 79 | # experimental codepath, may remove or change 80 | x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True) 81 | else: 82 | x_ge = F.avg_pool2d( 83 | x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False) 84 | if self.add_maxpool: 85 | # experimental codepath, may remove or change 86 | x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2) 87 | x_ge = self.mlp(x_ge) 88 | if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1: 89 | x_ge = F.interpolate(x_ge, size=size) 90 | return x * self.gate(x_ge) 91 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/global_context.py: -------------------------------------------------------------------------------- 1 | """ Global Context Attention Block 2 | 3 | Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` 4 | - https://arxiv.org/abs/1904.11492 5 | 6 | Official code consulted as reference: https://github.com/xvjiarui/GCNet 7 | 8 | Hacked together by / Copyright 2021 Ross Wightman 9 | """ 10 | from torch import nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .create_act import create_act_layer, get_act_layer 14 | from .helpers import make_divisible 15 | from .mlp import ConvMlp 16 | from .norm import LayerNorm2d 17 | 18 | 19 | class GlobalContext(nn.Module): 20 | 21 | def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False, 22 | rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'): 23 | super(GlobalContext, self).__init__() 24 | act_layer = get_act_layer(act_layer) 25 | 26 | self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None 27 | 28 | if rd_channels is None: 29 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
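# context transform: 1x1-conv MLPs (channels -> rd_channels -> channels, LayerNorm2d in between)
# fuse the pooled/attended context back into x additively (mlp_add) and/or as a gated scale (mlp_scale)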
30 | if fuse_add: 31 | self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 32 | else: 33 | self.mlp_add = None 34 | if fuse_scale: 35 | self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 36 | else: 37 | self.mlp_scale = None 38 | 39 | self.gate = create_act_layer(gate_layer) 40 | self.init_last_zero = init_last_zero 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | if self.conv_attn is not None: 45 | nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') 46 | if self.mlp_add is not None: 47 | nn.init.zeros_(self.mlp_add.fc2.weight) 48 | 49 | def forward(self, x): 50 | B, C, H, W = x.shape 51 | 52 | if self.conv_attn is not None: 53 | attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) 54 | attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) 55 | context = x.reshape(B, C, H * W).unsqueeze(1) @ attn 56 | context = context.view(B, C, 1, 1) 57 | else: 58 | context = x.mean(dim=(2, 3), keepdim=True) 59 | 60 | if self.mlp_scale is not None: 61 | mlp_x = self.mlp_scale(context) 62 | x = x * self.gate(mlp_x) 63 | if self.mlp_add is not None: 64 | mlp_x = self.mlp_add(context) 65 | x = x + mlp_x 66 | 67 | return x 68 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/helpers.py: -------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from itertools import repeat 6 | import collections.abc 7 | 8 | 9 | # From PyTorch internals 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 13 | return x 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | 18 | to_1tuple = _ntuple(1) 19 | to_2tuple = _ntuple(2) 20 | to_3tuple = _ntuple(3) 21 | to_4tuple = _ntuple(4) 22 | to_ntuple = _ntuple 23 | 24 | 25 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9): 26 | min_value = min_value or divisor 27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 28 | # Make sure that round down does not go down by more than 10%. 
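# e.g. make_divisible(10, 8) first rounds to 8, which is more than 10% below 10, so it is
# bumped up to 16; make_divisible(30, 8) rounds to 32 and needs no bump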
29 | if new_v < round_limit * v: 30 | new_v += divisor 31 | return new_v 32 | 33 | 34 | def extend_tuple(x, n): 35 | # pads a tuple to specified n by padding with last value 36 | if not isinstance(x, (tuple, list)): 37 | x = (x,) 38 | else: 39 | x = tuple(x) 40 | pad_n = n - len(x) 41 | if pad_n <= 0: 42 | return x[:n] 43 | return x + (x[-1],) * pad_n 44 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/inplace_abn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | try: 5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync 6 | has_iabn = True 7 | except ImportError: 8 | has_iabn = False 9 | 10 | def inplace_abn(x, weight, bias, running_mean, running_var, 11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): 12 | raise ImportError( 13 | "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'") 14 | 15 | def inplace_abn_sync(**kwargs): 16 | inplace_abn(**kwargs) 17 | 18 | 19 | class InplaceAbn(nn.Module): 20 | """Activated Batch Normalization 21 | 22 | This gathers a BatchNorm and an activation function in a single module 23 | 24 | Parameters 25 | ---------- 26 | num_features : int 27 | Number of feature channels in the input and output. 28 | eps : float 29 | Small constant to prevent numerical issues. 30 | momentum : float 31 | Momentum factor applied to compute running statistics. 32 | affine : bool 33 | If `True` apply learned scale and shift transformation after normalization. 34 | act_layer : str or nn.Module type 35 | Name or type of the activation functions, one of: `leaky_relu`, `elu` 36 | act_param : float 37 | Negative slope for the `leaky_relu` activation.
38 | """ 39 | 40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, 41 | act_layer="leaky_relu", act_param=0.01, drop_layer=None): 42 | super(InplaceAbn, self).__init__() 43 | self.num_features = num_features 44 | self.affine = affine 45 | self.eps = eps 46 | self.momentum = momentum 47 | if apply_act: 48 | if isinstance(act_layer, str): 49 | assert act_layer in ('leaky_relu', 'elu', 'identity', '') 50 | self.act_name = act_layer if act_layer else 'identity' 51 | else: 52 | # convert act layer passed as type to string 53 | if act_layer == nn.ELU: 54 | self.act_name = 'elu' 55 | elif act_layer == nn.LeakyReLU: 56 | self.act_name = 'leaky_relu' 57 | elif act_layer is None or act_layer == nn.Identity: 58 | self.act_name = 'identity' 59 | else: 60 | assert False, f'Invalid act layer {act_layer.__name__} for IABN' 61 | else: 62 | self.act_name = 'identity' 63 | self.act_param = act_param 64 | if self.affine: 65 | self.weight = nn.Parameter(torch.ones(num_features)) 66 | self.bias = nn.Parameter(torch.zeros(num_features)) 67 | else: 68 | self.register_parameter('weight', None) 69 | self.register_parameter('bias', None) 70 | self.register_buffer('running_mean', torch.zeros(num_features)) 71 | self.register_buffer('running_var', torch.ones(num_features)) 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | nn.init.constant_(self.running_mean, 0) 76 | nn.init.constant_(self.running_var, 1) 77 | if self.affine: 78 | nn.init.constant_(self.weight, 1) 79 | nn.init.constant_(self.bias, 0) 80 | 81 | def forward(self, x): 82 | output = inplace_abn( 83 | x, self.weight, self.bias, self.running_mean, self.running_var, 84 | self.training, self.momentum, self.eps, self.act_name, self.act_param) 85 | if isinstance(output, tuple): 86 | output = output[0] 87 | return output 88 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/lambda_layer.py: -------------------------------------------------------------------------------- 1 | """ Lambda Layer 2 | 3 | Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` 4 | - https://arxiv.org/abs/2102.08602 5 | 6 | @misc{2102.08602, 7 | Author = {Irwan Bello}, 8 | Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention}, 9 | Year = {2021}, 10 | } 11 | 12 | Status: 13 | This impl is a WIP. Code snippets in the paper were used as reference but 14 | good chance some details are missing/wrong. 15 | 16 | I've only implemented local lambda conv based pos embeddings. 17 | 18 | For a PyTorch impl that includes other embedding options checkout 19 | https://github.com/lucidrains/lambda-networks 20 | 21 | Hacked together by / Copyright 2021 Ross Wightman 22 | """ 23 | import torch 24 | from torch import nn 25 | import torch.nn.functional as F 26 | 27 | from .helpers import to_2tuple, make_divisible 28 | from .weight_init import trunc_normal_ 29 | 30 | 31 | def rel_pos_indices(size): 32 | size = to_2tuple(size) 33 | pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1) 34 | rel_pos = pos[:, None, :] - pos[:, :, None] 35 | rel_pos[0] += size[0] - 1 36 | rel_pos[1] += size[1] - 1 37 | return rel_pos # 2, H * W, H * W 38 | 39 | 40 | class LambdaLayer(nn.Module): 41 | """Lambda Layer 42 | 43 | Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` 44 | - https://arxiv.org/abs/2102.08602 45 | 46 | NOTE: intra-depth parameter 'u' is fixed at 1. 
It did not appear worth the complexity to add. 47 | 48 | The internal dimensions of the lambda module are controlled via the interaction of several arguments. 49 | * the output dimension of the module is specified by dim_out, which falls back to input dim if not set 50 | * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim 51 | * the query (q) and key (k) dimension are determined by 52 | * dim_head = (dim_out * attn_ratio // num_heads) if dim_head is None 53 | * q = num_heads * dim_head, k = dim_head 54 | * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not set 55 | 56 | Args: 57 | dim (int): input dimension to the module 58 | dim_out (int): output dimension of the module, same as dim if not set 59 | feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W 60 | stride (int): output stride of the module, avg pool used if stride == 2 61 | num_heads (int): parallel attention heads. 62 | dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set 63 | r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9) 64 | qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0) 65 | qkv_bias (bool): add bias to q, k, and v projections 66 | """ 67 | def __init__( 68 | self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9, 69 | qk_ratio=1.0, qkv_bias=False): 70 | super().__init__() 71 | dim_out = dim_out or dim 72 | assert dim_out % num_heads == 0, ' should be divided by num_heads' 73 | self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads 74 | self.num_heads = num_heads 75 | self.dim_v = dim_out // num_heads 76 | 77 | self.qkv = nn.Conv2d( 78 | dim, 79 | num_heads * self.dim_qk + self.dim_qk + self.dim_v, 80 | kernel_size=1, bias=qkv_bias) 81 | self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk) 82 | self.norm_v = nn.BatchNorm2d(self.dim_v) 83 | 84 | if r is not None: 85 | # local lambda convolution for pos 86 | self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0)) 87 | self.pos_emb = None 88 | self.rel_pos_indices = None 89 | else: 90 | # relative pos embedding 91 | assert feat_size is not None 92 | feat_size = to_2tuple(feat_size) 93 | rel_size = [2 * s - 1 for s in feat_size] 94 | self.conv_lambda = None 95 | self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk)) 96 | self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False) 97 | 98 | self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() 99 | 100 | self.reset_parameters() 101 | 102 | def reset_parameters(self): 103 | trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in 104 | if self.conv_lambda is not None: 105 | trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5) 106 | if self.pos_emb is not None: 107 | trunc_normal_(self.pos_emb, std=.02) 108 | 109 | def forward(self, x): 110 | B, C, H, W = x.shape 111 | M = H * W 112 | qkv = self.qkv(x) 113 | q, k, v = torch.split(qkv, [ 114 | self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1) 115 | q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2) # B, num_heads, M, K 116 | v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V 117 | k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M 118 | 119 
| content_lam = k @ v # B, K, V 120 | content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V 121 | 122 | if self.pos_emb is None: 123 | position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K 124 | position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V 125 | else: 126 | # FIXME relative pos embedding path not fully verified 127 | pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1) 128 | position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1) # B, 1, M, K, V 129 | position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2) # B, num_heads, M, V 130 | 131 | out = (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W) # B, C (num_heads * V), H, W 132 | out = self.pool(out) 133 | return out 134 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/linear.py: -------------------------------------------------------------------------------- 1 | """ Linear layer (alternate definition) 2 | """ 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn as nn 6 | 7 | 8 | class Linear(nn.Linear): 9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` 10 | 11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting 12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. 13 | """ 14 | def forward(self, input: torch.Tensor) -> torch.Tensor: 15 | if torch.jit.is_scripting(): 16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None 17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) 18 | else: 19 | return F.linear(input, self.weight, self.bias) 20 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/median_pool.py: -------------------------------------------------------------------------------- 1 | """ Median Pool 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .helpers import to_2tuple, to_4tuple 7 | 8 | 9 | class MedianPool2d(nn.Module): 10 | """ Median pool (usable as median filter when stride=1) module. 
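Pads the input (reflect mode), unfolds k x k windows, and takes the median over each window.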
11 | 12 | Args: 13 | kernel_size: size of pooling kernel, int or 2-tuple 14 | stride: pool stride, int or 2-tuple 15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 16 | same: override padding and enforce same padding, boolean 17 | """ 18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 19 | super(MedianPool2d, self).__init__() 20 | self.k = to_2tuple(kernel_size) 21 | self.stride = to_2tuple(stride) 22 | self.padding = to_4tuple(padding) # convert to l, r, t, b 23 | self.same = same 24 | 25 | def _padding(self, x): 26 | if self.same: 27 | ih, iw = x.size()[2:] 28 | if ih % self.stride[0] == 0: 29 | ph = max(self.k[0] - self.stride[0], 0) 30 | else: 31 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 32 | if iw % self.stride[1] == 0: 33 | pw = max(self.k[1] - self.stride[1], 0) 34 | else: 35 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 36 | pl = pw // 2 37 | pr = pw - pl 38 | pt = ph // 2 39 | pb = ph - pt 40 | padding = (pl, pr, pt, pb) 41 | else: 42 | padding = self.padding 43 | return padding 44 | 45 | def forward(self, x): 46 | x = F.pad(x, self._padding(x), mode='reflect') 47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) 48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 49 | return x 50 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/mixed_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Mixed Convolution 2 | 3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | 8 | import torch 9 | from torch import nn as nn 10 | 11 | from .conv2d_same import create_conv2d_pad 12 | 13 | 14 | def _split_channels(num_chan, num_groups): 15 | split = [num_chan // num_groups for _ in range(num_groups)] 16 | split[0] += num_chan - sum(split) 17 | return split 18 | 19 | 20 | class MixedConv2d(nn.ModuleDict): 21 | """ Mixed Grouped Convolution 22 | 23 | Based on MDConv and GroupedConv in MixNet impl: 24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py 25 | """ 26 | def __init__(self, in_channels, out_channels, kernel_size=3, 27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs): 28 | super(MixedConv2d, self).__init__() 29 | 30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] 31 | num_groups = len(kernel_size) 32 | in_splits = _split_channels(in_channels, num_groups) 33 | out_splits = _split_channels(out_channels, num_groups) 34 | self.in_channels = sum(in_splits) 35 | self.out_channels = sum(out_splits) 36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): 37 | conv_groups = in_ch if depthwise else 1 38 | # use add_module to keep key space clean 39 | self.add_module( 40 | str(idx), 41 | create_conv2d_pad( 42 | in_ch, out_ch, k, stride=stride, 43 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs) 44 | ) 45 | self.splits = in_splits 46 | 47 | def forward(self, x): 48 | x_split = torch.split(x, self.splits, 1) 49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())] 50 | x = torch.cat(x_out, 1) 51 | return x 52 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/mlp.py: -------------------------------------------------------------------------------- 1 | 
""" MLP module w/ dropout and configurable activation layer 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | 7 | from .helpers import to_2tuple 8 | 9 | 10 | class Mlp(nn.Module): 11 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 12 | """ 13 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.): 14 | super().__init__() 15 | out_features = out_features or in_features 16 | hidden_features = hidden_features or in_features 17 | bias = to_2tuple(bias) 18 | drop_probs = to_2tuple(drop) 19 | 20 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) 21 | self.act = act_layer() 22 | self.drop1 = nn.Dropout(drop_probs[0]) 23 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) 24 | self.drop2 = nn.Dropout(drop_probs[1]) 25 | 26 | def forward(self, x): 27 | x = self.fc1(x) 28 | x = self.act(x) 29 | x = self.drop1(x) 30 | x = self.fc2(x) 31 | x = self.drop2(x) 32 | return x 33 | 34 | 35 | class GluMlp(nn.Module): 36 | """ MLP w/ GLU style gating 37 | See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202 38 | """ 39 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.): 40 | super().__init__() 41 | out_features = out_features or in_features 42 | hidden_features = hidden_features or in_features 43 | assert hidden_features % 2 == 0 44 | bias = to_2tuple(bias) 45 | drop_probs = to_2tuple(drop) 46 | 47 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) 48 | self.act = act_layer() 49 | self.drop1 = nn.Dropout(drop_probs[0]) 50 | self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1]) 51 | self.drop2 = nn.Dropout(drop_probs[1]) 52 | 53 | def init_weights(self): 54 | # override init of fc1 w/ gate portion set to weight near zero, bias=1 55 | fc1_mid = self.fc1.bias.shape[0] // 2 56 | nn.init.ones_(self.fc1.bias[fc1_mid:]) 57 | nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6) 58 | 59 | def forward(self, x): 60 | x = self.fc1(x) 61 | x, gates = x.chunk(2, dim=-1) 62 | x = x * self.act(gates) 63 | x = self.drop1(x) 64 | x = self.fc2(x) 65 | x = self.drop2(x) 66 | return x 67 | 68 | 69 | class GatedMlp(nn.Module): 70 | """ MLP as used in gMLP 71 | """ 72 | def __init__( 73 | self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, 74 | gate_layer=None, bias=True, drop=0.): 75 | super().__init__() 76 | out_features = out_features or in_features 77 | hidden_features = hidden_features or in_features 78 | bias = to_2tuple(bias) 79 | drop_probs = to_2tuple(drop) 80 | 81 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) 82 | self.act = act_layer() 83 | self.drop1 = nn.Dropout(drop_probs[0]) 84 | if gate_layer is not None: 85 | assert hidden_features % 2 == 0 86 | self.gate = gate_layer(hidden_features) 87 | hidden_features = hidden_features // 2 # FIXME base reduction on gate property? 
88 | else: 89 | self.gate = nn.Identity() 90 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) 91 | self.drop2 = nn.Dropout(drop_probs[1]) 92 | 93 | def forward(self, x): 94 | x = self.fc1(x) 95 | x = self.act(x) 96 | x = self.drop1(x) 97 | x = self.gate(x) 98 | x = self.fc2(x) 99 | x = self.drop2(x) 100 | return x 101 | 102 | 103 | class ConvMlp(nn.Module): 104 | """ MLP using 1x1 convs that keeps spatial dims 105 | """ 106 | def __init__( 107 | self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, 108 | norm_layer=None, bias=True, drop=0.): 109 | super().__init__() 110 | out_features = out_features or in_features 111 | hidden_features = hidden_features or in_features 112 | bias = to_2tuple(bias) 113 | 114 | self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0]) 115 | self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() 116 | self.act = act_layer() 117 | self.drop = nn.Dropout(drop) 118 | self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1]) 119 | 120 | def forward(self, x): 121 | x = self.fc1(x) 122 | x = self.norm(x) 123 | x = self.act(x) 124 | x = self.drop(x) 125 | x = self.fc2(x) 126 | return x 127 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/non_local_attn.py: -------------------------------------------------------------------------------- 1 | """ Bilinear-Attention-Transform and Non-Local Attention 2 | 3 | Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms` 4 | - https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html 5 | Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification 6 | """ 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from .conv_bn_act import ConvNormAct 12 | from .helpers import make_divisible 13 | from .trace_utils import _assert 14 | 15 | 16 | class NonLocalAttn(nn.Module): 17 | """Spatial NL block for image classification. 18 | 19 | This was adapted from https://github.com/BA-Transform/BAT-Image-Classification 20 | Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net. 
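The t/p/g 1x1 convs below play the role of theta/phi/g in the original non-local block:
att = softmax(t @ p * scale) over all spatial positions, out = norm(z(att @ g)) + shortcut.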
21 | """ 22 | 23 | def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs): 24 | super(NonLocalAttn, self).__init__() 25 | if rd_channels is None: 26 | rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) 27 | self.scale = in_channels ** -0.5 if use_scale else 1.0 28 | self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) 29 | self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) 30 | self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) 31 | self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True) 32 | self.norm = nn.BatchNorm2d(in_channels) 33 | self.reset_parameters() 34 | 35 | def forward(self, x): 36 | shortcut = x 37 | 38 | t = self.t(x) 39 | p = self.p(x) 40 | g = self.g(x) 41 | 42 | B, C, H, W = t.size() 43 | t = t.view(B, C, -1).permute(0, 2, 1) 44 | p = p.view(B, C, -1) 45 | g = g.view(B, C, -1).permute(0, 2, 1) 46 | 47 | att = torch.bmm(t, p) * self.scale 48 | att = F.softmax(att, dim=2) 49 | x = torch.bmm(att, g) 50 | 51 | x = x.permute(0, 2, 1).reshape(B, C, H, W) 52 | x = self.z(x) 53 | x = self.norm(x) + shortcut 54 | 55 | return x 56 | 57 | def reset_parameters(self): 58 | for name, m in self.named_modules(): 59 | if isinstance(m, nn.Conv2d): 60 | nn.init.kaiming_normal_( 61 | m.weight, mode='fan_out', nonlinearity='relu') 62 | if len(list(m.parameters())) > 1: 63 | nn.init.constant_(m.bias, 0.0) 64 | elif isinstance(m, nn.BatchNorm2d): 65 | nn.init.constant_(m.weight, 0) 66 | nn.init.constant_(m.bias, 0) 67 | elif isinstance(m, nn.GroupNorm): 68 | nn.init.constant_(m.weight, 0) 69 | nn.init.constant_(m.bias, 0) 70 | 71 | 72 | class BilinearAttnTransform(nn.Module): 73 | 74 | def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): 75 | super(BilinearAttnTransform, self).__init__() 76 | 77 | self.conv1 = ConvNormAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer) 78 | self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1)) 79 | self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size)) 80 | self.conv2 = ConvNormAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) 81 | self.block_size = block_size 82 | self.groups = groups 83 | self.in_channels = in_channels 84 | 85 | def resize_mat(self, x, t: int): 86 | B, C, block_size, block_size1 = x.shape 87 | _assert(block_size == block_size1, '') 88 | if t <= 1: 89 | return x 90 | x = x.view(B * C, -1, 1, 1) 91 | x = x * torch.eye(t, t, dtype=x.dtype, device=x.device) 92 | x = x.view(B * C, block_size, block_size, t, t) 93 | x = torch.cat(torch.split(x, 1, dim=1), dim=3) 94 | x = torch.cat(torch.split(x, 1, dim=2), dim=4) 95 | x = x.view(B, C, block_size * t, block_size * t) 96 | return x 97 | 98 | def forward(self, x): 99 | _assert(x.shape[-1] % self.block_size == 0, '') 100 | _assert(x.shape[-2] % self.block_size == 0, '') 101 | B, C, H, W = x.shape 102 | out = self.conv1(x) 103 | rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) 104 | cp = F.adaptive_max_pool2d(out, (1, self.block_size)) 105 | p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size).sigmoid() 106 | q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size).sigmoid() 107 | p = p / p.sum(dim=3, keepdim=True) 108 | q = q / q.sum(dim=2, keepdim=True) 109 | p = p.view(B, self.groups, 1, self.block_size, 
self.block_size).expand(x.size( 110 | 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous() 111 | p = p.view(B, C, self.block_size, self.block_size) 112 | q = q.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size( 113 | 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous() 114 | q = q.view(B, C, self.block_size, self.block_size) 115 | p = self.resize_mat(p, H // self.block_size) 116 | q = self.resize_mat(q, W // self.block_size) 117 | y = p.matmul(x) 118 | y = y.matmul(q) 119 | 120 | y = self.conv2(y) 121 | return y 122 | 123 | 124 | class BatNonLocalAttn(nn.Module): 125 | """ BAT 126 | Adapted from: https://github.com/BA-Transform/BAT-Image-Classification 127 | """ 128 | 129 | def __init__( 130 | self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, 131 | drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_): 132 | super().__init__() 133 | if rd_channels is None: 134 | rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) 135 | self.conv1 = ConvNormAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer) 136 | self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer) 137 | self.conv2 = ConvNormAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) 138 | self.dropout = nn.Dropout2d(p=drop_rate) 139 | 140 | def forward(self, x): 141 | xl = self.conv1(x) 142 | y = self.ba(xl) 143 | y = self.conv2(y) 144 | y = self.dropout(y) 145 | return y + x 146 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/norm.py: -------------------------------------------------------------------------------- 1 | """ Normalization layers and wrappers 2 | 3 | Norm layer definitions that support fast norm and consistent channel arg order (always first arg). 4 | 5 | Hacked together by / Copyright 2022 Ross Wightman 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm 13 | 14 | 15 | class GroupNorm(nn.GroupNorm): 16 | def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True): 17 | # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN 18 | super().__init__(num_groups, num_channels, eps=eps, affine=affine) 19 | self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) 20 | 21 | def forward(self, x): 22 | if self.fast_norm: 23 | return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 24 | else: 25 | return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 26 | 27 | 28 | class GroupNorm1(nn.GroupNorm): 29 | """ Group Normalization with 1 group. 
30 | Input: tensor in shape [B, C, *] 31 | """ 32 | 33 | def __init__(self, num_channels, **kwargs): 34 | super().__init__(1, num_channels, **kwargs) 35 | self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) 36 | 37 | def forward(self, x: torch.Tensor) -> torch.Tensor: 38 | if self.fast_norm: 39 | return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 40 | else: 41 | return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 42 | 43 | 44 | class LayerNorm(nn.LayerNorm): 45 | """ LayerNorm w/ fast norm option 46 | """ 47 | def __init__(self, num_channels, eps=1e-6, affine=True): 48 | super().__init__(num_channels, eps=eps, elementwise_affine=affine) 49 | self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) 50 | 51 | def forward(self, x: torch.Tensor) -> torch.Tensor: 52 | if self._fast_norm: 53 | x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 54 | else: 55 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 56 | return x 57 | 58 | 59 | class LayerNorm2d(nn.LayerNorm): 60 | """ LayerNorm for channels of '2D' spatial NCHW tensors """ 61 | def __init__(self, num_channels, eps=1e-6, affine=True): 62 | super().__init__(num_channels, eps=eps, elementwise_affine=affine) 63 | self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) 64 | 65 | def forward(self, x: torch.Tensor) -> torch.Tensor: 66 | x = x.permute(0, 2, 3, 1) 67 | if self._fast_norm: 68 | x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 69 | else: 70 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 71 | x = x.permute(0, 3, 1, 2) 72 | return x 73 | 74 | 75 | def _is_contiguous(tensor: torch.Tensor) -> bool: 76 | # jit is oh so lovely :/ 77 | if torch.jit.is_scripting(): 78 | return tensor.is_contiguous() 79 | else: 80 | return tensor.is_contiguous(memory_format=torch.contiguous_format) 81 | 82 | 83 | @torch.jit.script 84 | def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float): 85 | s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True) 86 | x = (x - u) * torch.rsqrt(s + eps) 87 | x = x * weight[:, None, None] + bias[:, None, None] 88 | return x 89 | 90 | 91 | def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float): 92 | u = x.mean(dim=1, keepdim=True) 93 | s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0) 94 | x = (x - u) * torch.rsqrt(s + eps) 95 | x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1) 96 | return x 97 | 98 | 99 | class LayerNormExp2d(nn.LayerNorm): 100 | """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). 101 | 102 | Experimental implementation w/ manual norm for tensors non-contiguous tensors. 103 | 104 | This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last 105 | layout. However, benefits are not always clear and can perform worse on other GPUs. 
106 | """ 107 | 108 | def __init__(self, num_channels, eps=1e-6): 109 | super().__init__(num_channels, eps=eps) 110 | 111 | def forward(self, x) -> torch.Tensor: 112 | if _is_contiguous(x): 113 | x = F.layer_norm( 114 | x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) 115 | else: 116 | x = _layer_norm_cf(x, self.weight, self.bias, self.eps) 117 | return x 118 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/padding.py: -------------------------------------------------------------------------------- 1 | """ Padding Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import math 6 | from typing import List, Tuple 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | # Calculate symmetric padding for a convolution 12 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 13 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 14 | return padding 15 | 16 | 17 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 18 | def get_same_padding(x: int, k: int, s: int, d: int): 19 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) 20 | 21 | 22 | # Can SAME padding for given args be done statically? 23 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 24 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 25 | 26 | 27 | # Dynamically pad input x with 'SAME' padding for conv with specified args 28 | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 29 | ih, iw = x.size()[-2:] 30 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 31 | if pad_h > 0 or pad_w > 0: 32 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 33 | return x 34 | 35 | 36 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 37 | dynamic = False 38 | if isinstance(padding, str): 39 | # for any string padding, the padding will be calculated for you, one of three ways 40 | padding = padding.lower() 41 | if padding == 'same': 42 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 43 | if is_static_pad(kernel_size, **kwargs): 44 | # static case, no extra overhead 45 | padding = get_padding(kernel_size, **kwargs) 46 | else: 47 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 48 | padding = 0 49 | dynamic = True 50 | elif padding == 'valid': 51 | # 'VALID' padding, same as padding=0 52 | padding = 0 53 | else: 54 | # Default to PyTorch style 'same'-ish symmetric padding 55 | padding = get_padding(kernel_size, **kwargs) 56 | return padding, dynamic 57 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | """ Image to Patch Embedding using Conv2d 2 | 3 | A convolution based approach to patchifying a 2D image w/ embedding projection. 
4 | 5 | Based on the impl in https://github.com/google-research/vision_transformer 6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | from torch import nn as nn 10 | 11 | from .helpers import to_2tuple 12 | from .trace_utils import _assert 13 | 14 | 15 | class PatchEmbed(nn.Module): 16 | """ 2D Image to Patch Embedding 17 | """ 18 | def __init__( 19 | self, 20 | img_size=224, 21 | patch_size=16, 22 | in_chans=3, 23 | embed_dim=768, 24 | norm_layer=None, 25 | flatten=True, 26 | bias=True, 27 | ): 28 | super().__init__() 29 | img_size = to_2tuple(img_size) 30 | patch_size = to_2tuple(patch_size) 31 | self.img_size = img_size 32 | self.patch_size = patch_size 33 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 34 | self.num_patches = self.grid_size[0] * self.grid_size[1] 35 | self.flatten = flatten 36 | 37 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) 38 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 39 | 40 | def forward(self, x): 41 | B, C, H, W = x.shape 42 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 43 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 44 | x = self.proj(x) 45 | if self.flatten: 46 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 47 | x = self.norm(x) 48 | return x 49 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/pool2d_same.py: -------------------------------------------------------------------------------- 1 | """ AvgPool2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import List, Tuple, Optional 9 | 10 | from .helpers import to_2tuple 11 | from .padding import pad_same, get_padding_value 12 | 13 | 14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 15 | ceil_mode: bool = False, count_include_pad: bool = True): 16 | # FIXME how to deal with count_include_pad vs not for external padding? 
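# pad asymmetrically in TF 'SAME' style first, then pool with no additional padding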
17 | x = pad_same(x, kernel_size, stride) 18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 19 | 20 | 21 | class AvgPool2dSame(nn.AvgPool2d): 22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling 23 | """ 24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): 25 | kernel_size = to_2tuple(kernel_size) 26 | stride = to_2tuple(stride) 27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 28 | 29 | def forward(self, x): 30 | x = pad_same(x, self.kernel_size, self.stride) 31 | return F.avg_pool2d( 32 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) 33 | 34 | 35 | def max_pool2d_same( 36 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 37 | dilation: List[int] = (1, 1), ceil_mode: bool = False): 38 | x = pad_same(x, kernel_size, stride, value=-float('inf')) 39 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) 40 | 41 | 42 | class MaxPool2dSame(nn.MaxPool2d): 43 | """ Tensorflow like 'SAME' wrapper for 2D max pooling 44 | """ 45 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): 46 | kernel_size = to_2tuple(kernel_size) 47 | stride = to_2tuple(stride) 48 | dilation = to_2tuple(dilation) 49 | super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) 50 | 51 | def forward(self, x): 52 | x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) 53 | return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) 54 | 55 | 56 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): 57 | stride = stride or kernel_size 58 | padding = kwargs.pop('padding', '') 59 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) 60 | if is_dynamic: 61 | if pool_type == 'avg': 62 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs) 63 | elif pool_type == 'max': 64 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs) 65 | else: 66 | assert False, f'Unsupported pool type {pool_type}' 67 | else: 68 | if pool_type == 'avg': 69 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 70 | elif pool_type == 'max': 71 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 72 | else: 73 | assert False, f'Unsupported pool type {pool_type}' 74 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/selective_kernel.py: -------------------------------------------------------------------------------- 1 | """ Selective Kernel Convolution/Attention 2 | 3 | Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | from torch import nn as nn 9 | 10 | from .conv_bn_act import ConvNormActAa 11 | from .helpers import make_divisible 12 | from .trace_utils import _assert 13 | 14 | 15 | def _kernel_valid(k): 16 | if isinstance(k, (list, tuple)): 17 | for ki in k: 18 | return _kernel_valid(ki) 19 | assert k >= 3 and k % 2 20 | 21 | 22 | class SelectiveKernelAttn(nn.Module): 23 | def __init__(self, channels, num_paths=2, attn_channels=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): 24 | """ Selective Kernel Attention Module 25 | 26 | Selective Kernel attention mechanism factored out into its own module. 
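Pools the summed branch outputs, reduces them through a 1x1 conv + BN + act, and emits
per-path softmax weights of shape (B, num_paths, channels, 1, 1) used to fuse the branches.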
27 | 28 | """ 29 | super(SelectiveKernelAttn, self).__init__() 30 | self.num_paths = num_paths 31 | self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) 32 | self.bn = norm_layer(attn_channels) 33 | self.act = act_layer(inplace=True) 34 | self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) 35 | 36 | def forward(self, x): 37 | _assert(x.shape[1] == self.num_paths, '') 38 | x = x.sum(1).mean((2, 3), keepdim=True) 39 | x = self.fc_reduce(x) 40 | x = self.bn(x) 41 | x = self.act(x) 42 | x = self.fc_select(x) 43 | B, C, H, W = x.shape 44 | x = x.view(B, self.num_paths, C // self.num_paths, H, W) 45 | x = torch.softmax(x, dim=1) 46 | return x 47 | 48 | 49 | class SelectiveKernel(nn.Module): 50 | 51 | def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1, 52 | rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True, 53 | act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_layer=None): 54 | """ Selective Kernel Convolution Module 55 | 56 | As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications. 57 | 58 | Largest change is the input split, which divides the input channels across each convolution path, this can 59 | be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps 60 | the parameter count from ballooning when the convolutions themselves don't have groups, but still provides 61 | a noteworthy increase in performance over similar param count models without this attention layer. -Ross W 62 | 63 | Args: 64 | in_channels (int): module input (feature) channel count 65 | out_channels (int): module output (feature) channel count 66 | kernel_size (int, list): kernel size for each convolution branch 67 | stride (int): stride for convolutions 68 | dilation (int): dilation for module as a whole, impacts dilation of each branch 69 | groups (int): number of groups for each branch 70 | rd_ratio (int, float): reduction factor for attention features 71 | keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations 72 | split_input (bool): split input channels evenly across each convolution branch, keeps param count lower, 73 | can be viewed as grouping by path, output expands to module out_channels count 74 | act_layer (nn.Module): activation layer to use 75 | norm_layer (nn.Module): batchnorm/norm layer to use 76 | aa_layer (nn.Module): anti-aliasing module 77 | drop_layer (nn.Module): spatial drop module in convs (drop block, etc) 78 | """ 79 | super(SelectiveKernel, self).__init__() 80 | out_channels = out_channels or in_channels 81 | kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 
5x5 -> 3x3 + dilation 82 | _kernel_valid(kernel_size) 83 | if not isinstance(kernel_size, list): 84 | kernel_size = [kernel_size] * 2 85 | if keep_3x3: 86 | dilation = [dilation * (k - 1) // 2 for k in kernel_size] 87 | kernel_size = [3] * len(kernel_size) 88 | else: 89 | dilation = [dilation] * len(kernel_size) 90 | self.num_paths = len(kernel_size) 91 | self.in_channels = in_channels 92 | self.out_channels = out_channels 93 | self.split_input = split_input 94 | if self.split_input: 95 | assert in_channels % self.num_paths == 0 96 | in_channels = in_channels // self.num_paths 97 | groups = min(out_channels, groups) 98 | 99 | conv_kwargs = dict( 100 | stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer, 101 | aa_layer=aa_layer, drop_layer=drop_layer) 102 | self.paths = nn.ModuleList([ 103 | ConvNormActAa(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) 104 | for k, d in zip(kernel_size, dilation)]) 105 | 106 | attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor) 107 | self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) 108 | 109 | def forward(self, x): 110 | if self.split_input: 111 | x_split = torch.split(x, self.in_channels // self.num_paths, 1) 112 | x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] 113 | else: 114 | x_paths = [op(x) for op in self.paths] 115 | x = torch.stack(x_paths, dim=1) 116 | x_attn = self.attn(x) 117 | x = x * x_attn 118 | x = torch.sum(x, dim=1) 119 | return x 120 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/separable_conv.py: -------------------------------------------------------------------------------- 1 | """ Depthwise Separable Conv Modules 2 | 3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the 4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. 
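SeparableConv2d below factors a convolution into a depthwise (per-channel spatial) conv
followed by a 1x1 pointwise conv; SeparableConvNormAct adds a trailing norm + activation.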
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | from torch import nn as nn 9 | 10 | from .create_conv2d import create_conv2d 11 | from .create_norm_act import get_norm_act_layer 12 | 13 | 14 | class SeparableConvNormAct(nn.Module): 15 | """ Separable Conv w/ trailing Norm and Activation 16 | """ 17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 18 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, 19 | apply_act=True, drop_layer=None): 20 | super(SeparableConvNormAct, self).__init__() 21 | 22 | self.conv_dw = create_conv2d( 23 | in_channels, int(in_channels * channel_multiplier), kernel_size, 24 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 25 | 26 | self.conv_pw = create_conv2d( 27 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 28 | 29 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 30 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 31 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 32 | 33 | @property 34 | def in_channels(self): 35 | return self.conv_dw.in_channels 36 | 37 | @property 38 | def out_channels(self): 39 | return self.conv_pw.out_channels 40 | 41 | def forward(self, x): 42 | x = self.conv_dw(x) 43 | x = self.conv_pw(x) 44 | x = self.bn(x) 45 | return x 46 | 47 | 48 | SeparableConvBnAct = SeparableConvNormAct 49 | 50 | 51 | class SeparableConv2d(nn.Module): 52 | """ Separable Conv 53 | """ 54 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 55 | channel_multiplier=1.0, pw_kernel_size=1): 56 | super(SeparableConv2d, self).__init__() 57 | 58 | self.conv_dw = create_conv2d( 59 | in_channels, int(in_channels * channel_multiplier), kernel_size, 60 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 61 | 62 | self.conv_pw = create_conv2d( 63 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 64 | 65 | @property 66 | def in_channels(self): 67 | return self.conv_dw.in_channels 68 | 69 | @property 70 | def out_channels(self): 71 | return self.conv_pw.out_channels 72 | 73 | def forward(self, x): 74 | x = self.conv_dw(x) 75 | x = self.conv_pw(x) 76 | return x 77 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/space_to_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SpaceToDepth(nn.Module): 6 | def __init__(self, block_size=4): 7 | super().__init__() 8 | assert block_size == 4 9 | self.bs = block_size 10 | 11 | def forward(self, x): 12 | N, C, H, W = x.size() 13 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 14 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 15 | x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 16 | return x 17 | 18 | 19 | @torch.jit.script 20 | class SpaceToDepthJit(object): 21 | def __call__(self, x: torch.Tensor): 22 | # assuming hard-coded that block_size==4 for acceleration 23 | N, C, H, W = x.size() 24 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) 25 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 26 | x = x.view(N, C * 16, H // 4, 
W // 4) # (N, C*bs^2, H//bs, W//bs) 27 | return x 28 | 29 | 30 | class SpaceToDepthModule(nn.Module): 31 | def __init__(self, no_jit=False): 32 | super().__init__() 33 | if not no_jit: 34 | self.op = SpaceToDepthJit() 35 | else: 36 | self.op = SpaceToDepth() 37 | 38 | def forward(self, x): 39 | return self.op(x) 40 | 41 | 42 | class DepthToSpace(nn.Module): 43 | 44 | def __init__(self, block_size): 45 | super().__init__() 46 | self.bs = block_size 47 | 48 | def forward(self, x): 49 | N, C, H, W = x.size() 50 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 51 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 52 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) 53 | return x 54 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/split_attn.py: -------------------------------------------------------------------------------- 1 | """ Split Attention Conv2d (for ResNeSt Models) 2 | 3 | Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 4 | 5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt 6 | 7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman 8 | """ 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | from .helpers import make_divisible 14 | 15 | 16 | class RadixSoftmax(nn.Module): 17 | def __init__(self, radix, cardinality): 18 | super(RadixSoftmax, self).__init__() 19 | self.radix = radix 20 | self.cardinality = cardinality 21 | 22 | def forward(self, x): 23 | batch = x.size(0) 24 | if self.radix > 1: 25 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 26 | x = F.softmax(x, dim=1) 27 | x = x.reshape(batch, -1) 28 | else: 29 | x = torch.sigmoid(x) 30 | return x 31 | 32 | 33 | class SplitAttn(nn.Module): 34 | """Split-Attention (aka Splat) 35 | """ 36 | def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, 37 | dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, 38 | act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs): 39 | super(SplitAttn, self).__init__() 40 | out_channels = out_channels or in_channels 41 | self.radix = radix 42 | mid_chs = out_channels * radix 43 | if rd_channels is None: 44 | attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) 45 | else: 46 | attn_chs = rd_channels * radix 47 | 48 | padding = kernel_size // 2 if padding is None else padding 49 | self.conv = nn.Conv2d( 50 | in_channels, mid_chs, kernel_size, stride, padding, dilation, 51 | groups=groups * radix, bias=bias, **kwargs) 52 | self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() 53 | self.drop = drop_layer() if drop_layer is not None else nn.Identity() 54 | self.act0 = act_layer(inplace=True) 55 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) 56 | self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() 57 | self.act1 = act_layer(inplace=True) 58 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) 59 | self.rsoftmax = RadixSoftmax(radix, groups) 60 | 61 | def forward(self, x): 62 | x = self.conv(x) 63 | x = self.bn0(x) 64 | x = self.drop(x) 65 | x = self.act0(x) 66 | 67 | B, RC, H, W = x.shape 68 | if self.radix > 1: 69 | x = x.reshape((B, self.radix, RC // self.radix, H, W)) 70 | x_gap = x.sum(dim=1) 71 | else: 
72 | x_gap = x 73 | x_gap = x_gap.mean((2, 3), keepdim=True) 74 | x_gap = self.fc1(x_gap) 75 | x_gap = self.bn1(x_gap) 76 | x_gap = self.act1(x_gap) 77 | x_attn = self.fc2(x_gap) 78 | 79 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) 80 | if self.radix > 1: 81 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) 82 | else: 83 | out = x * x_attn 84 | return out.contiguous() 85 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/split_batchnorm.py: -------------------------------------------------------------------------------- 1 | """ Split BatchNorm 2 | 3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through 4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias 5 | keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' 6 | namespace. 7 | 8 | This allows easily removing the auxiliary BN layers after training to efficiently 9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, 10 | 'Disentangled Learning via An Auxiliary BN' 11 | 12 | Hacked together by / Copyright 2020 Ross Wightman 13 | """ 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d): 19 | 20 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 21 | track_running_stats=True, num_splits=2): 22 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 23 | assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' 24 | self.num_splits = num_splits 25 | self.aux_bn = nn.ModuleList([ 26 | nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) 27 | 28 | def forward(self, input: torch.Tensor): 29 | if self.training: # aux BN only relevant while training 30 | split_size = input.shape[0] // self.num_splits 31 | assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" 32 | split_input = input.split(split_size) 33 | x = [super().forward(split_input[0])] 34 | for i, a in enumerate(self.aux_bn): 35 | x.append(a(split_input[i + 1])) 36 | return torch.cat(x, dim=0) 37 | else: 38 | return super().forward(input) 39 | 40 | 41 | def convert_splitbn_model(module, num_splits=2): 42 | """ 43 | Recursively traverse module and its children to replace all instances of 44 | ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`. 
45 | Args: 46 | module (torch.nn.Module): input module 47 | num_splits: number of separate batchnorm layers to split input across 48 | Example:: 49 | >>> # model is an instance of torch.nn.Module 50 | >>> model = timm.models.convert_splitbn_model(model, num_splits=2) 51 | """ 52 | mod = module 53 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): 54 | return module 55 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 56 | mod = SplitBatchNorm2d( 57 | module.num_features, module.eps, module.momentum, module.affine, 58 | module.track_running_stats, num_splits=num_splits) 59 | mod.running_mean = module.running_mean 60 | mod.running_var = module.running_var 61 | mod.num_batches_tracked = module.num_batches_tracked 62 | if module.affine: 63 | mod.weight.data = module.weight.data.clone().detach() 64 | mod.bias.data = module.bias.data.clone().detach() 65 | for aux in mod.aux_bn: 66 | aux.running_mean = module.running_mean.clone() 67 | aux.running_var = module.running_var.clone() 68 | aux.num_batches_tracked = module.num_batches_tracked.clone() 69 | if module.affine: 70 | aux.weight.data = module.weight.data.clone().detach() 71 | aux.bias.data = module.bias.data.clone().detach() 72 | for name, child in module.named_children(): 73 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) 74 | del module 75 | return mod 76 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/squeeze_excite.py: -------------------------------------------------------------------------------- 1 | """ Squeeze-and-Excitation Channel Attention 2 | 3 | An SE implementation originally based on PyTorch SE-Net impl. 4 | Has since evolved with additional functionality / configuration. 5 | 6 | Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507 7 | 8 | Also included is Effective Squeeze-Excitation (ESE). 9 | Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 10 | 11 | Hacked together by / Copyright 2021 Ross Wightman 12 | """ 13 | from torch import nn as nn 14 | 15 | from .create_act import create_act_layer 16 | from .helpers import make_divisible 17 | 18 | 19 | class SEModule(nn.Module): 20 | """ SE Module as defined in original SE-Nets with a few additions 21 | Additions include: 22 | * divisor can be specified to keep channels % div == 0 (default: 8) 23 | * reduction channels can be specified directly by arg (if rd_channels is set) 24 | * reduction channels can be specified by float rd_ratio (default: 1/16) 25 | * global max pooling can be added to the squeeze aggregation 26 | * customizable activation, normalization, and gate layer 27 | """ 28 | def __init__( 29 | self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, 30 | bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): 31 | super(SEModule, self).__init__() 32 | self.add_maxpool = add_maxpool 33 | if not rd_channels: 34 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
35 | self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias) 36 | self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() 37 | self.act = create_act_layer(act_layer, inplace=True) 38 | self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias) 39 | self.gate = create_act_layer(gate_layer) 40 | 41 | def forward(self, x): 42 | x_se = x.mean((2, 3), keepdim=True) 43 | if self.add_maxpool: 44 | # experimental codepath, may remove or change 45 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) 46 | x_se = self.fc1(x_se) 47 | x_se = self.act(self.bn(x_se)) 48 | x_se = self.fc2(x_se) 49 | return x * self.gate(x_se) 50 | 51 | 52 | SqueezeExcite = SEModule # alias 53 | 54 | 55 | class EffectiveSEModule(nn.Module): 56 | """ 'Effective Squeeze-Excitation 57 | From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 58 | """ 59 | def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_): 60 | super(EffectiveSEModule, self).__init__() 61 | self.add_maxpool = add_maxpool 62 | self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) 63 | self.gate = create_act_layer(gate_layer) 64 | 65 | def forward(self, x): 66 | x_se = x.mean((2, 3), keepdim=True) 67 | if self.add_maxpool: 68 | # experimental codepath, may remove or change 69 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) 70 | x_se = self.fc(x_se) 71 | return x * self.gate(x_se) 72 | 73 | 74 | EffectiveSqueezeExcite = EffectiveSEModule # alias 75 | 76 | 77 | class SqueezeExciteCl(nn.Module): 78 | """ SE Module as defined in original SE-Nets with a few additions 79 | Additions include: 80 | * divisor can be specified to keep channels % div == 0 (default: 8) 81 | * reduction channels can be specified directly by arg (if rd_channels is set) 82 | * reduction channels can be specified by float rd_ratio (default: 1/16) 83 | * global max pooling can be added to the squeeze aggregation 84 | * customizable activation, normalization, and gate layer 85 | """ 86 | def __init__( 87 | self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, 88 | bias=True, act_layer=nn.ReLU, gate_layer='sigmoid'): 89 | super().__init__() 90 | if not rd_channels: 91 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
92 | self.fc1 = nn.Linear(channels, rd_channels, bias=bias) 93 | self.act = create_act_layer(act_layer, inplace=True) 94 | self.fc2 = nn.Linear(rd_channels, channels, bias=bias) 95 | self.gate = create_act_layer(gate_layer) 96 | 97 | def forward(self, x): 98 | x_se = x.mean((1, 2), keepdims=True) # FIXME avg dim [1:n-1], don't assume 2D NHWC 99 | x_se = self.fc1(x_se) 100 | x_se = self.act(x_se) 101 | x_se = self.fc2(x_se) 102 | return x * self.gate(x_se) -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/std_conv.py: -------------------------------------------------------------------------------- 1 | """ Convolution with Weight Standardization (StdConv and ScaledStdConv) 2 | 3 | StdConv: 4 | @article{weightstandardization, 5 | author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille}, 6 | title = {Weight Standardization}, 7 | journal = {arXiv preprint arXiv:1903.10520}, 8 | year = {2019}, 9 | } 10 | Code: https://github.com/joe-siyuan-qiao/WeightStandardization 11 | 12 | ScaledStdConv: 13 | Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` 14 | - https://arxiv.org/abs/2101.08692 15 | Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets 16 | 17 | Hacked together by / copyright Ross Wightman, 2021. 18 | """ 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | 23 | from .padding import get_padding, get_padding_value, pad_same 24 | 25 | 26 | class StdConv2d(nn.Conv2d): 27 | """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models. 28 | 29 | Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - 30 | https://arxiv.org/abs/1903.10520v2 31 | """ 32 | def __init__( 33 | self, in_channel, out_channels, kernel_size, stride=1, padding=None, 34 | dilation=1, groups=1, bias=False, eps=1e-6): 35 | if padding is None: 36 | padding = get_padding(kernel_size, stride, dilation) 37 | super().__init__( 38 | in_channel, out_channels, kernel_size, stride=stride, 39 | padding=padding, dilation=dilation, groups=groups, bias=bias) 40 | self.eps = eps 41 | 42 | def forward(self, x): 43 | weight = F.batch_norm( 44 | self.weight.reshape(1, self.out_channels, -1), None, None, 45 | training=True, momentum=0., eps=self.eps).reshape_as(self.weight) 46 | x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 47 | return x 48 | 49 | 50 | class StdConv2dSame(nn.Conv2d): 51 | """Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model. 
52 | 53 | Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - 54 | https://arxiv.org/abs/1903.10520v2 55 | """ 56 | def __init__( 57 | self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', 58 | dilation=1, groups=1, bias=False, eps=1e-6): 59 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) 60 | super().__init__( 61 | in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, 62 | groups=groups, bias=bias) 63 | self.same_pad = is_dynamic 64 | self.eps = eps 65 | 66 | def forward(self, x): 67 | if self.same_pad: 68 | x = pad_same(x, self.kernel_size, self.stride, self.dilation) 69 | weight = F.batch_norm( 70 | self.weight.reshape(1, self.out_channels, -1), None, None, 71 | training=True, momentum=0., eps=self.eps).reshape_as(self.weight) 72 | x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 73 | return x 74 | 75 | 76 | class ScaledStdConv2d(nn.Conv2d): 77 | """Conv2d layer with Scaled Weight Standardization. 78 | 79 | Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - 80 | https://arxiv.org/abs/2101.08692 81 | 82 | NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor. 83 | """ 84 | 85 | def __init__( 86 | self, in_channels, out_channels, kernel_size, stride=1, padding=None, 87 | dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): 88 | if padding is None: 89 | padding = get_padding(kernel_size, stride, dilation) 90 | super().__init__( 91 | in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, 92 | groups=groups, bias=bias) 93 | self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) 94 | self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in) 95 | self.eps = eps 96 | 97 | def forward(self, x): 98 | weight = F.batch_norm( 99 | self.weight.reshape(1, self.out_channels, -1), None, None, 100 | weight=(self.gain * self.scale).view(-1), 101 | training=True, momentum=0., eps=self.eps).reshape_as(self.weight) 102 | return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 103 | 104 | 105 | class ScaledStdConv2dSame(nn.Conv2d): 106 | """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support 107 | 108 | Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - 109 | https://arxiv.org/abs/2101.08692 110 | 111 | NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor. 
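    Example (illustrative sketch; shapes are arbitrary and assume NCHW input):

        >>> conv = ScaledStdConv2dSame(64, 128, kernel_size=3, stride=2)
        >>> y = conv(torch.randn(1, 64, 56, 56))  # dynamic SAME padding at stride 2 -> 1x128x28x28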
112 | """ 113 | 114 | def __init__( 115 | self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', 116 | dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): 117 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) 118 | super().__init__( 119 | in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, 120 | groups=groups, bias=bias) 121 | self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) 122 | self.scale = gamma * self.weight[0].numel() ** -0.5 123 | self.same_pad = is_dynamic 124 | self.eps = eps 125 | 126 | def forward(self, x): 127 | if self.same_pad: 128 | x = pad_same(x, self.kernel_size, self.stride, self.dilation) 129 | weight = F.batch_norm( 130 | self.weight.reshape(1, self.out_channels, -1), None, None, 131 | weight=(self.gain * self.scale).view(-1), 132 | training=True, momentum=0., eps=self.eps).reshape_as(self.weight) 133 | return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 134 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/test_time_pool.py: -------------------------------------------------------------------------------- 1 | """ Test Time Pooling (Average-Max Pool) 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | import logging 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d 11 | 12 | 13 | _logger = logging.getLogger(__name__) 14 | 15 | 16 | class TestTimePoolHead(nn.Module): 17 | def __init__(self, base, original_pool=7): 18 | super(TestTimePoolHead, self).__init__() 19 | self.base = base 20 | self.original_pool = original_pool 21 | base_fc = self.base.get_classifier() 22 | if isinstance(base_fc, nn.Conv2d): 23 | self.fc = base_fc 24 | else: 25 | self.fc = nn.Conv2d( 26 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) 27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) 28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) 29 | self.base.reset_classifier(0) # delete original fc layer 30 | 31 | def forward(self, x): 32 | x = self.base.forward_features(x) 33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) 34 | x = self.fc(x) 35 | x = adaptive_avgmax_pool2d(x, 1) 36 | return x.view(x.size(0), -1) 37 | 38 | 39 | def apply_test_time_pool(model, config, use_test_size=False): 40 | test_time_pool = False 41 | if not hasattr(model, 'default_cfg') or not model.default_cfg: 42 | return model, False 43 | if use_test_size and 'test_input_size' in model.default_cfg: 44 | df_input_size = model.default_cfg['test_input_size'] 45 | else: 46 | df_input_size = model.default_cfg['input_size'] 47 | if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: 48 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' % 49 | (str(config['input_size'][-2:]), str(df_input_size[-2:]))) 50 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) 51 | test_time_pool = True 52 | return model, test_time_pool 53 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/trace_utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | from torch import _assert 3 | except 
ImportError: 4 | def _assert(condition: bool, message: str): 5 | assert condition, message 6 | 7 | 8 | def _float_to_int(x: float) -> int: 9 | """ 10 | Symbolic tracing helper to substitute for inbuilt `int`. 11 | Hint: Inbuilt `int` can't accept an argument of type `Proxy` 12 | """ 13 | return int(x) 14 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/layers/weight_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | 5 | from torch.nn.init import _calculate_fan_in_and_fan_out 6 | 7 | 8 | def _trunc_normal_(tensor, mean, std, a, b): 9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 11 | def norm_cdf(x): 12 | # Computes standard normal cumulative distribution function 13 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 14 | 15 | if (mean < a - 2 * std) or (mean > b + 2 * std): 16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 17 | "The distribution of values may be incorrect.", 18 | stacklevel=2) 19 | 20 | # Values are generated by using a truncated uniform distribution and 21 | # then using the inverse CDF for the normal distribution. 22 | # Get upper and lower cdf values 23 | l = norm_cdf((a - mean) / std) 24 | u = norm_cdf((b - mean) / std) 25 | 26 | # Uniformly fill tensor with values from [l, u], then translate to 27 | # [2l-1, 2u-1]. 28 | tensor.uniform_(2 * l - 1, 2 * u - 1) 29 | 30 | # Use inverse cdf transform for normal distribution to get truncated 31 | # standard normal 32 | tensor.erfinv_() 33 | 34 | # Transform to proper mean, std 35 | tensor.mul_(std * math.sqrt(2.)) 36 | tensor.add_(mean) 37 | 38 | # Clamp to ensure it's in the proper range 39 | tensor.clamp_(min=a, max=b) 40 | return tensor 41 | 42 | 43 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 44 | # type: (Tensor, float, float, float, float) -> Tensor 45 | r"""Fills the input Tensor with values drawn from a truncated 46 | normal distribution. The values are effectively drawn from the 47 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 48 | with values outside :math:`[a, b]` redrawn until they are within 49 | the bounds. The method used for generating the random values works 50 | best when :math:`a \leq \text{mean} \leq b`. 51 | 52 | NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are 53 | applied while sampling the normal with mean/std applied, therefore a, b args 54 | should be adjusted to match the range of mean, std args. 55 | 56 | Args: 57 | tensor: an n-dimensional `torch.Tensor` 58 | mean: the mean of the normal distribution 59 | std: the standard deviation of the normal distribution 60 | a: the minimum cutoff value 61 | b: the maximum cutoff value 62 | Examples: 63 | >>> w = torch.empty(3, 5) 64 | >>> nn.init.trunc_normal_(w) 65 | """ 66 | with torch.no_grad(): 67 | return _trunc_normal_(tensor, mean, std, a, b) 68 | 69 | 70 | def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.): 71 | # type: (Tensor, float, float, float, float) -> Tensor 72 | r"""Fills the input Tensor with values drawn from a truncated 73 | normal distribution. 
The values are effectively drawn from the 74 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 75 | with values outside :math:`[a, b]` redrawn until they are within 76 | the bounds. The method used for generating the random values works 77 | best when :math:`a \leq \text{mean} \leq b`. 78 | 79 | NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the 80 | bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 81 | and the result is subsquently scaled and shifted by the mean and std args. 82 | 83 | Args: 84 | tensor: an n-dimensional `torch.Tensor` 85 | mean: the mean of the normal distribution 86 | std: the standard deviation of the normal distribution 87 | a: the minimum cutoff value 88 | b: the maximum cutoff value 89 | Examples: 90 | >>> w = torch.empty(3, 5) 91 | >>> nn.init.trunc_normal_(w) 92 | """ 93 | with torch.no_grad(): 94 | _trunc_normal_(tensor, 0, 1.0, a, b) 95 | tensor.mul_(std).add_(mean) 96 | return tensor 97 | 98 | 99 | def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): 100 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 101 | if mode == 'fan_in': 102 | denom = fan_in 103 | elif mode == 'fan_out': 104 | denom = fan_out 105 | elif mode == 'fan_avg': 106 | denom = (fan_in + fan_out) / 2 107 | 108 | variance = scale / denom 109 | 110 | if distribution == "truncated_normal": 111 | # constant is stddev of standard normal truncated to (-2, 2) 112 | trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978) 113 | elif distribution == "normal": 114 | with torch.no_grad(): 115 | tensor.normal_(std=math.sqrt(variance)) 116 | elif distribution == "uniform": 117 | bound = math.sqrt(3 * variance) 118 | with torch.no_grad(): 119 | tensor.uniform_(-bound, bound) 120 | else: 121 | raise ValueError(f"invalid distribution {distribution}") 122 | 123 | 124 | def lecun_normal_(tensor): 125 | variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal') 126 | -------------------------------------------------------------------------------- /networks/merit_lib/models_timm/registry.py: -------------------------------------------------------------------------------- 1 | """ Model Registry 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | 5 | import sys 6 | import re 7 | import fnmatch 8 | from collections import defaultdict 9 | from copy import deepcopy 10 | 11 | __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', 12 | 'is_pretrained_cfg_key', 'has_pretrained_cfg_key', 'get_pretrained_cfg_value', 'is_model_pretrained'] 13 | 14 | _module_to_models = defaultdict(set) # dict of sets to check membership of model in module 15 | _model_to_module = {} # mapping of model names to module names 16 | _model_entrypoints = {} # mapping of model names to entrypoint fns 17 | _model_has_pretrained = set() # set of model names that have pretrained weight url present 18 | _model_pretrained_cfgs = dict() # central repo for model default_cfgs 19 | 20 | 21 | def register_model(fn): 22 | # lookup containing module 23 | mod = sys.modules[fn.__module__] 24 | module_name_split = fn.__module__.split('.') 25 | module_name = module_name_split[-1] if len(module_name_split) else '' 26 | 27 | # add model to __all__ in module 28 | model_name = fn.__name__ 29 | if hasattr(mod, '__all__'): 30 | mod.__all__.append(model_name) 31 | else: 32 | mod.__all__ = [model_name] 33 | 34 | # add entries to registry dict/sets 35 | 
_model_entrypoints[model_name] = fn 36 | _model_to_module[model_name] = module_name 37 | _module_to_models[module_name].add(model_name) 38 | has_valid_pretrained = False # check if model has a pretrained url to allow filtering on this 39 | if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: 40 | # this will catch all models that have entrypoint matching cfg key, but miss any aliasing 41 | # entrypoints or non-matching combos 42 | cfg = mod.default_cfgs[model_name] 43 | has_valid_pretrained = ( 44 | ('url' in cfg and 'http' in cfg['url']) or 45 | ('file' in cfg and cfg['file']) or 46 | ('hf_hub_id' in cfg and cfg['hf_hub_id']) 47 | ) 48 | _model_pretrained_cfgs[model_name] = mod.default_cfgs[model_name] 49 | if has_valid_pretrained: 50 | _model_has_pretrained.add(model_name) 51 | return fn 52 | 53 | 54 | def _natural_key(string_): 55 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 56 | 57 | 58 | def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False): 59 | """ Return list of available model names, sorted alphabetically 60 | 61 | Args: 62 | filter (str) - Wildcard filter string that works with fnmatch 63 | module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') 64 | pretrained (bool) - Include only models with pretrained weights if True 65 | exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter 66 | name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) 67 | 68 | Example: 69 | model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' 70 | model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module 71 | """ 72 | if module: 73 | all_models = list(_module_to_models[module]) 74 | else: 75 | all_models = _model_entrypoints.keys() 76 | if filter: 77 | models = [] 78 | include_filters = filter if isinstance(filter, (tuple, list)) else [filter] 79 | for f in include_filters: 80 | include_models = fnmatch.filter(all_models, f) # include these models 81 | if len(include_models): 82 | models = set(models).union(include_models) 83 | else: 84 | models = all_models 85 | if exclude_filters: 86 | if not isinstance(exclude_filters, (tuple, list)): 87 | exclude_filters = [exclude_filters] 88 | for xf in exclude_filters: 89 | exclude_models = fnmatch.filter(models, xf) # exclude these models 90 | if len(exclude_models): 91 | models = set(models).difference(exclude_models) 92 | if pretrained: 93 | models = _model_has_pretrained.intersection(models) 94 | if name_matches_cfg: 95 | models = set(_model_pretrained_cfgs).intersection(models) 96 | return list(sorted(models, key=_natural_key)) 97 | 98 | 99 | def is_model(model_name): 100 | """ Check if a model name exists 101 | """ 102 | return model_name in _model_entrypoints 103 | 104 | 105 | def model_entrypoint(model_name): 106 | """Fetch a model entrypoint for specified model name 107 | """ 108 | return _model_entrypoints[model_name] 109 | 110 | 111 | def list_modules(): 112 | """ Return list of module names that contain models / model entrypoints 113 | """ 114 | modules = _module_to_models.keys() 115 | return list(sorted(modules)) 116 | 117 | 118 | def is_model_in_modules(model_name, module_names): 119 | """Check if a model exists within a subset of modules 120 | Args: 121 | model_name (str) - name of model to check 122 | module_names (tuple, list, set) - names of modules to search in 
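    Example (the names below are illustrative; any registered model and containing module work):
        >>> is_model_in_modules('maxvit_tiny_tf_224', ('maxxvit',))  # True only if that model was registered by the maxxvit module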
123 | """ 124 | assert isinstance(module_names, (tuple, list, set)) 125 | return any(model_name in _module_to_models[n] for n in module_names) 126 | 127 | 128 | def is_model_pretrained(model_name): 129 | return model_name in _model_has_pretrained 130 | 131 | 132 | def get_pretrained_cfg(model_name): 133 | if model_name in _model_pretrained_cfgs: 134 | return deepcopy(_model_pretrained_cfgs[model_name]) 135 | return {} 136 | 137 | 138 | def has_pretrained_cfg_key(model_name, cfg_key): 139 | """ Query model default_cfgs for existence of a specific key. 140 | """ 141 | if model_name in _model_pretrained_cfgs and cfg_key in _model_pretrained_cfgs[model_name]: 142 | return True 143 | return False 144 | 145 | 146 | def is_pretrained_cfg_key(model_name, cfg_key): 147 | """ Return truthy value for specified model default_cfg key, False if does not exist. 148 | """ 149 | if model_name in _model_pretrained_cfgs and _model_pretrained_cfgs[model_name].get(cfg_key, False): 150 | return True 151 | return False 152 | 153 | 154 | def get_pretrained_cfg_value(model_name, cfg_key): 155 | """ Get a specific model default_cfg value by key. None if it doesn't exist. 156 | """ 157 | if model_name in _model_pretrained_cfgs: 158 | return _model_pretrained_cfgs[model_name].get(cfg_key, None) 159 | return None -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | addict==2.4.0 3 | albumentations==1.3.0 4 | certifi==2023.5.7 5 | charset-normalizer==3.1.0 6 | contextlib2==21.6.0 7 | contourpy==1.0.7 8 | cycler==0.11.0 9 | einops==0.6.1 10 | fairscale==0.4.13 11 | filelock==3.12.2 12 | fonttools==4.39.3 13 | fsspec==2023.4.0 14 | fvcore==0.1.5.post20221221 15 | h5py==3.8.0 16 | huggingface-hub==0.14.1 17 | idna==3.4 18 | imageio==2.28.1 19 | imgaug==0.4.0 20 | importlib-resources==5.12.0 21 | iopath==0.1.10 22 | joblib==1.2.0 23 | kiwisolver==1.4.4 24 | lazy_loader==0.2 25 | loguru==0.7.0 26 | matplotlib==3.7.1 27 | MedPy==0.4.0 28 | ml-collections==0.1.1 29 | networkx==3.1 30 | nibabel==5.1.0 31 | numpy<1.25.0 32 | opencv-python==4.7.0.72 33 | opencv-python-headless==4.7.0.72 34 | packaging==23.1 35 | pandas<1.6 36 | Pillow==10.0.1 37 | pip==23.0.1 38 | portalocker==2.7.0 39 | pthflops==0.4.2 40 | pyparsing==3.0.9 41 | python-dateutil==2.8.2 42 | pytz==2023.3 43 | PyWavelets==1.4.1 44 | PyYAML==6.0 45 | qudida==0.0.4 46 | requests==2.30.0 47 | scikit-image==0.20.0 48 | scikit-learn==1.2.2 49 | scipy==1.11.3 50 | seaborn==0.12.2 51 | segmentation-mask-overlay==0.4.4 52 | setuptools==66.0.0 53 | shapely==2.0.1 54 | SimpleITK==2.2.1 55 | six==1.16.0 56 | tabulate==0.9.0 57 | tensorboardX 58 | termcolor==2.3.0 59 | thop==0.1.1.post2209072238 60 | threadpoolctl==3.1.0 61 | tifffile==2023.4.12 62 | timm==0.6.13 63 | tomli==2.0.1 64 | tqdm==4.65.0 65 | typing_extensions==4.7.1 66 | tzdata==2023.3 67 | warmup-scheduler==0.3 68 | wheel==0.38.4 69 | yacs==0.1.8 70 | yapf==0.33.0 71 | zipp==3.15.0 72 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import warnings 6 | from pydoc import locate 7 | 8 | import numpy as np 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | 12 | from networks.msa2net import Msa2Net 13 | from trainer import trainer_synapse 14 | 15 | from 
fvcore.nn import FlopCountAnalysis 16 | import sys 17 | 18 | warnings.filterwarnings("ignore") 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | "--root_path", 23 | type=str, 24 | default="data/Synapse/train_npz", 25 | help="root dir for train data", 26 | ) 27 | parser.add_argument( 28 | "--test_path", 29 | type=str, 30 | default="data/Synapse/test_vol_h5", 31 | help="root dir for test data", 32 | ) 33 | parser.add_argument("--dataset", type=str, default="Synapse", help="experiment_name") 34 | parser.add_argument("--list_dir", type=str, default="./lists/lists_Synapse", help="list dir") 35 | parser.add_argument("--num_classes", type=int, default=9, help="output channel of network") 36 | parser.add_argument("--output_dir", type=str, default="./model_out/MSA2Net", help="output dir") 37 | parser.add_argument("--max_iterations", type=int, default=90000, help="maximum epoch number to train") 38 | parser.add_argument("--max_epochs", type=int, default=400, help="maximum epoch number to train") 39 | parser.add_argument("--batch_size", type=int, default=24, help="batch_size per gpu") 40 | parser.add_argument("--num_workers", type=int, default=8, help="num_workers") 41 | parser.add_argument("--eval_interval", type=int, default=20, help="eval_interval") 42 | parser.add_argument("--model_name", type=str, default="msa2net", help="model_name") 43 | parser.add_argument("--n_gpu", type=int, default=1, help="total gpu") 44 | parser.add_argument("--deterministic", type=int, default=1, help="whether to use deterministic training") 45 | parser.add_argument("--base_lr", type=float, default=0.05, help="segmentation network base learning rate") 46 | parser.add_argument("--img_size", type=int, default=224, help="input patch size of network input") 47 | parser.add_argument("--z_spacing", type=int, default=1, help="z_spacing") 48 | parser.add_argument("--seed", type=int, default=1234, help="random seed") 49 | parser.add_argument("--zip", action="store_true", help="use zipped dataset instead of folder dataset") 50 | parser.add_argument( 51 | "--cache-mode", 52 | type=str, 53 | default="part", 54 | choices=["no", "full", "part"], 55 | help="no: no cache, " 56 | "full: cache all data, " 57 | "part: sharding the dataset into nonoverlapping pieces and only cache one piece", 58 | ) 59 | parser.add_argument("--resume", help="resume from checkpoint") 60 | parser.add_argument("--accumulation-steps", type=int, help="gradient accumulation steps") 61 | parser.add_argument( 62 | "--use-checkpoint", action="store_true", help="whether to use gradient checkpointing to save memory" 63 | ) 64 | parser.add_argument( 65 | "--amp-opt-level", 66 | type=str, 67 | default="O1", 68 | choices=["O0", "O1", "O2"], 69 | help="mixed precision opt level, if O0, no amp is used", 70 | ) 71 | parser.add_argument("--tag", help="tag of experiment") 72 | parser.add_argument("--eval", action="store_true", help="Perform evaluation only") 73 | parser.add_argument("--throughput", action="store_true", help="Test throughput only") 74 | 75 | args = parser.parse_args() 76 | 77 | 78 | if __name__ == "__main__": 79 | # setting device on GPU if available, else CPU 80 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 81 | print("Using device:", device) 82 | print() 83 | 84 | # Additional Info when using cuda 85 | if device.type == "cuda": 86 | print(torch.cuda.get_device_name(0)) 87 | print("Memory Usage:") 88 | print("Allocated:", round(torch.cuda.memory_allocated(0) / 1024**3, 1), "GB") 89 | print("Cached: ", 
round(torch.cuda.memory_reserved(0) / 1024**3, 1), "GB") 90 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 91 | if not args.deterministic: 92 | cudnn.benchmark = True 93 | cudnn.deterministic = False 94 | else: 95 | cudnn.benchmark = False 96 | cudnn.deterministic = True 97 | 98 | random.seed(args.seed) 99 | np.random.seed(args.seed) 100 | torch.manual_seed(args.seed) 101 | torch.cuda.manual_seed(args.seed) 102 | 103 | dataset_name = args.dataset 104 | dataset_config = { 105 | "Synapse": { 106 | "root_path": args.root_path, 107 | "list_dir": args.list_dir, 108 | "num_classes": 9, 109 | }, 110 | } 111 | 112 | if args.batch_size != 24 and args.batch_size % 5 == 0: 113 | args.base_lr *= args.batch_size / 24 114 | args.num_classes = dataset_config[dataset_name]["num_classes"] 115 | args.root_path = dataset_config[dataset_name]["root_path"] 116 | args.list_dir = dataset_config[dataset_name]["list_dir"] 117 | 118 | if not os.path.exists(args.output_dir): 119 | os.makedirs(args.output_dir) 120 | 121 | net = Msa2Net().cuda(0) # Msa2net + masag 122 | 123 | input = torch.rand((1,3,224,224)).cuda(0) 124 | flops = FlopCountAnalysis(net, input) 125 | model_flops = flops.total() 126 | print(f"MAdds: {round(model_flops * 1e-9, 2)} G") 127 | 128 | #sys.exit() 129 | 130 | trainer = { 131 | "Synapse": trainer_synapse, 132 | } 133 | trainer[dataset_name](args, net, args.output_dir) 134 | --------------------------------------------------------------------------------
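A quick way to sanity-check the model definition before launching a full Synapse run is to mirror the forward pass that train.py already performs for its FLOP count. The sketch below is illustrative only: it reuses the same no-argument Msa2Net() constructor and 1x3x224x224 input that train.py uses, runs on whichever device is available, and makes no assumption about the exact structure of the returned logits.

    import torch
    from networks.msa2net import Msa2Net

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Msa2Net().to(device).eval()                 # same constructor call as train.py
    dummy = torch.rand(1, 3, 224, 224, device=device)   # img_size defaults to 224 in train.py
    with torch.no_grad():
        out = model(dummy)
    print(type(out), getattr(out, "shape", None))       # inspect whatever the network head returns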