├── Test_Minist ├── test_imgs │ ├── img1.png │ ├── img2.png │ ├── img3.png │ ├── img4.png │ └── img5.png ├── utils │ ├── decode_heads │ │ ├── __init__.py │ │ ├── fpn_head.py │ │ ├── fcn_head.py │ │ ├── segformer_head.py │ │ ├── psp_head.py │ │ ├── aspp_head.py │ │ ├── uper_head.py │ │ └── decode_head.py │ ├── transforms_utils.py │ ├── config.py │ ├── color_seg.py │ ├── labels_dict.py │ └── segformer.py ├── configs │ ├── test_720_sm.yaml │ ├── test_720_ss.yaml │ ├── test_1080_sm.yaml │ └── test_1080_ss.yaml ├── run.sh └── tools │ └── test.py └── README.md /Test_Minist/test_imgs/img1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irfanICMLL/SSIW/HEAD/Test_Minist/test_imgs/img1.png -------------------------------------------------------------------------------- /Test_Minist/test_imgs/img2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irfanICMLL/SSIW/HEAD/Test_Minist/test_imgs/img2.png -------------------------------------------------------------------------------- /Test_Minist/test_imgs/img3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irfanICMLL/SSIW/HEAD/Test_Minist/test_imgs/img3.png -------------------------------------------------------------------------------- /Test_Minist/test_imgs/img4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irfanICMLL/SSIW/HEAD/Test_Minist/test_imgs/img4.png -------------------------------------------------------------------------------- /Test_Minist/test_imgs/img5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irfanICMLL/SSIW/HEAD/Test_Minist/test_imgs/img5.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository contains the source code of our paper: 2 | [Wei Yin, Yifan Liu, Chunhua Shen, Anton van den Hengel, Baichuan Sun, The devil is in the labels: Semantic segmentation from sentences](https://arxiv.org/abs/2202.02002) 3 | 4 | 5 | Embedding: https://cloudstor.aarnet.edu.au/plus/s/gXaGsZyvoUwu97t 6 | CKPT: https://cloudstor.aarnet.edu.au/plus/s/AtYYaVSVVAlEwve 7 | -------------------------------------------------------------------------------- /Test_Minist/utils/decode_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .aspp_head import ASPPHead 2 | 3 | from .fcn_head import FCNHead 4 | from .fpn_head import FPNHead 5 | 6 | from .psp_head import PSPHead 7 | 8 | from .uper_head import UPerHead 9 | from .segformer_head import SegFormerHead 10 | 11 | __all__ = [ 12 | 'FCNHead', 'ASPPHead', 'FPNHead', 'PSPHead', 'SegFormerHead' 13 | ] 14 | -------------------------------------------------------------------------------- /Test_Minist/configs/test_720_sm.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ignore_label: 255 3 | 4 | TEST: 5 | base_size: 720 6 | test_h: 640 7 | test_w: 640 8 | scales: [1.0] 9 | 10 | test_gpu: [0, ] 11 | model_path: 'models/segformer_7data.path' 12 | emb_path: 'models/universal_cat2vec.npy' 13 | 14 | num_model_classes: 512 15 | num_train_classes: 194 16 | test_with_embeddings: False 17 | logit_softmax_weight: 500 18 | 
single_scale_single_crop: False 19 | single_scale_multi_crop: True 20 | multi_scale_multi_crop: False 21 | 22 | emd_method: 'embeddings' 23 | dataset_lib: dataset 24 | 25 | distributed: False 26 | 27 | root_dir: '' 28 | cam_id: 0 29 | img_file_type: 'jpeg' 30 | gpus_num: 1 31 | save_folder: '' 32 | -------------------------------------------------------------------------------- /Test_Minist/configs/test_720_ss.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ignore_label: 255 3 | 4 | TEST: 5 | base_size: 720 6 | test_h: 640 7 | test_w: 640 8 | scales: [1.0] 9 | 10 | test_gpu: [0, ] 11 | model_path: 'models/segformer_7data.pth' 12 | emb_path: 'models/universal_cat2vec.npy' 13 | 14 | num_model_classes: 512 15 | num_train_classes: 194 16 | test_with_embeddings: False 17 | logit_softmax_weight: 500 18 | single_scale_single_crop: True 19 | single_scale_multi_crop: False 20 | multi_scale_multi_crop: False 21 | 22 | emd_method: 'embeddings' 23 | dataset_lib: dataset 24 | 25 | distributed: False 26 | 27 | root_dir: '' 28 | cam_id: 0 29 | img_file_type: 'jpeg' 30 | gpus_num: 1 31 | save_folder: '' 32 | -------------------------------------------------------------------------------- /Test_Minist/configs/test_1080_sm.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ignore_label: 255 3 | 4 | TEST: 5 | base_size: 1080 6 | test_h: 640 7 | test_w: 640 8 | scales: [1.0] 9 | 10 | test_gpu: [0, ] 11 | model_path: 'models/segformer_7data.pth' 12 | emb_path: 'models/universal_cat2vec.npy' 13 | 14 | num_model_classes: 512 15 | num_train_classes: 194 16 | test_with_embeddings: False 17 | logit_softmax_weight: 500 18 | single_scale_single_crop: False 19 | single_scale_multi_crop: True 20 | multi_scale_multi_crop: False 21 | 22 | emd_method: 'embeddings' 23 | dataset_lib: dataset 24 | 25 | distributed: False 26 | 27 | root_dir: '' 28 | cam_id: 0 29 | img_file_type: 'jpeg' 30 | gpus_num: 1 31 | save_folder: '' 32 | -------------------------------------------------------------------------------- /Test_Minist/configs/test_1080_ss.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ignore_label: 255 3 | 4 | TEST: 5 | base_size: 1080 6 | test_h: 640 7 | test_w: 640 8 | scales: [1.0] 9 | 10 | test_gpu: [0, ] 11 | model_path: 'models/segformer_7data.pth' 12 | emb_path: 'models/universal_cat2vec.npy' 13 | 14 | num_model_classes: 512 15 | num_train_classes: 194 16 | test_with_embeddings: False 17 | logit_softmax_weight: 500 18 | single_scale_single_crop: True 19 | single_scale_multi_crop: False 20 | multi_scale_multi_crop: False 21 | 22 | emd_method: 'embeddings' 23 | dataset_lib: dataset 24 | 25 | distributed: False 26 | 27 | root_dir: '' 28 | cam_id: 0 29 | img_file_type: 'jpeg' 30 | gpus_num: 1 31 | save_folder: '' 32 | -------------------------------------------------------------------------------- /Test_Minist/run.sh: -------------------------------------------------------------------------------- 1 | 2 | # single scale and single crop: the model runs a single forward pass per image. The short edge of the image will be resized to 720. 3 | python tools/test.py --config test_720_ss 4 | 5 | # single scale and single crop: the model runs a single forward pass per image. The short edge of the image will be resized to 1080. 6 | #python tools/test.py --config test_1080_ss 7 | 8 | 9 | # single scale and multiple crops. 
Each image is split into 4 crops, which are fed to the model in multiple forward passes. The performance and details 10 | # will be much better than the single-crop modes above. The short edge of the image will be resized to 720. 11 | #python tools/test.py --config test_720_sm 12 | 13 | # single scale and multiple crops. Each image is split into 4 crops, which are fed to the model in multiple forward passes. The performance and details 14 | # will be much better than the single-crop modes above. The short edge of the image will be resized to 1080. 15 | #python tools/test.py --config test_1080_sm 16 | -------------------------------------------------------------------------------- /Test_Minist/utils/decode_heads/fpn_head.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | 5 | from mmseg.ops import resize 6 | #from ..builder import HEADS 7 | from .decode_head import BaseDecodeHead 8 | 9 | 10 | #@HEADS.register_module() 11 | class FPNHead(BaseDecodeHead): 12 | """Panoptic Feature Pyramid Networks. 13 | This head is the implementation of `Semantic FPN 14 | `_. 15 | Args: 16 | feature_strides (tuple[int]): The strides for input feature maps. 17 | stack_lateral. All strides are supposed to be powers of 2. The first 18 | one is of the largest resolution. 19 | """ 20 | 21 | def __init__(self, feature_strides, **kwargs): 22 | super(FPNHead, self).__init__( 23 | input_transform='multiple_select', **kwargs) 24 | assert len(feature_strides) == len(self.in_channels) 25 | assert min(feature_strides) == feature_strides[0] 26 | self.feature_strides = feature_strides 27 | 28 | self.scale_heads = nn.ModuleList() 29 | for i in range(len(feature_strides)): 30 | head_length = max( 31 | 1, 32 | int(np.log2(feature_strides[i]) - np.log2(feature_strides[0]))) 33 | scale_head = [] 34 | for k in range(head_length): 35 | scale_head.append( 36 | ConvModule( 37 | self.in_channels[i] if k == 0 else self.channels, 38 | self.channels, 39 | 3, 40 | padding=1, 41 | conv_cfg=self.conv_cfg, 42 | norm_cfg=self.norm_cfg, 43 | act_cfg=self.act_cfg)) 44 | if feature_strides[i] != feature_strides[0]: 45 | scale_head.append( 46 | nn.Upsample( 47 | scale_factor=2, 48 | mode='bilinear', 49 | align_corners=self.align_corners)) 50 | self.scale_heads.append(nn.Sequential(*scale_head)) 51 | 52 | def forward(self, inputs): 53 | 54 | x = self._transform_inputs(inputs) 55 | 56 | output = self.scale_heads[0](x[0]) 57 | for i in range(1, len(self.feature_strides)): 58 | # non inplace 59 | output = output + resize( 60 | self.scale_heads[i](x[i]), 61 | size=output.shape[2:], 62 | mode='bilinear', 63 | align_corners=self.align_corners) 64 | 65 | output = self.cls_seg(output) 66 | return output 67 | -------------------------------------------------------------------------------- /Test_Minist/utils/decode_heads/fcn_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | 5 | #from ..builder import HEADS 6 | from .decode_head import BaseDecodeHead 7 | 8 | 9 | #@HEADS.register_module() 10 | class FCNHead(BaseDecodeHead): 11 | """Fully Convolutional Networks for Semantic Segmentation. 12 | This head is the implementation of `FCNNet `_. 13 | Args: 14 | num_convs (int): Number of convs in the head. Default: 2. 15 | kernel_size (int): The kernel size for convs in the head. Default: 3. 
16 | concat_input (bool): Whether concat the input and output of convs 17 | before classification layer. 18 | """ 19 | 20 | def __init__(self, 21 | num_convs=2, 22 | kernel_size=3, 23 | concat_input=True, 24 | **kwargs): 25 | assert num_convs >= 0 26 | self.num_convs = num_convs 27 | self.concat_input = concat_input 28 | self.kernel_size = kernel_size 29 | super(FCNHead, self).__init__(**kwargs) 30 | if num_convs == 0: 31 | assert self.in_channels == self.channels 32 | 33 | convs = [] 34 | convs.append( 35 | ConvModule( 36 | self.in_channels, 37 | self.channels, 38 | kernel_size=kernel_size, 39 | padding=kernel_size // 2, 40 | conv_cfg=self.conv_cfg, 41 | norm_cfg=self.norm_cfg, 42 | act_cfg=self.act_cfg)) 43 | for i in range(num_convs - 1): 44 | convs.append( 45 | ConvModule( 46 | self.channels, 47 | self.channels, 48 | kernel_size=kernel_size, 49 | padding=kernel_size // 2, 50 | conv_cfg=self.conv_cfg, 51 | norm_cfg=self.norm_cfg, 52 | act_cfg=self.act_cfg)) 53 | if num_convs == 0: 54 | self.convs = nn.Identity() 55 | else: 56 | self.convs = nn.Sequential(*convs) 57 | if self.concat_input: 58 | self.conv_cat = ConvModule( 59 | self.in_channels + self.channels, 60 | self.channels, 61 | kernel_size=kernel_size, 62 | padding=kernel_size // 2, 63 | conv_cfg=self.conv_cfg, 64 | norm_cfg=self.norm_cfg, 65 | act_cfg=self.act_cfg) 66 | 67 | def forward(self, inputs): 68 | """Forward function.""" 69 | x = self._transform_inputs(inputs) 70 | output = self.convs(x) 71 | if self.concat_input: 72 | output = self.conv_cat(torch.cat([x, output], dim=1)) 73 | output = self.cls_seg(output) 74 | return output 75 | -------------------------------------------------------------------------------- /Test_Minist/utils/transforms_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import random 4 | import torch 5 | from typing import Optional, Tuple 6 | 7 | def get_imagenet_mean_std() -> Tuple[Tuple[float,float,float], Tuple[float,float,float]]: 8 | """ See use here in Pytorch ImageNet script: 9 | https://github.com/pytorch/examples/blob/master/imagenet/main.py#L197 10 | Returns: 11 | - mean: Tuple[float,float,float], 12 | - std: Tuple[float,float,float] = None 13 | """ 14 | value_scale = 255 15 | mean = [0.485, 0.456, 0.406] 16 | mean = [item * value_scale for item in mean] 17 | std = [0.229, 0.224, 0.225] 18 | std = [item * value_scale for item in std] 19 | return mean, std 20 | 21 | 22 | def normalize_img( input: torch.Tensor, 23 | mean: Tuple[float,float,float], 24 | std: Optional[Tuple[float,float,float]] = None): 25 | """ Pass in by reference Torch tensor, and normalize its values. 26 | Args: 27 | - input: Torch tensor of shape (3,M,N), must be in this order, and 28 | of type float (necessary). 
29 | - mean: mean values for each RGB channel 30 | - std: standard deviation values for each RGB channel 31 | Returns: 32 | - None 33 | """ 34 | if std is None: 35 | for t, m in zip(input, mean): 36 | t.sub_(m) 37 | else: 38 | for t, m, s in zip(input, mean, std): 39 | t.sub_(m).div_(s) 40 | 41 | 42 | def pad_to_crop_sz( 43 | image: np.ndarray, 44 | crop_h: int, 45 | crop_w: int, 46 | mean: Tuple[float,float,float] 47 | ) -> Tuple[np.ndarray,int,int]: 48 | ori_h, ori_w, _ = image.shape 49 | pad_h = max(crop_h - ori_h, 0) 50 | pad_w = max(crop_w - ori_w, 0) 51 | pad_h_half = int(pad_h / 2) 52 | pad_w_half = int(pad_w / 2) 53 | if pad_h > 0 or pad_w > 0: 54 | image = cv2.copyMakeBorder( 55 | src=image, 56 | top=pad_h_half, 57 | bottom=pad_h - pad_h_half, 58 | left=pad_w_half, 59 | right=pad_w - pad_w_half, 60 | borderType=cv2.BORDER_CONSTANT, 61 | value=mean) 62 | return image, pad_h_half, pad_w_half 63 | 64 | 65 | def resize_by_scaled_short_side( 66 | image: np.ndarray, 67 | base_size: int, 68 | scale: float) -> np.ndarray: 69 | """ Equivalent to ResizeShort(), but functional, instead of OOP paradigm, and w/ scale param. 70 | 71 | Args: 72 | image: Numpy array of shape () 73 | scale: scaling factor for image 74 | 75 | Returns: 76 | image_scaled: 77 | """ 78 | h, w, _ = image.shape 79 | short_size = round(scale * base_size) 80 | new_h = short_size 81 | new_w = short_size 82 | # Preserve the aspect ratio 83 | if h > w: 84 | new_h = round(short_size / float(w) * h) 85 | else: 86 | new_w = round(short_size / float(h) * w) 87 | image_scaled = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR) 88 | return image_scaled 89 | -------------------------------------------------------------------------------- /Test_Minist/utils/decode_heads/segformer_head.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule 5 | from collections import OrderedDict 6 | 7 | from mmseg.ops import resize 8 | from .decode_head import BaseDecodeHead 9 | 10 | 11 | 12 | class MLP(nn.Module): 13 | """ 14 | Linear Embedding 15 | """ 16 | def __init__(self, input_dim=2048, embed_dim=768): 17 | super().__init__() 18 | self.proj = nn.Linear(input_dim, embed_dim) 19 | 20 | def forward(self, x): 21 | x = x.flatten(2).transpose(1, 2) 22 | x = self.proj(x) 23 | return x 24 | 25 | 26 | 27 | class SegFormerHead(BaseDecodeHead): 28 | """ 29 | SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers 30 | """ 31 | def __init__(self, feature_strides, **kwargs): 32 | super(SegFormerHead, self).__init__(input_transform='multiple_select', **kwargs) 33 | assert len(feature_strides) == len(self.in_channels) 34 | assert min(feature_strides) == feature_strides[0] 35 | self.feature_strides = feature_strides 36 | 37 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels 38 | 39 | decoder_params = dict(embed_dim=768)#kwargs['decoder_params'] 40 | embedding_dim = decoder_params['embed_dim'] 41 | 42 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) 43 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) 44 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) 45 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) 46 | 47 | self.linear_fuse = ConvModule( 48 | in_channels=embedding_dim*4, 49 | out_channels=embedding_dim, 50 | kernel_size=1, 
51 | norm_cfg=dict(type='SyncBN', requires_grad=True) 52 | ) 53 | 54 | self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) 55 | 56 | def forward(self, inputs): 57 | x = inputs #self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32 58 | c1, c2, c3, c4 = x 59 | 60 | ############## MLP decoder on C1-C4 ########### 61 | n, _, h, w = c4.shape 62 | 63 | _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]) 64 | _c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) 65 | 66 | _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]) 67 | _c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) 68 | 69 | _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]) 70 | _c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) 71 | 72 | _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]) 73 | 74 | _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) 75 | 76 | x = self.dropout(_c) 77 | x = self.linear_pred(x) 78 | 79 | return x 80 | -------------------------------------------------------------------------------- /Test_Minist/utils/decode_heads/psp_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | 5 | from mmseg.ops import resize 6 | #from ..builder import HEADS 7 | from .decode_head import BaseDecodeHead 8 | 9 | 10 | class PPM(nn.ModuleList): 11 | """Pooling Pyramid Module used in PSPNet. 12 | Args: 13 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 14 | Module. 15 | in_channels (int): Input channels. 16 | channels (int): Channels after modules, before conv_seg. 17 | conv_cfg (dict|None): Config of conv layers. 18 | norm_cfg (dict|None): Config of norm layers. 19 | act_cfg (dict): Config of activation layers. 20 | align_corners (bool): align_corners argument of F.interpolate. 21 | """ 22 | 23 | def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, 24 | act_cfg, align_corners): 25 | super(PPM, self).__init__() 26 | self.pool_scales = pool_scales 27 | self.align_corners = align_corners 28 | self.in_channels = in_channels 29 | self.channels = channels 30 | self.conv_cfg = conv_cfg 31 | self.norm_cfg = norm_cfg 32 | self.act_cfg = act_cfg 33 | for pool_scale in pool_scales: 34 | self.append( 35 | nn.Sequential( 36 | nn.AdaptiveAvgPool2d(pool_scale), 37 | ConvModule( 38 | self.in_channels, 39 | self.channels, 40 | 1, 41 | conv_cfg=self.conv_cfg, 42 | norm_cfg=self.norm_cfg, 43 | act_cfg=self.act_cfg))) 44 | 45 | def forward(self, x): 46 | """Forward function.""" 47 | ppm_outs = [] 48 | for ppm in self: 49 | ppm_out = ppm(x) 50 | upsampled_ppm_out = resize( 51 | ppm_out, 52 | size=x.size()[2:], 53 | mode='bilinear', 54 | align_corners=self.align_corners) 55 | ppm_outs.append(upsampled_ppm_out) 56 | return ppm_outs 57 | 58 | 59 | #@HEADS.register_module() 60 | class PSPHead(BaseDecodeHead): 61 | """Pyramid Scene Parsing Network. 62 | This head is the implementation of 63 | `PSPNet `_. 64 | Args: 65 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 66 | Module. Default: (1, 2, 3, 6). 
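Example (illustrative channel sizes, not taken from a shipped config): head = PSPHead(in_channels=2048, channels=512, num_classes=194, pool_scales=(1, 2, 3, 6))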
67 | """ 68 | 69 | def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): 70 | super(PSPHead, self).__init__(**kwargs) 71 | assert isinstance(pool_scales, (list, tuple)) 72 | self.pool_scales = pool_scales 73 | self.psp_modules = PPM( 74 | self.pool_scales, 75 | self.in_channels, 76 | self.channels, 77 | conv_cfg=self.conv_cfg, 78 | norm_cfg=self.norm_cfg, 79 | act_cfg=self.act_cfg, 80 | align_corners=self.align_corners) 81 | self.bottleneck = ConvModule( 82 | self.in_channels + len(pool_scales) * self.channels, 83 | self.channels, 84 | 3, 85 | padding=1, 86 | conv_cfg=self.conv_cfg, 87 | norm_cfg=self.norm_cfg, 88 | act_cfg=self.act_cfg) 89 | 90 | def forward(self, inputs): 91 | """Forward function.""" 92 | x = self._transform_inputs(inputs) 93 | psp_outs = [x] 94 | psp_outs.extend(self.psp_modules(x)) 95 | psp_outs = torch.cat(psp_outs, dim=1) 96 | output = self.bottleneck(psp_outs) 97 | output = self.cls_seg(output) 98 | return output 99 | -------------------------------------------------------------------------------- /Test_Minist/utils/decode_heads/aspp_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | 5 | from mmseg.ops import resize 6 | #from ..builder import HEADS 7 | from .decode_head import BaseDecodeHead 8 | 9 | 10 | class ASPPModule(nn.ModuleList): 11 | """Atrous Spatial Pyramid Pooling (ASPP) Module. 12 | Args: 13 | dilations (tuple[int]): Dilation rate of each layer. 14 | in_channels (int): Input channels. 15 | channels (int): Channels after modules, before conv_seg. 16 | conv_cfg (dict|None): Config of conv layers. 17 | norm_cfg (dict|None): Config of norm layers. 18 | act_cfg (dict): Config of activation layers. 19 | """ 20 | 21 | def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg, 22 | act_cfg): 23 | super(ASPPModule, self).__init__() 24 | self.dilations = dilations 25 | self.in_channels = in_channels 26 | self.channels = channels 27 | self.conv_cfg = conv_cfg 28 | self.norm_cfg = norm_cfg 29 | self.act_cfg = act_cfg 30 | for dilation in dilations: 31 | self.append( 32 | ConvModule( 33 | self.in_channels, 34 | self.channels, 35 | 1 if dilation == 1 else 3, 36 | dilation=dilation, 37 | padding=0 if dilation == 1 else dilation, 38 | conv_cfg=self.conv_cfg, 39 | norm_cfg=self.norm_cfg, 40 | act_cfg=self.act_cfg)) 41 | 42 | def forward(self, x): 43 | """Forward function.""" 44 | aspp_outs = [] 45 | for aspp_module in self: 46 | aspp_outs.append(aspp_module(x)) 47 | 48 | return aspp_outs 49 | 50 | 51 | #@HEADS.register_module() 52 | class ASPPHead(BaseDecodeHead): 53 | """Rethinking Atrous Convolution for Semantic Image Segmentation. 54 | This head is the implementation of `DeepLabV3 55 | `_. 56 | Args: 57 | dilations (tuple[int]): Dilation rates for ASPP module. 58 | Default: (1, 6, 12, 18). 
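Example (illustrative values): head = ASPPHead(in_channels=2048, channels=512, num_classes=194, dilations=(1, 6, 12, 18))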
59 | """ 60 | 61 | def __init__(self, dilations=(1, 6, 12, 18), **kwargs): 62 | super(ASPPHead, self).__init__(**kwargs) 63 | assert isinstance(dilations, (list, tuple)) 64 | self.dilations = dilations 65 | self.image_pool = nn.Sequential( 66 | nn.AdaptiveAvgPool2d(1), 67 | ConvModule( 68 | self.in_channels, 69 | self.channels, 70 | 1, 71 | conv_cfg=self.conv_cfg, 72 | norm_cfg=self.norm_cfg, 73 | act_cfg=self.act_cfg)) 74 | self.aspp_modules = ASPPModule( 75 | dilations, 76 | self.in_channels, 77 | self.channels, 78 | conv_cfg=self.conv_cfg, 79 | norm_cfg=self.norm_cfg, 80 | act_cfg=self.act_cfg) 81 | self.bottleneck = ConvModule( 82 | (len(dilations) + 1) * self.channels, 83 | self.channels, 84 | 3, 85 | padding=1, 86 | conv_cfg=self.conv_cfg, 87 | norm_cfg=self.norm_cfg, 88 | act_cfg=self.act_cfg) 89 | 90 | def forward(self, inputs): 91 | """Forward function.""" 92 | x = self._transform_inputs(inputs) 93 | aspp_outs = [ 94 | resize( 95 | self.image_pool(x), 96 | size=x.size()[2:], 97 | mode='bilinear', 98 | align_corners=self.align_corners) 99 | ] 100 | aspp_outs.extend(self.aspp_modules(x)) 101 | aspp_outs = torch.cat(aspp_outs, dim=1) 102 | output = self.bottleneck(aspp_outs) 103 | output = self.cls_seg(output) 104 | return output 105 | -------------------------------------------------------------------------------- /Test_Minist/utils/decode_heads/uper_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | 5 | from mmseg.ops import resize 6 | #from ..builder import HEADS 7 | from .decode_head import BaseDecodeHead 8 | from .psp_head import PPM 9 | 10 | 11 | #@HEADS.register_module() 12 | class UPerHead(BaseDecodeHead): 13 | """Unified Perceptual Parsing for Scene Understanding. 14 | This head is the implementation of `UPerNet 15 | `_. 16 | Args: 17 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 18 | Module applied on the last feature. Default: (1, 2, 3, 6). 
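Example (illustrative values; in_channels and in_index must be equal-length lists because this head selects multiple backbone levels): head = UPerHead(in_channels=[64, 128, 320, 512], in_index=[0, 1, 2, 3], channels=512, num_classes=194)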
19 | """ 20 | 21 | def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): 22 | super(UPerHead, self).__init__( 23 | input_transform='multiple_select', **kwargs) 24 | # PSP Module 25 | self.psp_modules = PPM( 26 | pool_scales, 27 | self.in_channels[-1], 28 | self.channels, 29 | conv_cfg=self.conv_cfg, 30 | norm_cfg=self.norm_cfg, 31 | act_cfg=self.act_cfg, 32 | align_corners=self.align_corners) 33 | self.bottleneck = ConvModule( 34 | self.in_channels[-1] + len(pool_scales) * self.channels, 35 | self.channels, 36 | 3, 37 | padding=1, 38 | conv_cfg=self.conv_cfg, 39 | norm_cfg=self.norm_cfg, 40 | act_cfg=self.act_cfg) 41 | # FPN Module 42 | self.lateral_convs = nn.ModuleList() 43 | self.fpn_convs = nn.ModuleList() 44 | for in_channels in self.in_channels[:-1]: # skip the top layer 45 | l_conv = ConvModule( 46 | in_channels, 47 | self.channels, 48 | 1, 49 | conv_cfg=self.conv_cfg, 50 | norm_cfg=self.norm_cfg, 51 | act_cfg=self.act_cfg, 52 | inplace=False) 53 | fpn_conv = ConvModule( 54 | self.channels, 55 | self.channels, 56 | 3, 57 | padding=1, 58 | conv_cfg=self.conv_cfg, 59 | norm_cfg=self.norm_cfg, 60 | act_cfg=self.act_cfg, 61 | inplace=False) 62 | self.lateral_convs.append(l_conv) 63 | self.fpn_convs.append(fpn_conv) 64 | 65 | self.fpn_bottleneck = ConvModule( 66 | len(self.in_channels) * self.channels, 67 | self.channels, 68 | 3, 69 | padding=1, 70 | conv_cfg=self.conv_cfg, 71 | norm_cfg=self.norm_cfg, 72 | act_cfg=self.act_cfg) 73 | 74 | def psp_forward(self, inputs): 75 | """Forward function of PSP module.""" 76 | x = inputs[-1] 77 | psp_outs = [x] 78 | psp_outs.extend(self.psp_modules(x)) 79 | psp_outs = torch.cat(psp_outs, dim=1) 80 | output = self.bottleneck(psp_outs) 81 | 82 | return output 83 | 84 | def forward(self, inputs): 85 | """Forward function.""" 86 | 87 | inputs = self._transform_inputs(inputs) 88 | 89 | # build laterals 90 | laterals = [ 91 | lateral_conv(inputs[i]) 92 | for i, lateral_conv in enumerate(self.lateral_convs) 93 | ] 94 | 95 | laterals.append(self.psp_forward(inputs)) 96 | 97 | # build top-down path 98 | used_backbone_levels = len(laterals) 99 | for i in range(used_backbone_levels - 1, 0, -1): 100 | prev_shape = laterals[i - 1].shape[2:] 101 | laterals[i - 1] += resize( 102 | laterals[i], 103 | size=prev_shape, 104 | mode='bilinear', 105 | align_corners=self.align_corners) 106 | 107 | # build outputs 108 | fpn_outs = [ 109 | self.fpn_convs[i](laterals[i]) 110 | for i in range(used_backbone_levels - 1) 111 | ] 112 | # append psp feature 113 | fpn_outs.append(laterals[-1]) 114 | 115 | for i in range(used_backbone_levels - 1, 0, -1): 116 | fpn_outs[i] = resize( 117 | fpn_outs[i], 118 | size=fpn_outs[0].shape[2:], 119 | mode='bilinear', 120 | align_corners=self.align_corners) 121 | fpn_outs = torch.cat(fpn_outs, dim=1) 122 | output = self.fpn_bottleneck(fpn_outs) 123 | output = self.cls_seg(output) 124 | return output 125 | -------------------------------------------------------------------------------- /Test_Minist/utils/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | from ast import literal_eval 4 | import copy 5 | 6 | 7 | class CfgNode(dict): 8 | """ 9 | CfgNode represents an internal node in the configuration tree. It's a simple 10 | dict-like container that allows for attribute-based access to keys. 
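Example: cfg = CfgNode({'TEST': {'base_size': 720}}) converts the nested dict recursively, so cfg.TEST.base_size returns 720.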
11 | """ 12 | 13 | def __init__(self, init_dict=None, key_list=None, new_allowed=False): 14 | # Recursively convert nested dictionaries in init_dict into CfgNodes 15 | init_dict = {} if init_dict is None else init_dict 16 | key_list = [] if key_list is None else key_list 17 | for k, v in init_dict.items(): 18 | if type(v) is dict: 19 | # Convert dict to CfgNode 20 | init_dict[k] = CfgNode(v, key_list=key_list + [k]) 21 | super(CfgNode, self).__init__(init_dict) 22 | 23 | def __getattr__(self, name): 24 | if name in self: 25 | return self[name] 26 | else: 27 | raise AttributeError(name) 28 | 29 | def __setattr__(self, name, value): 30 | self[name] = value 31 | 32 | def __str__(self): 33 | def _indent(s_, num_spaces): 34 | s = s_.split("\n") 35 | if len(s) == 1: 36 | return s_ 37 | first = s.pop(0) 38 | s = [(num_spaces * " ") + line for line in s] 39 | s = "\n".join(s) 40 | s = first + "\n" + s 41 | return s 42 | 43 | r = "" 44 | s = [] 45 | for k, v in sorted(self.items()): 46 | seperator = "\n" if isinstance(v, CfgNode) else " " 47 | attr_str = "{}:{}{}".format(str(k), seperator, str(v)) 48 | attr_str = _indent(attr_str, 2) 49 | s.append(attr_str) 50 | r += "\n".join(s) 51 | return r 52 | 53 | def __repr__(self): 54 | return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__()) 55 | 56 | 57 | def load_cfg_from_cfg_file(file): 58 | cfg = {} 59 | assert os.path.isfile(file) and file.endswith('.yaml'), \ 60 | '{} is not a yaml file'.format(file) 61 | 62 | with open(file, 'r') as f: 63 | cfg_from_file = yaml.safe_load(f) 64 | 65 | for key in cfg_from_file: 66 | for k, v in cfg_from_file[key].items(): 67 | cfg[k] = v 68 | 69 | cfg = CfgNode(cfg) 70 | return cfg 71 | 72 | 73 | def merge_cfg_from_list(cfg, cfg_list): 74 | new_cfg = copy.deepcopy(cfg) 75 | assert len(cfg_list) % 2 == 0 76 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): 77 | subkey = full_key.split('.')[-1] 78 | assert subkey in cfg, 'Non-existent key: {}'.format(full_key) 79 | value = _decode_cfg_value(v) 80 | value = _check_and_coerce_cfg_value_type( 81 | value, cfg[subkey], subkey, full_key 82 | ) 83 | setattr(new_cfg, subkey, value) 84 | 85 | return new_cfg 86 | 87 | 88 | def _decode_cfg_value(v): 89 | """Decodes a raw config value (e.g., from a yaml config files or command 90 | line argument) into a Python object. 91 | """ 92 | # All remaining processing is only applied to strings 93 | if not isinstance(v, str): 94 | return v 95 | # Try to interpret `v` as a: 96 | # string, number, tuple, list, dict, boolean, or None 97 | try: 98 | v = literal_eval(v) 99 | # The following two excepts allow v to pass through when it represents a 100 | # string. 101 | # 102 | # Longer explanation: 103 | # The type of v is always a string (before calling literal_eval), but 104 | # sometimes it *represents* a string and other times a data structure, like 105 | # a list. In the case that v represents a string, what we got back from the 106 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is 107 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other 108 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval 109 | # will raise a SyntaxError. 110 | except ValueError: 111 | pass 112 | except SyntaxError: 113 | pass 114 | return v 115 | 116 | 117 | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): 118 | """Checks that `replacement`, which is intended to replace `original` is of 119 | the right type. 
The type is correct if it matches exactly or is one of a few 120 | cases in which the type can be easily coerced. 121 | """ 122 | original_type = type(original) 123 | replacement_type = type(replacement) 124 | 125 | # The types must match (with some exceptions) 126 | if replacement_type == original_type: 127 | return replacement 128 | 129 | # Cast replacement from from_type to to_type if the replacement and original 130 | # types match from_type and to_type 131 | def conditional_cast(from_type, to_type): 132 | if replacement_type == from_type and original_type == to_type: 133 | return True, to_type(replacement) 134 | else: 135 | return False, None 136 | 137 | # Conditionally casts 138 | # list <-> tuple 139 | casts = [(tuple, list), (list, tuple)] 140 | # For py2: allow converting from str (bytes) to a unicode string 141 | try: 142 | casts.append((str, unicode)) # noqa: F821 143 | except Exception: 144 | pass 145 | 146 | for (from_type, to_type) in casts: 147 | converted, converted_value = conditional_cast(from_type, to_type) 148 | if converted: 149 | return converted_value 150 | 151 | # Original is None, can directly replace it 152 | if original_type == type(None) and original == None: 153 | return replacement 154 | 155 | raise ValueError( 156 | "Type mismatch ({} vs. {}) with values ({} vs. {}) for config " 157 | "key: {}".format( 158 | original_type, replacement_type, original, replacement, full_key 159 | ) 160 | ) 161 | 162 | 163 | def _assert_with_logging(cond, msg): 164 | if not cond: 165 | logger.debug(msg) 166 | assert cond, msg 167 | -------------------------------------------------------------------------------- /Test_Minist/utils/color_seg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def make_palette(num_classes=256): 5 | """ 6 | Inputs: 7 | num_classes: the number of classes 8 | Outputs: 9 | palette: the colormap as a k x 3 array of RGB colors 10 | """ 11 | palette = np.zeros((num_classes, 3), dtype=np.uint8) 12 | for k in range(0, num_classes): 13 | label = k 14 | i = 0 15 | while label: 16 | palette[k, 0] |= (((label >> 0) & 1) << (7 - i)) 17 | palette[k, 1] |= (((label >> 1) & 1) << (7 - i)) 18 | palette[k, 2] |= (((label >> 2) & 1) << (7 - i)) 19 | label >>= 3 20 | i += 1 21 | idx1 = np.arange(0, num_classes, 2)[::-1] 22 | idx2 = np.arange(1, num_classes, 2) 23 | idx = np.concatenate([idx1[:, None], idx2[:, None]], axis=1).flatten() 24 | palette = palette[idx] 25 | palette[num_classes - 1, :] = [255, 255, 255] 26 | return palette 27 | 28 | PALETTE = make_palette(256) 29 | 30 | 31 | def color_seg(seg, palette=None): 32 | if palette == None: 33 | color_out = PALETTE[seg.reshape(-1)].reshape(seg.shape + (3,)) 34 | else: 35 | color_out = palette[seg.reshape(-1)].reshape(seg.shape + (3,)) 36 | return color_out 37 | 38 | def color_map_list(class_num): 39 | map1 = np.asarray([ 40 | [0, 0, 0], 41 | [120, 120, 120], 42 | [180, 120, 120], 43 | [6, 230, 230], 44 | [80, 50, 50], 45 | [4, 200, 3], 46 | [120, 120, 80], 47 | [204, 5, 255], 48 | [230, 230, 230], 49 | [4, 250, 7], 50 | [224, 5, 255], 51 | [235, 255, 7], 52 | [150, 5, 61], 53 | [120, 120, 70], 54 | [8, 255, 51], 55 | [255, 6, 82], 56 | [143, 255, 140], 57 | [204, 255, 4], 58 | [255, 51, 7], 59 | [204, 70, 3], 60 | [0, 102, 200], 61 | [61, 230, 250], 62 | [255, 6, 51], 63 | [11, 102, 255], 64 | [255, 7, 71], 65 | [255, 9, 224], 66 | [9, 7, 230], 67 | [255, 9, 92], 68 | [112, 9, 255], 69 | [8, 255, 214], 70 | [7, 255, 224], 71 | [255, 184, 6], 72 | [10, 
255, 71], 73 | [255, 41, 10], 74 | [7, 255, 255], 75 | [224, 255, 8], 76 | [102, 8, 255], 77 | [255, 61, 6], 78 | [255, 194, 7], 79 | [255, 122, 8], 80 | [0, 255, 20], 81 | [255, 8, 41], 82 | [255, 5, 153], 83 | [6, 51, 255], 84 | [235, 12, 255], 85 | [160, 150, 20], 86 | [0, 163, 255], 87 | [140, 140, 140], 88 | [250, 10, 15], 89 | [20, 255, 0], 90 | [31, 255, 0], 91 | [255, 31, 0], 92 | [255, 224, 0], 93 | [153, 255, 0], 94 | [0, 0, 255], 95 | [255, 71, 0], 96 | [0, 235, 255], 97 | [0, 173, 255], 98 | [31, 0, 255], 99 | [11, 200, 200], 100 | [255, 82, 0], 101 | [0, 255, 245], 102 | [0, 61, 255], 103 | [0, 255, 112], 104 | [0, 255, 133], 105 | [255, 163, 0], 106 | [255, 102, 0], 107 | [194, 255, 0], 108 | [0, 143, 255], 109 | [51, 255, 0], 110 | [0, 82, 255], 111 | [0, 255, 41], 112 | [0, 255, 173], 113 | [10, 0, 255], 114 | [173, 255, 0], 115 | [0, 255, 153], 116 | [255, 92, 0], 117 | [255, 0, 255], 118 | [255, 0, 245], 119 | [255, 0, 102], 120 | [255, 173, 0], 121 | [255, 0, 20], 122 | [255, 184, 184], 123 | [0, 31, 255], 124 | [0, 255, 61], 125 | [0, 71, 255], 126 | [255, 0, 204], 127 | [0, 255, 194], 128 | [0, 255, 82], 129 | [0, 10, 255], 130 | [0, 112, 255], 131 | [51, 0, 255], 132 | [0, 194, 255], 133 | [0, 122, 255], 134 | [0, 255, 163], 135 | [255, 153, 0], 136 | [0, 255, 10], 137 | [255, 112, 0], 138 | [143, 255, 0], 139 | [82, 0, 255], 140 | [163, 255, 0], 141 | [255, 235, 0], 142 | [8, 184, 170], 143 | [133, 0, 255], 144 | [0, 255, 92], 145 | [184, 0, 255], 146 | [255, 0, 31], 147 | [0, 184, 255], 148 | [0, 214, 255], 149 | [255, 0, 112], 150 | [92, 255, 0], 151 | [0, 224, 255], 152 | [112, 224, 255], 153 | [70, 184, 160], 154 | [163, 0, 255], 155 | [153, 0, 255], 156 | [71, 255, 0], 157 | [255, 0, 163], 158 | [255, 204, 0], 159 | [255, 0, 143], 160 | [0, 255, 235], 161 | [133, 255, 0], 162 | [255, 0, 235], 163 | [245, 0, 255], 164 | [255, 0, 122], 165 | [255, 245, 0], 166 | [10, 190, 212], 167 | [214, 255, 0], 168 | [0, 204, 255], 169 | [20, 0, 255], 170 | [255, 255, 0], 171 | [0, 153, 255], 172 | [0, 41, 255], 173 | [0, 255, 204], 174 | [41, 0, 255], 175 | [41, 255, 0], 176 | [173, 0, 255], 177 | [0, 245, 255], 178 | [71, 0, 255], 179 | [122, 0, 255], 180 | [0, 255, 184], 181 | [0, 92, 255], 182 | [184, 255, 0], 183 | [0, 133, 255], 184 | [255, 214, 0], 185 | [25, 194, 194], 186 | [102, 255, 0], 187 | [92, 0, 255], 188 | [165, 42, 42], 189 | [0, 192, 0], 190 | [196, 196, 196], 191 | [190, 153, 153], 192 | [180, 165, 180], 193 | [102, 102, 156], 194 | [128, 64, 255], 195 | [140, 140, 200], 196 | [170, 170, 170], 197 | [250, 170, 160], 198 | [96, 96, 96], 199 | [230, 150, 140], 200 | [128, 64, 128], 201 | [110, 110, 110], 202 | [244, 35, 232], 203 | [150, 100, 100], 204 | [70, 70, 70], 205 | [150, 120, 90], 206 | [220, 20, 60], 207 | [255, 0, 0], 208 | [200, 128, 128], 209 | [64, 170, 64], 210 | [128, 64, 64], 211 | [70, 130, 180], 212 | [152, 251, 152], 213 | [107, 142, 35], 214 | [0, 170, 30], 215 | [255, 255, 128], 216 | [250, 0, 30], 217 | [220, 220, 220], 218 | [222, 40, 40], 219 | [100, 170, 30], 220 | [40, 40, 40], 221 | [33, 33, 33], 222 | [0, 0, 142], 223 | [210, 170, 100], 224 | [153, 153, 153], 225 | [128, 128, 128], 226 | [250, 170, 30], 227 | [192, 192, 192], 228 | [220, 220, 0], 229 | [119, 11, 32], 230 | [0, 80, 100], 231 | [149, 32, 32], 232 | [10, 59, 140], 233 | [160, 0, 142], 234 | [0, 60, 100], 235 | [240, 100, 100] 236 | ]) 237 | idx1 = np.arange(0, map1.shape[0], 2)[::-1] 238 | idx2 = np.arange(1, map1.shape[0], 2) 239 | idx = np.concatenate([idx1[:, 
None], idx2[:, None]], axis=1).flatten() 240 | map1 = map1[idx] 241 | 242 | pa = np.ones((class_num, 3), dtype=np.uint8) * 255 243 | pa[:map1.shape[0], :] = map1 244 | return pa 245 | -------------------------------------------------------------------------------- /Test_Minist/utils/decode_heads/decode_head.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import torch 4 | import torch.nn as nn 5 | #from mmcv.cnn import normal_init 6 | #from mmcv.runner import auto_fp16, force_fp32 7 | 8 | #from mmseg.core import build_pixel_sampler 9 | #from mmseg.ops import resize 10 | #from ..builder import build_loss 11 | #from ..losses import accuracy 12 | 13 | 14 | class BaseDecodeHead(nn.Module, metaclass=ABCMeta): 15 | """Base class for BaseDecodeHead. 16 | Args: 17 | in_channels (int|Sequence[int]): Input channels. 18 | channels (int): Channels after modules, before conv_seg. 19 | num_classes (int): Number of classes. 20 | dropout_ratio (float): Ratio of dropout layer. Default: 0.1. 21 | conv_cfg (dict|None): Config of conv layers. Default: None. 22 | norm_cfg (dict|None): Config of norm layers. Default: None. 23 | act_cfg (dict): Config of activation layers. 24 | Default: dict(type='ReLU') 25 | in_index (int|Sequence[int]): Input feature index. Default: -1 26 | input_transform (str|None): Transformation type of input features. 27 | Options: 'resize_concat', 'multiple_select', None. 28 | 'resize_concat': Multiple feature maps will be resize to the 29 | same size as first one and than concat together. 30 | Usually used in FCN head of HRNet. 31 | 'multiple_select': Multiple feature maps will be bundle into 32 | a list and passed into decode head. 33 | None: Only one select feature map is allowed. 34 | Default: None. 35 | loss_decode (dict): Config of decode loss. 36 | Default: dict(type='CrossEntropyLoss'). 37 | ignore_index (int | None): The label index to be ignored. When using 38 | masked BCE loss, ignore_index should be set to None. Default: 255 39 | sampler (dict|None): The config of segmentation map sampler. 40 | Default: None. 41 | align_corners (bool): align_corners argument of F.interpolate. 42 | Default: False. 
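Example (illustrative, via a concrete subclass from this package): head = FCNHead(in_channels=768, channels=256, num_classes=194); num_classes sets the output channels of the final 1x1 conv_seg classifier.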
43 | """ 44 | 45 | def __init__(self, 46 | in_channels, 47 | channels, 48 | *, 49 | num_classes, 50 | dropout_ratio=0.1, 51 | conv_cfg=None, 52 | norm_cfg=None, 53 | act_cfg=dict(type='ReLU'), 54 | in_index=-1, 55 | input_transform=None, 56 | loss_decode=dict( 57 | type='CrossEntropyLoss', 58 | use_sigmoid=False, 59 | loss_weight=1.0), 60 | ignore_index=255, 61 | sampler=None, 62 | align_corners=False): 63 | super(BaseDecodeHead, self).__init__() 64 | self._init_inputs(in_channels, in_index, input_transform) 65 | self.channels = channels 66 | self.num_classes = num_classes 67 | self.dropout_ratio = dropout_ratio 68 | self.conv_cfg = conv_cfg 69 | self.norm_cfg = norm_cfg 70 | self.act_cfg = act_cfg 71 | self.in_index = in_index 72 | self.loss_decode = None # build_loss(loss_decode) 73 | self.ignore_index = ignore_index 74 | self.align_corners = align_corners 75 | # if sampler is not None: 76 | # self.sampler = build_pixel_sampler(sampler, context=self) 77 | # else: 78 | # self.sampler = None 79 | self.sampler = None 80 | 81 | self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) 82 | if dropout_ratio > 0: 83 | self.dropout = nn.Dropout2d(dropout_ratio) 84 | else: 85 | self.dropout = None 86 | self.fp16_enabled = False 87 | 88 | def extra_repr(self): 89 | pass 90 | # """Extra repr.""" 91 | # s = f'input_transform={self.input_transform}, ' \ 92 | # f'ignore_index={self.ignore_index}, ' \ 93 | # f'align_corners={self.align_corners}' 94 | # return s 95 | 96 | def _init_inputs(self, in_channels, in_index, input_transform): 97 | # """Check and initialize input transforms. 98 | # 99 | # The in_channels, in_index and input_transform must match. 100 | # Specifically, when input_transform is None, only single feature map 101 | # will be selected. So in_channels and in_index must be of type int. 102 | # When input_transform 103 | # 104 | # Args: 105 | # in_channels (int|Sequence[int]): Input channels. 106 | # in_index (int|Sequence[int]): Input feature index. 107 | # input_transform (str|None): Transformation type of input features. 108 | # Options: 'resize_concat', 'multiple_select', None. 109 | # 'resize_concat': Multiple feature maps will be resize to the 110 | # same size as first one and than concat together. 111 | # Usually used in FCN head of HRNet. 112 | # 'multiple_select': Multiple feature maps will be bundle into 113 | # a list and passed into decode head. 114 | # None: Only one select feature map is allowed. 115 | # """ 116 | # 117 | if input_transform is not None: 118 | assert input_transform in ['resize_concat', 'multiple_select'] 119 | self.input_transform = input_transform 120 | self.in_index = in_index 121 | if input_transform is not None: 122 | assert isinstance(in_channels, (list, tuple)) 123 | assert isinstance(in_index, (list, tuple)) 124 | assert len(in_channels) == len(in_index) 125 | if input_transform == 'resize_concat': 126 | self.in_channels = sum(in_channels) 127 | else: 128 | self.in_channels = in_channels 129 | else: 130 | assert isinstance(in_channels, int) 131 | assert isinstance(in_index, int) 132 | self.in_channels = in_channels 133 | 134 | def init_weights(self): 135 | pass 136 | # """Initialize weights of classification layer.""" 137 | # normal_init(self.conv_seg, mean=0, std=0.01) 138 | 139 | def _transform_inputs(self, inputs): 140 | pass 141 | # """Transform inputs for decoder. 142 | # 143 | # Args: 144 | # inputs (list[Tensor]): List of multi-level img features. 
145 | # 146 | # Returns: 147 | # Tensor: The transformed inputs 148 | # """ 149 | # 150 | # if self.input_transform == 'resize_concat': 151 | # inputs = [inputs[i] for i in self.in_index] 152 | # upsampled_inputs = [ 153 | # resize( 154 | # input=x, 155 | # size=inputs[0].shape[2:], 156 | # mode='bilinear', 157 | # align_corners=self.align_corners) for x in inputs 158 | # ] 159 | # inputs = torch.cat(upsampled_inputs, dim=1) 160 | # elif self.input_transform == 'multiple_select': 161 | # inputs = [inputs[i] for i in self.in_index] 162 | # else: 163 | # inputs = inputs[self.in_index] 164 | # 165 | # return inputs 166 | 167 | #@auto_fp16() 168 | #@abstractmethod 169 | def forward(self, inputs): 170 | """Placeholder of forward function.""" 171 | pass 172 | 173 | def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg): 174 | pass 175 | # """Forward function for training. 176 | # Args: 177 | # inputs (list[Tensor]): List of multi-level img features. 178 | # img_metas (list[dict]): List of image info dict where each dict 179 | # has: 'img_shape', 'scale_factor', 'flip', and may also contain 180 | # 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 181 | # For details on the values of these keys see 182 | # `mmseg/datasets/pipelines/formatting.py:Collect`. 183 | # gt_semantic_seg (Tensor): Semantic segmentation masks 184 | # used if the architecture supports semantic segmentation task. 185 | # train_cfg (dict): The training config. 186 | # 187 | # Returns: 188 | # dict[str, Tensor]: a dictionary of loss components 189 | # """ 190 | # seg_logits = self.forward(inputs) 191 | # losses = self.losses(seg_logits, gt_semantic_seg) 192 | # return losses 193 | 194 | def forward_test(self, inputs, img_metas, test_cfg): 195 | pass 196 | # """Forward function for testing. 197 | # 198 | # Args: 199 | # inputs (list[Tensor]): List of multi-level img features. 200 | # img_metas (list[dict]): List of image info dict where each dict 201 | # has: 'img_shape', 'scale_factor', 'flip', and may also contain 202 | # 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 203 | # For details on the values of these keys see 204 | # `mmseg/datasets/pipelines/formatting.py:Collect`. 205 | # test_cfg (dict): The testing config. 206 | # 207 | # Returns: 208 | # Tensor: Output segmentation map. 
209 | # """ 210 | # output = self.forward(inputs) 211 | # output = self.emb2cls(output) 212 | # return output 213 | 214 | def cls_seg(self, feat): 215 | """Classify each pixel.""" 216 | if self.dropout is not None: 217 | feat = self.dropout(feat) 218 | output = self.conv_seg(feat) 219 | return output 220 | 221 | def emb2cls(self, cls_score): 222 | pass 223 | # if hasattr(self.loss_decode, 'vec'): 224 | # # normalize 225 | # vec = self.loss_decode.vec.to(device=cls_score.device) 226 | # if hasattr(self.loss_decode, 'norm'): 227 | # cls_score = cls_score / cls_score.norm(dim=1, keepdim=True) 228 | # vec = vec / vec.norm(dim=1, keepdim=True) 229 | # if hasattr(self.loss_decode, 'logit_scale'): 230 | # logit_scale = self.loss_decode.logit_scale 231 | # cls_score = logit_scale * cls_score.permute(0, 2, 3, 1) @ vec.t() # [N, H, W, num_cls] 232 | # else: 233 | # cls_score = cls_score.permute(0, 2, 3, 1) @ vec.t() # [N, H, W, num_cls] 234 | # cls_score = cls_score.permute(0, 3, 1, 2) # [N, num_cls, H, W] 235 | # return cls_score 236 | # else: 237 | # raise NameError("No vec in loss_decode") 238 | 239 | #@force_fp32(apply_to=('seg_logit', )) 240 | def losses(self, seg_logit, seg_label): 241 | pass 242 | # """Compute segmentation loss.""" 243 | # loss = dict() 244 | # seg_logit = resize( 245 | # input=seg_logit, 246 | # size=seg_label.shape[2:], 247 | # mode='bilinear', 248 | # align_corners=self.align_corners) 249 | # if self.sampler is not None: 250 | # seg_weight = self.sampler.sample(seg_logit, seg_label) 251 | # else: 252 | # seg_weight = None 253 | # seg_label = seg_label.squeeze(1) 254 | # loss['loss_seg'] = self.loss_decode( 255 | # seg_logit, 256 | # seg_label, 257 | # weight=seg_weight, 258 | # ignore_index=self.ignore_index) 259 | # print(seg_logit.shape, ) 260 | # seg_logit = self.emb2cls(seg_logit) 261 | # loss['acc_seg'] = accuracy(seg_logit, seg_label) 262 | # return loss 263 | -------------------------------------------------------------------------------- /Test_Minist/tools/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import os, sys 4 | CODE_SPACE=os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 5 | sys.path.append(CODE_SPACE) 6 | os.chdir(CODE_SPACE) 7 | 8 | import argparse 9 | import cv2 10 | import logging 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | import json 15 | 16 | import utils.config as config 17 | from utils.config import CfgNode 18 | from utils.transforms_utils import get_imagenet_mean_std, normalize_img, pad_to_crop_sz, resize_by_scaled_short_side 19 | import matplotlib.pyplot as plt 20 | from utils.color_seg import color_seg 21 | 22 | import glob 23 | from PIL import Image 24 | from utils.labels_dict import UNI_UID2UNAME, ALL_LABEL2ID, UNAME2EM_NAME 25 | import torch.multiprocessing as mp 26 | from utils.segformer import get_configured_segformer 27 | from tqdm import tqdm 28 | 29 | def get_logger(): 30 | """ 31 | """ 32 | logger_name = "main-logger" 33 | logger = logging.getLogger(logger_name) 34 | logger.setLevel(logging.INFO) 35 | if not logger.handlers: 36 | handler = logging.StreamHandler() 37 | fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" 38 | handler.setFormatter(logging.Formatter(fmt)) 39 | logger.addHandler(handler) 40 | return logger 41 | 42 | logger = get_logger() 43 | 44 | 45 | def get_parser() -> CfgNode: 46 | """ 47 | TODO: add to library to avoid replication. 
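Example (mirrors run.sh): python tools/test.py --config test_720_ss ; the bare name is expanded to configs/test_720_ss.yaml below.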
48 | """ 49 | parser = argparse.ArgumentParser(description='Yvan Yin\'s Semantic Segmentation Model.') 50 | parser.add_argument('--root_dir', type=str, help='root dir for the data') 51 | parser.add_argument('--cam_id', type=str, help='camera ID') 52 | parser.add_argument('--img_folder', default='image_', type=str, help='the images folder name except the camera ID') 53 | parser.add_argument('--img_file_type', default='jpeg', type=str, help='the file type of images, such as jpeg, png, jpg...') 54 | 55 | parser.add_argument('--config', type=str, default='720_ss', help='config file') 56 | parser.add_argument('--gpus_num', type=int, default=1, help='number of gpus') 57 | parser.add_argument('--save_folder', type=str, default='ann/semantics', help='the folder for saving semantic masks') 58 | parser.add_argument('opts', help='see mseg_semantic/config/test/default_config_360.yaml for all options, model path should be passed in', 59 | default=None, nargs=argparse.REMAINDER) 60 | args = parser.parse_args() 61 | config_path = os.path.join('configs', f'{args.config}.yaml') 62 | args.config = config_path 63 | 64 | # test on samples 65 | if args.root_dir is None: 66 | args.root_dir = f'{CODE_SPACE}/test_imgs' 67 | args.cam_id='01' 68 | args.img_file_type = 'png' 69 | 70 | assert args.config is not None 71 | cfg = config.load_cfg_from_cfg_file(args.config) 72 | cfg.root_dir = args.root_dir 73 | cfg.cam_id = args.cam_id 74 | cfg.img_folder = args.img_folder 75 | cfg.img_file_type = args.img_file_type 76 | cfg.gpus_num = args.gpus_num 77 | cfg.save_folder = args.save_folder 78 | return cfg 79 | 80 | 81 | 82 | def get_prediction(embs, gt_embs_list): 83 | prediction = [] 84 | logits = [] 85 | B, _, _, _ = embs.shape 86 | for b in range(B): 87 | score = embs[b,...] 88 | score = score.unsqueeze(0) 89 | emb = gt_embs_list 90 | emb = emb / emb.norm(dim=1, keepdim=True) 91 | score = score / score.norm(dim=1, keepdim=True) 92 | score = score.permute(0, 2, 3, 1) @ emb.t() 93 | # [N, H, W, num_cls] You maybe need to remove the .t() based on the shape of your saved .npy 94 | score = score.permute(0, 3, 1, 2) # [N, num_cls, H, W] 95 | prediction.append(score.max(1)[1]) 96 | logits.append(score) 97 | if len(prediction) == 1: 98 | prediction = prediction[0] 99 | logit = logits[0] 100 | else: 101 | prediction = torch.cat(prediction, dim=0) 102 | logit = torch.cat(logits, dim=0) 103 | return logit 104 | 105 | 106 | def single_scale_single_crop_cuda(model, 107 | image: np.ndarray, 108 | h: int, w: int, gt_embs_list, 109 | args=None) -> np.ndarray: 110 | ori_h, ori_w, _ = image.shape 111 | mean, std = get_imagenet_mean_std() 112 | crop_h = (np.ceil((ori_h - 1) / 32) * 32).astype(np.int32) 113 | crop_w = (np.ceil((ori_w - 1) / 32) * 32).astype(np.int32) 114 | 115 | image, pad_h_half, pad_w_half = pad_to_crop_sz(image, crop_h, crop_w, mean) 116 | image_crop = torch.from_numpy(image.transpose((2, 0, 1))).float() 117 | normalize_img(image_crop, mean, std) 118 | image_crop = image_crop.unsqueeze(0).cuda() 119 | with torch.no_grad(): 120 | emb, _, _ = model(inputs=image_crop, label_space=['universal']) 121 | logit = get_prediction(emb, gt_embs_list) 122 | logit_universal = F.softmax(logit * 100, dim=1).squeeze() 123 | 124 | # disregard predictions from padded portion of image 125 | prediction_crop = logit_universal[:, pad_h_half:pad_h_half + ori_h, pad_w_half:pad_w_half + ori_w] 126 | 127 | # CHW -> HWC 128 | prediction_crop = prediction_crop.permute(1, 2, 0) 129 | prediction_crop = prediction_crop.data.cpu().numpy() 130 | 131 | 
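# prediction_crop is now an (H, W, num_classes) numpy array of softmax scores for the unpadded region.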
# upsample or shrink predictions back down to scale=1.0 132 | prediction = cv2.resize(prediction_crop, (w, h), interpolation=cv2.INTER_LINEAR) 133 | return prediction 134 | 135 | 136 | def single_scale_cuda(model, 137 | image: np.ndarray, 138 | h: int, w: int, gt_embs_list, stride_rate: float = 2/3, 139 | args=None) -> np.ndarray: 140 | mean, std = get_imagenet_mean_std() 141 | crop_h = args.test_h 142 | crop_w = args.test_w 143 | ori_h, ori_w, _ = image.shape 144 | image, pad_h_half, pad_w_half = pad_to_crop_sz(image, crop_h, crop_w, mean) 145 | new_h, new_w, _ = image.shape 146 | stride_h = int(np.ceil(crop_h*stride_rate)) 147 | stride_w = int(np.ceil(crop_w*stride_rate)) 148 | grid_h = int(np.ceil(float(new_h-crop_h)/stride_h) + 1) 149 | grid_w = int(np.ceil(float(new_w-crop_w)/stride_w) + 1) 150 | 151 | prediction_crop = torch.zeros((gt_embs_list.shape[0], new_h, new_w)).cuda() 152 | count_crop = torch.zeros((new_h, new_w)).cuda() 153 | # loop w/ sliding window, obtain start/end indices 154 | for index_h in range(0, grid_h): 155 | for index_w in range(0, grid_w): 156 | s_h = index_h * stride_h 157 | e_h = min(s_h + crop_h, new_h) 158 | s_h = e_h - crop_h 159 | s_w = index_w * stride_w 160 | e_w = min(s_w + crop_w, new_w) 161 | s_w = e_w - crop_w 162 | image_crop = image[s_h:e_h, s_w:e_w].copy() 163 | count_crop[s_h:e_h, s_w:e_w] += 1 164 | 165 | image_crop = torch.from_numpy(image_crop.transpose((2, 0, 1))).float() 166 | normalize_img(image_crop, mean, std) 167 | image_crop = image_crop.unsqueeze(0) 168 | with torch.no_grad(): 169 | emb, _, _ = model(inputs=image_crop, label_space=['universal']) 170 | logit = get_prediction(emb, gt_embs_list) 171 | logit_universal = F.softmax(logit * 100, dim=1) 172 | prediction_crop[:, s_h:e_h, s_w:e_w] += logit_universal.squeeze() 173 | 174 | prediction_crop /= count_crop.unsqueeze(0) 175 | # disregard predictions from padded portion of image 176 | prediction_crop = prediction_crop[:, pad_h_half:pad_h_half+ori_h, pad_w_half:pad_w_half+ori_w] 177 | 178 | # CHW -> HWC 179 | prediction_crop = prediction_crop.permute(1,2,0) 180 | prediction_crop = prediction_crop.data.cpu().numpy() 181 | 182 | # upsample or shrink predictions back down to scale=1.0 183 | prediction = cv2.resize(prediction_crop, (w, h), interpolation=cv2.INTER_LINEAR) 184 | return prediction 185 | 186 | 187 | def do_test(args, local_rank): 188 | imgs_on_devices = organize_images(args, local_rank) 189 | model = get_configured_segformer(args.num_model_classes, 190 | criterion=None, 191 | load_imagenet_model=False) 192 | model.eval() 193 | 194 | if args.distributed: 195 | model = torch.nn.parallel.DistributedDataParallel(model.cuda(), 196 | device_ids=[local_rank,], 197 | output_device=local_rank, 198 | find_unused_parameters=True) 199 | else: 200 | model = torch.nn.DataParallel(model) 201 | 202 | ckpt_path = args.model_path 203 | checkpoint = torch.load(ckpt_path, map_location='cpu')['state_dict'] 204 | ckpt_filter = {k: v for k, v in checkpoint.items() if 'criterion.0.criterion.weight' not in k} 205 | model.load_state_dict(ckpt_filter, strict=False) 206 | 207 | gt_embs_list = torch.tensor(np.load(args.emb_path)).cuda().float() 208 | id_to_label = UNI_UID2UNAME 209 | 210 | test_single(args, imgs_on_devices, local_rank, model, gt_embs_list) 211 | 212 | def test_single(args, imgs_list, local_rank, model, gt_embs_list): 213 | for i, rgb_path in tqdm(enumerate(imgs_list)): 214 | save_path = os.path.join(args.root_dir, args.save_folder, os.path.basename(rgb_path)) 215 | save_path = 
os.path.splitext(save_path)[0] + '.png' 216 | 217 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 218 | 219 | rgb = cv2.imread(rgb_path, -1)[:, :, ::-1] 220 | image_resized = resize_by_scaled_short_side(rgb, args.base_size, 1) 221 | h, w, _ = rgb.shape 222 | 223 | if args.single_scale_single_crop: 224 | out_logit = single_scale_single_crop_cuda(model, image_resized, h, w, gt_embs_list=gt_embs_list, args=args) 225 | elif args.single_scale_multi_crop: 226 | out_logit = single_scale_cuda(model, image_resized, h, w, gt_embs_list=gt_embs_list, args=args) 227 | 228 | prediction = out_logit.argmax(axis=-1).squeeze() 229 | probs = out_logit.max(axis=-1).squeeze() 230 | high_prob_mask = probs > 0.5 231 | 232 | mask = high_prob_mask 233 | prediction[~mask] = 255 234 | 235 | pred_color = color_seg(prediction) 236 | vis_seg = visual_segments(pred_color, rgb) 237 | 238 | vis_seg.save(os.path.splitext(save_path)[0] + '_vis.png') 239 | cv2.imwrite(save_path, prediction.astype(np.uint8)) 240 | 241 | def visual_segments(segments, rgb): 242 | seg = Image.fromarray(segments) 243 | rgb = Image.fromarray(rgb) 244 | 245 | seg1 = seg.convert('RGBA') 246 | rgb1 = rgb.convert('RGBA') 247 | 248 | vis_seg = Image.blend(rgb1, seg1, 0.8) 249 | return vis_seg 250 | 251 | def organize_images(args, local_rank): 252 | imgs_dir = args.root_dir 253 | imgs_list = glob.glob(imgs_dir + f'/*.{args.img_file_type}') 254 | imgs_list.sort() 255 | num_devices = args.gpus_num 256 | 257 | imgs_on_device = imgs_list[local_rank::num_devices] 258 | return imgs_on_device 259 | 260 | def main_worker(local_rank: int, cfg: dict): 261 | if cfg.distributed: 262 | global_rank = local_rank 263 | world_size = cfg.gpus_num 264 | 265 | torch.cuda.set_device(global_rank) 266 | dist.init_process_group(backend="nccl", 267 | init_method=cfg.dist_url, 268 | world_size=world_size, 269 | rank=global_rank,) 270 | do_test(cfg, local_rank) 271 | if __name__ == '__main__': 272 | args = get_parser() 273 | logger.info(args) 274 | 275 | dist_url = 'tcp://127.0.0.1:6769' 276 | dist_url = dist_url[:-2] + str(os.getpid() % 100).zfill(2) 277 | args.dist_url = dist_url 278 | 279 | num_gpus = torch.cuda.device_count() 280 | if num_gpus != args.gpus_num: 281 | raise RuntimeError('The requested number of GPUs does not match the number of detected GPUs.
Please check or set CUDA_VISIBLE_DEVICES') 282 | 283 | if num_gpus > 1: 284 | args.distributed = True 285 | else: 286 | args.distributed = False 287 | 288 | save_path = os.path.join(args.root_dir, args.save_folder, 'id2labels.json') 289 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 290 | with open(save_path, 'w') as f: 291 | json.dump(UNI_UID2UNAME, f) 292 | 293 | if not args.distributed: 294 | main_worker(0, args) 295 | else: 296 | mp.spawn(main_worker, nprocs=args.gpus_num, args=(args, )) 297 | -------------------------------------------------------------------------------- /Test_Minist/utils/labels_dict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import gensim 4 | import gensim.downloader 5 | 6 | 7 | UNI_UID2UNAME = {0: 'backpack', 1: 'umbrella', 2: 'bag', 3: 'tie', 4: 'suitcase', 5: 'case', 6: 'bird', 8 | 7: 'cat', 8: 'dog', 9: 'horse', 10: 'sheep', 11: 'cow', 12: 'elephant', 13: 'bear', 9 | 14: 'zebra', 15: 'giraffe', 16: 'animal_other', 17: 'microwave', 18: 'radiator', 19: 'oven', 10 | 20: 'toaster', 21: 'storage_tank', 22: 'conveyor_belt', 23: 'sink', 24: 'refrigerator', 11 | 25: 'washer_dryer', 26: 'fan', 27: 'dishwasher', 28: 'toilet', 29: 'bathtub', 30: 'shower', 12 | 31: 'tunnel', 32: 'bridge', 33: 'pier_wharf', 34: 'tent', 35: 'building', 36: 'ceiling', 13 | 37: 'laptop', 38: 'keyboard', 39: 'mouse', 40: 'remote', 41: 'cell phone', 42: 'television', 14 | 43: 'floor', 44: 'stage', 45: 'banana', 46: 'apple', 47: 'sandwich', 48: 'orange', 15 | 49: 'broccoli', 50: 'carrot', 51: 'hot_dog', 52: 'pizza', 53: 'donut', 54: 'cake', 16 | 55: 'fruit_other', 56: 'food_other', 57: 'chair_other', 58: 'armchair', 59: 'swivel_chair', 17 | 60: 'stool', 61: 'seat', 62: 'couch', 63: 'trash_can', 64: 'potted_plant', 65: 'nightstand', 18 | 66: 'bed', 67: 'table', 68: 'pool_table', 69: 'barrel', 70: 'desk', 71: 'ottoman', 19 | 72: 'wardrobe', 73: 'crib', 74: 'basket', 75: 'chest_of_drawers', 76: 'bookshelf', 20 | 77: 'counter_other', 78: 'bathroom_counter', 79: 'kitchen_island', 80: 'door', 81: 'light_other', 21 | 82: 'lamp', 83: 'sconce', 84: 'chandelier', 85: 'mirror', 86: 'whiteboard', 87: 'shelf', 22 | 88: 'stairs', 89: 'escalator', 90: 'cabinet', 91: 'fireplace', 92: 'stove', 93: 'arcade_machine', 23 | 94: 'gravel', 95: 'platform', 96: 'playingfield', 97: 'railroad', 98: 'road', 99: 'snow', 24 | 100: 'sidewalk_pavement', 101: 'runway', 102: 'terrain', 103: 'book', 104: 'box', 105: 'clock', 25 | 106: 'vase', 107: 'scissors', 108: 'plaything_other', 109: 'teddy_bear', 110: 'hair_dryer', 26 | 111: 'toothbrush', 112: 'painting', 113: 'poster', 114: 'bulletin_board', 115: 'bottle', 116: 'cup', 27 | 117: 'wine_glass', 118: 'knife', 119: 'fork', 120: 'spoon', 121: 'bowl', 122: 'tray', 123: 'range_hood', 28 | 124: 'plate', 125: 'person', 126: 'rider_other', 127: 'bicyclist', 128: 'motorcyclist', 129: 'paper', 29 | 130: 'streetlight', 131: 'road_barrier', 132: 'mailbox', 133: 'cctv_camera', 134: 'junction_box', 30 | 135: 'traffic_sign', 136: 'traffic_light', 137: 'fire_hydrant', 138: 'parking_meter', 139: 'bench', 31 | 140: 'bike_rack', 141: 'billboard', 142: 'sky', 143: 'pole', 144: 'fence', 145: 'railing_banister', 32 | 146: 'guard_rail', 147: 'mountain_hill', 148: 'rock', 149: 'frisbee', 150: 'skis', 151: 'snowboard', 33 | 152: 'sports_ball', 153: 'kite', 154: 'baseball_bat', 155: 'baseball_glove', 156: 'skateboard', 34 | 157: 'surfboard', 158: 'tennis_racket', 159: 'net', 160: 'base', 161: 'sculpture', 162: 
'column', 35 | 163: 'fountain', 164: 'awning', 165: 'apparel', 166: 'banner', 167: 'flag', 168: 'blanket', 36 | 169: 'curtain_other', 170: 'shower_curtain', 171: 'pillow', 172: 'towel', 173: 'rug_floormat', 37 | 174: 'vegetation', 175: 'bicycle', 176: 'car', 177: 'autorickshaw', 178: 'motorcycle', 179: 'airplane', 38 | 180: 'bus', 181: 'train', 182: 'truck', 183: 'trailer', 184: 'boat_ship', 185: 'slow_wheeled_object', 39 | 186: 'river_lake', 187: 'sea', 188: 'water_other', 189: 'swimming_pool', 190: 'waterfall', 40 | 191: 'wall', 192: 'window', 193: 'window_blind', 255: 'unlabeled'} 41 | 42 | UNAME2EM_NAME = {'backpack': 'backpack', 'umbrella': 'umbrella', 'bag': 'bag', 'tie': 'tie', 43 | 'suitcase': 'suitcase', 'case': 'case', 'bird': 'bird', 'cat': 'cat', 'dog': 'dog', 44 | 'horse': 'horse', 'sheep': 'sheep', 'cow': 'cow', 'elephant': 'elephant', 'bear': 'bear', 45 | 'zebra': 'zebra', 'giraffe': 'giraffe', 'animal_other': 'animal', 46 | 'microwave': 'microwave', 'radiator': 'radiator', 'oven': 'oven', 'toaster': 'toaster', 47 | 'storage_tank': 'storage_tank', 'conveyor_belt': 'conveyor belt', 'sink': 'sink', 48 | 'refrigerator': 'refrigerator', 'washer_dryer': 'washer dryer', 'fan': 'fan', 49 | 'dishwasher': 'dishwasher', 'toilet': 'toilet', 'bathtub': 'bathtub', 'shower': 'shower', 50 | 'tunnel': 'tunnel', 'bridge': 'bridge', 'pier_wharf': 'pier wharf', 'tent': 'tent', 51 | 'building': 'building', 'ceiling': 'ceiling', 'laptop': 'laptop', 'keyboard': 'keyboard', 52 | 'mouse': 'mouse', 'remote': 'remote', 'cell phone': 'cell phone', 'television': 'television', 53 | 'floor': 'floor', 'stage': 'stage', 'banana': 'banana', 'apple': 'apple', 'sandwich': 'sandwich', 54 | 'orange': 'orange', 'broccoli': 'broccoli', 'carrot': 'carrot', 'hot_dog': 'hotdog', 55 | 'pizza': 'pizza', 'donut': 'donut', 'cake': 'cake', 'fruit_other': 'fruit', 56 | 'food_other': 'food', 'chair_other': 'chair', 'armchair': 'arm chair', 57 | 'swivel_chair': 'swivel chair', 'stool': 'stool', 'seat': 'seat', 'couch': 'couch', 58 | 'trash_can': 'trash can', 'potted_plant': 'potted plant', 'nightstand': 'nightstand', 59 | 'bed': 'bed', 'table': 'table', 'pool_table': 'pool table', 'barrel': 'barrel', 'desk': 'desk', 60 | 'ottoman': 'ottoman', 'wardrobe': 'wardrobe', 'crib': 'crib', 'basket': 'basket', 61 | 'chest_of_drawers': 'chest of drawers', 'bookshelf': 'bookshelf', 'counter_other': 'counter', 62 | 'bathroom_counter': 'bathroom counter', 'kitchen_island': 'kitchen island', 'door': 'door', 63 | 'light_other': 'light', 'lamp': 'lamp', 'sconce': 'sconce', 'chandelier': 'chandelier', 64 | 'mirror': 'mirror', 'whiteboard': 'whiteboard', 'shelf': 'shelf', 'stairs': 'stairs', 65 | 'escalator': 'escalator', 'cabinet': 'cabinet', 'fireplace': 'fireplace', 'stove': 'stove', 66 | 'arcade_machine': 'arcade machine', 'gravel': 'gravel', 'platform': 'platform', 67 | 'playingfield': 'playing field', 'railroad': 'railroad', 'road': 'road', 'snow': 'snow', 68 | 'sidewalk_pavement': 'sidewalk pavement', 'runway': 'runway', 'terrain': 'terrain', 69 | 'book': 'book', 'box': 'box', 'clock': 'clock', 'vase': 'vase', 'scissors': 'scissors', 70 | 'plaything_other': 'plaything other', 'teddy_bear': 'teddy bear', 'hair_dryer': 'hair dryer', 71 | 'toothbrush': 'toothbrush', 'painting': 'painting', 'poster': 'poster', 'bulletin_board': 'bulletin board', 72 | 'bottle': 'bottle', 'cup': 'cup', 'wine_glass': 'wine glass', 'knife': 'knife', 'fork': 'fork', 73 | 'spoon': 'spoon', 'bowl': 'bowl', 'tray': 'tray', 'range_hood': 'range hood', 'plate': 
'plate', 74 | 'person': 'person', 'rider_other': 'rider', 'bicyclist': 'bicyclist', 'motorcyclist': 'motorcyclist', 75 | 'paper': 'paper', 'streetlight': 'streetlight', 'road_barrier': 'road barrier', 'mailbox': 'mailbox', 76 | 'cctv_camera': 'cctv camera', 'junction_box': 'junction box', 'traffic_sign': 'traffic sign', 77 | 'traffic_light': 'traffic light', 'fire_hydrant': 'fire hydrant', 'parking_meter': 'parking meter', 78 | 'bench': 'bench', 'bike_rack': 'bike rack', 'billboard': 'billboard', 'sky': 'sky', 'pole': 'pole', 79 | 'fence': 'fence', 'railing_banister': 'railing banister', 'guard_rail': 'guard rail', 80 | 'mountain_hill': 'mountain hill', 'rock': 'rock', 'frisbee': 'frisbee', 'skis': 'skis', 'snowboard': 'snowboard', 81 | 'sports_ball': 'sports ball', 'kite': 'kite', 'baseball_bat': 'baseball bat', 'baseball_glove': 'baseball glove', 82 | 'skateboard': 'skateboard', 'surfboard': 'surfboard', 'tennis_racket': 'tennis_racket', 'net': 'net', 83 | 'base': 'base', 'sculpture': 'sculpture', 'column': 'column', 'fountain': 'fountain', 'awning': 'awning', 84 | 'apparel': 'apparel', 'banner': 'banner', 'flag': 'flag', 'blanket': 'blanket', 'curtain_other': 'other curtain', 85 | 'shower_curtain': 'shower curtain', 'pillow': 'pillow', 'towel': 'towel', 'rug_floormat': 'rug floormat', 86 | 'vegetation': 'vegetation', 'bicycle': 'bicycle', 'car': 'car', 'autorickshaw': 'autorickshaw', 87 | 'motorcycle': 'motorcycle', 'airplane': 'airplane', 'bus': 'bus', 'train': 'train', 'truck': 'truck', 88 | 'trailer': 'trailer', 'boat_ship': 'boat ship', 'slow_wheeled_object': 'slow wheeled object', 89 | 'river_lake': 'river lake', 'sea': 'sea', 'water_other': 'water', 'swimming_pool': 'swimming pool', 90 | 'waterfall': 'waterfall', 'wall': 'wall', 'window': 'window', 'window_blind': 'window blind', 91 | 'unlabeled': 'unlabeled'} 92 | 93 | 94 | ALL_LABEL2ID = {'backpack': 0, 'umbrella': 1, 'bag': 2, 'tie': 3, 'suitcase': 4, 'case': 5, 'bird': 6, 'cat': 7, 95 | 'dog': 8, 'horse': 9, 'sheep': 10, 'cow': 11, 'elephant': 12, 'bear': 13, 'zebra': 14, 96 | 'giraffe': 15, 'animal_other': 16, 'microwave': 17, 'radiator': 18, 'oven': 19, 'toaster': 20, 97 | 'storage_tank': 21, 'conveyor_belt': 22, 'sink': 23, 'refrigerator': 24, 'washer_dryer': 25, 98 | 'fan': 26, 'dishwasher': 27, 'toilet': 28, 'bathtub': 29, 'shower': 30, 'tunnel': 31, 99 | 'bridge': 32, 'pier_wharf': 33, 'tent': 34, 'building': 35, 'ceiling': 36, 'laptop': 37, 100 | 'keyboard': 38, 'mouse': 39, 'remote': 40, 'cell phone': 41, 'television': 42, 'floor': 43, 101 | 'stage': 44, 'banana': 45, 'apple': 46, 'sandwich': 47, 'orange': 48, 'broccoli': 49, 102 | 'carrot': 50, 'hot_dog': 51, 'pizza': 52, 'donut': 53, 'cake': 54, 'fruit_other': 55, 103 | 'food_other': 56, 'chair_other': 57, 'armchair': 58, 'swivel_chair': 59, 'stool': 60, 104 | 'seat': 61, 'couch': 62, 'trash_can': 63, 'potted_plant': 64, 'nightstand': 65, 'bed': 66, 105 | 'table': 67, 'pool_table': 68, 'barrel': 69, 'desk': 70, 'ottoman': 71, 'wardrobe': 72, 106 | 'crib': 73, 'basket': 74, 'chest_of_drawers': 75, 'bookshelf': 76, 'counter_other': 77, 107 | 'bathroom_counter': 78, 'kitchen_island': 79, 'door': 80, 'light_other': 81, 'lamp': 82, 108 | 'sconce': 83, 'chandelier': 84, 'mirror': 85, 'whiteboard': 86, 'shelf': 87, 'stairs': 88, 109 | 'escalator': 89, 'cabinet': 90, 'fireplace': 91, 'stove': 92, 'arcade_machine': 93, 'gravel': 94, 110 | 'platform': 95, 'playingfield': 96, 'railroad': 97, 'road': 98, 'snow': 99, 111 | 'sidewalk_pavement': 100, 'runway': 101, 'terrain': 102, 
'book': 103, 'box': 104, 'clock': 105, 112 | 'vase': 106, 'scissors': 107, 'plaything_other': 108, 'teddy_bear': 109, 'hair_dryer': 110, 113 | 'toothbrush': 111, 'painting': 112, 'poster': 113, 'bulletin_board': 114, 'bottle': 115, 114 | 'cup': 116, 'wine_glass': 117, 'knife': 118, 'fork': 119, 'spoon': 120, 'bowl': 121, 'tray': 122, 115 | 'range_hood': 123, 'plate': 124, 'person': 125, 'rider_other': 126, 'bicyclist': 127, 116 | 'motorcyclist': 128, 'paper': 129, 'streetlight': 130, 'road_barrier': 131, 'mailbox': 132, 117 | 'cctv_camera': 133, 'junction_box': 134, 'traffic_sign': 135, 'traffic_light': 136, 118 | 'fire_hydrant': 137, 'parking_meter': 138, 'bench': 139, 'bike_rack': 140, 'billboard': 141, 119 | 'sky': 142, 'pole': 143, 'fence': 144, 'railing_banister': 145, 'guard_rail': 146, 120 | 'mountain_hill': 147, 'rock': 148, 'frisbee': 149, 'skis': 150, 'snowboard': 151, 121 | 'sports_ball': 152, 'kite': 153, 'baseball_bat': 154, 'baseball_glove': 155, 'skateboard': 156, 122 | 'surfboard': 157, 'tennis_racket': 158, 'net': 159, 'base': 160, 'sculpture': 161, 'column': 162, 123 | 'fountain': 163, 'awning': 164, 'apparel': 165, 'banner': 166, 'flag': 167, 'blanket': 168, 124 | 'curtain_other': 169, 'shower_curtain': 170, 'pillow': 171, 'towel': 172, 'rug_floormat': 173, 125 | 'vegetation': 174, 'bicycle': 175, 'car': 176, 'autorickshaw': 177, 'motorcycle': 178, 126 | 'airplane': 179, 'bus': 180, 'train': 181, 'truck': 182, 'trailer': 183, 'boat_ship': 184, 127 | 'slow_wheeled_object': 185, 'river_lake': 186, 'sea': 187, 'water_other': 188, 'swimming_pool': 189, 128 | 'waterfall': 190, 'wall': 191, 'window': 192, 'window_blind': 193, 'chopsticks': 194, 129 | 'musical instrument': 195, 'drink': 196, 'wrench': 197, 'mule(which is a type of animal)': 198, 130 | 'camel(which is a type of animal)': 199, 'tap': 200, 'snowmobile': 201, 'fish': 202, 131 | 'crocodile(which is a type of animal)': 203, 'screwdriver': 204, 132 | 'panda(which is a type of animal)': 205, 'pig(which is a type of animal)': 206, 133 | 'red panda(which is a type of animal)': 207, 'frying pan': 208, 134 | 'monkey(which is a type of animal)': 209, 'kangaroo(which is a type of animal)': 210, 135 | 'leopard(which is a type of animal)': 211, 'koala(which is a type of animal)': 212, 136 | 'lion(which is a type of animal)': 213, 'hammer': 214, 'tiger(which is a type of animal)': 215, 137 | 'camera': 216, 'starfish': 217, 'drill': 218, 'rhinoceros(which is a type of animal)': 219, 138 | 'hippopotamus(which is a type of animal)': 220, 'turtle': 221, 'flashlight': 222, 139 | 'rabbit(which is a type of animal)': 223, 'skull': 224, 'kettle': 225, 'fox': 226, 140 | 'lynx(which is a type of animal)': 227, 'hat': 228, 'harbor seal': 229, 141 | 'alpaca(which is a type of animal)': 230, 'teapot': 231, 'glove': 232, 142 | 'sea lion(which is a type of animal)': 233, 'printer': 234, 'balloon': 235, 'stapler': 236, 143 | 'calculator': 237, 'unlabeled': 255} 144 | -------------------------------------------------------------------------------- /Test_Minist/utils/segformer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from functools import partial 6 | 7 | #from mmseg.models.builder import BACKBONES 8 | from mmseg.utils import get_root_logger 9 | from mmcv.runner import load_checkpoint 10 | from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule 11 | 12 | import math 13 | #from .decode_heads.decode_head 
import BaseDecodeHead 14 | from utils.decode_heads.aspp_head import ASPPHead, ASPPModule 15 | from mmcv.cnn.bricks import build_norm_layer 16 | from mmseg.ops import resize 17 | 18 | # pip install timm==0.3.2 19 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 20 | import logging 21 | from utils.decode_heads.segformer_head import SegFormerHead 22 | logger = logging.getLogger(__name__) 23 | 24 | class FCNHead(nn.Module): 25 | """Fully Convolution Networks for Semantic Segmentation. 26 | This head is implemented of `FCNNet `_. 27 | Args: 28 | num_convs (int): Number of convs in the head. Default: 2. 29 | kernel_size (int): The kernel size for convs in the head. Default: 3. 30 | concat_input (bool): Whether concat the input and output of convs 31 | before classification layer. 32 | """ 33 | 34 | def __init__(self, 35 | num_convs=1, 36 | kernel_size=3, 37 | in_channels=320, 38 | num_classes=150, 39 | norm_cfg=None, 40 | act_cfg=dict(type='ReLU'), 41 | **kwargs): 42 | assert num_convs >= 0 43 | self.kernel_size = kernel_size 44 | super(FCNHead, self).__init__() 45 | self.in_channels = in_channels 46 | self.channels = 256 47 | self.num_classes = num_classes 48 | self.norm_cfg = norm_cfg 49 | self.act_cfg = act_cfg 50 | convs = [] 51 | convs.append( 52 | ConvModule( 53 | self.in_channels, 54 | self.channels, 55 | kernel_size=kernel_size, 56 | padding=kernel_size // 2, # 57 | conv_cfg=None, 58 | norm_cfg=self.norm_cfg, # sync bn 59 | act_cfg=self.act_cfg)) # relu 60 | 61 | self.convs = nn.Sequential(*convs) 62 | self.cls_seg = nn.Conv2d(in_channels=self.channels, out_channels=self.num_classes, 63 | kernel_size=1) 64 | 65 | def forward(self, inputs): 66 | """Forward function.""" 67 | x = inputs[-2] 68 | output = self.convs(x) 69 | output = self.cls_seg(output) 70 | return output 71 | 72 | class DepthwiseSeparableASPPModule(ASPPModule): 73 | """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable 74 | conv.""" 75 | 76 | def __init__(self, **kwargs): 77 | super(DepthwiseSeparableASPPModule, self).__init__(**kwargs) 78 | for i, dilation in enumerate(self.dilations): 79 | if dilation > 1: 80 | self[i] = DepthwiseSeparableConvModule( 81 | self.in_channels, 82 | self.channels, 83 | 3, 84 | dilation=dilation, 85 | padding=dilation, 86 | norm_cfg=self.norm_cfg, 87 | act_cfg=self.act_cfg) 88 | 89 | 90 | class DynHead(nn.Module): 91 | def __init__(self, 92 | in_channels, 93 | num_classes, 94 | norm_cfg, 95 | act_cfg, 96 | upsample_f, 97 | dyn_ch, 98 | mask_ch, 99 | use_low_level_info=False, 100 | channel_reduce_factor=2, 101 | zero_init=False, 102 | supress_std=True): 103 | super(DynHead, self).__init__() 104 | 105 | channels = dyn_ch 106 | num_bases = 0 107 | if use_low_level_info: 108 | num_bases = mask_ch 109 | num_out_channel = (2 + num_bases) * channels + \ 110 | channels + \ 111 | channels * channels + \ 112 | channels + \ 113 | channels * num_classes + \ 114 | num_classes 115 | 116 | self.classifier = nn.Sequential( 117 | ConvModule( 118 | in_channels, 119 | in_channels // channel_reduce_factor, 120 | 3, 121 | padding=1, 122 | norm_cfg=norm_cfg, 123 | act_cfg=act_cfg, ), 124 | nn.Conv2d(in_channels // channel_reduce_factor, num_out_channel, 1) 125 | ) 126 | 127 | 128 | if zero_init: 129 | nn.init.constant_(self.classifier[-1].weight, 0) 130 | else: 131 | nn.init.xavier_normal_(self.classifier[-1].weight) 132 | if supress_std: 133 | param = self.classifier[-1].weight / num_out_channel 134 | self.classifier[-1].weight = nn.Parameter(param) 135 | 
nn.init.constant_(self.classifier[-1].bias, 0) 136 | 137 | def forward(self, feature): 138 | return self.classifier(feature) 139 | 140 | 141 | #@HEADS.register_module() 142 | class BilinearPADHead_fast_xavier_init(ASPPHead): 143 | """Encoder-Decoder with Atrous Separable Convolution for Semantic Image 144 | Segmentation. 145 | This head is the implementation of `DeepLabV3+ 146 | `_. 147 | Args: 148 | c1_in_channels (int): The input channels of c1 decoder. If is 0, 149 | the no decoder will be used. 150 | c1_channels (int): The intermediate channels of c1 decoder. 151 | """ 152 | 153 | def __init__(self, c1_in_channels, c1_channels, 154 | upsample_factor, 155 | dyn_branch_ch, 156 | mask_head_ch, 157 | pad_out_channel_factor=4, 158 | channel_reduce_factor=2, 159 | zero_init=False, 160 | supress_std=True, 161 | feature_strides=None, 162 | **kwargs): 163 | super(BilinearPADHead_fast_xavier_init, self).__init__(**kwargs) 164 | assert c1_in_channels >= 0 165 | self.pad_out_channel = self.num_classes 166 | self.upsample_f = upsample_factor 167 | self.dyn_ch = dyn_branch_ch 168 | self.mask_ch = mask_head_ch 169 | self.use_low_level_info = True 170 | self.channel_reduce_factor = channel_reduce_factor 171 | 172 | self.aspp_modules = DepthwiseSeparableASPPModule( 173 | dilations=self.dilations, 174 | in_channels=self.in_channels, 175 | channels=self.channels, 176 | conv_cfg=self.conv_cfg, 177 | norm_cfg=self.norm_cfg, 178 | act_cfg=self.act_cfg) 179 | 180 | last_stage_ch = self.channels 181 | self.classifier = DynHead(last_stage_ch, 182 | self.pad_out_channel, 183 | self.norm_cfg, 184 | self.act_cfg, 185 | self.upsample_f, 186 | self.dyn_ch, 187 | self.mask_ch, 188 | self.use_low_level_info, 189 | self.channel_reduce_factor, 190 | zero_init, 191 | supress_std) 192 | 193 | if c1_in_channels > 0: 194 | self.c1_bottleneck = nn.Sequential( 195 | ConvModule( 196 | c1_in_channels, 197 | c1_channels, 198 | 3, 199 | padding=1, 200 | conv_cfg=self.conv_cfg, 201 | norm_cfg=self.norm_cfg, 202 | act_cfg=self.act_cfg), 203 | ConvModule( 204 | c1_channels, 205 | self.mask_ch, 206 | 1, 207 | conv_cfg=self.conv_cfg, 208 | act_cfg=None, 209 | ), 210 | ) 211 | else: 212 | self.c1_bottleneck = None 213 | 214 | _, norm = build_norm_layer(self.norm_cfg, 2 + self.mask_ch) 215 | self.add_module("cat_norm", norm) 216 | nn.init.constant_(self.cat_norm.weight, 1) 217 | nn.init.constant_(self.cat_norm.bias, 0) 218 | 219 | coord_tmp = self.computer_locations_per_level(640, 640) 220 | self.register_buffer("coord", coord_tmp.float(), persistent=False) 221 | 222 | def computer_locations_per_level(self, height, width, h=8, w=8): 223 | shifts_x = torch.arange(0, 1, step=1/w, dtype=torch.float32) 224 | shifts_y = torch.arange(0, 1, step=1/h, dtype=torch.float32) 225 | shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) 226 | locations = torch.stack((shift_x, shift_y), dim=0) 227 | stride_h = height // 32 228 | stride_w = width // 32 229 | coord = locations.repeat(stride_h*stride_w, 1, 1, 1) 230 | return coord 231 | 232 | 233 | def forward(self, inputs): 234 | """Forward function.""" 235 | # inputs: [1/32 stage, 1/4 stage] 236 | x = inputs[0] # 1/32 stage 237 | 238 | aspp_outs = [ 239 | resize( 240 | self.image_pool(x), 241 | size=x.size()[2:], 242 | mode='bilinear', 243 | align_corners=self.align_corners) 244 | ] 245 | aspp_outs.extend(self.aspp_modules(x)) 246 | aspp_outs = torch.cat(aspp_outs, dim=1) 247 | output = self.bottleneck(aspp_outs) 248 | 249 | plot = False 250 | 251 | if self.c1_bottleneck is not None: 252 | c1_output = 
self.c1_bottleneck(inputs[1]) 253 | if plot: 254 | output2 = output 255 | output3 = c1_output 256 | if self.upsample_f != 8: 257 | c1_output = resize( 258 | c1_output, 259 | scale_factor=self.upsample_f // 8, 260 | mode='bilinear', 261 | align_corners=self.align_corners) 262 | output = self.classifier(output) 263 | output = self.interpolate_fast(output, c1_output, self.cat_norm) 264 | if plot: 265 | outputs = [] 266 | outputs.append(output) 267 | outputs.append(output2) 268 | outputs.append(output3) 269 | return outputs 270 | 271 | return output 272 | 273 | def interpolate(self, x, x_cat=None, norm=None): 274 | dy_ch = self.dyn_ch 275 | B, conv_ch, H, W = x.size() 276 | x = x.view(B, conv_ch, H * W).permute(0, 2, 1) 277 | x = x.reshape(B * H * W, conv_ch) 278 | weights, biases = self.get_subnetworks_params(x, channels=dy_ch) 279 | f = self.upsample_f 280 | self.coord_generator(H, W) 281 | coord = self.coord.reshape(1, H, W, 2, f, f).permute(0, 3, 1, 4, 2, 5).reshape(1, 2, H * f, W * f) 282 | coord = coord.repeat(B, 1, 1, 1) 283 | if x_cat is not None: 284 | coord = torch.cat((coord, x_cat), 1) 285 | coord = norm(coord) 286 | 287 | B_coord, ch_coord, H_coord, W_coord = coord.size() 288 | coord = coord.reshape(B_coord, ch_coord, H, f, W, f).permute(0, 2, 4, 1, 3, 5).reshape(1, 289 | B_coord * H * W * ch_coord, 290 | f, f) 291 | output = self.subnetworks_forward(coord, weights, biases, B * H * W) 292 | output = output.reshape(B, H, W, self.pad_out_channel, f, f).permute(0, 3, 1, 4, 2, 5) 293 | output = output.reshape(B, self.pad_out_channel, H * f, W * f) 294 | return output 295 | 296 | def interpolate_fast(self, x, x_cat=None, norm=None): 297 | dy_ch = self.dyn_ch 298 | B, conv_ch, H, W = x.size() 299 | weights, biases = self.get_subnetworks_params_fast(x, channels=dy_ch) 300 | f = self.upsample_f 301 | #self.coord_generator(H, W) 302 | coord = self.coord.reshape(1, H, W, 2, f, f).permute(0, 3, 1, 4, 2, 5).reshape(1, 2, H * f, W * f) 303 | coord = coord.repeat(B, 1, 1, 1) 304 | if x_cat is not None: 305 | coord = torch.cat((coord, x_cat), 1) 306 | coord = norm(coord) 307 | 308 | output = self.subnetworks_forward_fast(coord, weights, biases, B * H * W) 309 | return output 310 | 311 | def get_subnetworks_params(self, attns, num_bases=0, channels=16): 312 | assert attns.dim() == 2 313 | n_inst = attns.size(0) 314 | if self.use_low_level_info: 315 | num_bases = self.mask_ch 316 | else: 317 | num_bases = 0 318 | 319 | w0, b0, w1, b1, w2, b2 = torch.split_with_sizes(attns, [ 320 | (2 + num_bases) * channels, channels, 321 | channels * channels, channels, 322 | channels * self.pad_out_channel, self.pad_out_channel 323 | ], dim=1) 324 | 325 | w0 = w0.reshape(n_inst * channels, 2 + num_bases, 1, 1) 326 | b0 = b0.reshape(n_inst * channels) 327 | w1 = w1.reshape(n_inst * channels, channels, 1, 1) 328 | b1 = b1.reshape(n_inst * channels) 329 | w2 = w2.reshape(n_inst * self.pad_out_channel, channels, 1, 1) 330 | b2 = b2.reshape(n_inst * self.pad_out_channel) 331 | 332 | return [w0, w1, w2], [b0, b1, b2] 333 | 334 | def get_subnetworks_params_fast(self, attns, num_bases=0, channels=16): 335 | assert attns.dim() == 4 336 | B, conv_ch, H, W = attns.size() 337 | if self.use_low_level_info: 338 | num_bases = self.mask_ch 339 | else: 340 | num_bases = 0 341 | 342 | w0, b0, w1, b1, w2, b2 = torch.split_with_sizes(attns, [ 343 | (2 + num_bases) * channels, channels, 344 | channels * channels, channels, 345 | channels * self.pad_out_channel, self.pad_out_channel 346 | ], dim=1) 347 | 348 | w0 = resize(w0, 
scale_factor=self.upsample_f, mode='nearest') 349 | b0 = resize(b0, scale_factor=self.upsample_f, mode='nearest') 350 | w1 = resize(w1, scale_factor=self.upsample_f, mode='nearest') 351 | b1 = resize(b1, scale_factor=self.upsample_f, mode='nearest') 352 | w2 = resize(w2, scale_factor=self.upsample_f, mode='nearest') 353 | b2 = resize(b2, scale_factor=self.upsample_f, mode='nearest') 354 | 355 | return [w0, w1, w2], [b0, b1, b2] 356 | 357 | def subnetworks_forward(self, inputs, weights, biases, n_subnets): 358 | assert inputs.dim() == 4 359 | n_layer = len(weights) 360 | x = inputs 361 | # NOTE: x has to be treated as min_batch size 1 362 | for i, (w, b) in enumerate(zip(weights, biases)): 363 | x = F.conv2d( 364 | x, w, bias=b, 365 | stride=1, padding=0, 366 | groups=n_subnets 367 | ) 368 | if i < n_layer - 1: 369 | x = F.relu(x) 370 | return x 371 | 372 | def subnetworks_forward_fast(self, inputs, weights, biases, n_subnets): 373 | assert inputs.dim() == 4 374 | n_layer = len(weights) 375 | x = inputs 376 | if self.use_low_level_info: 377 | num_bases = self.mask_ch 378 | else: 379 | num_bases = 0 380 | for i, (w, b) in enumerate(zip(weights, biases)): 381 | if i == 0: 382 | x = self.padconv(x, w, b, cin=2 + num_bases, cout=self.dyn_ch, relu=True) 383 | if i == 1: 384 | x = self.padconv(x, w, b, cin=self.dyn_ch, cout=self.dyn_ch, relu=True) 385 | if i == 2: 386 | x = self.padconv(x, w, b, cin=self.dyn_ch, cout=self.pad_out_channel, relu=False) 387 | return x 388 | 389 | def padconv(self, input, w, b, cin, cout, relu): 390 | input = input.repeat(1, cout, 1, 1) 391 | x = input * w 392 | conv_w = torch.ones((cout, cin, 1, 1), device=input.device) 393 | x = F.conv2d( 394 | x, conv_w, stride=1, padding=0, 395 | groups=cout 396 | ) 397 | x = x + b 398 | if relu: 399 | x = F.relu(x) 400 | return x 401 | 402 | def coord_generator(self, height, width): 403 | f = self.upsample_f 404 | coord = compute_locations_per_level(f, f) 405 | H = height 406 | W = width 407 | coord = coord.repeat(H * W, 1, 1, 1) 408 | self.coord = coord.to(device='cuda') 409 | 410 | 411 | def compute_locations_per_level(h, w): 412 | shifts_x = torch.arange( 413 | 0, 1, step=1 / w, 414 | dtype=torch.float32, device='cuda' 415 | ) 416 | shifts_y = torch.arange( 417 | 0, 1, step=1 / h, 418 | dtype=torch.float32, device='cuda' 419 | ) 420 | shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) 421 | locations = torch.stack((shift_x, shift_y), dim=0) 422 | return locations 423 | 424 | 425 | 426 | class Mlp(nn.Module): 427 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 428 | super().__init__() 429 | out_features = out_features or in_features 430 | hidden_features = hidden_features or in_features 431 | self.fc1 = nn.Linear(in_features, hidden_features) 432 | self.dwconv = DWConv(hidden_features) 433 | self.act = act_layer() 434 | self.fc2 = nn.Linear(hidden_features, out_features) 435 | self.drop = nn.Dropout(drop) 436 | 437 | self.apply(self._init_weights) 438 | 439 | def _init_weights(self, m): 440 | if isinstance(m, nn.Linear): 441 | trunc_normal_(m.weight, std=.02) 442 | if isinstance(m, nn.Linear) and m.bias is not None: 443 | nn.init.constant_(m.bias, 0) 444 | elif isinstance(m, nn.LayerNorm): 445 | nn.init.constant_(m.bias, 0) 446 | nn.init.constant_(m.weight, 1.0) 447 | elif isinstance(m, nn.Conv2d): 448 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 449 | fan_out //= m.groups 450 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 451 | if m.bias is 
not None: 452 | m.bias.data.zero_() 453 | 454 | def forward(self, x, H, W): 455 | x = self.fc1(x) 456 | x = self.dwconv(x, H, W) 457 | x = self.act(x) 458 | x = self.drop(x) 459 | x = self.fc2(x) 460 | x = self.drop(x) 461 | return x 462 | 463 | 464 | class Attention(nn.Module): 465 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 466 | super().__init__() 467 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 468 | 469 | self.dim = dim 470 | self.num_heads = num_heads 471 | head_dim = dim // num_heads 472 | self.scale = qk_scale or head_dim ** -0.5 473 | 474 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 475 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 476 | self.attn_drop = nn.Dropout(attn_drop) 477 | self.proj = nn.Linear(dim, dim) 478 | self.proj_drop = nn.Dropout(proj_drop) 479 | 480 | self.sr_ratio = sr_ratio 481 | if sr_ratio > 1: 482 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 483 | self.norm = nn.LayerNorm(dim) 484 | 485 | self.apply(self._init_weights) 486 | 487 | def _init_weights(self, m): 488 | if isinstance(m, nn.Linear): 489 | trunc_normal_(m.weight, std=.02) 490 | if isinstance(m, nn.Linear) and m.bias is not None: 491 | nn.init.constant_(m.bias, 0) 492 | elif isinstance(m, nn.LayerNorm): 493 | nn.init.constant_(m.bias, 0) 494 | nn.init.constant_(m.weight, 1.0) 495 | elif isinstance(m, nn.Conv2d): 496 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 497 | fan_out //= m.groups 498 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 499 | if m.bias is not None: 500 | m.bias.data.zero_() 501 | 502 | def forward(self, x, H, W): 503 | B, N, C = x.shape 504 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 505 | 506 | if self.sr_ratio > 1: 507 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 508 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 509 | x_ = self.norm(x_) 510 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 511 | else: 512 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 513 | k, v = kv[0], kv[1] 514 | 515 | attn = (q @ k.transpose(-2, -1)) * self.scale 516 | attn = attn.softmax(dim=-1) 517 | attn = self.attn_drop(attn) 518 | 519 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 520 | x = self.proj(x) 521 | x = self.proj_drop(x) 522 | 523 | return x 524 | 525 | 526 | class Block(nn.Module): 527 | 528 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 529 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 530 | super().__init__() 531 | self.norm1 = norm_layer(dim) 532 | self.attn = Attention( 533 | dim, 534 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 535 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 536 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 537 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 538 | self.norm2 = norm_layer(dim) 539 | mlp_hidden_dim = int(dim * mlp_ratio) 540 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 541 | 542 | self.apply(self._init_weights) 543 | 544 | def _init_weights(self, m): 545 | if isinstance(m, nn.Linear): 546 | trunc_normal_(m.weight, std=.02) 547 | if isinstance(m, nn.Linear) and m.bias is not None: 548 | nn.init.constant_(m.bias, 0) 549 | elif isinstance(m, nn.LayerNorm): 550 | nn.init.constant_(m.bias, 0) 551 | nn.init.constant_(m.weight, 1.0) 552 | elif isinstance(m, nn.Conv2d): 553 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 554 | fan_out //= m.groups 555 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 556 | if m.bias is not None: 557 | m.bias.data.zero_() 558 | 559 | def forward(self, x, H, W): 560 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 561 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 562 | 563 | return x 564 | 565 | 566 | class OverlapPatchEmbed(nn.Module): 567 | """ Image to Patch Embedding 568 | """ 569 | 570 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 571 | super().__init__() 572 | img_size = to_2tuple(img_size) 573 | patch_size = to_2tuple(patch_size) 574 | 575 | self.img_size = img_size 576 | self.patch_size = patch_size 577 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 578 | self.num_patches = self.H * self.W 579 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 580 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 581 | self.norm = nn.LayerNorm(embed_dim) 582 | 583 | self.apply(self._init_weights) 584 | 585 | def _init_weights(self, m): 586 | if isinstance(m, nn.Linear): 587 | trunc_normal_(m.weight, std=.02) 588 | if isinstance(m, nn.Linear) and m.bias is not None: 589 | nn.init.constant_(m.bias, 0) 590 | elif isinstance(m, nn.LayerNorm): 591 | nn.init.constant_(m.bias, 0) 592 | nn.init.constant_(m.weight, 1.0) 593 | elif isinstance(m, nn.Conv2d): 594 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 595 | fan_out //= m.groups 596 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 597 | if m.bias is not None: 598 | m.bias.data.zero_() 599 | 600 | def forward(self, x): 601 | x = self.proj(x) 602 | _, _, H, W = x.shape 603 | x = x.flatten(2).transpose(1, 2) 604 | x = self.norm(x) 605 | 606 | return x, H, W 607 | 608 | 609 | class MixVisionTransformer(nn.Module): 610 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 611 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 612 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 613 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): 614 | super().__init__() 615 | self.num_classes = num_classes 616 | self.depths = depths 617 | 618 | # patch_embed 619 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, 620 | embed_dim=embed_dims[0]) # 1/4 621 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], 622 | embed_dim=embed_dims[1]) # 1/8 623 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], 624 | embed_dim=embed_dims[2]) # auxilary output 625 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], 626 | 
embed_dim=embed_dims[3]) # 1/32 627 | 628 | # transformer encoder 629 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 630 | cur = 0 631 | self.block1 = nn.ModuleList([Block( 632 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 633 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 634 | sr_ratio=sr_ratios[0]) 635 | for i in range(depths[0])]) 636 | self.norm1 = norm_layer(embed_dims[0]) 637 | 638 | cur += depths[0] 639 | self.block2 = nn.ModuleList([Block( 640 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 641 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 642 | sr_ratio=sr_ratios[1]) 643 | for i in range(depths[1])]) 644 | self.norm2 = norm_layer(embed_dims[1]) 645 | 646 | cur += depths[1] 647 | self.block3 = nn.ModuleList([Block( 648 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 649 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 650 | sr_ratio=sr_ratios[2]) 651 | for i in range(depths[2])]) 652 | self.norm3 = norm_layer(embed_dims[2]) 653 | 654 | cur += depths[2] 655 | self.block4 = nn.ModuleList([Block( 656 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 657 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 658 | sr_ratio=sr_ratios[3]) 659 | for i in range(depths[3])]) 660 | self.norm4 = norm_layer(embed_dims[3]) 661 | 662 | # classification head 663 | # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 664 | 665 | self.apply(self._init_weights) 666 | 667 | def _init_weights(self, m): 668 | if isinstance(m, nn.Linear): 669 | trunc_normal_(m.weight, std=.02) 670 | if isinstance(m, nn.Linear) and m.bias is not None: 671 | nn.init.constant_(m.bias, 0) 672 | elif isinstance(m, nn.LayerNorm): 673 | nn.init.constant_(m.bias, 0) 674 | nn.init.constant_(m.weight, 1.0) 675 | elif isinstance(m, nn.Conv2d): 676 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 677 | fan_out //= m.groups 678 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 679 | if m.bias is not None: 680 | m.bias.data.zero_() 681 | 682 | def init_weights(self, pretrained=None): 683 | if isinstance(pretrained, str): 684 | logger = get_root_logger() 685 | load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) 686 | 687 | def reset_drop_path(self, drop_path_rate): 688 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 689 | cur = 0 690 | for i in range(self.depths[0]): 691 | self.block1[i].drop_path.drop_prob = dpr[cur + i] 692 | 693 | cur += self.depths[0] 694 | for i in range(self.depths[1]): 695 | self.block2[i].drop_path.drop_prob = dpr[cur + i] 696 | 697 | cur += self.depths[1] 698 | for i in range(self.depths[2]): 699 | self.block3[i].drop_path.drop_prob = dpr[cur + i] 700 | 701 | cur += self.depths[2] 702 | for i in range(self.depths[3]): 703 | self.block4[i].drop_path.drop_prob = dpr[cur + i] 704 | 705 | def freeze_patch_emb(self): 706 | self.patch_embed1.requires_grad = False 707 | 708 | @torch.jit.ignore 709 | def no_weight_decay(self): 710 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better 711 | 
712 | def get_classifier(self): 713 | return self.head 714 | 715 | def reset_classifier(self, num_classes, global_pool=''): 716 | self.num_classes = num_classes 717 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 718 | 719 | def forward_features(self, x): 720 | B = x.shape[0] 721 | outs = [] 722 | 723 | # stage 1 724 | x, H, W = self.patch_embed1(x) 725 | for i, blk in enumerate(self.block1): 726 | x = blk(x, H, W) 727 | x = self.norm1(x) 728 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 729 | outs.append(x) 730 | 731 | # stage 2 732 | x, H, W = self.patch_embed2(x) 733 | for i, blk in enumerate(self.block2): 734 | x = blk(x, H, W) 735 | x = self.norm2(x) 736 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 737 | outs.append(x) 738 | 739 | # stage 3 740 | x, H, W = self.patch_embed3(x) 741 | for i, blk in enumerate(self.block3): 742 | x = blk(x, H, W) 743 | x = self.norm3(x) 744 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 745 | outs.append(x) 746 | 747 | # stage 4 748 | x, H, W = self.patch_embed4(x) 749 | for i, blk in enumerate(self.block4): 750 | x = blk(x, H, W) 751 | x = self.norm4(x) 752 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 753 | outs.append(x) 754 | 755 | return outs 756 | 757 | def forward(self, x): 758 | x = self.forward_features(x) 759 | # x = self.head(x) 760 | 761 | return x 762 | 763 | 764 | class DWConv(nn.Module): 765 | def __init__(self, dim=768): 766 | super(DWConv, self).__init__() 767 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 768 | 769 | def forward(self, x, H, W): 770 | B, N, C = x.shape 771 | x = x.transpose(1, 2).view(B, C, H, W).contiguous() 772 | x = self.dwconv(x) 773 | x = x.flatten(2).transpose(1, 2) 774 | 775 | return x 776 | 777 | 778 | 779 | #@BACKBONES.register_module() 780 | class mit_b0(MixVisionTransformer): 781 | def __init__(self, **kwargs): 782 | super(mit_b0, self).__init__( 783 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 784 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 785 | drop_rate=0.0, drop_path_rate=0.1) 786 | 787 | 788 | #@BACKBONES.register_module() 789 | class mit_b1(MixVisionTransformer): 790 | def __init__(self, **kwargs): 791 | super(mit_b1, self).__init__( 792 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 793 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 794 | drop_rate=0.0, drop_path_rate=0.1) 795 | 796 | 797 | #@BACKBONES.register_module() 798 | class mit_b2(MixVisionTransformer): 799 | def __init__(self, **kwargs): 800 | super(mit_b2, self).__init__( 801 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 802 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 803 | drop_rate=0.0, drop_path_rate=0.1) 804 | 805 | 806 | #@BACKBONES.register_module() 807 | class mit_b3(MixVisionTransformer): 808 | def __init__(self, **kwargs): 809 | super(mit_b3, self).__init__( 810 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 811 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 812 | drop_rate=0.0, drop_path_rate=0.1) 813 | 814 | 815 | #@BACKBONES.register_module() 816 | class 
mit_b4(MixVisionTransformer): 817 | def __init__(self, **kwargs): 818 | super(mit_b4, self).__init__( 819 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 820 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 821 | drop_rate=0.0, drop_path_rate=0.1) 822 | 823 | 824 | #@BACKBONES.register_module() 825 | class mit_b5(MixVisionTransformer): 826 | def __init__(self, **kwargs): 827 | super(mit_b5, self).__init__( 828 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 829 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 830 | drop_rate=0.0, drop_path_rate=0.1) 831 | 832 | class SegFormer(nn.Module): 833 | def __init__(self, num_classes, load_imagenet_model, imagenet_ckpt_fpath, **kwargs): 834 | super(SegFormer, self).__init__(**kwargs) 835 | 836 | self.encoder = mit_b5() 837 | # self.head = BilinearPADHead_fast_xavier_init(num_classes=num_classes, 838 | # c1_in_channels=64, 839 | # c1_channels=48, 840 | # upsample_factor=8, 841 | # dyn_branch_ch=16, 842 | # mask_head_ch=16, 843 | # pad_out_channel_factor=4, 844 | # channel_reduce_factor=2, 845 | # zero_init=False, 846 | # supress_std=True, 847 | # feature_strides=None, 848 | # in_channels=512, 849 | # channels=512, 850 | # in_index=3, 851 | # dilations=(1, 3, 6, 9), 852 | # dropout_ratio=0.1, 853 | # norm_cfg=dict(type='SyncBN', requires_grad=True), 854 | # align_corners=False,) 855 | 856 | self.head = SegFormerHead(num_classes=num_classes, 857 | in_channels=[64, 128, 320, 512], 858 | channels=128, 859 | in_index=[0,1,2,3], 860 | feature_strides=[4, 8, 16, 32], 861 | #decoder_params=dict(embed_dim=768), 862 | dropout_ratio=0.1, 863 | norm_cfg=dict(type='SyncBN', requires_grad=True), 864 | align_corners=False) 865 | self.auxi_net = FCNHead(num_convs=1, 866 | kernel_size=3, 867 | concat_input=True, 868 | in_channels=320, 869 | num_classes=num_classes, 870 | norm_cfg=dict(type='SyncBN', requires_grad=True)) 871 | self.init_weights(load_imagenet_model, imagenet_ckpt_fpath) 872 | 873 | def init_weights(self, load_imagenet_model: bool=False, imagenet_ckpt_fpath: str='') -> None: 874 | """ For training, we use a model pretrained on ImageNet. Irrelevant at inference. 
875 | Args: 876 | - pretrained_fpath: str representing path to pretrained model 877 | Returns: 878 | - None 879 | """ 880 | logger.info('=> init weights from normal distribution') 881 | if not load_imagenet_model: 882 | return 883 | if os.path.isfile(imagenet_ckpt_fpath): 884 | print('===========> loading pretrained model {}'.format(imagenet_ckpt_fpath)) 885 | self.encoder.init_weights(pretrained=imagenet_ckpt_fpath) 886 | else: 887 | # logger.info(pretrained) 888 | print('cannot find ImageNet model path, use random initialization') 889 | raise RuntimeError('no pretrained model found at {}'.format(imagenet_ckpt_fpath)) 890 | 891 | def forward(self, inputs): 892 | h = inputs.size()[2] 893 | w = inputs.size()[3] 894 | x = self.encoder(inputs) 895 | #out = self.head([x[3], x[0]]) 896 | out = self.head(x) 897 | auxi_out = self.auxi_net(x) 898 | high_out = F.interpolate(out, size=(h,w), mode='bilinear', align_corners=True) 899 | return high_out, out, auxi_out 900 | 901 | 902 | class SegModel(nn.Module): 903 | def __init__(self, criterions, num_classes, load_imagenet_model, imagenet_ckpt_fpath, **kwargs): 904 | super(SegModel, self).__init__(**kwargs) 905 | self.segmodel = SegFormer(num_classes=num_classes, 906 | load_imagenet_model=load_imagenet_model, 907 | imagenet_ckpt_fpath=imagenet_ckpt_fpath) 908 | self.criterion = None 909 | def forward(self, inputs, gt=None, label_space=None, others=None): 910 | high_reso, low_reso, auxi_out = self.segmodel(inputs) 911 | return high_reso, None, None 912 | 913 | def get_seg_model( 914 | criterion: list, 915 | n_classes: int, 916 | load_imagenet_model: bool = False, 917 | imagenet_ckpt_fpath: str = '', 918 | **kwargs 919 | ) -> nn.Module: 920 | model = SegModel(criterions=criterion, 921 | num_classes=n_classes, 922 | load_imagenet_model=load_imagenet_model, 923 | imagenet_ckpt_fpath=imagenet_ckpt_fpath) 924 | assert isinstance(model, nn.Module) 925 | return model 926 | 927 | def get_configured_segformer( 928 | n_classes: int, 929 | criterion: list, 930 | load_imagenet_model: bool = False, 931 | imagenet_ckpt_fpath: str = '', 932 | ) -> nn.Module: 933 | """ 934 | Args: 935 | - n_classes: integer representing number of output classes 936 | - load_imagenet_model: whether to initialize from ImageNet-pretrained model 937 | - imagenet_ckpt_fpath: string representing path to file with weights to 938 | initialize model with 939 | Returns: 940 | - model: HRNet model w/ architecture configured according to model yaml, 941 | and with specified number of classes and weights initialized 942 | (at training, init using imagenet-pretrained model) 943 | """ 944 | 945 | model = get_seg_model(criterion, n_classes, load_imagenet_model, imagenet_ckpt_fpath) 946 | return model 947 | 948 | 949 | if __name__=='__main__': 950 | imagenet_ckpt_fpath = '' 951 | load_imagenet_model = False 952 | criterions=[] 953 | from mseg_semantic.model.criterion import Cross_sim_loss 954 | 955 | loss_method = Cross_sim_loss(data_index=['universal'], 956 | data_root='./data', 957 | ignore_label=255, 958 | emd_method='wiki_embeddings') 959 | criterions.append(loss_method) 960 | model = get_configured_segformer(180, criterions, load_imagenet_model, imagenet_ckpt_fpath) 961 | num_p = sum(p.numel() for p in model.parameters() if p.requires_grad) 962 | print(num_p) 963 | --------------------------------------------------------------------------------