├── CCTNet.png ├── LICENSE ├── README.md ├── __init__.py ├── autorun.sh ├── custom_transforms.py ├── data └── .gitignore ├── dataset.py ├── loss.py ├── models ├── __init__.py ├── banet.py ├── beit.py ├── bisenetv2.py ├── cctnet.py ├── checkpoint.py ├── cswin.py ├── danet.py ├── deeplabv3.py ├── edgenet.py ├── fcn.py ├── fpn.py ├── head │ ├── __init__.py │ ├── ann.py │ ├── apc.py │ ├── aspp.py │ ├── aspp_plus.py │ ├── base_decoder.py │ ├── cefpn.py │ ├── da.py │ ├── dnl.py │ ├── edge.py │ ├── fcfpn.py │ ├── fcn.py │ ├── gc.py │ ├── mlp.py │ ├── psa.py │ ├── psp.py │ ├── seg.py │ ├── unet.py │ └── uper.py ├── hrnet.py ├── model_store.py ├── pspnet.py ├── resT.py ├── resnet.py ├── segbase.py ├── swinT.py ├── transformer.py ├── unet.py ├── utils.py └── volo.py ├── mutil_scale_test.py ├── post_process.py ├── pre_process.py ├── pretrained_weights └── .gitignore ├── requirements.txt ├── seg_metric.py ├── test.py ├── tools ├── edge │ └── .gitignore ├── flops_params_fps_count.py ├── generate_edge.py ├── generate_heatmap.py ├── heat_map.py ├── heatmap │ └── outputs │ │ ├── ori_image.png │ │ └── ori_label.png ├── heatmap_fun.py └── utils.py ├── train.py └── work_dir └── .gitignore /CCTNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyxu1996/CCTNet/5a5db40d2e38bd478b404583050049eedca90844/CCTNet.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CCTNet: Coupled CNN and Transformer Network for Crop Segmentation of Remote Sensing Images, [RemoteSensing](https://www.mdpi.com/2072-4292/14/9/1956/htm) 2 | ## Introduction 3 | We propose a Coupled CNN and Transformer Network to combine the local modeling advantage of the CNN and the global modeling advantage of Transformer to achieve SOTA performance on the [Barley Remote Sensing Dataset](https://tianchi.aliyun.com/dataset/dataDetail?spm=5176.12281978.0.0.76944054ZQD0l2&dataId=74952). By applying our code base, you can easily deal with ultra-high-resolution remote sensing images. If our work is helpful to you, please star us. 4 | 5 | CCTNet Framework
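The code base handles ultra-high-resolution tiles by sliding-window cropping: `dataset.py` cuts each tile into overlapping `crop_size` patches (stride rate 2/3) before training and testing. Below is a minimal sketch of how the window positions are laid out, simplified from `slide_crop` in `dataset.py`; the function name is illustrative only.
```
import math

def sliding_windows(h, w, crop=(512, 512), stride_rate=2 / 3):
    """Yield (y0, x0) top-left corners of overlapping crops covering an h x w
    image (h >= crop height, w >= crop width), clamping the last row/column to
    the image border. Simplified from slide_crop in dataset.py."""
    H, W = crop
    stride_h, stride_w = int(H * stride_rate), int(W * stride_rate)
    rows = int(math.ceil((h - H) / stride_h)) + 1
    cols = int(math.ceil((w - W) / stride_w)) + 1
    for r in range(rows):
        for c in range(cols):
            yield min(r * stride_h, h - H), min(c * stride_w, w - W)

# e.g. a 6000 x 6000 tile is covered by 18 x 18 = 324 overlapping 512 x 512 crops
print(sum(1 for _ in sliding_windows(6000, 6000)))
```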
6 | ## Usage
7 | * Install packages
8 |
9 | This repository is based on `python 3.6.12` and `torch 1.6.0`.
10 |
11 | ```
12 | git clone https://github.com/zyxu1996/CCTNet.git
13 | cd CCTNet
14 | ```
15 | ```
16 | pip install -r requirements.txt
17 | ```
18 | * Prepare datasets and pretrained weights
19 |
20 | * The code base supports three high-resolution datasets, namely Barley, Potsdam and Vaihingen.
21 | * Download the `Barley, Potsdam and Vaihingen` datasets from BaiduYun and put them in `./data`.
22 | `BaiduYun`: [https://pan.baidu.com/s/1MyDw1qncPKYJFK_zjFxFBA](https://pan.baidu.com/s/1MyDw1qncPKYJFK_zjFxFBA)
23 | `Password`: s7f2
24 |
25 | The data file structure of the three datasets is as follows.
26 | ```
27 | ├── data ├── data ├── data
28 | ├──barley ├──potsdam ├──vaihingen
29 | ├──images ├──images ├──images
30 | ├──image_1_0_0.png ├──top_potsdam_2_10.tif ├──top_mosaic_09cm_area1.tif
31 | ├──image_1_0_1.png ├──top_potsdam_2_11.tif ├──top_mosaic_09cm_area2.tif
32 | ... ... ...
33 | ├──labels ├──labels ├──labels
34 | ├──image_1_0_0.png ├──top_potsdam_2_10.png ├──top_mosaic_09cm_area1.png
35 | ├──image_1_0_1.png ├──top_potsdam_2_11.png ├──top_mosaic_09cm_area2.png
36 | ... ... ...
37 | ├──annotations ├──annotations ├──annotations
38 | ├──train.txt ├──train.txt ├──train.txt
39 | ├──test.txt ├──test.txt ├──test.txt
40 |
41 | ```
42 | The `train.txt`/`test.txt` files list the image names without extensions, one per line; a small helper sketch for generating them is given after the Acknowledgments.
43 | * Download the pretrained weights from [CSwin-Transformer](https://github.com/microsoft/CSWin-Transformer), and put them in `./pretrained_weights`.
44 | CSwin: the `CSwin Tiny, Small, Base and Large` models pretrained on `ImageNet-1K` and `ImageNet-22K` are used.
45 | ResNet: the `ResNet 18, 34, 50 and 101` pretrained models are used; the download links are contained in our code.
46 |
47 | * Training
48 |
49 | * The training and testing settings, including the choice of dataset and model, are written in the script.
50 | ```
51 | sh autorun.sh
52 | ```
53 | * If you run train.py directly, please uncomment the following code.
54 | ```
55 | if __name__ == '__main__':
56 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
57 | os.environ.setdefault('RANK', '0')
58 | os.environ.setdefault('WORLD_SIZE', '1')
59 | os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
60 | os.environ.setdefault('MASTER_PORT', '29556')
61 | ```
62 | * Testing
63 | * Generate the final results and visualize the predictions.
64 | ```
65 | cd ./work_dir/your_work
66 | ```
67 | * Remember to uncomment the test command in `autorun.sh`, and keep the `--information num1` in the testing command the same as the `--information` used in the training command.
68 | `CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 --master_port 29505 test.py --dataset barley --val_batchsize 8 --models cctnet --head seghead --crop_size 512 512 --trans_cnn cswin_tiny resnet50 --save_dir work_dir --base_dir ../../ --information num1
69 | `
70 | * Then run the script autorun.sh.
71 | ```
72 | sh autorun.sh
73 | ```
74 | ## Acknowledgments
75 | Thanks to Guangzhou Jingwei Information Technology Co., Ltd., and the Xingren City government for providing the Barley Remote Sensing Dataset.
76 | Thanks to the ISPRS for providing the Potsdam and Vaihingen datasets.
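As referenced in the dataset preparation step above, the `annotations` split files are consumed by `RemoteData` in `dataset.py`, which appends the `.png`/`.tif` extension itself. Below is a minimal sketch for generating `train.txt` and `test.txt`; it is a hypothetical helper that uses a random split, so adjust the split rule to your own dataset.
```
import os
import random

def write_split(data_root='./data/barley', test_ratio=0.2, seed=0):
    # Hypothetical helper: list image names without extensions,
    # as RemoteData in dataset.py expects, and write the two split files.
    names = sorted(os.path.splitext(f)[0]
                   for f in os.listdir(os.path.join(data_root, 'images')))
    random.Random(seed).shuffle(names)
    n_test = int(len(names) * test_ratio)
    ann_dir = os.path.join(data_root, 'annotations')
    os.makedirs(ann_dir, exist_ok=True)
    for txt, subset in (('test.txt', names[:n_test]), ('train.txt', names[n_test:])):
        with open(os.path.join(ann_dir, txt), 'w') as f:
            f.write('\n'.join(subset) + '\n')

write_split()
```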
77 | ## Citation 78 | ``` 79 | @article{wang2022cctnet, 80 | title={CCTNet: Coupled CNN and Transformer Network for Crop Segmentation of Remote Sensing Images}, 81 | author={Wang, Hong and Chen, Xianzhong and Zhang, Tianxiang and Xu, Zhiyong and Li, Jiangyun}, 82 | journal={Remote Sensing}, 83 | volume={14}, 84 | number={9}, 85 | pages={1956}, 86 | year={2022}, 87 | publisher={MDPI} 88 | } 89 | ``` 90 | ## Other Links 91 | * [HRCNet: High-Resolution Context Extraction Network for Semantic Segmentation of Remote Sensing Images](https://github.com/zyxu1996/HRCNet-High-Resolution-Context-Extraction-Network) 92 | * [Efficient Transformer for Remote Sensing Image Segmentation](https://github.com/zyxu1996/Efficient-Transformer) 93 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import * -------------------------------------------------------------------------------- /autorun.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | ################### Test ################# 3 | 4 | #CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 --master_port 29505 test.py --dataset barley --val_batchsize 8 --models cctnet --head seghead --crop_size 512 512 --trans_cnn cswin_tiny resnet50 --save_dir work_dir --base_dir ../../ --information num1 5 | 6 | 7 | ################### Train ################# 8 | 9 | # barley 10 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port 29506 train.py --dataset barley --end_epoch 50 --lr 0.0001 --train_batchsize 4 --models cctnet --head seghead --crop_size 512 512 --trans_cnn cswin_tiny resnet50 --use_mixup 0 --information num1 11 | 12 | # vaihingen 13 | #CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port 29507 train.py --dataset vaihingen --end_epoch 100 --lr 0.0003 --train_batchsize 4 --models cctnet --head seghead --crop_size 512 512 --trans_cnn cswin_tiny resnet50 --use_mixup 0 --information num2 14 | 15 | 16 | # potsdam 17 | #CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port 29508 train.py --dataset potsdam --end_epoch 50 --lr 0.0001 --train_batchsize 4 --models cctnet --head seghead --crop_size 512 512 --trans_cnn cswin_tiny resnet50 --use_mixup 0 --information num3 -------------------------------------------------------------------------------- /custom_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import cv2 5 | import os 6 | import torch.nn as nn 7 | from torchvision import transforms 8 | 9 | class RandomHorizontalFlip(object): 10 | def __call__(self, sample): 11 | image = sample['image'] 12 | label = sample['label'] 13 | if random.random() < 0.5: 14 | image = cv2.flip(image, 1) 15 | label = cv2.flip(label, 1) 16 | 17 | return {'image': image, 'label': label} 18 | 19 | 20 | class RandomVerticalFlip(object): 21 | def __call__(self, sample): 22 | image = sample['image'] 23 | label = sample['label'] 24 | if random.random() < 0.5: 25 | image = cv2.flip(image, 0) 26 | label = cv2.flip(label, 0) 27 | 28 | return {'image': image, 'label': label} 29 | 30 | 31 | class RandomScaleCrop(object): 32 | def __init__(self, base_size=None, crop_size=None, fill=0): 33 | """shape [H, W]""" 34 | if base_size is None: 35 | 
base_size = [512, 512] 36 | if crop_size is None: 37 | crop_size = [512, 512] 38 | self.base_size = np.array(base_size) 39 | self.crop_size = np.array(crop_size) 40 | self.fill = fill 41 | 42 | def __call__(self, sample): 43 | img = sample['image'] 44 | mask = sample['label'] 45 | # random scale (short edge) 46 | short_size = random.choice([self.base_size * 0.5, self.base_size * 0.75, self.base_size, 47 | self.base_size * 1.25, self.base_size * 1.5]) 48 | short_size = short_size.astype(np.int) 49 | h, w = img.shape[0:2] 50 | if h > w: 51 | ow = short_size[1] 52 | oh = int(1.0 * h * ow / w) 53 | else: 54 | oh = short_size[0] 55 | ow = int(1.0 * w * oh / h) 56 | #img = img.resize((ow, oh), Image.BILINEAR) 57 | #mask = mask.resize((ow, oh), Image.NEAREST) 58 | img = cv2.resize(img, (ow, oh), interpolation=cv2.INTER_LINEAR) 59 | mask = cv2.resize(mask, (ow, oh), interpolation=cv2.INTER_NEAREST) 60 | # pad crop 61 | if short_size[0] < self.crop_size[0] or short_size[1] < self.crop_size[1]: 62 | padh = self.crop_size[0] - oh if oh < self.crop_size[0] else 0 63 | padw = self.crop_size[1] - ow if ow < self.crop_size[1] else 0 64 | #img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 65 | #mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) 66 | img = cv2.copyMakeBorder(img, 0, padh, 0, padw, borderType=cv2.BORDER_DEFAULT) 67 | mask = cv2.copyMakeBorder(mask, 0, padh, 0, padw, borderType=cv2.BORDER_DEFAULT) 68 | # random crop crop_size 69 | h, w = img.shape[0:2] 70 | x1 = random.randint(0, w - self.crop_size[1]) 71 | y1 = random.randint(0, h - self.crop_size[0]) 72 | img = img[y1:y1+self.crop_size[0], x1:x1+self.crop_size[1], :] 73 | mask = mask[y1:y1+self.crop_size[0], x1:x1+self.crop_size[1]] 74 | return {'image': img, 'label': mask} 75 | 76 | 77 | class ImageSplit(nn.Module): 78 | def __init__(self, numbers=None): 79 | super(ImageSplit, self).__init__() 80 | """numbers [H, W] 81 | split from left to right, top to bottom""" 82 | if numbers is None: 83 | numbers = [2, 2] 84 | self.num = numbers 85 | 86 | def forward(self, x): 87 | flag = None 88 | if len(x.shape) == 3: 89 | x = x.unsqueeze(dim=1) 90 | flag = 1 91 | b, c, h, w = x.shape 92 | num_h, num_w = self.num[0], self.num[1] 93 | assert h % num_h == 0 and w % num_w == 0 94 | split_h, split_w = h // num_h, w // num_w 95 | 96 | outputs = [] 97 | outputss = [] 98 | for i in range(b): 99 | for h_i in range(num_h): 100 | for w_i in range(num_w): 101 | output = x[i][:, split_h * h_i: split_h * (h_i + 1), 102 | split_w * w_i: split_w * (w_i + 1)].unsqueeze(dim=0) 103 | outputs.append(output) 104 | outputs = torch.cat(outputs, dim=0).unsqueeze(dim=0) 105 | outputss.append(outputs) 106 | outputs = [] 107 | outputss = torch.cat(outputss, dim=0).contiguous() 108 | if flag is not None: 109 | outputss = outputss.squeeze(dim=2) 110 | return outputss 111 | 112 | 113 | class ToTensor(object): 114 | """Convert ndarrays in sample to Tensors.""" 115 | def __init__(self, add_edge=True): 116 | """imagenet normalize""" 117 | self.normalize = transforms.Normalize((.485, .456, .406), (.229, .224, .225)) 118 | self.add_edge = add_edge 119 | 120 | def get_edge(self, img, edge_width=3): 121 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 122 | gray = cv2.GaussianBlur(gray, (11, 11), 0) 123 | edge = cv2.Canny(gray, 50, 150) 124 | # cv2.imshow('edge', edge) 125 | # cv2.waitKey(0) 126 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (edge_width, edge_width)) 127 | edge = cv2.dilate(edge, kernel) 128 | edge = edge / 255 129 | edge = 
torch.from_numpy(edge).unsqueeze(dim=0).float() 130 | 131 | return edge 132 | 133 | def __call__(self, sample): 134 | # swap color axis because 135 | # numpy image: H x W x C 136 | # torch image: C X H X W 137 | img = sample['image'] 138 | mask = sample['label'] 139 | 140 | mask = np.expand_dims(mask, axis=2) 141 | img = np.array(img).astype(np.float32).transpose((2, 0, 1)) 142 | mask = np.array(mask).astype(np.int64).transpose((2, 0, 1)) 143 | 144 | img = torch.from_numpy(img).float().div(255) 145 | img = self.normalize(img) 146 | mask = torch.from_numpy(mask).float() 147 | 148 | if self.add_edge: 149 | edge = self.get_edge(sample['image']) 150 | img = img + edge 151 | 152 | return {'image': img, 'label': mask} 153 | 154 | 155 | class RGBGrayExchange(): 156 | def __init__(self, path=None, palette=None): 157 | self.palette = palette 158 | """RGB format""" 159 | if palette is None: 160 | self.palette = [[255, 255, 255], [0, 0, 255], [0, 255, 255], 161 | [0, 255, 0], [255, 255, 0], [255, 0, 0]] 162 | self.path = path 163 | 164 | def read_img(self): 165 | img = cv2.imread(self.path, cv2.IMREAD_UNCHANGED) 166 | if len(img.shape) == 3: 167 | img = img[:, :, ::-1] 168 | return img 169 | 170 | def RGB_to_Gray(self, image=None): 171 | if not self.path is None: 172 | image = self.read_img() 173 | Gray = np.zeros(shape=[image.shape[0], image.shape[1]], dtype=np.uint8) 174 | for i in range(len(self.palette)): 175 | index = image == np.array(self.palette[i]) 176 | index[..., 0][index[..., 1] == False] = False 177 | index[..., 0][index[..., 2] == False] = False 178 | Gray[index[..., 0]] = i 179 | print('unique pixels:{}'.format(np.unique(Gray))) 180 | return Gray 181 | 182 | def Gray_to_RGB(self, image=None): 183 | if not self.path is None: 184 | image = self.read_img() 185 | RGB = np.zeros(shape=[image.shape[0], image.shape[1], 3], dtype=np.uint8) 186 | for i in range(len(self.palette)): 187 | index = image == i 188 | RGB[index] = np.array(self.palette[i]) 189 | print('unique pixels:{}'.format(np.unique(RGB))) 190 | return RGB 191 | 192 | 193 | class Mixup(nn.Module): 194 | def __init__(self, alpha=1.0, use_edge=False): 195 | super(Mixup, self).__init__() 196 | self.alpha = alpha 197 | self.use_edge = use_edge 198 | 199 | def criterion(self, lam, outputs, targets_a, targets_b, criterion): 200 | return lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b) 201 | 202 | def forward(self, inputs, targets, criterion, model): 203 | if self.alpha > 0: 204 | lam = np.random.beta(self.alpha, self.alpha) 205 | else: 206 | lam = 1 207 | batch_size = inputs.size(0) 208 | index = torch.randperm(batch_size).cuda() 209 | mix_inputs = lam*inputs + (1-lam)*inputs[index, :] 210 | targets_a, targets_b = targets, targets[index] 211 | outputs = model(mix_inputs) 212 | 213 | losses = 0 214 | if isinstance(outputs, (list, tuple)): 215 | if self.use_edge: 216 | for i in range(len(outputs) - 1): 217 | loss = self.criterion(lam, outputs[i], targets_a, targets_b, criterion[0]) 218 | losses += loss 219 | edge_targets_a = edge_contour(targets).long() 220 | edge_targets_b = edge_targets_a[index] 221 | loss2 = self.criterion(lam, outputs[-1], edge_targets_a, edge_targets_b, criterion[1]) 222 | losses += loss2 223 | else: 224 | for i in range(len(outputs)): 225 | loss = self.criterion(lam, outputs[i], targets_a, targets_b, criterion) 226 | losses += loss 227 | else: 228 | losses = self.criterion(lam, outputs, targets_a, targets_b, criterion) 229 | return losses 230 | 231 | 232 | def edge_contour(label, 
edge_width=3): 233 | import cv2 234 | cuda_type = label.is_cuda 235 | label = label.cpu().numpy().astype(np.int) 236 | b, h, w = label.shape 237 | edge = np.zeros(label.shape) 238 | 239 | # right 240 | edge_right = edge[:, 1:h, :] 241 | edge_right[(label[:, 1:h, :] != label[:, :h - 1, :]) & (label[:, 1:h, :] != 255) 242 | & (label[:, :h - 1, :] != 255)] = 1 243 | 244 | # up 245 | edge_up = edge[:, :, :w - 1] 246 | edge_up[(label[:, :, :w - 1] != label[:, :, 1:w]) 247 | & (label[:, :, :w - 1] != 255) 248 | & (label[:, :, 1:w] != 255)] = 1 249 | 250 | # upright 251 | edge_upright = edge[:, :h - 1, :w - 1] 252 | edge_upright[(label[:, :h - 1, :w - 1] != label[:, 1:h, 1:w]) 253 | & (label[:, :h - 1, :w - 1] != 255) 254 | & (label[:, 1:h, 1:w] != 255)] = 1 255 | 256 | # bottomright 257 | edge_bottomright = edge[:, :h - 1, 1:w] 258 | edge_bottomright[(label[:, :h - 1, 1:w] != label[:, 1:h, :w - 1]) 259 | & (label[:, :h - 1, 1:w] != 255) 260 | & (label[:, 1:h, :w - 1] != 255)] = 1 261 | 262 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (edge_width, edge_width)) 263 | for i in range(edge.shape[0]): 264 | edge[i] = cv2.dilate(edge[i], kernel) 265 | 266 | # edge[edge == 1] = 255 # view edge 267 | # import random 268 | # cv2.imwrite(os.path.join('./edge', '{}.png'.format(random.random())), edge[0]) 269 | if cuda_type: 270 | edge = torch.from_numpy(edge).cuda() 271 | else: 272 | edge = torch.from_numpy(edge) 273 | 274 | return edge 275 | 276 | 277 | if __name__ == '__main__': 278 | path = './data/vaihingen/annotations/labels' 279 | filelist = os.listdir(path) 280 | for file in filelist: 281 | print(file) 282 | img = cv2.imread(os.path.join(path, file), cv2.IMREAD_UNCHANGED) 283 | img = torch.from_numpy(img).unsqueeze(dim=0).repeat(2, 1, 1) 284 | img = edge_contour(img) 285 | # cv2.imwrite(os.path.join(save_path, os.path.splitext(file)[0] + '.png'), gray) 286 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyxu1996/CCTNet/5a5db40d2e38bd478b404583050049eedca90844/data/.gitignore -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import torch.utils.data as data 5 | from torchvision import transforms 6 | import custom_transforms as tr 7 | import tifffile as tiff 8 | import math 9 | 10 | 11 | class RemoteData(data.Dataset): 12 | def __init__(self, base_dir='./data/', train=True, dataset='vaihingen', crop_szie=None, val_full_img=False): 13 | super(RemoteData, self).__init__() 14 | self.dataset_dir = base_dir 15 | self.train = train 16 | self.dataset = dataset 17 | self.val_full_img = val_full_img 18 | self.images = [] 19 | self.labels = [] 20 | self.names = [] 21 | self.alphas = [] 22 | alpha = None 23 | if crop_szie is None: 24 | crop_szie = [512, 512] 25 | self.crop_size = crop_szie 26 | if train: 27 | self.image_dir = os.path.join(self.dataset_dir, self.dataset + '/images') 28 | self.label_dir = os.path.join(self.dataset_dir, self.dataset + '/labels') 29 | txt = os.path.join(self.dataset_dir, self.dataset + '/annotations' + '/train.txt') 30 | else: 31 | self.image_dir = os.path.join(self.dataset_dir, self.dataset + '/images') 32 | self.label_dir = os.path.join(self.dataset_dir, self.dataset + '/labels') 33 | txt = os.path.join(self.dataset_dir, 
self.dataset + '/annotations' + '/test.txt') 34 | 35 | with open(txt, "r") as f: 36 | self.filename_list = f.readlines() 37 | for filename in self.filename_list: 38 | if self.dataset in ['barley']: 39 | image = os.path.join(self.image_dir, filename.strip() + '.png') 40 | image = Image.open(image) 41 | image = np.array(image) 42 | if image.shape[2] == 4: 43 | alpha = image[..., 3] 44 | image = image[..., 0:3] 45 | else: 46 | image = os.path.join(self.image_dir, filename.strip() + '.tif') 47 | image = tiff.imread(image) 48 | label = os.path.join(self.label_dir, filename.strip() + '.png') 49 | label = Image.open(label) 50 | label = np.array(label) 51 | if self.val_full_img: 52 | self.images.append(image) 53 | self.labels.append(label) 54 | self.names.append(filename.strip()) 55 | if alpha is not None: 56 | self.alphas.append(alpha) 57 | else: 58 | if alpha is not None: 59 | slide_crop(image, label, self.crop_size, self.images, self.labels, self.dataset, 60 | alpha=alpha, alpha_patches=self.alphas, stride_rate=2/3) 61 | else: 62 | slide_crop(image, label, self.crop_size, self.images, self.labels, self.dataset, stride_rate=2/3) 63 | assert(len(self.images) == len(self.labels)) 64 | 65 | def __len__(self): 66 | return len(self.images) 67 | 68 | def __getitem__(self, index): 69 | sample = {'image': self.images[index], 'label': self.labels[index]} 70 | sample = self.transform(sample) 71 | if self.val_full_img: 72 | sample['name'] = self.names[index] 73 | if self.alphas != [] and self.train == False: 74 | sample['alpha'] = self.alphas[index] 75 | return sample 76 | 77 | def transform(self, sample): 78 | if self.train: 79 | if self.dataset in ['barley']: 80 | composed_transforms = transforms.Compose([ 81 | tr.RandomHorizontalFlip(), 82 | tr.RandomVerticalFlip(), 83 | tr.ToTensor(add_edge=False), 84 | ]) 85 | else: 86 | composed_transforms = transforms.Compose([ 87 | tr.RandomHorizontalFlip(), 88 | tr.RandomVerticalFlip(), 89 | tr.RandomScaleCrop(base_size=self.crop_size, crop_size=self.crop_size), 90 | tr.ToTensor(add_edge=False), 91 | ]) 92 | else: 93 | composed_transforms = transforms.Compose([ 94 | tr.ToTensor(add_edge=False), 95 | ]) 96 | return composed_transforms(sample) 97 | 98 | def __str__(self): 99 | return 'dataset:{} train:{}'.format(self.dataset, self.train) 100 | 101 | 102 | def slide_crop(image, label, crop_size, image_patches, label_patches, dataset, 103 | stride_rate=1.0/2.0, alpha=None, alpha_patches=None): 104 | """images shape [h, w, c]""" 105 | if len(image.shape) == 2: 106 | image = np.expand_dims(image, axis=2) 107 | if len(label.shape) == 2: 108 | label = np.expand_dims(label, axis=2) 109 | if alpha is not None: 110 | alpha = np.expand_dims(alpha, axis=2) 111 | stride_rate = stride_rate 112 | h, w, c = image.shape 113 | H, W = crop_size 114 | stride_h = int(H * stride_rate) 115 | stride_w = int(W * stride_rate) 116 | assert h >= crop_size[0] and w >= crop_size[1] 117 | h_grids = int(math.ceil(1.0 * (h - H) / stride_h)) + 1 118 | w_grids = int(math.ceil(1.0 * (w - W) / stride_w)) + 1 119 | for idh in range(h_grids): 120 | for idw in range(w_grids): 121 | h0 = idh * stride_h 122 | w0 = idw * stride_w 123 | h1 = min(h0 + H, h) 124 | w1 = min(w0 + W, w) 125 | if h1 == h and w1 != w: 126 | crop_img = image[h - H:h, w0:w0 + W, :] 127 | crop_label = label[h - H:h, w0:w0 + W, :] 128 | if alpha is not None: 129 | crop_alpha = alpha[h - H:h, w0:w0 + W, :] 130 | if w1 == w and h1 != h: 131 | crop_img = image[h0:h0 + H, w - W:w, :] 132 | crop_label = label[h0:h0 + H, w - W:w, :] 133 | if 
alpha is not None: 134 | crop_alpha = alpha[h0:h0 + H, w - W:w, :] 135 | if h1 == h and w1 == w: 136 | crop_img = image[h - H:h, w - W:w, :] 137 | crop_label = label[h - H:h, w - W:w, :] 138 | if alpha is not None: 139 | crop_alpha = alpha[h - H:h, w - W:w, :] 140 | if w1 != w and h1 != h: 141 | crop_img = image[h0:h0 + H, w0:w0 + W, :] 142 | crop_label = label[h0:h0 + H, w0:w0 + W, :] 143 | if alpha is not None: 144 | crop_alpha = alpha[h0:h0 + H, w0:w0 + W, :] 145 | crop_img = crop_img.squeeze() 146 | crop_label = crop_label.squeeze() 147 | if alpha is not None: 148 | crop_alpha = crop_alpha.squeeze() 149 | if (dataset in ['barley'] and np.any(crop_alpha > 0)) or dataset not in ['barley']: 150 | image_patches.append(crop_img) 151 | label_patches.append(crop_label) 152 | if alpha is not None: 153 | alpha_patches.append(crop_alpha) 154 | 155 | 156 | def label_to_RGB(image, classes=6): 157 | RGB = np.zeros(shape=[image.shape[0], image.shape[1], 3], dtype=np.uint8) 158 | if classes == 6: # potsdam and vaihingen 159 | palette = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]] 160 | if classes == 4: # barley 161 | palette = [[255, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]] 162 | for i in range(classes): 163 | index = image == i 164 | RGB[index] = np.array(palette[i]) 165 | return RGB 166 | 167 | 168 | def RGB_to_label(image=None, classes=6): 169 | if classes == 6: # potsdam and vaihingen 170 | palette = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]] 171 | if classes == 4: # barley 172 | palette = [[255, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]] 173 | label = np.zeros(shape=[image.shape[0], image.shape[1]], dtype=np.uint8) 174 | for i in range(len(palette)): 175 | index = image == np.array(palette[i]) 176 | index[..., 0][index[..., 1] == False] = False 177 | index[..., 0][index[..., 2] == False] = False 178 | label[index[..., 0]] = i 179 | return label 180 | 181 | 182 | if __name__ == '__main__': 183 | from torch.utils.data import DataLoader 184 | import matplotlib.pyplot as plt 185 | 186 | remotedata_train = RemoteData(train=True, dataset='vaihingen') 187 | dataloader = DataLoader(remotedata_train, batch_size=1, shuffle=False, num_workers=1) 188 | # print(dataloader) 189 | 190 | for ii, sample in enumerate(dataloader): 191 | im = sample['label'].numpy().astype(np.uint8) 192 | pic = sample['image'].numpy().astype(np.uint8) 193 | print(im.shape) 194 | im = np.squeeze(im, axis=0) 195 | pic = np.squeeze(pic, axis=0) 196 | print(im.shape) 197 | im = np.transpose(im, axes=[1, 2, 0])[:, :, 0:3] 198 | pic = np.transpose(pic, axes=[1, 2, 0])[:, :, 0:3] 199 | print(im.shape) 200 | im = np.squeeze(im, axis=2) 201 | # print(im) 202 | im = label_to_RGB(im) 203 | plt.imshow(pic) 204 | plt.show() 205 | plt.imshow(im) 206 | plt.show() 207 | if ii == 10: 208 | break 209 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyxu1996/CCTNet/5a5db40d2e38bd478b404583050049eedca90844/models/__init__.py -------------------------------------------------------------------------------- /models/danet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | try: 3 | from .resnet import resnet50_v1b 4 | except: 5 | from resnet import resnet50_v1b 6 | import torch.nn.functional as F 7 | import torch 8 | 9 | 
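# Structure note (added comment): DANet pairs a dilated ResNet-50 backbone
# (SegBaseModel) with the dual-attention head _DAHead. The head runs two
# parallel branches on the c4 feature map:
#   - _PositionAttentionModule: self-attention over the H*W spatial positions;
#   - _ChannelAttentionModule: self-attention over the channel dimension;
# fuses them by element-wise addition, and applies a dropout + 1x1 classifier.
# With aux=True, each branch also emits its own auxiliary prediction.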
10 | class SegBaseModel(nn.Module): 11 | r"""Base Model for Semantic Segmentation 12 | 13 | Parameters 14 | ---------- 15 | backbone : string 16 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 17 | 'resnet101' or 'resnet152'). 18 | """ 19 | 20 | def __init__(self, nclass, aux, backbone='resnet50', dilated=True, pretrained_base=False, **kwargs): 21 | super(SegBaseModel, self).__init__() 22 | self.aux = aux 23 | self.nclass = nclass 24 | if backbone == 'resnet50': 25 | self.pretrained = resnet50_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs) 26 | 27 | def base_forward(self, x): 28 | """forwarding pre-trained network""" 29 | x = self.pretrained.conv1(x) 30 | x = self.pretrained.bn1(x) 31 | x = self.pretrained.relu(x) 32 | x = self.pretrained.maxpool(x) 33 | c1 = self.pretrained.layer1(x) 34 | c2 = self.pretrained.layer2(c1) 35 | c3 = self.pretrained.layer3(c2) 36 | c4 = self.pretrained.layer4(c3) 37 | 38 | return c1, c2, c3, c4 39 | 40 | 41 | class _FCNHead(nn.Module): 42 | def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, **kwargs): 43 | super(_FCNHead, self).__init__() 44 | inter_channels = in_channels // 4 45 | self.block = nn.Sequential( 46 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 47 | norm_layer(inter_channels), 48 | nn.ReLU(inplace=True), 49 | nn.Dropout(0.1), 50 | nn.Conv2d(inter_channels, channels, 1) 51 | ) 52 | 53 | def forward(self, x): 54 | return self.block(x) 55 | 56 | 57 | class _PositionAttentionModule(nn.Module): 58 | """ Position attention module""" 59 | 60 | def __init__(self, in_channels, **kwargs): 61 | super(_PositionAttentionModule, self).__init__() 62 | self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1) 63 | self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1) 64 | self.conv_d = nn.Conv2d(in_channels, in_channels, 1) 65 | self.alpha = nn.Parameter(torch.zeros(1)) 66 | self.softmax = nn.Softmax(dim=-1) 67 | 68 | def forward(self, x): 69 | batch_size, _, height, width = x.size() 70 | feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1) 71 | feat_c = self.conv_c(x).view(batch_size, -1, height * width) 72 | attention_s = self.softmax(torch.bmm(feat_b, feat_c)) 73 | feat_d = self.conv_d(x).view(batch_size, -1, height * width) 74 | feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width) 75 | out = self.alpha * feat_e + x 76 | 77 | return out 78 | 79 | 80 | class _ChannelAttentionModule(nn.Module): 81 | """Channel attention module""" 82 | 83 | def __init__(self, **kwargs): 84 | super(_ChannelAttentionModule, self).__init__() 85 | self.beta = nn.Parameter(torch.zeros(1)) 86 | self.softmax = nn.Softmax(dim=-1) 87 | 88 | def forward(self, x): 89 | batch_size, _, height, width = x.size() 90 | feat_a = x.view(batch_size, -1, height * width) 91 | feat_a_transpose = x.view(batch_size, -1, height * width).permute(0, 2, 1) 92 | attention = torch.bmm(feat_a, feat_a_transpose) 93 | attention_new = torch.max(attention, dim=-1, keepdim=True)[0].expand_as(attention) - attention 94 | attention = self.softmax(attention_new) 95 | 96 | feat_e = torch.bmm(attention, feat_a).view(batch_size, -1, height, width) 97 | out = self.beta * feat_e + x 98 | 99 | return out 100 | 101 | 102 | class _DAHead(nn.Module): 103 | def __init__(self, in_channels, nclass, aux=True, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs): 104 | super(_DAHead, self).__init__() 105 | self.aux = aux 106 | inter_channels = in_channels // 4 107 | self.conv_p1 = 
nn.Sequential( 108 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 109 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 110 | nn.ReLU(True) 111 | ) 112 | self.conv_c1 = nn.Sequential( 113 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 114 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 115 | nn.ReLU(True) 116 | ) 117 | self.pam = _PositionAttentionModule(inter_channels, **kwargs) 118 | self.cam = _ChannelAttentionModule(**kwargs) 119 | self.conv_p2 = nn.Sequential( 120 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 121 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 122 | nn.ReLU(True) 123 | ) 124 | self.conv_c2 = nn.Sequential( 125 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 126 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 127 | nn.ReLU(True) 128 | ) 129 | self.out = nn.Sequential( 130 | nn.Dropout(0.1), 131 | nn.Conv2d(inter_channels, nclass, 1) 132 | ) 133 | if aux: 134 | self.conv_p3 = nn.Sequential( 135 | nn.Dropout(0.1), 136 | nn.Conv2d(inter_channels, nclass, 1) 137 | ) 138 | self.conv_c3 = nn.Sequential( 139 | nn.Dropout(0.1), 140 | nn.Conv2d(inter_channels, nclass, 1) 141 | ) 142 | 143 | def forward(self, x): 144 | feat_p = self.conv_p1(x) 145 | feat_p = self.pam(feat_p) 146 | feat_p = self.conv_p2(feat_p) 147 | 148 | feat_c = self.conv_c1(x) 149 | feat_c = self.cam(feat_c) 150 | feat_c = self.conv_c2(feat_c) 151 | 152 | feat_fusion = feat_p + feat_c 153 | 154 | outputs = [] 155 | fusion_out = self.out(feat_fusion) 156 | outputs.append(fusion_out) 157 | if self.aux: 158 | p_out = self.conv_p3(feat_p) 159 | c_out = self.conv_c3(feat_c) 160 | outputs.append(p_out) 161 | outputs.append(c_out) 162 | 163 | return tuple(outputs) 164 | 165 | 166 | class DANet(SegBaseModel): 167 | r"""Pyramid Scene Parsing Network 168 | 169 | Parameters 170 | ---------- 171 | nclass : int 172 | Number of categories for the training dataset. 173 | backbone : string 174 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 175 | 'resnet101' or 'resnet152'). 176 | norm_layer : object 177 | Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`; 178 | for Synchronized Cross-GPU BachNormalization). 179 | aux : bool 180 | Auxiliary loss. 181 | Reference: 182 | Jun Fu, Jing Liu, Haijie Tian, Yong Li, Yongjun Bao, Zhiwei Fang,and Hanqing Lu. 183 | "Dual Attention Network for Scene Segmentation." 
*CVPR*, 2019 184 | """ 185 | 186 | def __init__(self, nclass, backbone='resnet50', aux=False, pretrained_base=False, **kwargs): 187 | super(DANet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs) 188 | self.head = _DAHead(2048, nclass, aux, **kwargs) 189 | 190 | def forward(self, x): 191 | size = x.size()[2:] 192 | _, _, c3, c4 = self.base_forward(x) 193 | outputs = [] 194 | x = self.head(c4) 195 | x0 = F.interpolate(x[0], size, mode='bilinear', align_corners=True) 196 | if self.aux: 197 | x1 = F.interpolate(x[1], size, mode='bilinear', align_corners=True) 198 | x2 = F.interpolate(x[2], size, mode='bilinear', align_corners=True) 199 | outputs.append(x0) 200 | outputs.append(x1) 201 | outputs.append(x2) 202 | return outputs 203 | return x0 204 | 205 | 206 | if __name__ == '__main__': 207 | from tools.flops_params_fps_count import flops_params_fps 208 | model = DANet(nclass=6) 209 | flops_params_fps(model) 210 | -------------------------------------------------------------------------------- /models/deeplabv3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | try: 4 | from .resnet import resnet50_v1b 5 | except: 6 | from resnet import resnet50_v1b 7 | import torch.nn.functional as F 8 | 9 | 10 | class SegBaseModel(nn.Module): 11 | r"""Base Model for Semantic Segmentation 12 | 13 | Parameters 14 | ---------- 15 | backbone : string 16 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 17 | 'resnet101' or 'resnet152'). 18 | """ 19 | 20 | def __init__(self, nclass, aux, backbone='resnet50', jpu=False, pretrained_base=True, **kwargs): 21 | super(SegBaseModel, self).__init__() 22 | dilated = False if jpu else True 23 | self.aux = aux 24 | self.nclass = nclass 25 | if backbone == 'resnet50': 26 | self.pretrained = resnet50_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs) 27 | 28 | # self.jpu = JPU([512, 1024, 2048], width=512, **kwargs) if jpu else None 29 | 30 | def base_forward(self, x): 31 | """forwarding pre-trained network""" 32 | x = self.pretrained.conv1(x) 33 | x = self.pretrained.bn1(x) 34 | x = self.pretrained.relu(x) 35 | x = self.pretrained.maxpool(x) 36 | c1 = self.pretrained.layer1(x) 37 | c2 = self.pretrained.layer2(c1) 38 | c3 = self.pretrained.layer3(c2) 39 | c4 = self.pretrained.layer4(c3) 40 | 41 | return c1, c2, c3, c4 42 | 43 | def evaluate(self, x): 44 | """evaluating network with inputs and targets""" 45 | return self.forward(x)[0] 46 | 47 | def demo(self, x): 48 | pred = self.forward(x) 49 | if self.aux: 50 | pred = pred[0] 51 | return pred 52 | 53 | 54 | class _FCNHead(nn.Module): 55 | def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, **kwargs): 56 | super(_FCNHead, self).__init__() 57 | inter_channels = in_channels // 4 58 | self.block = nn.Sequential( 59 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 60 | norm_layer(inter_channels), 61 | nn.ReLU(inplace=True), 62 | nn.Dropout(0.1), 63 | nn.Conv2d(inter_channels, channels, 1) 64 | ) 65 | 66 | def forward(self, x): 67 | return self.block(x) 68 | 69 | 70 | class DeepLabV3(SegBaseModel): 71 | r"""DeepLabV3 72 | 73 | Parameters 74 | ---------- 75 | nclass : int 76 | Number of categories for the training dataset. 77 | backbone : string 78 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 79 | 'resnet101' or 'resnet152'). 
80 | norm_layer : object 81 | Normalization layer used in backbone network (default: :class:`nn.BatchNorm`; 82 | for Synchronized Cross-GPU BachNormalization). 83 | aux : bool 84 | Auxiliary loss. 85 | 86 | Reference: 87 | Chen, Liang-Chieh, et al. "Rethinking atrous convolution for semantic image segmentation." 88 | arXiv preprint arXiv:1706.05587 (2017). 89 | """ 90 | 91 | def __init__(self, nclass, backbone='resnet50', aux=False, pretrained_base=False, **kwargs): 92 | super(DeepLabV3, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs) 93 | self.head = _DeepLabHead(nclass, **kwargs) 94 | if self.aux: 95 | self.auxlayer = _FCNHead(1024, nclass, **kwargs) 96 | 97 | self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head']) 98 | 99 | def forward(self, x): 100 | size = x.size()[2:] 101 | _, _, c3, c4 = self.base_forward(x) 102 | outputs = [] 103 | x = self.head(c4) 104 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 105 | 106 | 107 | if self.aux: 108 | auxout = self.auxlayer(c3) 109 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 110 | outputs.append(auxout) 111 | return x 112 | 113 | 114 | class _DeepLabHead(nn.Module): 115 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs): 116 | super(_DeepLabHead, self).__init__() 117 | self.aspp = _ASPP(2048, [12, 24, 36], norm_layer=norm_layer, norm_kwargs=norm_kwargs, **kwargs) 118 | self.block = nn.Sequential( 119 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 120 | norm_layer(256, **({} if norm_kwargs is None else norm_kwargs)), 121 | nn.ReLU(True), 122 | nn.Dropout(0.1), 123 | nn.Conv2d(256, nclass, 1) 124 | ) 125 | 126 | def forward(self, x): 127 | x = self.aspp(x) 128 | return self.block(x) 129 | 130 | 131 | class _ASPPConv(nn.Module): 132 | def __init__(self, in_channels, out_channels, atrous_rate, norm_layer, norm_kwargs): 133 | super(_ASPPConv, self).__init__() 134 | self.block = nn.Sequential( 135 | nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rate, dilation=atrous_rate, bias=False), 136 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)), 137 | nn.ReLU(True) 138 | ) 139 | 140 | def forward(self, x): 141 | return self.block(x) 142 | 143 | 144 | class _AsppPooling(nn.Module): 145 | def __init__(self, in_channels, out_channels, norm_layer, norm_kwargs, **kwargs): 146 | super(_AsppPooling, self).__init__() 147 | self.gap = nn.Sequential( 148 | nn.AdaptiveAvgPool2d(1), 149 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 150 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)), 151 | nn.ReLU(True) 152 | ) 153 | 154 | def forward(self, x): 155 | size = x.size()[2:] 156 | pool = self.gap(x) 157 | out = F.interpolate(pool, size, mode='bilinear', align_corners=True) 158 | return out 159 | 160 | 161 | class _ASPP(nn.Module): 162 | def __init__(self, in_channels, atrous_rates, norm_layer, norm_kwargs, **kwargs): 163 | super(_ASPP, self).__init__() 164 | out_channels = 256 165 | self.b0 = nn.Sequential( 166 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 167 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)), 168 | nn.ReLU(True) 169 | ) 170 | 171 | rate1, rate2, rate3 = tuple(atrous_rates) 172 | self.b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer, norm_kwargs) 173 | self.b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer, norm_kwargs) 174 | self.b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer, norm_kwargs) 
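        # b4 (below) is the image-level context branch: global average pooling
        # followed by a 1x1 conv, bilinearly upsampled back to the input size
        # before all five branches are concatenated and projected.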
175 | self.b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer, norm_kwargs=norm_kwargs) 176 | 177 | self.project = nn.Sequential( 178 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 179 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)), 180 | nn.ReLU(True), 181 | nn.Dropout(0.5) 182 | ) 183 | 184 | def forward(self, x): 185 | feat1 = self.b0(x) 186 | feat2 = self.b1(x) 187 | feat3 = self.b2(x) 188 | feat4 = self.b3(x) 189 | feat5 = self.b4(x) 190 | x = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) 191 | x = self.project(x) 192 | return x 193 | 194 | 195 | if __name__ == '__main__': 196 | from tools.flops_params_fps_count import flops_params_fps 197 | model = DeepLabV3(nclass=6) 198 | flops_params_fps(model) 199 | 200 | 201 | 202 | 203 | 204 | -------------------------------------------------------------------------------- /models/edgenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from models.resT import rest_tiny 4 | 5 | 6 | def EdgeNet(nclass=6): 7 | model = rest_tiny(nclass=nclass, pretrained=True, aux=True, edge_aux=False, head='mlphead') 8 | return model 9 | 10 | 11 | def edgenet_init(weight_dir): 12 | with torch.no_grad(): 13 | model = rest_tiny(nclass=6, pretrained=False, aux=True, edge_aux=False, head='mlphead').eval() 14 | if os.path.isfile(weight_dir): 15 | print('loaded edge model successfully') 16 | checkpoint = torch.load(weight_dir, map_location=lambda storage, loc: storage) 17 | checkpoint = {k: v for k, v in checkpoint.items() if not 'loss' in k} 18 | checkpoint = {k.replace('module.model.', ''): v for k, v in checkpoint.items()} 19 | model.load_state_dict(checkpoint) 20 | return model 21 | 22 | 23 | if __name__ == '__main__': 24 | from tools.flops_params_fps_count import flops_params_fps 25 | model = EdgeNet(nclass=6) 26 | flops_params_fps(model) 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /models/fcn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | 8 | class VGG(nn.Module): 9 | def __init__(self, features, num_classes=1000, init_weights=True): 10 | super(VGG, self).__init__() 11 | self.features = features 12 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 13 | self.classifier = nn.Sequential( 14 | nn.Linear(512 * 7 * 7, 4096), 15 | nn.ReLU(True), 16 | nn.Dropout(), 17 | nn.Linear(4096, 4096), 18 | nn.ReLU(True), 19 | nn.Dropout(), 20 | nn.Linear(4096, num_classes) 21 | ) 22 | if init_weights: 23 | self._initialize_weights() 24 | 25 | def forward(self, x): 26 | x = self.features(x) 27 | x = self.avgpool(x) 28 | x = x.view(x.size(0), -1) 29 | x = self.classifier(x) 30 | return x 31 | 32 | def _initialize_weights(self): 33 | for m in self.modules(): 34 | if isinstance(m, nn.Conv2d): 35 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 36 | if m.bias is not None: 37 | nn.init.constant_(m.bias, 0) 38 | elif isinstance(m, nn.BatchNorm2d): 39 | nn.init.constant_(m.weight, 1) 40 | nn.init.constant_(m.bias, 0) 41 | elif isinstance(m, nn.Linear): 42 | nn.init.normal_(m.weight, 0, 0.01) 43 | nn.init.constant_(m.bias, 0) 44 | 45 | 46 | def make_layers(cfg, batch_norm=False): 47 | layers = [] 48 | in_channels = 3 49 | for v in cfg: 50 | if v == 'M': 51 | 
layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 52 | else: 53 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 54 | if batch_norm: 55 | layers += (conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)) 56 | else: 57 | layers += [conv2d, nn.ReLU(inplace=True)] 58 | in_channels = v 59 | return nn.Sequential(*layers) 60 | 61 | 62 | cfg = { 63 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 64 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 65 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 66 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 67 | } 68 | 69 | 70 | def vgg16(**kwargs): 71 | 72 | model = VGG(make_layers(cfg['D']), **kwargs) 73 | 74 | return model 75 | 76 | 77 | class FCN16s(nn.Module): 78 | def __init__(self, nclass, backbone='vgg16', aux=False, norm_layer=nn.BatchNorm2d, **kwargs): 79 | super(FCN16s, self).__init__() 80 | self.aux = aux 81 | if backbone == 'vgg16': 82 | self.pretrained = vgg16().features 83 | else: 84 | raise RuntimeError('unknown backbone: {}'.format(backbone)) 85 | self.pool4 = nn.Sequential(*self.pretrained[:24]) 86 | self.pool5 = nn.Sequential(*self.pretrained[24:]) 87 | self.head = _FCNHead(512, nclass, norm_layer) 88 | self.score_pool4 = nn.Conv2d(512, nclass, 1) 89 | if aux: 90 | self.auxlayer = _FCNHead(512, nclass, norm_layer) 91 | 92 | self.__setattr__('exclusive', ['head', 'score_pool4', 'auxlayer'] if aux else ['head', 'score_pool4']) 93 | 94 | def forward(self, x): 95 | pool4 = self.pool4(x) 96 | pool5 = self.pool5(pool4) 97 | 98 | outputs = [] 99 | score_fr = self.head(pool5) 100 | 101 | score_pool4 = self.score_pool4(pool4) 102 | 103 | upscore2 = F.interpolate(score_fr, score_pool4.size()[2:], mode='bilinear', align_corners=True) 104 | fuse_pool4 = upscore2 + score_pool4 105 | 106 | out = F.interpolate(fuse_pool4, x.size()[2:], mode='bilinear', align_corners=True) 107 | outputs = out 108 | 109 | if self.aux: 110 | auxout = self.auxlayer(pool5) 111 | auxout = F.interpolate(auxout, x.size()[2:], mode='bilinear', align_corners=True) 112 | outputs.append(auxout) 113 | 114 | return outputs 115 | 116 | 117 | class _FCNHead(nn.Module): 118 | def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, **kwargs): 119 | super(_FCNHead, self).__init__() 120 | inter_channels = in_channels // 4 121 | self.block = nn.Sequential( 122 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 123 | norm_layer(inter_channels), 124 | nn.ReLU(inplace=True), 125 | nn.Dropout(0.1), 126 | nn.Conv2d(inter_channels, channels, 1) 127 | ) 128 | 129 | def forward(self, x): 130 | return self.block(x) 131 | 132 | 133 | if __name__ == '__main__': 134 | from tools.flops_params_fps_count import flops_params_fps 135 | model = FCN16s(nclass=6) 136 | flops_params_fps(model) 137 | -------------------------------------------------------------------------------- /models/fpn.py: -------------------------------------------------------------------------------- 1 | '''FPN in PyTorch. 2 | 3 | See the paper "Feature Pyramid Networks for Object Detection" for more details. 
4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from torch.autograd import Variable 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | """3x3 convolution with padding""" 14 | conv3x3 = nn.Sequential( 15 | nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False), 17 | nn.BatchNorm2d(out_planes), 18 | nn.ReLU(inplace=True), 19 | ) 20 | return conv3x3 21 | 22 | 23 | class Seg_head(nn.Module): 24 | def __init__(self, in_planes=256, out_planes=128, n_class=6): 25 | super(Seg_head, self).__init__() 26 | self.conv1 = conv3x3(in_planes, out_planes) 27 | self.conv2 = conv3x3(in_planes, out_planes) 28 | self.conv3 = conv3x3(in_planes, out_planes) 29 | self.conv4 = conv3x3(in_planes, out_planes) 30 | self.final_layer = nn.Sequential( 31 | nn.Conv2d( 32 | in_channels=out_planes * 4, 33 | out_channels=out_planes * 4, 34 | kernel_size=1, 35 | stride=1, 36 | padding=0), 37 | nn.BatchNorm2d(out_planes * 4), 38 | nn.ReLU(inplace=True), 39 | nn.Conv2d( 40 | in_channels=out_planes * 4, 41 | out_channels=n_class, 42 | kernel_size=1, 43 | stride=1, 44 | padding=0) 45 | ) 46 | 47 | def forward(self, p2, p3, p4, p5): 48 | x2 = self.conv1(p2) 49 | x3 = F.interpolate(self.conv2(p3), scale_factor=2, mode='bilinear') 50 | x4 = F.interpolate(self.conv2(p4), scale_factor=4, mode='bilinear') 51 | x5 = F.interpolate(self.conv2(p5), scale_factor=8, mode='bilinear') 52 | x = torch.cat((x2, x3, x4, x5), dim=1) 53 | x = self.final_layer(x) 54 | output = F.interpolate(x, scale_factor=4, mode='bilinear') 55 | 56 | return output 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | expansion = 4 61 | 62 | def __init__(self, in_planes, planes, stride=1): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 67 | self.bn2 = nn.BatchNorm2d(planes) 68 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 69 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 70 | 71 | self.shortcut = nn.Sequential() 72 | if stride != 1 or in_planes != self.expansion*planes: 73 | self.shortcut = nn.Sequential( 74 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 75 | nn.BatchNorm2d(self.expansion*planes) 76 | ) 77 | 78 | def forward(self, x): 79 | out = F.relu(self.bn1(self.conv1(x))) 80 | out = F.relu(self.bn2(self.conv2(out))) 81 | out = self.bn3(self.conv3(out)) 82 | out += self.shortcut(x) 83 | out = F.relu(out) 84 | return out 85 | 86 | 87 | class FPN(nn.Module): 88 | def __init__(self, block=Bottleneck, num_blocks=[3, 4, 6, 3], nclass=6): 89 | super(FPN, self).__init__() 90 | self.in_planes = 64 91 | 92 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 93 | self.bn1 = nn.BatchNorm2d(64) 94 | 95 | # Bottom-up layers 96 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 97 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 98 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 99 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 100 | 101 | # Top layer 102 | self.toplayer = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0) # Reduce channels 103 | 104 | # Smooth layers 105 | self.smooth1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 106 | self.smooth2 = nn.Conv2d(256, 256, 
kernel_size=3, stride=1, padding=1) 107 | self.smooth3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 108 | 109 | # Lateral layers 110 | self.latlayer1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) 111 | self.latlayer2 = nn.Conv2d( 512, 256, kernel_size=1, stride=1, padding=0) 112 | self.latlayer3 = nn.Conv2d( 256, 256, kernel_size=1, stride=1, padding=0) 113 | 114 | self.seg_head = Seg_head(in_planes=256, out_planes=128, n_class=nclass) 115 | 116 | def _make_layer(self, block, planes, num_blocks, stride): 117 | strides = [stride] + [1]*(num_blocks-1) 118 | layers = [] 119 | for stride in strides: 120 | layers.append(block(self.in_planes, planes, stride)) 121 | self.in_planes = planes * block.expansion 122 | return nn.Sequential(*layers) 123 | 124 | def _upsample_add(self, x, y): 125 | '''Upsample and add two feature maps. 126 | 127 | Args: 128 | x: (Variable) top feature map to be upsampled. 129 | y: (Variable) lateral feature map. 130 | 131 | Returns: 132 | (Variable) added feature map. 133 | 134 | Note in PyTorch, when input size is odd, the upsampled feature map 135 | with `F.upsample(..., scale_factor=2, mode='nearest')` 136 | maybe not equal to the lateral feature map size. 137 | 138 | e.g. 139 | original input size: [N,_,15,15] -> 140 | conv2d feature map size: [N,_,8,8] -> 141 | upsampled feature map size: [N,_,16,16] 142 | 143 | So we choose bilinear upsample which supports arbitrary output sizes. 144 | ''' 145 | _,_,H,W = y.size() 146 | return F.upsample(x, size=(H,W), mode='bilinear') + y 147 | 148 | def forward(self, x): 149 | # Bottom-up 150 | c1 = F.relu(self.bn1(self.conv1(x))) 151 | c1 = F.max_pool2d(c1, kernel_size=3, stride=2, padding=1) 152 | c2 = self.layer1(c1) 153 | c3 = self.layer2(c2) 154 | c4 = self.layer3(c3) 155 | c5 = self.layer4(c4) 156 | # Top-down 157 | p5 = self.toplayer(c5) 158 | p4 = self._upsample_add(p5, self.latlayer1(c4)) 159 | p3 = self._upsample_add(p4, self.latlayer2(c3)) 160 | p2 = self._upsample_add(p3, self.latlayer3(c2)) 161 | # Smooth 162 | p4 = self.smooth1(p4) 163 | p3 = self.smooth2(p3) 164 | p2 = self.smooth3(p2) 165 | 166 | output = self.seg_head(p2, p3, p4, p5) 167 | return output 168 | 169 | 170 | if __name__ == '__main__': 171 | from tools.flops_params_fps_count import flops_params_fps 172 | model = FPN(nclass=6) 173 | flops_params_fps(model) 174 | 175 | 176 | -------------------------------------------------------------------------------- /models/head/__init__.py: -------------------------------------------------------------------------------- 1 | from .ann import ANNHead 2 | from .apc import APCHead 3 | from .aspp import ASPPHead 4 | from .aspp_plus import ASPPPlusHead 5 | from .da import DAHead 6 | from .dnl import DNLHead 7 | from .fcfpn import FCFPNHead 8 | from .fcn import FCNHead 9 | from .gc import GCHead 10 | from .psa import PSAHead 11 | from .psp import PSPHead 12 | from .unet import UNetHead 13 | from .uper import UPerHead 14 | from .seg import SegHead 15 | from .cefpn import CEFPNHead 16 | from .mlp import MLPHead 17 | from .edge import EdgeHead 18 | 19 | __all__ = [ 20 | 'ANNHead', 'APCHead', 'ASPPHead', 'ASPPPlusHead', 'DAHead', 'DNLHead', 'FCFPNHead', 'FCNHead', 21 | 'GCHead', 'PSAHead', 'PSPHead', 'UNetHead', 'UPerHead', 'SegHead', 'CEFPNHead', 'MLPHead', 'EdgeHead' 22 | ] -------------------------------------------------------------------------------- /models/head/apc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 
| import torch.nn.functional as F 4 | from mmcv.cnn import ConvModule 5 | from .base_decoder import BaseDecodeHead, resize 6 | 7 | norm_cfg = dict(type='BN', requires_grad=True) 8 | 9 | 10 | class ACM(nn.Module): 11 | """Adaptive Context Module used in APCNet. 12 | Args: 13 | pool_scale (int): Pooling scale used in Adaptive Context 14 | Module to extract region features. 15 | fusion (bool): Add one conv to fuse residual feature. 16 | in_channels (int): Input channels. 17 | channels (int): Channels after modules, before conv_seg. 18 | conv_cfg (dict | None): Config of conv layers. 19 | norm_cfg (dict | None): Config of norm layers. 20 | act_cfg (dict): Config of activation layers. 21 | """ 22 | 23 | def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg, 24 | norm_cfg, act_cfg): 25 | super(ACM, self).__init__() 26 | self.pool_scale = pool_scale 27 | self.fusion = fusion 28 | self.in_channels = in_channels 29 | self.channels = channels 30 | self.conv_cfg = conv_cfg 31 | self.norm_cfg = norm_cfg 32 | self.act_cfg = act_cfg 33 | self.pooled_redu_conv = ConvModule( 34 | self.in_channels, 35 | self.channels, 36 | 1, 37 | conv_cfg=self.conv_cfg, 38 | norm_cfg=self.norm_cfg, 39 | act_cfg=self.act_cfg) 40 | 41 | self.input_redu_conv = ConvModule( 42 | self.in_channels, 43 | self.channels, 44 | 1, 45 | conv_cfg=self.conv_cfg, 46 | norm_cfg=self.norm_cfg, 47 | act_cfg=self.act_cfg) 48 | 49 | self.global_info = ConvModule( 50 | self.channels, 51 | self.channels, 52 | 1, 53 | conv_cfg=self.conv_cfg, 54 | norm_cfg=self.norm_cfg, 55 | act_cfg=self.act_cfg) 56 | 57 | self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0) 58 | 59 | self.residual_conv = ConvModule( 60 | self.channels, 61 | self.channels, 62 | 1, 63 | conv_cfg=self.conv_cfg, 64 | norm_cfg=self.norm_cfg, 65 | act_cfg=self.act_cfg) 66 | 67 | if self.fusion: 68 | self.fusion_conv = ConvModule( 69 | self.channels, 70 | self.channels, 71 | 1, 72 | conv_cfg=self.conv_cfg, 73 | norm_cfg=self.norm_cfg, 74 | act_cfg=self.act_cfg) 75 | 76 | def forward(self, x): 77 | """Forward function.""" 78 | pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale) 79 | # [batch_size, channels, h, w] 80 | x = self.input_redu_conv(x) 81 | # [batch_size, channels, pool_scale, pool_scale] 82 | pooled_x = self.pooled_redu_conv(pooled_x) 83 | batch_size = x.size(0) 84 | # [batch_size, pool_scale * pool_scale, channels] 85 | pooled_x = pooled_x.view(batch_size, self.channels, 86 | -1).permute(0, 2, 1).contiguous() 87 | # [batch_size, h * w, pool_scale * pool_scale] 88 | affinity_matrix = self.gla(x + resize( 89 | self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:]) 90 | ).permute(0, 2, 3, 1).reshape( 91 | batch_size, -1, self.pool_scale**2) 92 | affinity_matrix = F.sigmoid(affinity_matrix) 93 | # [batch_size, h * w, channels] 94 | z_out = torch.matmul(affinity_matrix, pooled_x) 95 | # [batch_size, channels, h * w] 96 | z_out = z_out.permute(0, 2, 1).contiguous() 97 | # [batch_size, channels, h, w] 98 | z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3)) 99 | z_out = self.residual_conv(z_out) 100 | z_out = F.relu(z_out + x) 101 | if self.fusion: 102 | z_out = self.fusion_conv(z_out) 103 | 104 | return z_out 105 | 106 | 107 | class APCHead(BaseDecodeHead): 108 | """Adaptive Pyramid Context Network for Semantic Segmentation. 109 | This head is the implementation of 110 | `APCNet `_. 113 | Args: 114 | pool_scales (tuple[int]): Pooling scales used in Adaptive Context 115 | Module. Default: (1, 2, 3, 6). 
116 | fusion (bool): Add one conv to fuse residual feature. 117 | """ 118 | 119 | def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, in_channels=768, num_classes=6, channels=512, in_index=3): 120 | super(APCHead, self).__init__(in_index=in_index, in_channels=in_channels, 121 | num_classes=num_classes, channels=channels, dropout_ratio=0.1, norm_cfg=norm_cfg, align_corners=False) 122 | assert isinstance(pool_scales, (list, tuple)) 123 | self.pool_scales = pool_scales 124 | self.fusion = fusion 125 | acm_modules = [] 126 | for pool_scale in self.pool_scales: 127 | acm_modules.append( 128 | ACM(pool_scale, 129 | self.fusion, 130 | self.in_channels, 131 | self.channels, 132 | conv_cfg=self.conv_cfg, 133 | norm_cfg=self.norm_cfg, 134 | act_cfg=self.act_cfg)) 135 | self.acm_modules = nn.ModuleList(acm_modules) 136 | self.bottleneck = ConvModule( 137 | self.in_channels + len(pool_scales) * self.channels, 138 | self.channels, 139 | 3, 140 | padding=1, 141 | conv_cfg=self.conv_cfg, 142 | norm_cfg=self.norm_cfg, 143 | act_cfg=self.act_cfg) 144 | 145 | def forward(self, inputs): 146 | """Forward function.""" 147 | x = self._transform_inputs(inputs) 148 | acm_outs = [x] 149 | for acm_module in self.acm_modules: 150 | acm_outs.append(acm_module(x)) 151 | acm_outs = torch.cat(acm_outs, dim=1) 152 | output = self.bottleneck(acm_outs) 153 | output = self.cls_seg(output) 154 | return output -------------------------------------------------------------------------------- /models/head/aspp.py: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # Created by: Hang Zhang 3 | # Email: zhang.hang@rutgers.edu 4 | # Copyright (c) 2017 5 | ########################################################################### 6 | from __future__ import division 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.functional import interpolate 10 | 11 | 12 | up_kwargs = {'mode': 'bilinear', 'align_corners': False} 13 | norm_layer = nn.BatchNorm2d 14 | 15 | 16 | def ASPPConv(in_channels, out_channels, atrous_rate, norm_layer): 17 | block = nn.Sequential( 18 | nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rate, 19 | dilation=atrous_rate, bias=False), 20 | norm_layer(out_channels), 21 | nn.ReLU(True)) 22 | return block 23 | 24 | 25 | class AsppPooling(nn.Module): 26 | def __init__(self, in_channels, out_channels, norm_layer, up_kwargs): 27 | super(AsppPooling, self).__init__() 28 | self._up_kwargs = up_kwargs 29 | self.gap = nn.Sequential(nn.AdaptiveAvgPool2d(1), 30 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 31 | norm_layer(out_channels), 32 | nn.ReLU(True)) 33 | 34 | def forward(self, x): 35 | _, _, h, w = x.size() 36 | pool = self.gap(x) 37 | return interpolate(pool, (h,w), **self._up_kwargs) 38 | 39 | 40 | class ASPP_Module(nn.Module): 41 | def __init__(self, in_channels, atrous_rates, norm_layer, up_kwargs): 42 | super(ASPP_Module, self).__init__() 43 | out_channels = in_channels // 8 44 | rate1, rate2, rate3 = tuple(atrous_rates) 45 | self.b0 = nn.Sequential( 46 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 47 | norm_layer(out_channels), 48 | nn.ReLU(True)) 49 | self.b1 = ASPPConv(in_channels, out_channels, rate1, norm_layer) 50 | self.b2 = ASPPConv(in_channels, out_channels, rate2, norm_layer) 51 | self.b3 = ASPPConv(in_channels, out_channels, rate3, norm_layer) 52 | self.b4 = AsppPooling(in_channels, out_channels, norm_layer, up_kwargs) 53 | 54 | self.project = 
nn.Sequential( 55 | nn.Conv2d(5*out_channels, out_channels, 1, bias=False), 56 | norm_layer(out_channels), 57 | nn.ReLU(True), 58 | nn.Dropout2d(0.5, False)) 59 | 60 | def forward(self, x): 61 | feat0 = self.b0(x) 62 | feat1 = self.b1(x) 63 | feat2 = self.b2(x) 64 | feat3 = self.b3(x) 65 | feat4 = self.b4(x) 66 | y = torch.cat((feat0, feat1, feat2, feat3, feat4), 1) 67 | return self.project(y) 68 | 69 | 70 | class ASPPHead(nn.Module): 71 | def __init__(self, in_channels, num_classes, norm_layer=norm_layer, up_kwargs=up_kwargs, atrous_rates=[12, 24, 36], in_index=3): 72 | super(ASPPHead, self).__init__() 73 | inter_channels = in_channels // 8 74 | self.in_index = in_index 75 | self.aspp = ASPP_Module(in_channels, atrous_rates, norm_layer, up_kwargs) 76 | self.block = nn.Sequential( 77 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 78 | norm_layer(inter_channels), 79 | nn.ReLU(True), 80 | nn.Dropout(0.1, False), 81 | nn.Conv2d(inter_channels, num_classes, 1)) 82 | 83 | def _transform_inputs(self, inputs): 84 | if isinstance(self.in_index, (list, tuple)): 85 | inputs = [inputs[i] for i in self.in_index] 86 | elif isinstance(self.in_index, int): 87 | inputs = inputs[self.in_index] 88 | return inputs 89 | 90 | def forward(self, inputs): 91 | x = self._transform_inputs(inputs) 92 | x = self.aspp(x) 93 | x = self.block(x) 94 | return x 95 | 96 | -------------------------------------------------------------------------------- /models/head/aspp_plus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .aspp import ASPP_Module 6 | 7 | 8 | up_kwargs = {'mode': 'bilinear', 'align_corners': False} 9 | norm_layer = nn.BatchNorm2d 10 | 11 | 12 | class _ConvBNReLU(nn.Module): 13 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 14 | dilation=1, groups=1, relu6=False, norm_layer=norm_layer): 15 | super(_ConvBNReLU, self).__init__() 16 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) 17 | self.bn = norm_layer(out_channels) 18 | self.relu = nn.ReLU6(True) if relu6 else nn.ReLU(True) 19 | 20 | def forward(self, x): 21 | x = self.conv(x) 22 | x = self.bn(x) 23 | x = self.relu(x) 24 | return x 25 | 26 | 27 | class ASPPPlusHead(nn.Module): 28 | def __init__(self, num_classes, in_channels, norm_layer=norm_layer, up_kwargs=up_kwargs, in_index=[0, 3]): 29 | super(ASPPPlusHead, self).__init__() 30 | self._up_kwargs = up_kwargs 31 | self.in_index = in_index 32 | self.channels = in_channels // 2 ** in_index[1] 33 | self.aspp = ASPP_Module(in_channels, [12, 24, 36], norm_layer=norm_layer, up_kwargs=up_kwargs) 34 | self.c1_block = _ConvBNReLU(self.channels, self.channels, 3, padding=1, norm_layer=norm_layer) 35 | self.block = nn.Sequential( 36 | _ConvBNReLU(self.channels + in_channels // 8, self.channels + in_channels // 8, 3, padding=1, norm_layer=norm_layer), 37 | nn.Dropout(0.5), 38 | _ConvBNReLU(self.channels + in_channels // 8, self.channels + in_channels // 8, 3, padding=1, norm_layer=norm_layer), 39 | nn.Dropout(0.1), 40 | nn.Conv2d(self.channels + in_channels // 8, num_classes, 1)) 41 | 42 | def _transform_inputs(self, inputs): 43 | if isinstance(self.in_index, (list, tuple)): 44 | inputs = [inputs[i] for i in self.in_index] 45 | elif isinstance(self.in_index, int): 46 | inputs = inputs[self.in_index] 47 | return inputs 48 | 49 | def forward(self, inputs): 50 | inputs = 
self._transform_inputs(inputs) 51 | c1, x = inputs 52 | size = c1.size()[2:] 53 | c1 = self.c1_block(c1) 54 | x = self.aspp(x) 55 | x = F.interpolate(x, size, **self._up_kwargs) 56 | return self.block(torch.cat([x, c1], dim=1)) 57 | -------------------------------------------------------------------------------- /models/head/base_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from mmcv.cnn import normal_init 4 | import warnings 5 | import torch.nn.functional as F 6 | 7 | norm_cfg = dict(type='BN', requires_grad=True) 8 | 9 | 10 | def resize(input, 11 | size=None, 12 | scale_factor=None, 13 | mode='nearest', 14 | align_corners=None, 15 | warning=True): 16 | if warning: 17 | if size is not None and align_corners: 18 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 19 | output_h, output_w = tuple(int(x) for x in size) 20 | if output_h > input_h or output_w > output_h: 21 | if ((output_h > 1 and output_w > 1 and input_h > 1 22 | and input_w > 1) and (output_h - 1) % (input_h - 1) 23 | and (output_w - 1) % (input_w - 1)): 24 | warnings.warn( 25 | f'When align_corners={align_corners}, ' 26 | 'the output would more aligned if ' 27 | f'input size {(input_h, input_w)} is `x+1` and ' 28 | f'out size {(output_h, output_w)} is `nx+1`') 29 | if isinstance(size, torch.Size): 30 | size = tuple(int(x) for x in size) 31 | return F.interpolate(input, size, scale_factor, mode, align_corners) 32 | 33 | 34 | class BaseDecodeHead(nn.Module): 35 | """Base class for BaseDecodeHead. 36 | 37 | Args: 38 | in_channels (int|Sequence[int]): Input channels. 39 | channels (int): Channels after modules, before conv_seg. 40 | num_classes (int): Number of classes. 41 | dropout_ratio (float): Ratio of dropout layer. Default: 0.1. 42 | conv_cfg (dict|None): Config of conv layers. Default: None. 43 | norm_cfg (dict|None): Config of norm layers. Default: None. 44 | act_cfg (dict): Config of activation layers. 45 | Default: dict(type='ReLU') 46 | in_index (int|Sequence[int]): Input feature index. Default: -1 47 | input_transform (str|None): Transformation type of input features. 48 | Options: 'resize_concat', 'multiple_select', None. 49 | 'resize_concat': Multiple feature maps will be resize to the 50 | same size as first one and than concat together. 51 | Usually used in FCN head of HRNet. 52 | 'multiple_select': Multiple feature maps will be bundle into 53 | a list and passed into decode head. 54 | None: Only one select feature map is allowed. 55 | Default: None. 56 | loss_decode (dict): Config of decode loss. 57 | Default: dict(type='CrossEntropyLoss'). 58 | ignore_index (int | None): The label index to be ignored. When using 59 | masked BCE loss, ignore_index should be set to None. Default: 255 60 | sampler (dict|None): The config of segmentation map sampler. 61 | Default: None. 62 | align_corners (bool): align_corners argument of F.interpolate. 63 | Default: False. 
64 | """ 65 | 66 | def __init__(self, 67 | in_channels, 68 | channels, 69 | *, 70 | num_classes, 71 | dropout_ratio=0.1, 72 | conv_cfg=None, 73 | norm_cfg=None, 74 | act_cfg=dict(type='ReLU'), 75 | in_index=-1, 76 | input_transform=None, 77 | ignore_index=255, 78 | sampler=None, 79 | align_corners=False): 80 | super(BaseDecodeHead, self).__init__() 81 | self._init_inputs(in_channels, in_index, input_transform) 82 | self.channels = channels 83 | self.num_classes = num_classes 84 | self.dropout_ratio = dropout_ratio 85 | self.conv_cfg = conv_cfg 86 | self.norm_cfg = norm_cfg 87 | self.act_cfg = act_cfg 88 | self.in_index = in_index 89 | self.ignore_index = ignore_index 90 | self.align_corners = align_corners 91 | self.sampler = None 92 | 93 | self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) 94 | if dropout_ratio > 0: 95 | self.dropout = nn.Dropout2d(dropout_ratio) 96 | else: 97 | self.dropout = None 98 | self.fp16_enabled = False 99 | 100 | def extra_repr(self): 101 | """Extra repr.""" 102 | s = f'input_transform={self.input_transform}, ' \ 103 | f'ignore_index={self.ignore_index}, ' \ 104 | f'align_corners={self.align_corners}' 105 | return s 106 | 107 | def _init_inputs(self, in_channels, in_index, input_transform): 108 | """Check and initialize input transforms. 109 | 110 | The in_channels, in_index and input_transform must match. 111 | Specifically, when input_transform is None, only single feature map 112 | will be selected. So in_channels and in_index must be of type int. 113 | When input_transform 114 | 115 | Args: 116 | in_channels (int|Sequence[int]): Input channels. 117 | in_index (int|Sequence[int]): Input feature index. 118 | input_transform (str|None): Transformation type of input features. 119 | Options: 'resize_concat', 'multiple_select', None. 120 | 'resize_concat': Multiple feature maps will be resize to the 121 | same size as first one and than concat together. 122 | Usually used in FCN head of HRNet. 123 | 'multiple_select': Multiple feature maps will be bundle into 124 | a list and passed into decode head. 125 | None: Only one select feature map is allowed. 126 | """ 127 | 128 | if input_transform is not None: 129 | assert input_transform in ['resize_concat', 'multiple_select'] 130 | self.input_transform = input_transform 131 | self.in_index = in_index 132 | if input_transform is not None: 133 | assert isinstance(in_channels, (list, tuple)) 134 | assert isinstance(in_index, (list, tuple)) 135 | assert len(in_channels) == len(in_index) 136 | if input_transform == 'resize_concat': 137 | self.in_channels = sum(in_channels) 138 | else: 139 | self.in_channels = in_channels 140 | else: 141 | assert isinstance(in_channels, int) 142 | assert isinstance(in_index, int) 143 | self.in_channels = in_channels 144 | 145 | def init_weights(self): 146 | """Initialize weights of classification layer.""" 147 | normal_init(self.conv_seg, mean=0, std=0.01) 148 | 149 | def _transform_inputs(self, inputs): 150 | """Transform inputs for decoder. 151 | 152 | Args: 153 | inputs (list[Tensor]): List of multi-level img features. 
154 | 155 | Returns: 156 | Tensor: The transformed inputs 157 | """ 158 | 159 | if self.input_transform == 'resize_concat': 160 | inputs = [inputs[i] for i in self.in_index] 161 | upsampled_inputs = [ 162 | resize( 163 | input=x, 164 | size=inputs[0].shape[2:], 165 | mode='bilinear', 166 | align_corners=self.align_corners) for x in inputs 167 | ] 168 | inputs = torch.cat(upsampled_inputs, dim=1) 169 | elif self.input_transform == 'multiple_select': 170 | inputs = [inputs[i] for i in self.in_index] 171 | else: 172 | inputs = inputs[self.in_index] 173 | 174 | return inputs 175 | 176 | def forward(self, inputs): 177 | """Placeholder of forward function.""" 178 | pass 179 | 180 | def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg): 181 | """Forward function for training. 182 | Args: 183 | inputs (list[Tensor]): List of multi-level img features. 184 | img_metas (list[dict]): List of image info dict where each dict 185 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 186 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 187 | For details on the values of these keys see 188 | `mmseg/datasets/pipelines/formatting.py:Collect`. 189 | gt_semantic_seg (Tensor): Semantic segmentation masks 190 | used if the architecture supports semantic segmentation task. 191 | train_cfg (dict): The training config. 192 | 193 | Returns: 194 | dict[str, Tensor]: a dictionary of loss components 195 | """ 196 | seg_logits = self.forward(inputs) 197 | losses = self.losses(seg_logits, gt_semantic_seg) 198 | return losses 199 | 200 | def forward_test(self, inputs, img_metas, test_cfg): 201 | """Forward function for testing. 202 | 203 | Args: 204 | inputs (list[Tensor]): List of multi-level img features. 205 | img_metas (list[dict]): List of image info dict where each dict 206 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 207 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 208 | For details on the values of these keys see 209 | `mmseg/datasets/pipelines/formatting.py:Collect`. 210 | test_cfg (dict): The testing config. 211 | 212 | Returns: 213 | Tensor: Output segmentation map. 
214 | """ 215 | return self.forward(inputs) 216 | 217 | def cls_seg(self, feat): 218 | """Classify each pixel.""" 219 | if self.dropout is not None: 220 | feat = self.dropout(feat) 221 | output = self.conv_seg(feat) 222 | return output 223 | -------------------------------------------------------------------------------- /models/head/cefpn.py: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # Created by: Hang Zhang 3 | # Email: zhang.hang@rutgers.edu 4 | # Copyright (c) 2017 5 | ########################################################################### 6 | from __future__ import division 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.functional import upsample 10 | 11 | up_kwargs = {'mode': 'bilinear', 'align_corners': True} 12 | norm_layer = nn.BatchNorm2d 13 | 14 | 15 | class CEFPNHead(nn.Module): 16 | def __init__(self, in_channels=[256, 512, 1024, 2048], num_classes=6, channels=256, 17 | norm_layer=norm_layer, up_kwargs=up_kwargs, in_index=[0, 1, 2, 3]): 18 | super(CEFPNHead, self).__init__() 19 | assert up_kwargs is not None 20 | self._up_kwargs = up_kwargs 21 | self.in_index = in_index 22 | self.C5_2_F4 = nn.Sequential( 23 | nn.Conv2d(in_channels[3], in_channels[2], kernel_size=1, bias=False), 24 | norm_layer(in_channels[2]), 25 | nn.ReLU(inplace=True)) 26 | self.C4_2_F4 = nn.Sequential( 27 | nn.Conv2d(in_channels[2], channels, kernel_size=1, bias=False), 28 | norm_layer(channels), 29 | nn.ReLU(inplace=True)) 30 | self.C3_2_F3 = nn.Sequential( 31 | nn.Conv2d(in_channels[1], channels, kernel_size=1, bias=False), 32 | norm_layer(channels), 33 | nn.ReLU(inplace=True)) 34 | self.C2_2_F2 = nn.Sequential( 35 | nn.Conv2d(in_channels[0], channels, kernel_size=1, bias=False), 36 | norm_layer(channels), 37 | nn.ReLU(inplace=True)) 38 | 39 | fpn_out = [] 40 | for _ in range(len(in_channels)): 41 | fpn_out.append(nn.Sequential( 42 | nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False), 43 | norm_layer(channels), 44 | nn.ReLU(inplace=True), 45 | )) 46 | self.fpn_out = nn.ModuleList(fpn_out) 47 | inter_channels = len(in_channels) * channels 48 | self.conv5 = nn.Sequential(nn.Conv2d(inter_channels, 512, 3, padding=1, bias=False), 49 | norm_layer(512), 50 | nn.ReLU(), 51 | nn.Dropout(0.1, False), 52 | nn.Conv2d(512, num_classes, 1)) 53 | # channel_attention_guide 54 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 55 | self.max_pool = nn.AdaptiveMaxPool2d(1) 56 | self.shared_MLP = nn.Sequential( 57 | nn.Linear(in_features=channels, out_features=channels // 16), 58 | nn.ReLU(inplace=True), 59 | nn.Linear(in_features=channels // 16, out_features=channels)) 60 | self.sigmoid = nn.Sigmoid() 61 | 62 | # sub_pixel_context_enhancement 63 | self.conv1 = nn.Sequential(nn.Conv2d(in_channels[-1], in_channels[-1] // 2, kernel_size=3, padding=1, bias=False), 64 | norm_layer(in_channels[-1] // 2), 65 | nn.ReLU()) 66 | self.max_pool2 = nn.MaxPool2d(3, stride=2, padding=1) 67 | self.conv2 = nn.Sequential(nn.Conv2d(in_channels[-1], in_channels[-1] * 2, 1, bias=False), 68 | norm_layer(in_channels[-1] * 2), 69 | nn.ReLU()) 70 | self.global_pool = nn.AdaptiveAvgPool2d(1) 71 | self.conv3 = nn.Sequential(nn.Conv2d(in_channels[-1], in_channels[-1] // 8, 1, bias=False), 72 | norm_layer(in_channels[-1] // 8), 73 | nn.ReLU()) 74 | # inchannels to channels 75 | self.smooth1 = nn.Sequential( 76 | nn.Conv2d(in_channels[0], channels, kernel_size=1, bias=False), 77 | norm_layer(channels), 78 
| nn.ReLU(inplace=True)) 79 | self.smooth2 = nn.Sequential( 80 | nn.Conv2d(in_channels[0], channels, kernel_size=1, bias=False), 81 | norm_layer(channels), 82 | nn.ReLU(inplace=True)) 83 | self.smooth3 = nn.Sequential( 84 | nn.Conv2d(in_channels[0], channels, kernel_size=1, bias=False), 85 | norm_layer(channels), 86 | nn.ReLU(inplace=True)) 87 | 88 | def sub_pixel_conv(self, inputs, up_factor=2): 89 | b, c, h, w = inputs.shape 90 | assert c % (up_factor * up_factor) == 0 91 | inputs = inputs.permute(0, 2, 3, 1) # b h w c 92 | inputs = inputs.view(b, h, w, c // (up_factor * up_factor), up_factor, up_factor) 93 | inputs = inputs.permute(0, 1, 4, 2, 5, 3).contiguous() 94 | inputs = inputs.view(b, h * up_factor, w * up_factor, c // (up_factor * up_factor)).permute(0, 3, 1, 2) 95 | inputs = inputs.contiguous() 96 | return inputs 97 | 98 | def channel_attention_guide(self, inputs): 99 | avgout = self.shared_MLP(self.avg_pool(inputs).view(inputs.size(0), -1)).unsqueeze(2).unsqueeze(3) 100 | maxout = self.shared_MLP(self.max_pool(inputs).view(inputs.size(0), -1)).unsqueeze(2).unsqueeze(3) 101 | weights = self.sigmoid(avgout + maxout) 102 | output = weights * inputs 103 | return output 104 | 105 | def sub_pixel_context_enhancement(self, inputs): 106 | h, w = inputs.size()[2:] 107 | input1 = self.sub_pixel_conv(self.conv1(inputs)) 108 | input2 = self.sub_pixel_conv(self.conv2(self.max_pool2(inputs)), up_factor=4) 109 | input3 = upsample(self.conv3(inputs), (h * 2, w * 2), **self._up_kwargs) 110 | output = input1 + input2 + input3 111 | output = self.smooth3(output) 112 | return output 113 | 114 | def _transform_inputs(self, inputs): 115 | if isinstance(self.in_index, (list, tuple)): 116 | inputs = [inputs[i] for i in self.in_index] 117 | elif isinstance(self.in_index, int): 118 | inputs = inputs[self.in_index] 119 | return inputs 120 | 121 | def forward(self, inputs): 122 | inputs = self._transform_inputs(inputs) 123 | c5 = inputs[-1] 124 | c1_size = inputs[0].size()[2:] 125 | if hasattr(self, 'extramodule'): 126 | c5 = self.extramodule(c5) 127 | 128 | feat = self.sub_pixel_context_enhancement(c5) 129 | feat_up = upsample(self.channel_attention_guide(self.fpn_out[3](feat)), c1_size, **self._up_kwargs) 130 | fpn_features = [feat_up] 131 | 132 | feat = self.smooth1(self.sub_pixel_conv(self.C5_2_F4(c5))) + self.C4_2_F4(inputs[2]) 133 | feat_up = upsample(self.channel_attention_guide(self.fpn_out[2](feat)), c1_size, **self._up_kwargs) 134 | fpn_features.append(feat_up) 135 | 136 | feats = [] 137 | feats.append(self.C2_2_F2(inputs[0])) 138 | feats.append(self.smooth2(self.sub_pixel_conv(inputs[2])) + self.C3_2_F3(inputs[1])) 139 | 140 | for i in reversed(range(len(inputs) - 2)): 141 | feat_i = feats[i] 142 | feat = upsample(feat, feat_i.size()[2:], **self._up_kwargs) 143 | feat = feat + feat_i 144 | feat_up = upsample(self.channel_attention_guide(self.fpn_out[i](feat)), c1_size, **self._up_kwargs) 145 | fpn_features.append(feat_up) 146 | fpn_features = torch.cat(fpn_features, 1) 147 | 148 | return self.conv5(fpn_features) 149 | -------------------------------------------------------------------------------- /models/head/da.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | up_kwargs = {'mode': 'bilinear', 'align_corners': False} 5 | norm_layer = nn.BatchNorm2d 6 | 7 | 8 | class _PositionAttentionModule(nn.Module): 9 | """ Position attention module""" 10 | 11 | def __init__(self, in_channels): 12 | 
super(_PositionAttentionModule, self).__init__() 13 | self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1) 14 | self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1) 15 | self.conv_d = nn.Conv2d(in_channels, in_channels, 1) 16 | self.alpha = nn.Parameter(torch.zeros(1)) 17 | self.softmax = nn.Softmax(dim=-1) 18 | 19 | def forward(self, x): 20 | batch_size, _, height, width = x.size() 21 | feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1) 22 | feat_c = self.conv_c(x).view(batch_size, -1, height * width) 23 | attention_s = self.softmax(torch.bmm(feat_b, feat_c)) 24 | feat_d = self.conv_d(x).view(batch_size, -1, height * width) 25 | feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width) 26 | out = self.alpha * feat_e + x 27 | 28 | return out 29 | 30 | 31 | class _ChannelAttentionModule(nn.Module): 32 | """Channel attention module""" 33 | 34 | def __init__(self): 35 | super(_ChannelAttentionModule, self).__init__() 36 | self.beta = nn.Parameter(torch.zeros(1)) 37 | self.softmax = nn.Softmax(dim=-1) 38 | 39 | def forward(self, x): 40 | batch_size, _, height, width = x.size() 41 | feat_a = x.view(batch_size, -1, height * width) 42 | feat_a_transpose = x.view(batch_size, -1, height * width).permute(0, 2, 1) 43 | attention = torch.bmm(feat_a, feat_a_transpose) 44 | attention_new = torch.max(attention, dim=-1, keepdim=True)[0].expand_as(attention) - attention 45 | attention = self.softmax(attention_new) 46 | 47 | feat_e = torch.bmm(attention, feat_a).view(batch_size, -1, height, width) 48 | out = self.beta * feat_e + x 49 | 50 | return out 51 | 52 | 53 | class DAHead(nn.Module): 54 | def __init__(self, in_channels, num_classes, aux=False, norm_layer=norm_layer, norm_kwargs=None, in_index=3): 55 | super(DAHead, self).__init__() 56 | self.aux = aux 57 | self.in_index = in_index 58 | inter_channels = in_channels // 4 59 | self.conv_p1 = nn.Sequential( 60 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 61 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 62 | nn.ReLU(True) 63 | ) 64 | self.conv_c1 = nn.Sequential( 65 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 66 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 67 | nn.ReLU(True) 68 | ) 69 | self.pam = _PositionAttentionModule(inter_channels) 70 | self.cam = _ChannelAttentionModule() 71 | self.conv_p2 = nn.Sequential( 72 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 73 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 74 | nn.ReLU(True) 75 | ) 76 | self.conv_c2 = nn.Sequential( 77 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 78 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 79 | nn.ReLU(True) 80 | ) 81 | self.out = nn.Sequential( 82 | nn.Dropout(0.1), 83 | nn.Conv2d(inter_channels, num_classes, 1) 84 | ) 85 | if aux: 86 | self.conv_p3 = nn.Sequential( 87 | nn.Dropout(0.1), 88 | nn.Conv2d(inter_channels, num_classes, 1) 89 | ) 90 | self.conv_c3 = nn.Sequential( 91 | nn.Dropout(0.1), 92 | nn.Conv2d(inter_channels, num_classes, 1) 93 | ) 94 | 95 | def _transform_inputs(self, inputs): 96 | if isinstance(self.in_index, (list, tuple)): 97 | inputs = [inputs[i] for i in self.in_index] 98 | elif isinstance(self.in_index, int): 99 | inputs = inputs[self.in_index] 100 | return inputs 101 | 102 | def forward(self, inputs): 103 | x = self._transform_inputs(inputs) 104 | feat_p = 
self.conv_p1(x) 105 | feat_p = self.pam(feat_p) 106 | feat_p = self.conv_p2(feat_p) 107 | 108 | feat_c = self.conv_c1(x) 109 | feat_c = self.cam(feat_c) 110 | feat_c = self.conv_c2(feat_c) 111 | 112 | feat_fusion = feat_p + feat_c 113 | 114 | outputs = [] 115 | fusion_out = self.out(feat_fusion) 116 | outputs.append(fusion_out) 117 | if self.aux: 118 | p_out = self.conv_p3(feat_p) 119 | c_out = self.conv_c3(feat_c) 120 | outputs.append(p_out) 121 | outputs.append(c_out) 122 | 123 | return outputs 124 | -------------------------------------------------------------------------------- /models/head/dnl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mmcv.cnn import NonLocal2d 3 | from torch import nn 4 | from .fcn import FCNHead 5 | 6 | norm_cfg = dict(type='BN', requires_grad=True) 7 | 8 | 9 | class DisentangledNonLocal2d(NonLocal2d): 10 | """Disentangled Non-Local Blocks. 11 | Args: 12 | temperature (float): Temperature to adjust attention. Default: 0.05 13 | """ 14 | 15 | def __init__(self, *arg, temperature, **kwargs): 16 | super().__init__(*arg, **kwargs) 17 | self.temperature = temperature 18 | self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1) 19 | 20 | def embedded_gaussian(self, theta_x, phi_x): 21 | """Embedded gaussian with temperature.""" 22 | 23 | # NonLocal2d pairwise_weight: [N, HxW, HxW] 24 | pairwise_weight = torch.matmul(theta_x, phi_x) 25 | if self.use_scale: 26 | # theta_x.shape[-1] is `self.inter_channels` 27 | pairwise_weight /= theta_x.shape[-1]**0.5 28 | pairwise_weight /= self.temperature 29 | pairwise_weight = pairwise_weight.softmax(dim=-1) 30 | return pairwise_weight 31 | 32 | def forward(self, x): 33 | # x: [N, C, H, W] 34 | n = x.size(0) 35 | 36 | # g_x: [N, HxW, C] 37 | g_x = self.g(x).view(n, self.inter_channels, -1) 38 | g_x = g_x.permute(0, 2, 1) 39 | 40 | # theta_x: [N, HxW, C], phi_x: [N, C, HxW] 41 | if self.mode == 'gaussian': 42 | theta_x = x.view(n, self.in_channels, -1) 43 | theta_x = theta_x.permute(0, 2, 1) 44 | if self.sub_sample: 45 | phi_x = self.phi(x).view(n, self.in_channels, -1) 46 | else: 47 | phi_x = x.view(n, self.in_channels, -1) 48 | elif self.mode == 'concatenation': 49 | theta_x = self.theta(x).view(n, self.inter_channels, -1, 1) 50 | phi_x = self.phi(x).view(n, self.inter_channels, 1, -1) 51 | else: 52 | theta_x = self.theta(x).view(n, self.inter_channels, -1) 53 | theta_x = theta_x.permute(0, 2, 1) 54 | phi_x = self.phi(x).view(n, self.inter_channels, -1) 55 | 56 | # subtract mean 57 | theta_x -= theta_x.mean(dim=-2, keepdim=True) 58 | phi_x -= phi_x.mean(dim=-1, keepdim=True) 59 | 60 | pairwise_func = getattr(self, self.mode) 61 | # pairwise_weight: [N, HxW, HxW] 62 | pairwise_weight = pairwise_func(theta_x, phi_x) 63 | 64 | # y: [N, HxW, C] 65 | y = torch.matmul(pairwise_weight, g_x) 66 | # y: [N, C, H, W] 67 | y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels, 68 | *x.size()[2:]) 69 | 70 | # unary_mask: [N, 1, HxW] 71 | unary_mask = self.conv_mask(x) 72 | unary_mask = unary_mask.view(n, 1, -1) 73 | unary_mask = unary_mask.softmax(dim=-1) 74 | # unary_x: [N, 1, C] 75 | unary_x = torch.matmul(unary_mask, g_x) 76 | # unary_x: [N, C, 1, 1] 77 | unary_x = unary_x.permute(0, 2, 1).contiguous().reshape( 78 | n, self.inter_channels, 1, 1) 79 | 80 | output = x + self.conv_out(y + unary_x) 81 | 82 | return output 83 | 84 | 85 | class DNLHead(FCNHead): 86 | """Disentangled Non-Local Neural Networks. 
87 | This head is the implementation of `DNLNet 88 | `_. 89 | Args: 90 | reduction (int): Reduction factor of projection transform. Default: 2. 91 | use_scale (bool): Whether to scale pairwise_weight by 92 | sqrt(1/inter_channels). Default: False. 93 | mode (str): The nonlocal mode. Options are 'embedded_gaussian', 94 | 'dot_product'. Default: 'embedded_gaussian.'. 95 | temperature (float): Temperature to adjust attention. Default: 0.05 96 | """ 97 | 98 | def __init__(self, 99 | reduction=2, 100 | use_scale=True, 101 | mode='embedded_gaussian', 102 | temperature=0.05, 103 | in_channels=768, 104 | num_classes=6, 105 | in_index=3, 106 | channels=512, 107 | ): 108 | super(DNLHead, self).__init__(num_convs=2, in_channels=in_channels, num_classes=num_classes, in_index=in_index, channels=channels) 109 | self.reduction = reduction 110 | self.use_scale = use_scale 111 | self.mode = mode 112 | self.temperature = temperature 113 | self.dnl_block = DisentangledNonLocal2d( 114 | in_channels=self.channels, 115 | reduction=self.reduction, 116 | use_scale=self.use_scale, 117 | conv_cfg=self.conv_cfg, 118 | norm_cfg=self.norm_cfg, 119 | mode=self.mode, 120 | temperature=self.temperature) 121 | 122 | def forward(self, inputs): 123 | """Forward function.""" 124 | x = self._transform_inputs(inputs) 125 | output = self.convs[0](x) 126 | output = self.dnl_block(output) 127 | output = self.convs[1](output) 128 | if self.concat_input: 129 | output = self.conv_cat(torch.cat([x, output], dim=1)) 130 | output = self.cls_seg(output) 131 | return output -------------------------------------------------------------------------------- /models/head/edge.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | up_kwargs = {'mode': 'bilinear', 'align_corners': False} 7 | 8 | 9 | class EdgeHead(nn.Module): 10 | """Edge awareness module""" 11 | 12 | def __init__(self, in_channels=[96, 192], channels=96, out_fea=2, in_index=[0, 1]): 13 | super(EdgeHead, self).__init__() 14 | self.in_index = in_index 15 | self.conv1 = nn.Sequential( 16 | nn.Conv2d(in_channels[0], in_channels[0], 1, 1, 0), 17 | nn.BatchNorm2d(in_channels[0]), 18 | nn.ReLU(True), 19 | nn.Conv2d(in_channels[0], channels, 1, 1, 0), 20 | nn.BatchNorm2d(channels), 21 | nn.ReLU(True), 22 | ) 23 | # self.conv2 = nn.Sequential( 24 | # nn.Conv2d(in_channels[1], in_channels[1], 1, 1, 0), 25 | # nn.BatchNorm2d(in_channels[1]), 26 | # nn.ReLU(True), 27 | # nn.Conv2d(in_channels[1], channels, 1, 1, 0), 28 | # nn.BatchNorm2d(channels), 29 | # nn.ReLU(True), 30 | # ) 31 | self.conv3 = nn.Conv2d(channels, out_fea, 1, 1, 0) 32 | 33 | def _transform_inputs(self, inputs): 34 | if isinstance(self.in_index, (list, tuple)): 35 | inputs = [inputs[i] for i in self.in_index] 36 | elif isinstance(self.in_index, int): 37 | inputs = inputs[self.in_index] 38 | return inputs 39 | 40 | def forward(self, inputs): 41 | inputs = self._transform_inputs(inputs) 42 | x1, x2 = inputs 43 | _, _, h, w = x1.size() 44 | 45 | edge1_fea = self.conv1(x1) 46 | # edge2_fea = self.conv2(x2) 47 | 48 | edge1_fea = F.interpolate(edge1_fea, size=(h, w), **up_kwargs) 49 | # edge2_fea = F.interpolate(edge2_fea, size=(h, w), **up_kwargs) 50 | 51 | # edge_fea = torch.cat([edge1_fea, edge2_fea], dim=1) 52 | 53 | edge = self.conv3(edge1_fea) 54 | 55 | return edge 56 | -------------------------------------------------------------------------------- /models/head/fcfpn.py: 
-------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # Created by: Hang Zhang 3 | # Email: zhang.hang@rutgers.edu 4 | # Copyright (c) 2017 5 | ########################################################################### 6 | from __future__ import division 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.functional import upsample 10 | 11 | up_kwargs = {'mode': 'bilinear', 'align_corners': True} 12 | norm_layer = nn.BatchNorm2d 13 | 14 | 15 | class FCFPNHead(nn.Module): 16 | def __init__(self, in_channels=[256, 512, 1024, 2048], num_classes=6, channels=256, 17 | norm_layer=norm_layer, up_kwargs=up_kwargs, in_index=[0, 1, 2, 3]): 18 | super(FCFPNHead, self).__init__() 19 | assert up_kwargs is not None 20 | self._up_kwargs = up_kwargs 21 | self.in_index = in_index 22 | fpn_lateral = [] 23 | for inchannel in in_channels[:-1]: 24 | fpn_lateral.append(nn.Sequential( 25 | nn.Conv2d(inchannel, channels, kernel_size=1, bias=False), 26 | norm_layer(channels), 27 | nn.ReLU(inplace=True), 28 | )) 29 | self.fpn_lateral = nn.ModuleList(fpn_lateral) 30 | fpn_out = [] 31 | for _ in range(len(in_channels) - 1): 32 | fpn_out.append(nn.Sequential( 33 | nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False), 34 | norm_layer(channels), 35 | nn.ReLU(inplace=True), 36 | )) 37 | self.fpn_out = nn.ModuleList(fpn_out) 38 | self.c4conv = nn.Sequential(nn.Conv2d(in_channels[-1], channels, 3, padding=1, bias=False), 39 | norm_layer(channels), 40 | nn.ReLU()) 41 | inter_channels = len(in_channels) * channels 42 | self.conv5 = nn.Sequential(nn.Conv2d(inter_channels, 512, 3, padding=1, bias=False), 43 | norm_layer(512), 44 | nn.ReLU(), 45 | nn.Dropout(0.1, False), 46 | nn.Conv2d(512, num_classes, 1)) 47 | 48 | def _transform_inputs(self, inputs): 49 | if isinstance(self.in_index, (list, tuple)): 50 | inputs = [inputs[i] for i in self.in_index] 51 | elif isinstance(self.in_index, int): 52 | inputs = inputs[self.in_index] 53 | return inputs 54 | 55 | def forward(self, inputs): 56 | inputs = self._transform_inputs(inputs) 57 | c4 = inputs[-1] 58 | if hasattr(self, 'extramodule'): 59 | c4 = self.extramodule(c4) 60 | feat = self.c4conv(c4) 61 | c1_size = inputs[0].size()[2:] 62 | feat_up = upsample(feat, c1_size, **self._up_kwargs) 63 | fpn_features = [feat_up] 64 | 65 | for i in reversed(range(len(inputs) - 1)): 66 | feat_i = self.fpn_lateral[i](inputs[i]) 67 | feat = upsample(feat, feat_i.size()[2:], **self._up_kwargs) 68 | feat = feat + feat_i 69 | feat_up = upsample(self.fpn_out[i](feat), c1_size, **self._up_kwargs) 70 | fpn_features.append(feat_up) 71 | fpn_features = torch.cat(fpn_features, 1) 72 | 73 | return self.conv5(fpn_features) 74 | -------------------------------------------------------------------------------- /models/head/fcn.py: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # Created by: Hang Zhang 3 | # Email: zhang.hang@rutgers.edu 4 | # Copyright (c) 2017 5 | ########################################################################### 6 | from __future__ import division 7 | import torch.nn as nn 8 | import torch 9 | from .base_decoder import BaseDecodeHead 10 | from mmcv.cnn import ConvModule 11 | 12 | norm_cfg = dict(type='BN', requires_grad=True) 13 | 14 | 15 | class FCNHead(BaseDecodeHead): 16 | """Fully Convolution Networks for Semantic Segmentation. 
17 | This head is implemented of `FCNNet `_. 18 | Args: 19 | num_convs (int): Number of convs in the head. Default: 2. 20 | kernel_size (int): The kernel size for convs in the head. Default: 3. 21 | concat_input (bool): Whether concat the input and output of convs 22 | before classification layer. 23 | """ 24 | 25 | def __init__(self, 26 | num_convs=2, 27 | kernel_size=3, 28 | concat_input=False, 29 | in_channels=768, 30 | num_classes=6, 31 | in_index=3, 32 | channels=512 33 | ): 34 | assert num_convs >= 0 35 | self.num_convs = num_convs 36 | self.concat_input = concat_input 37 | self.kernel_size = kernel_size 38 | super(FCNHead, self).__init__(in_channels=in_channels, in_index=in_index, channels=channels, dropout_ratio=0.1, 39 | num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False) 40 | if num_convs == 0: 41 | assert self.in_channels == self.channels 42 | convs = [] 43 | convs.append( 44 | ConvModule( 45 | self.in_channels, 46 | self.channels, 47 | kernel_size=kernel_size, 48 | padding=kernel_size // 2, 49 | conv_cfg=self.conv_cfg, 50 | norm_cfg=self.norm_cfg, 51 | act_cfg=self.act_cfg)) 52 | for i in range(num_convs - 1): 53 | convs.append( 54 | ConvModule( 55 | self.channels, 56 | self.channels, 57 | kernel_size=kernel_size, 58 | padding=kernel_size // 2, 59 | conv_cfg=self.conv_cfg, 60 | norm_cfg=self.norm_cfg, 61 | act_cfg=self.act_cfg)) 62 | if num_convs == 0: 63 | self.convs = nn.Identity() 64 | else: 65 | self.convs = nn.Sequential(*convs) 66 | if self.concat_input: 67 | self.conv_cat = ConvModule( 68 | self.in_channels + self.channels, 69 | self.channels, 70 | kernel_size=kernel_size, 71 | padding=kernel_size // 2, 72 | conv_cfg=self.conv_cfg, 73 | norm_cfg=self.norm_cfg, 74 | act_cfg=self.act_cfg) 75 | 76 | def forward(self, inputs): 77 | """Forward function.""" 78 | x = self._transform_inputs(inputs) 79 | output = self.convs(x) 80 | if self.concat_input: 81 | output = self.conv_cat(torch.cat([x, output], dim=1)) 82 | output = self.cls_seg(output) 83 | return output 84 | -------------------------------------------------------------------------------- /models/head/gc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mmcv.cnn import ContextBlock 3 | from .fcn import FCNHead 4 | 5 | 6 | class GCHead(FCNHead): 7 | """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond. 8 | This head is the implementation of `GCNet 9 | `_. 10 | Args: 11 | ratio (float): Multiplier of channels ratio. Default: 1/4. 12 | pooling_type (str): The pooling type of context aggregation. 13 | Options are 'att', 'avg'. Default: 'avg'. 14 | fusion_types (tuple[str]): The fusion type for feature fusion. 15 | Options are 'channel_add', 'channel_mul'. 
Default: ('channel_add',) 16 | """ 17 | 18 | def __init__(self, 19 | ratio=1 / 4., 20 | pooling_type='att', 21 | fusion_types=('channel_add', ), 22 | in_channels=768, 23 | num_classes=6, 24 | in_index=3, 25 | channels=512, 26 | ): 27 | super(GCHead, self).__init__(num_convs=2, in_channels=in_channels, num_classes=num_classes, in_index=in_index, channels=channels) 28 | self.ratio = ratio 29 | self.pooling_type = pooling_type 30 | self.fusion_types = fusion_types 31 | self.gc_block = ContextBlock( 32 | in_channels=self.channels, 33 | ratio=self.ratio, 34 | pooling_type=self.pooling_type, 35 | fusion_types=self.fusion_types) 36 | 37 | def forward(self, inputs): 38 | """Forward function.""" 39 | x = self._transform_inputs(inputs) 40 | output = self.convs[0](x) 41 | output = self.gc_block(output) 42 | output = self.convs[1](output) 43 | if self.concat_input: 44 | output = self.conv_cat(torch.cat([x, output], dim=1)) 45 | output = self.cls_seg(output) 46 | return output -------------------------------------------------------------------------------- /models/head/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from .base_decoder import BaseDecodeHead, resize 4 | 5 | up_kwargs = {'mode': 'bilinear', 'align_corners': False} 6 | 7 | 8 | class MLP(nn.Module): 9 | """ 10 | Linear Embedding 11 | """ 12 | 13 | def __init__(self, input_dim=2048, embed_dim=768, norm_act=True): 14 | super().__init__() 15 | self.proj = nn.Linear(input_dim, embed_dim) 16 | self.norm_act = norm_act 17 | if self.norm_act: 18 | self.norm = nn.LayerNorm(input_dim) 19 | self.act = nn.GELU() 20 | 21 | def forward(self, x): 22 | x = x.flatten(2).transpose(1, 2) 23 | if self.norm_act: 24 | x = self.norm(x) 25 | x = self.proj(x) 26 | if self.norm_act: 27 | x = self.act(x) 28 | return x 29 | 30 | 31 | class MLPHead(BaseDecodeHead): 32 | """ 33 | SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers 34 | """ 35 | 36 | def __init__(self, in_channels=[96, 192, 384, 768], channels=512, num_classes=6, in_index=[0, 1, 2, 3]): 37 | super(MLPHead, self).__init__(input_transform='multiple_select', in_index=in_index, 38 | in_channels=in_channels, num_classes=num_classes, channels=channels) 39 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels 40 | 41 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=channels) 42 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=channels) 43 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=channels) 44 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=channels) 45 | 46 | self.linear_c3_out = MLP(input_dim=channels, embed_dim=channels) 47 | self.linear_c2_out = MLP(input_dim=channels, embed_dim=channels) 48 | self.linear_c1_out = MLP(input_dim=channels, embed_dim=channels) 49 | 50 | self.linear_fuse = MLP(input_dim=channels * 4, embed_dim=channels) 51 | self.linear_pred = MLP(input_dim=channels, embed_dim=num_classes, norm_act=False) 52 | 53 | def forward(self, inputs): 54 | x = self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32 55 | c1, c2, c3, c4 = x 56 | out = [] 57 | ############## MLP decoder on C1-C4 ########### 58 | n, _, h, w = c4.shape 59 | 60 | _c4 = self.linear_c4(c4).permute(0, 2, 1).contiguous().reshape(n, -1, c4.shape[2], c4.shape[3]) 61 | _c4 = resize(_c4, size=c3.size()[2:], **up_kwargs) 62 | 63 | out.append(resize(_c4, size=c1.size()[2:], **up_kwargs)) 64 | 65 | _c3 = self.linear_c3(c3).permute(0, 2, 
1).contiguous().reshape(n, -1, c3.shape[2], c3.shape[3]) 66 | _c3 = _c4 + _c3 67 | 68 | _c3_out = self.linear_c3_out(_c3).permute(0, 2, 1).contiguous().reshape(n, -1, c3.shape[2], c3.shape[3]) 69 | out.append(resize(_c3_out, size=c1.size()[2:], **up_kwargs)) 70 | 71 | _c2 = self.linear_c2(c2).permute(0, 2, 1).contiguous().reshape(n, -1, c2.shape[2], c2.shape[3]) 72 | _c3 = resize(_c3, size=c2.size()[2:], **up_kwargs) 73 | _c2 = _c3 + _c2 74 | 75 | _c2_out = self.linear_c2_out(_c2).permute(0, 2, 1).contiguous().reshape(n, -1, c2.shape[2], c2.shape[3]) 76 | out.append(resize(_c2_out, size=c1.size()[2:], **up_kwargs)) 77 | 78 | _c1 = self.linear_c1(c1).permute(0, 2, 1).contiguous().reshape(n, -1, c1.shape[2], c1.shape[3]) 79 | _c2 = resize(_c2, size=c1.size()[2:], **up_kwargs) 80 | _c1 = _c2 + _c1 81 | 82 | _c1_out = self.linear_c1_out(_c1).permute(0, 2, 1).contiguous().reshape(n, -1, c1.shape[2], c1.shape[3]) 83 | out.append(_c1_out) 84 | 85 | _c = self.linear_fuse(torch.cat(out, dim=1)).permute(0, 2, 1).contiguous().reshape(n, -1, c1.shape[2], c1.shape[3]) 86 | _c = self.dropout(_c) 87 | x = self.linear_pred(_c).permute(0, 2, 1).contiguous().reshape(n, -1, c1.shape[2], c1.shape[3]) 88 | 89 | return x 90 | -------------------------------------------------------------------------------- /models/head/psa.py: -------------------------------------------------------------------------------- 1 | """Point-wise Spatial Attention Network""" 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | up_kwargs = {'mode': 'bilinear', 'align_corners': True} 7 | norm_layer = nn.BatchNorm2d 8 | 9 | 10 | class _ConvBNReLU(nn.Module): 11 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 12 | dilation=1, groups=1, relu6=False, norm_layer=norm_layer): 13 | super(_ConvBNReLU, self).__init__() 14 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) 15 | self.bn = norm_layer(out_channels) 16 | self.relu = nn.ReLU6(True) if relu6 else nn.ReLU(True) 17 | 18 | def forward(self, x): 19 | x = self.conv(x) 20 | x = self.bn(x) 21 | x = self.relu(x) 22 | return x 23 | 24 | 25 | class PSAHead(nn.Module): 26 | def __init__(self, in_channels=768, num_classes=6, norm_layer=norm_layer, in_index=3): 27 | super(PSAHead, self).__init__() 28 | self.in_index = in_index 29 | # psa_out_channels = crop_size // stride_rate ** 2 30 | psa_out_channels = (512 // 32) ** 2 31 | self.psa = _PointwiseSpatialAttention(in_channels, psa_out_channels, norm_layer) 32 | 33 | self.conv_post = _ConvBNReLU(psa_out_channels, in_channels, 1, norm_layer=norm_layer) 34 | self.project = nn.Sequential( 35 | _ConvBNReLU(in_channels * 2, in_channels // 2, 3, padding=1, norm_layer=norm_layer), 36 | nn.Dropout2d(0.1, False), 37 | nn.Conv2d(in_channels // 2, num_classes, 1)) 38 | 39 | def _transform_inputs(self, inputs): 40 | if isinstance(self.in_index, (list, tuple)): 41 | inputs = [inputs[i] for i in self.in_index] 42 | elif isinstance(self.in_index, int): 43 | inputs = inputs[self.in_index] 44 | return inputs 45 | 46 | def forward(self, inputs): 47 | x = self._transform_inputs(inputs) 48 | global_feature = self.psa(x) 49 | out = self.conv_post(global_feature) 50 | out = torch.cat([x, out], dim=1) 51 | out = self.project(out) 52 | 53 | return out 54 | 55 | 56 | class _PointwiseSpatialAttention(nn.Module): 57 | def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d): 58 | super(_PointwiseSpatialAttention, self).__init__() 59 | reduced_channels = out_channels 
// 2 60 | self.collect_attention = _AttentionGeneration(in_channels, reduced_channels, out_channels, norm_layer) 61 | self.distribute_attention = _AttentionGeneration(in_channels, reduced_channels, out_channels, norm_layer) 62 | 63 | def forward(self, x): 64 | collect_fm = self.collect_attention(x) 65 | distribute_fm = self.distribute_attention(x) 66 | psa_fm = torch.cat([collect_fm, distribute_fm], dim=1) 67 | return psa_fm 68 | 69 | 70 | class _AttentionGeneration(nn.Module): 71 | def __init__(self, in_channels, reduced_channels, out_channels, norm_layer): 72 | super(_AttentionGeneration, self).__init__() 73 | self.conv_reduce = _ConvBNReLU(in_channels, reduced_channels, 1, norm_layer=norm_layer) 74 | self.attention = nn.Sequential( 75 | _ConvBNReLU(reduced_channels, reduced_channels, 1, norm_layer=norm_layer), 76 | nn.Conv2d(reduced_channels, out_channels, 1, bias=False)) 77 | 78 | self.reduced_channels = reduced_channels 79 | 80 | def forward(self, x): 81 | reduce_x = self.conv_reduce(x) 82 | attention = self.attention(reduce_x) 83 | n, c, h, w = attention.size() 84 | attention = attention.view(n, c, -1) 85 | reduce_x = reduce_x.view(n, self.reduced_channels, -1) 86 | fm = torch.bmm(reduce_x, torch.softmax(attention, dim=1)) 87 | fm = fm.view(n, self.reduced_channels, h, w) 88 | 89 | return fm 90 | -------------------------------------------------------------------------------- /models/head/psp.py: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # Created by: Hang Zhang 3 | # Email: zhang.hang@rutgers.edu 4 | # Copyright (c) 2017 5 | ########################################################################### 6 | from __future__ import division 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | up_kwargs = {'mode': 'bilinear', 'align_corners': True} 13 | norm_layer = nn.BatchNorm2d 14 | 15 | 16 | class PyramidPooling(nn.Module): 17 | """ 18 | Reference: 19 | Zhao, Hengshuang, et al. 
*"Pyramid scene parsing network."* 20 | """ 21 | def __init__(self, in_channels, norm_layer, up_kwargs): 22 | super(PyramidPooling, self).__init__() 23 | self.pool1 = nn.AdaptiveAvgPool2d(1) 24 | self.pool2 = nn.AdaptiveAvgPool2d(2) 25 | self.pool3 = nn.AdaptiveAvgPool2d(3) 26 | self.pool4 = nn.AdaptiveAvgPool2d(6) 27 | 28 | out_channels = int(in_channels/4) 29 | self.conv1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), 30 | norm_layer(out_channels), 31 | nn.ReLU(True)) 32 | self.conv2 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), 33 | norm_layer(out_channels), 34 | nn.ReLU(True)) 35 | self.conv3 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), 36 | norm_layer(out_channels), 37 | nn.ReLU(True)) 38 | self.conv4 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), 39 | norm_layer(out_channels), 40 | nn.ReLU(True)) 41 | # bilinear interpolate options 42 | self._up_kwargs = up_kwargs 43 | 44 | def forward(self, x): 45 | _, _, h, w = x.size() 46 | feat1 = F.interpolate(self.conv1(self.pool1(x)), (h, w), **self._up_kwargs) 47 | feat2 = F.interpolate(self.conv2(self.pool2(x)), (h, w), **self._up_kwargs) 48 | feat3 = F.interpolate(self.conv3(self.pool3(x)), (h, w), **self._up_kwargs) 49 | feat4 = F.interpolate(self.conv4(self.pool4(x)), (h, w), **self._up_kwargs) 50 | return torch.cat((x, feat1, feat2, feat3, feat4), 1) 51 | 52 | 53 | class PSPHead(nn.Module): 54 | def __init__(self, in_channels, num_classes, norm_layer=norm_layer, up_kwargs=up_kwargs, in_index=3): 55 | super(PSPHead, self).__init__() 56 | inter_channels = in_channels // 4 57 | self.in_index = in_index 58 | self.conv5 = nn.Sequential(PyramidPooling(in_channels, norm_layer, up_kwargs), 59 | nn.Conv2d(in_channels * 2, inter_channels, 3, padding=1, bias=False), 60 | norm_layer(inter_channels), 61 | nn.ReLU(True), 62 | nn.Dropout(0.1, False), 63 | nn.Conv2d(inter_channels, num_classes, 1)) 64 | 65 | def _transform_inputs(self, inputs): 66 | if isinstance(self.in_index, (list, tuple)): 67 | inputs = [inputs[i] for i in self.in_index] 68 | elif isinstance(self.in_index, int): 69 | inputs = inputs[self.in_index] 70 | return inputs 71 | 72 | def forward(self, inputs): 73 | x = self._transform_inputs(inputs) 74 | return self.conv5(x) 75 | -------------------------------------------------------------------------------- /models/head/seg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | up_kwargs = {'mode': 'bilinear', 'align_corners': False} 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | """3x3 convolution with padding""" 10 | conv3x3 = nn.Sequential( 11 | nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False), 13 | nn.BatchNorm2d(out_planes), 14 | nn.ReLU(inplace=True), 15 | ) 16 | return conv3x3 17 | 18 | 19 | class SegHead(nn.Module): 20 | def __init__(self, in_channels=[96, 192, 384, 768], num_classes=6, in_index=[0, 1, 2, 3]): 21 | super(SegHead, self).__init__() 22 | self.in_index = in_index 23 | 24 | self.conv1 = conv3x3(in_channels[0], in_channels[0]) 25 | self.conv2 = conv3x3(in_channels[1], in_channels[0]) 26 | self.conv3 = conv3x3(in_channels[2], in_channels[0]) 27 | self.conv4 = conv3x3(in_channels[3], in_channels[0]) 28 | self.final_layer = nn.Sequential( 29 | nn.Conv2d( 30 | in_channels=in_channels[0] * 4, 31 | out_channels=in_channels[0] * 4, 32 | kernel_size=1, 33 | stride=1, 
34 | padding=0), 35 | nn.BatchNorm2d(in_channels[0] * 4), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d( 38 | in_channels=in_channels[0] * 4, 39 | out_channels=num_classes, 40 | kernel_size=1, 41 | stride=1, 42 | padding=0) 43 | ) 44 | 45 | def _transform_inputs(self, inputs): 46 | if isinstance(self.in_index, (list, tuple)): 47 | inputs = [inputs[i] for i in self.in_index] 48 | elif isinstance(self.in_index, int): 49 | inputs = inputs[self.in_index] 50 | return inputs 51 | 52 | def forward(self, inputs): 53 | inputs = self._transform_inputs(inputs) 54 | p2, p3, p4, p5 = inputs 55 | h, w = p2.shape[-2:] 56 | x2 = self.conv1(p2) 57 | x3 = F.interpolate(self.conv2(p3), size=(h, w), **up_kwargs) 58 | x4 = F.interpolate(self.conv3(p4), size=(h, w), **up_kwargs) 59 | x5 = F.interpolate(self.conv4(p5), size=(h, w), **up_kwargs) 60 | x = torch.cat((x2, x3, x4, x5), dim=1) 61 | x = self.final_layer(x) 62 | 63 | return x 64 | -------------------------------------------------------------------------------- /models/head/unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | up_kwargs = {'mode': 'bilinear', 'align_corners': True} 7 | norm_layer = nn.BatchNorm2d 8 | 9 | 10 | class Conv2dReLU(nn.Sequential): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | padding=0, 17 | stride=1, 18 | use_batchnorm=True, 19 | ): 20 | 21 | conv = nn.Conv2d( 22 | in_channels, 23 | out_channels, 24 | kernel_size, 25 | stride=stride, 26 | padding=padding, 27 | bias=not (use_batchnorm), 28 | ) 29 | relu = nn.ReLU(inplace=True) 30 | 31 | if use_batchnorm: 32 | bn = nn.BatchNorm2d(out_channels) 33 | else: 34 | bn = nn.Identity() 35 | 36 | super(Conv2dReLU, self).__init__(conv, bn, relu) 37 | 38 | 39 | class SCSEAttention(nn.Module): 40 | def __init__(self, in_channels, reduction=16): 41 | super().__init__() 42 | self.cSE = nn.Sequential( 43 | nn.AdaptiveAvgPool2d(1), 44 | nn.Conv2d(in_channels, in_channels // reduction, 1), 45 | nn.ReLU(inplace=True), 46 | nn.Conv2d(in_channels // reduction, in_channels, 1), 47 | nn.Sigmoid(), 48 | ) 49 | self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid()) 50 | 51 | def forward(self, x): 52 | return x * self.cSE(x) + x * self.sSE(x) 53 | 54 | 55 | class DecoderBlock(nn.Module): 56 | def __init__( 57 | self, 58 | in_channels, 59 | skip_channels, 60 | out_channels, 61 | use_batchnorm=True, 62 | use_attention=False, 63 | ): 64 | super().__init__() 65 | self.conv1 = Conv2dReLU( 66 | in_channels + skip_channels, 67 | out_channels, 68 | kernel_size=3, 69 | padding=1, 70 | use_batchnorm=use_batchnorm, 71 | ) 72 | 73 | self.conv2 = Conv2dReLU( 74 | out_channels, 75 | out_channels, 76 | kernel_size=3, 77 | padding=1, 78 | use_batchnorm=use_batchnorm, 79 | ) 80 | self.use_attention = use_attention 81 | if self.use_attention: 82 | self.attention1 = SCSEAttention(in_channels=in_channels + skip_channels) 83 | self.attention2 = SCSEAttention(in_channels=out_channels) 84 | 85 | def forward(self, x, skip=None): 86 | x = F.interpolate(x, scale_factor=2, **up_kwargs) 87 | if skip is not None: 88 | x = torch.cat([x, skip], dim=1) 89 | if self.use_attention: 90 | x = self.attention1(x) 91 | x = self.conv1(x) 92 | x = self.conv2(x) 93 | if self.use_attention: 94 | x = self.attention2(x) 95 | 96 | return x 97 | 98 | 99 | class CenterBlock(nn.Sequential): 100 | def __init__(self, in_channels, out_channels, use_batchnorm=True): 101 | 
conv1 = Conv2dReLU( 102 | in_channels, 103 | out_channels, 104 | kernel_size=3, 105 | padding=1, 106 | use_batchnorm=use_batchnorm, 107 | ) 108 | conv2 = Conv2dReLU( 109 | out_channels, 110 | out_channels, 111 | kernel_size=3, 112 | padding=1, 113 | use_batchnorm=use_batchnorm, 114 | ) 115 | super().__init__(conv1, conv2) 116 | 117 | 118 | class UNetHead(nn.Module): 119 | def __init__( 120 | self, 121 | in_channels, 122 | num_classes=6, 123 | n_blocks=4, 124 | use_batchnorm=True, 125 | use_attention=False, 126 | center=False, 127 | in_index=[0, 1, 2, 3], 128 | ): 129 | super(UNetHead, self).__init__() 130 | self.in_index = in_index 131 | decoder_channels = [in_channels[i] // 4 for i in self.in_index] 132 | if n_blocks != len(decoder_channels): 133 | raise ValueError( 134 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( 135 | n_blocks, len(decoder_channels) 136 | ) 137 | ) 138 | encoder_channels = in_channels[::-1] # reverse channels to start from head of encoder 139 | 140 | # computing blocks input and output channels 141 | head_channels = encoder_channels[0] 142 | in_channels = [head_channels] + list(decoder_channels[:-1]) 143 | skip_channels = list(encoder_channels[1:]) + [0] 144 | out_channels = decoder_channels 145 | 146 | if center: 147 | self.center = CenterBlock( 148 | head_channels, head_channels, use_batchnorm=use_batchnorm 149 | ) 150 | else: 151 | self.center = nn.Identity() 152 | # combine decoder keyword arguments 153 | kwargs = dict(use_batchnorm=use_batchnorm, use_attention=use_attention) 154 | blocks = [ 155 | DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) 156 | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) 157 | ] 158 | self.blocks = nn.ModuleList(blocks) 159 | self.head = nn.Conv2d(out_channels[-1], num_classes, kernel_size=1) 160 | 161 | def _transform_inputs(self, inputs): 162 | if isinstance(self.in_index, (list, tuple)): 163 | inputs = [inputs[i] for i in self.in_index] 164 | elif isinstance(self.in_index, int): 165 | inputs = inputs[self.in_index] 166 | return inputs 167 | 168 | def forward(self, features): 169 | 170 | features = self._transform_inputs(features) 171 | features = features[::-1] # reverse channels to start from head of encoder 172 | 173 | head = features[0] 174 | skips = features[1:] 175 | 176 | x = self.center(head) 177 | for i, decoder_block in enumerate(self.blocks): 178 | skip = skips[i] if i < len(skips) else None 179 | x = decoder_block(x, skip) 180 | x = self.head(x) 181 | return x 182 | -------------------------------------------------------------------------------- /models/head/uper.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch.nn as nn 3 | import torch 4 | from .base_decoder import BaseDecodeHead, resize 5 | from mmcv.cnn import ConvModule 6 | 7 | norm_cfg = dict(type='BN', requires_grad=True) 8 | 9 | 10 | class PPM(nn.ModuleList): 11 | """Pooling Pyramid Module used in PSPNet. 12 | Args: 13 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 14 | Module. 15 | in_channels (int): Input channels. 16 | channels (int): Channels after modules, before conv_seg. 17 | conv_cfg (dict|None): Config of conv layers. 18 | norm_cfg (dict|None): Config of norm layers. 19 | act_cfg (dict): Config of activation layers. 20 | align_corners (bool): align_corners argument of F.interpolate. 
21 | """ 22 | 23 | def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, 24 | act_cfg, align_corners): 25 | super(PPM, self).__init__() 26 | self.pool_scales = pool_scales 27 | self.align_corners = align_corners 28 | self.in_channels = in_channels 29 | self.channels = channels 30 | self.conv_cfg = conv_cfg 31 | self.norm_cfg = norm_cfg 32 | self.act_cfg = act_cfg 33 | for pool_scale in pool_scales: 34 | self.append( 35 | nn.Sequential( 36 | nn.AdaptiveAvgPool2d(pool_scale), 37 | ConvModule( 38 | self.in_channels, 39 | self.channels, 40 | 1, 41 | conv_cfg=self.conv_cfg, 42 | norm_cfg=self.norm_cfg, 43 | act_cfg=self.act_cfg))) 44 | 45 | def forward(self, x): 46 | """Forward function.""" 47 | ppm_outs = [] 48 | for ppm in self: 49 | """ppm work on batch > 1 when training""" 50 | ppm_out = ppm(x) 51 | upsampled_ppm_out = resize( 52 | ppm_out, 53 | size=x.size()[2:], 54 | mode='bilinear', 55 | align_corners=self.align_corners) 56 | ppm_outs.append(upsampled_ppm_out) 57 | return ppm_outs 58 | 59 | 60 | class UPerHead(BaseDecodeHead): 61 | """Unified Perceptual Parsing for Scene Understanding. 62 | This head is the implementation of `UPerNet 63 | `_. 64 | Args: 65 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 66 | Module applied on the last feature. Default: (1, 2, 3, 6). 67 | """ 68 | 69 | def __init__(self, pool_scales=(1, 2, 3, 6), in_channels=[96, 192, 384, 768], num_classes=6): 70 | super(UPerHead, self).__init__( 71 | input_transform='multiple_select', in_index=[0, 1, 2, 3], in_channels=in_channels, num_classes=num_classes, 72 | channels=512, dropout_ratio=0.1, norm_cfg=norm_cfg, align_corners=False) 73 | # PSP Module 74 | self.psp_modules = PPM( 75 | pool_scales, 76 | self.in_channels[-1], 77 | self.channels, 78 | conv_cfg=self.conv_cfg, 79 | norm_cfg=self.norm_cfg, 80 | act_cfg=self.act_cfg, 81 | align_corners=self.align_corners) 82 | self.bottleneck = ConvModule( 83 | self.in_channels[-1] + len(pool_scales) * self.channels, 84 | self.channels, 85 | 3, 86 | padding=1, 87 | conv_cfg=self.conv_cfg, 88 | norm_cfg=self.norm_cfg, 89 | act_cfg=self.act_cfg) 90 | # FPN Module 91 | self.lateral_convs = nn.ModuleList() 92 | self.fpn_convs = nn.ModuleList() 93 | for in_channels in self.in_channels[:-1]: # skip the top layer 94 | l_conv = ConvModule( 95 | in_channels, 96 | self.channels, 97 | 1, 98 | conv_cfg=self.conv_cfg, 99 | norm_cfg=self.norm_cfg, 100 | act_cfg=self.act_cfg, 101 | inplace=False) 102 | fpn_conv = ConvModule( 103 | self.channels, 104 | self.channels, 105 | 3, 106 | padding=1, 107 | conv_cfg=self.conv_cfg, 108 | norm_cfg=self.norm_cfg, 109 | act_cfg=self.act_cfg, 110 | inplace=False) 111 | self.lateral_convs.append(l_conv) 112 | self.fpn_convs.append(fpn_conv) 113 | 114 | self.fpn_bottleneck = ConvModule( 115 | len(self.in_channels) * self.channels, 116 | self.channels, 117 | 3, 118 | padding=1, 119 | conv_cfg=self.conv_cfg, 120 | norm_cfg=self.norm_cfg, 121 | act_cfg=self.act_cfg) 122 | 123 | def psp_forward(self, inputs): 124 | """Forward function of PSP module.""" 125 | x = inputs[-1] 126 | psp_outs = [x] 127 | psp_outs.extend(self.psp_modules(x)) 128 | psp_outs = torch.cat(psp_outs, dim=1) 129 | output = self.bottleneck(psp_outs) 130 | 131 | return output 132 | 133 | def forward(self, inputs): 134 | """Forward function.""" 135 | 136 | inputs = self._transform_inputs(inputs) 137 | 138 | # build laterals 139 | laterals = [ 140 | lateral_conv(inputs[i]) 141 | for i, lateral_conv in enumerate(self.lateral_convs) 142 | ] 143 | 144 | 
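        # the pyramid-pooled (PSP) feature of the last stage is appended as the top
        # lateral, so the top-down fusion below also spreads its global context
        # into the lower levels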
laterals.append(self.psp_forward(inputs)) 145 | 146 | # build top-down path 147 | used_backbone_levels = len(laterals) 148 | for i in range(used_backbone_levels - 1, 0, -1): 149 | prev_shape = laterals[i - 1].shape[2:] 150 | laterals[i - 1] += resize( 151 | laterals[i], 152 | size=prev_shape, 153 | mode='bilinear', 154 | align_corners=self.align_corners) 155 | 156 | # build outputs 157 | fpn_outs = [ 158 | self.fpn_convs[i](laterals[i]) 159 | for i in range(used_backbone_levels - 1) 160 | ] 161 | # append psp feature 162 | fpn_outs.append(laterals[-1]) 163 | 164 | for i in range(used_backbone_levels - 1, 0, -1): 165 | fpn_outs[i] = resize( 166 | fpn_outs[i], 167 | size=fpn_outs[0].shape[2:], 168 | mode='bilinear', 169 | align_corners=self.align_corners) 170 | fpn_outs = torch.cat(fpn_outs, dim=1) 171 | output = self.fpn_bottleneck(fpn_outs) 172 | output = self.cls_seg(output) 173 | return output 174 | 175 | 176 | 177 | -------------------------------------------------------------------------------- /models/model_store.py: -------------------------------------------------------------------------------- 1 | """Model store which provides pretrained models.""" 2 | from __future__ import print_function 3 | __all__ = ['get_model_file', 'purge'] 4 | import os 5 | import zipfile 6 | import portalocker 7 | 8 | from .utils import download, check_sha1 9 | 10 | _model_sha1 = {name: checksum for checksum, name in [ 11 | # resnest 12 | ('fb9de5b360976e3e8bd3679d3e93c5409a5eff3c', 'resnest50'), 13 | ('966fb78c22323b0c68097c5c1242bd16d3e07fd5', 'resnest101'), 14 | ('d7fd712f5a1fcee5b3ce176026fbb6d0d278454a', 'resnest200'), 15 | ('51ae5f19032e22af4ec08e695496547acdba5ce5', 'resnest269'), 16 | # rectified 17 | #('9b5dc32b3b36ca1a6b41ecd4906830fc84dae8ed', 'resnet101_rt'), 18 | # resnet other variants 19 | ('a75c83cfc89a56a4e8ba71b14f1ec67e923787b3', 'resnet50s'), 20 | ('03a0f310d6447880f1b22a83bd7d1aa7fc702c6e', 'resnet101s'), 21 | ('36670e8bc2428ecd5b7db1578538e2dd23872813', 'resnet152s'), 22 | # other segmentation backbones 23 | ('da4785cfc837bf00ef95b52fb218feefe703011f', 'wideresnet38'), 24 | ('b41562160173ee2e979b795c551d3c7143b1e5b5', 'wideresnet50'), 25 | # deepten paper 26 | ('1225f149519c7a0113c43a056153c1bb15468ac0', 'deepten_resnet50_minc'), 27 | # segmentation resnet models 28 | ('662e979de25a389f11c65e9f1df7e06c2c356381', 'fcn_resnet50s_ade'), 29 | ('4de91d5922d4d3264f678b663f874da72e82db00', 'encnet_resnet50s_pcontext'), 30 | ('9f27ea13d514d7010e59988341bcbd4140fcc33d', 'encnet_resnet101s_pcontext'), 31 | ('07ac287cd77e53ea583f37454e17d30ce1509a4a', 'encnet_resnet50s_ade'), 32 | ('3f54fa3b67bac7619cd9b3673f5c8227cf8f4718', 'encnet_resnet101s_ade'), 33 | # resnest segmentation models 34 | ('4aba491aaf8e4866a9c9981b210e3e3266ac1f2a', 'fcn_resnest50_ade'), 35 | ('2225f09d0f40b9a168d9091652194bc35ec2a5a9', 'deeplab_resnest50_ade'), 36 | ('06ca799c8cc148fe0fafb5b6d052052935aa3cc8', 'deeplab_resnest101_ade'), 37 | ('7b9e7d3e6f0e2c763c7d77cad14d306c0a31fe05', 'deeplab_resnest200_ade'), 38 | ('0074dd10a6e6696f6f521653fb98224e75955496', 'deeplab_resnest269_ade'), 39 | ('77a2161deeb1564e8b9c41a4bb7a3f33998b00ad', 'fcn_resnest50_pcontext'), 40 | ('08dccbc4f4694baab631e037a374d76d8108c61f', 'deeplab_resnest50_pcontext'), 41 | ('faf5841853aae64bd965a7bdc2cdc6e7a2b5d898', 'deeplab_resnest101_pcontext'), 42 | ('fe76a26551dd5dcf2d474fd37cba99d43f6e984e', 'deeplab_resnest200_pcontext'), 43 | ('b661fd26c49656e01e9487cd9245babb12f37449', 'deeplab_resnest269_pcontext'), 44 | ]} 45 | 46 | encoding_repo_url = 
'https://s3.us-west-1.wasabisys.com/encoding' 47 | _url_format = '{repo_url}models/{file_name}.zip' 48 | 49 | def short_hash(name): 50 | if name not in _model_sha1: 51 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) 52 | return _model_sha1[name][:8] 53 | 54 | def get_model_file(name, root=os.path.join('~', '.encoding', 'models')): 55 | r"""Return location for the pretrained on local file system. 56 | This function will download from online model zoo when model cannot be found or has mismatch. 57 | The root directory will be created if it doesn't exist. 58 | Parameters 59 | ---------- 60 | name : str 61 | Name of the model. 62 | root : str, default '~/.encoding/models' 63 | Location for keeping the model parameters. 64 | Returns 65 | ------- 66 | file_path 67 | Path to the requested pretrained model file. 68 | """ 69 | if name not in _model_sha1: 70 | from torchvision.models.resnet import model_urls 71 | if name not in model_urls: 72 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) 73 | root = os.path.expanduser(root) 74 | return download(model_urls[name], 75 | path=root, 76 | overwrite=True) 77 | file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name)) 78 | root = os.path.expanduser(root) 79 | if not os.path.exists(root): 80 | os.makedirs(root) 81 | 82 | file_path = os.path.join(root, file_name+'.pth') 83 | sha1_hash = _model_sha1[name] 84 | 85 | lockfile = os.path.join(root, file_name + '.lock') 86 | with portalocker.Lock(lockfile, timeout=300): 87 | if os.path.exists(file_path): 88 | if check_sha1(file_path, sha1_hash): 89 | return file_path 90 | else: 91 | print('Mismatch in the content of model file {} detected.' + 92 | ' Downloading again.'.format(file_path)) 93 | else: 94 | print('Model file {} is not found. Downloading.'.format(file_path)) 95 | 96 | zip_file_path = os.path.join(root, file_name+'.zip') 97 | repo_url = os.environ.get('ENCODING_REPO', encoding_repo_url) 98 | if repo_url[-1] != '/': 99 | repo_url = repo_url + '/' 100 | download(_url_format.format(repo_url=repo_url, file_name=file_name), 101 | path=zip_file_path, 102 | overwrite=True) 103 | with zipfile.ZipFile(zip_file_path) as zf: 104 | zf.extractall(root) 105 | os.remove(zip_file_path) 106 | 107 | if check_sha1(file_path, sha1_hash): 108 | return file_path 109 | else: 110 | raise ValueError('Downloaded file has different hash. Please try again.') 111 | 112 | def purge(root=os.path.join('~', '.encoding', 'models')): 113 | r"""Purge all pretrained model files in local file store. 114 | Parameters 115 | ---------- 116 | root : str, default '~/.encoding/models' 117 | Location for keeping the model parameters. 
118 | """ 119 | root = os.path.expanduser(root) 120 | files = os.listdir(root) 121 | for f in files: 122 | if f.endswith(".pth"): 123 | os.remove(os.path.join(root, f)) 124 | 125 | def pretrained_model_list(): 126 | return list(_model_sha1.keys()) -------------------------------------------------------------------------------- /models/pspnet.py: -------------------------------------------------------------------------------- 1 | """Pyramid Scene Parsing Network""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | try: 6 | from .resnet import resnet50_v1b 7 | except: 8 | from resnet import resnet50_v1b 9 | 10 | 11 | class SeparableConv2d(nn.Module): 12 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1, 13 | dilation=1, bias=False, norm_layer=nn.BatchNorm2d): 14 | super(SeparableConv2d, self).__init__() 15 | self.conv = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation, groups=inplanes, bias=bias) 16 | self.bn = norm_layer(inplanes) 17 | self.pointwise = nn.Conv2d(inplanes, planes, 1, bias=bias) 18 | 19 | def forward(self, x): 20 | x = self.conv(x) 21 | x = self.bn(x) 22 | x = self.pointwise(x) 23 | return x 24 | 25 | 26 | # copy from: https://github.com/wuhuikai/FastFCN/blob/master/encoding/nn/customize.py 27 | class JPU(nn.Module): 28 | def __init__(self, in_channels, width=512, norm_layer=nn.BatchNorm2d, **kwargs): 29 | super(JPU, self).__init__() 30 | 31 | self.conv5 = nn.Sequential( 32 | nn.Conv2d(in_channels[-1], width, 3, padding=1, bias=False), 33 | norm_layer(width), 34 | nn.ReLU(True)) 35 | self.conv4 = nn.Sequential( 36 | nn.Conv2d(in_channels[-2], width, 3, padding=1, bias=False), 37 | norm_layer(width), 38 | nn.ReLU(True)) 39 | self.conv3 = nn.Sequential( 40 | nn.Conv2d(in_channels[-3], width, 3, padding=1, bias=False), 41 | norm_layer(width), 42 | nn.ReLU(True)) 43 | 44 | self.dilation1 = nn.Sequential( 45 | SeparableConv2d(3 * width, width, 3, padding=1, dilation=1, bias=False), 46 | norm_layer(width), 47 | nn.ReLU(True)) 48 | self.dilation2 = nn.Sequential( 49 | SeparableConv2d(3 * width, width, 3, padding=2, dilation=2, bias=False), 50 | norm_layer(width), 51 | nn.ReLU(True)) 52 | self.dilation3 = nn.Sequential( 53 | SeparableConv2d(3 * width, width, 3, padding=4, dilation=4, bias=False), 54 | norm_layer(width), 55 | nn.ReLU(True)) 56 | self.dilation4 = nn.Sequential( 57 | SeparableConv2d(3 * width, width, 3, padding=8, dilation=8, bias=False), 58 | norm_layer(width), 59 | nn.ReLU(True)) 60 | 61 | def forward(self, *inputs): 62 | feats = [self.conv5(inputs[-1]), self.conv4(inputs[-2]), self.conv3(inputs[-3])] 63 | size = feats[-1].size()[2:] 64 | feats[-2] = F.interpolate(feats[-2], size, mode='bilinear', align_corners=True) 65 | feats[-3] = F.interpolate(feats[-3], size, mode='bilinear', align_corners=True) 66 | feat = torch.cat(feats, dim=1) 67 | feat = torch.cat([self.dilation1(feat), self.dilation2(feat), self.dilation3(feat), self.dilation4(feat)], 68 | dim=1) 69 | 70 | return inputs[0], inputs[1], inputs[2], feat 71 | 72 | 73 | class SegBaseModel(nn.Module): 74 | r"""Base Model for Semantic Segmentation 75 | 76 | Parameters 77 | ---------- 78 | backbone : string 79 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 80 | 'resnet101' or 'resnet152'). 
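    jpu : bool
        If True, the backbone is built without dilation and the JPU module defined
        above is applied on top of its outputs instead.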
81 | """ 82 | 83 | def __init__(self, nclass, aux, backbone='resnet50', jpu=False, pretrained_base=False, **kwargs): 84 | super(SegBaseModel, self).__init__() 85 | dilated = False if jpu else True 86 | self.aux = aux 87 | self.nclass = nclass 88 | if backbone == 'resnet50': 89 | self.pretrained = resnet50_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs) 90 | 91 | else: 92 | raise RuntimeError('unknown backbone: {}'.format(backbone)) 93 | 94 | self.jpu = JPU([512, 1024, 2048], width=512, **kwargs) if jpu else None 95 | 96 | def base_forward(self, x): 97 | """forwarding pre-trained network""" 98 | x = self.pretrained.conv1(x) 99 | x = self.pretrained.bn1(x) 100 | x = self.pretrained.relu(x) 101 | x = self.pretrained.maxpool(x) 102 | c1 = self.pretrained.layer1(x) 103 | c2 = self.pretrained.layer2(c1) 104 | c3 = self.pretrained.layer3(c2) 105 | c4 = self.pretrained.layer4(c3) 106 | 107 | if self.jpu: 108 | return self.jpu(c1, c2, c3, c4) 109 | else: 110 | return c1, c2, c3, c4 111 | 112 | def evaluate(self, x): 113 | """evaluating network with inputs and targets""" 114 | return self.forward(x)[0] 115 | 116 | def demo(self, x): 117 | pred = self.forward(x) 118 | if self.aux: 119 | pred = pred[0] 120 | return pred 121 | 122 | 123 | class _FCNHead(nn.Module): 124 | def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, **kwargs): 125 | super(_FCNHead, self).__init__() 126 | inter_channels = in_channels // 4 127 | self.block = nn.Sequential( 128 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 129 | norm_layer(inter_channels), 130 | nn.ReLU(inplace=True), 131 | nn.Dropout(0.1), 132 | nn.Conv2d(inter_channels, channels, 1) 133 | ) 134 | 135 | def forward(self, x): 136 | return self.block(x) 137 | 138 | 139 | class PSPNet(SegBaseModel): 140 | r"""Pyramid Scene Parsing Network 141 | 142 | Parameters 143 | ---------- 144 | nclass : int 145 | Number of categories for the training dataset. 146 | backbone : string 147 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 148 | 'resnet101' or 'resnet152'). 149 | norm_layer : object 150 | Normalization layer used in backbone network (default: :class:`nn.BatchNorm`; 151 | for Synchronized Cross-GPU BachNormalization). 152 | aux : bool 153 | Auxiliary loss. 154 | 155 | Reference: 156 | Zhao, Hengshuang, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, and Jiaya Jia. 157 | "Pyramid scene parsing network." 
*CVPR*, 2017 158 | """ 159 | 160 | def __init__(self, nclass, backbone='resnet50', aux=False, pretrained_base=False, **kwargs): 161 | super(PSPNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs) 162 | self.head = _PSPHead(nclass, **kwargs) 163 | if self.aux: 164 | self.auxlayer = _FCNHead(1024, nclass, **kwargs) 165 | 166 | self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head']) 167 | 168 | def forward(self, x): 169 | size = x.size()[2:] 170 | _, _, c3, c4 = self.base_forward(x) 171 | outputs = [] 172 | x = self.head(c4) 173 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 174 | outputs = x 175 | 176 | if self.aux: 177 | auxout = self.auxlayer(c3) 178 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 179 | outputs.append(auxout) 180 | return outputs 181 | 182 | 183 | def _PSP1x1Conv(in_channels, out_channels, norm_layer, norm_kwargs): 184 | return nn.Sequential( 185 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 186 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)), 187 | nn.ReLU(True) 188 | ) 189 | 190 | 191 | class _PyramidPooling(nn.Module): 192 | def __init__(self, in_channels, **kwargs): 193 | super(_PyramidPooling, self).__init__() 194 | out_channels = int(in_channels / 4) 195 | self.avgpool1 = nn.AdaptiveAvgPool2d(1) 196 | self.avgpool2 = nn.AdaptiveAvgPool2d(2) 197 | self.avgpool3 = nn.AdaptiveAvgPool2d(3) 198 | self.avgpool4 = nn.AdaptiveAvgPool2d(6) 199 | self.conv1 = _PSP1x1Conv(in_channels, out_channels, **kwargs) 200 | self.conv2 = _PSP1x1Conv(in_channels, out_channels, **kwargs) 201 | self.conv3 = _PSP1x1Conv(in_channels, out_channels, **kwargs) 202 | self.conv4 = _PSP1x1Conv(in_channels, out_channels, **kwargs) 203 | 204 | def forward(self, x): 205 | size = x.size()[2:] 206 | feat1 = F.interpolate(self.conv1(self.avgpool1(x)), size, mode='bilinear', align_corners=True) 207 | feat2 = F.interpolate(self.conv2(self.avgpool2(x)), size, mode='bilinear', align_corners=True) 208 | feat3 = F.interpolate(self.conv3(self.avgpool3(x)), size, mode='bilinear', align_corners=True) 209 | feat4 = F.interpolate(self.conv4(self.avgpool4(x)), size, mode='bilinear', align_corners=True) 210 | return torch.cat([x, feat1, feat2, feat3, feat4], dim=1) 211 | 212 | 213 | class _PSPHead(nn.Module): 214 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs): 215 | super(_PSPHead, self).__init__() 216 | self.psp = _PyramidPooling(2048, norm_layer=norm_layer, norm_kwargs=norm_kwargs) 217 | self.block = nn.Sequential( 218 | nn.Conv2d(4096, 512, 3, padding=1, bias=False), 219 | norm_layer(512, **({} if norm_kwargs is None else norm_kwargs)), 220 | nn.ReLU(True), 221 | nn.Dropout(0.1), 222 | nn.Conv2d(512, nclass, 1) 223 | ) 224 | 225 | def forward(self, x): 226 | x = self.psp(x) 227 | return self.block(x) 228 | 229 | 230 | if __name__ == '__main__': 231 | from tools.flops_params_fps_count import flops_params_fps 232 | model = PSPNet(nclass=6) 233 | flops_params_fps(model) -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | __all__ = ['ResNetV1b', 'resnet18_v1b', 'resnet34_v1b', 'resnet50_v1b', 6 | 'resnet101_v1b', 'resnet152_v1b', 'resnet152_v1s', 'resnet101_v1s', 'resnet50_v1s'] 7 | 8 | model_urls = { 9 | 'resnet18': 
'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | pretrained_save_dir = './pretrained_weights' 17 | 18 | class BasicBlockV1b(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, 22 | previous_dilation=1, norm_layer=nn.BatchNorm2d): 23 | super(BasicBlockV1b, self).__init__() 24 | self.conv1 = nn.Conv2d(inplanes, planes, 3, stride, 25 | dilation, dilation, bias=False) 26 | self.bn1 = norm_layer(planes) 27 | self.relu = nn.ReLU(True) 28 | self.conv2 = nn.Conv2d(planes, planes, 3, 1, previous_dilation, 29 | dilation=previous_dilation, bias=False) 30 | self.bn2 = norm_layer(planes) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | identity = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | 44 | if self.downsample is not None: 45 | identity = self.downsample(x) 46 | 47 | out += identity 48 | out = self.relu(out) 49 | 50 | return out 51 | 52 | 53 | class BottleneckV1b(nn.Module): 54 | expansion = 4 55 | 56 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, 57 | previous_dilation=1, norm_layer=nn.BatchNorm2d): 58 | super(BottleneckV1b, self).__init__() 59 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 60 | self.bn1 = norm_layer(planes) 61 | self.conv2 = nn.Conv2d(planes, planes, 3, stride, 62 | dilation, dilation, bias=False) 63 | self.bn2 = norm_layer(planes) 64 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 65 | self.bn3 = norm_layer(planes * self.expansion) 66 | self.relu = nn.ReLU(True) 67 | self.downsample = downsample 68 | self.stride = stride 69 | 70 | def forward(self, x): 71 | identity = x 72 | 73 | out = self.conv1(x) 74 | out = self.bn1(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv2(out) 78 | out = self.bn2(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv3(out) 82 | out = self.bn3(out) 83 | 84 | if self.downsample is not None: 85 | identity = self.downsample(x) 86 | 87 | out += identity 88 | out = self.relu(out) 89 | 90 | return out 91 | 92 | 93 | class ResNetV1b(nn.Module): 94 | 95 | def __init__(self, block, layers, num_classes=1000, dilated=True, deep_stem=False, 96 | zero_init_residual=False, norm_layer=nn.BatchNorm2d): 97 | self.inplanes = 128 if deep_stem else 64 98 | super(ResNetV1b, self).__init__() 99 | if deep_stem: 100 | self.conv1 = nn.Sequential( 101 | nn.Conv2d(3, 64, 3, 2, 1, bias=False), 102 | norm_layer(64), 103 | nn.ReLU(True), 104 | nn.Conv2d(64, 64, 3, 1, 1, bias=False), 105 | norm_layer(64), 106 | nn.ReLU(True), 107 | nn.Conv2d(64, 128, 3, 1, 1, bias=False) 108 | ) 109 | else: 110 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False) 111 | self.bn1 = norm_layer(self.inplanes) 112 | self.relu = nn.ReLU(True) 113 | self.maxpool = nn.MaxPool2d(3, 2, 1) 114 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) 115 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 116 | if dilated: 117 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, norm_layer=norm_layer) 118 
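            # dilated variant: layer3/layer4 keep stride 1 and use dilation 2/4,
            # so the backbone output stays at 1/8 of the input resolution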
| self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, norm_layer=norm_layer) 119 | else: 120 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer) 121 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer) 122 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 123 | self.fc = nn.Linear(512 * block.expansion, num_classes) 124 | 125 | for m in self.modules(): 126 | if isinstance(m, nn.Conv2d): 127 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 128 | elif isinstance(m, nn.BatchNorm2d): 129 | nn.init.constant_(m.weight, 1) 130 | nn.init.constant_(m.bias, 0) 131 | 132 | if zero_init_residual: 133 | for m in self.modules(): 134 | if isinstance(m, BottleneckV1b): 135 | nn.init.constant_(m.bn3.weight, 0) 136 | elif isinstance(m, BasicBlockV1b): 137 | nn.init.constant_(m.bn2.weight, 0) 138 | 139 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d): 140 | downsample = None 141 | if stride != 1 or self.inplanes != planes * block.expansion: 142 | downsample = nn.Sequential( 143 | nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False), 144 | norm_layer(planes * block.expansion), 145 | ) 146 | 147 | layers = [] 148 | if dilation in (1, 2): 149 | layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample, 150 | previous_dilation=dilation, norm_layer=norm_layer)) 151 | elif dilation == 4: 152 | layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample, 153 | previous_dilation=dilation, norm_layer=norm_layer)) 154 | else: 155 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 156 | self.inplanes = planes * block.expansion 157 | for _ in range(1, blocks): 158 | layers.append(block(self.inplanes, planes, dilation=dilation, 159 | previous_dilation=dilation, norm_layer=norm_layer)) 160 | 161 | return nn.Sequential(*layers) 162 | 163 | def forward(self, x): 164 | x = self.conv1(x) 165 | x = self.bn1(x) 166 | x = self.relu(x) 167 | x = self.maxpool(x) 168 | 169 | x = self.layer1(x) 170 | x = self.layer2(x) 171 | x = self.layer3(x) 172 | x = self.layer4(x) 173 | 174 | x = self.avgpool(x) 175 | x = x.view(x.size(0), -1) 176 | x = self.fc(x) 177 | 178 | return x 179 | 180 | 181 | def resnet18_v1b(pretrained=False, **kwargs): 182 | model = ResNetV1b(BasicBlockV1b, [2, 2, 2, 2], **kwargs) 183 | if pretrained: 184 | old_dict = model_zoo.load_url(model_urls['resnet18'], model_dir=pretrained_save_dir) 185 | model_dict = model.state_dict() 186 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)} 187 | model_dict.update(old_dict) 188 | model.load_state_dict(model_dict) 189 | return model 190 | 191 | 192 | def resnet34_v1b(pretrained=False, **kwargs): 193 | model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs) 194 | if pretrained: 195 | old_dict = model_zoo.load_url(model_urls['resnet34'], model_dir=pretrained_save_dir) 196 | model_dict = model.state_dict() 197 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)} 198 | model_dict.update(old_dict) 199 | model.load_state_dict(model_dict) 200 | return model 201 | 202 | 203 | def resnet50_v1b(pretrained=False, **kwargs): 204 | model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], **kwargs) 205 | if pretrained: 206 | print('load pretrain resnet50_v1b') 207 | old_dict = model_zoo.load_url(model_urls['resnet50'], model_dir=pretrained_save_dir) 208 | model_dict = model.state_dict() 209 | old_dict = {k: v for 
k, v in old_dict.items() if (k in model_dict)} 210 | model_dict.update(old_dict) 211 | model.load_state_dict(model_dict) 212 | return model 213 | 214 | 215 | def resnet101_v1b(pretrained=False, **kwargs): 216 | model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], **kwargs) 217 | if pretrained: 218 | print('load pretrain resnet101_v1b') 219 | old_dict = model_zoo.load_url(model_urls['resnet101'], model_dir=pretrained_save_dir) 220 | model_dict = model.state_dict() 221 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)} 222 | model_dict.update(old_dict) 223 | model.load_state_dict(model_dict) 224 | return model 225 | 226 | 227 | def resnet152_v1b(pretrained=False, **kwargs): 228 | model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], **kwargs) 229 | if pretrained: 230 | print('load pretrain resnet152_v1b') 231 | old_dict = model_zoo.load_url(model_urls['resnet152'], model_dir=pretrained_save_dir) 232 | model_dict = model.state_dict() 233 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)} 234 | model_dict.update(old_dict) 235 | model.load_state_dict(model_dict) 236 | return model 237 | 238 | 239 | def resnet50_v1s(pretrained=False, root=pretrained_save_dir, **kwargs): 240 | model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, **kwargs) 241 | if pretrained: 242 | print('load pretrain resnet50_v1s') 243 | from models.model_store import get_model_file 244 | model.load_state_dict(torch.load(get_model_file('resnet50s', root=root)), strict=False) 245 | return model 246 | 247 | 248 | def resnet101_v1s(pretrained=False, root=pretrained_save_dir, **kwargs): 249 | model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, **kwargs) 250 | if pretrained: 251 | print('load pretrain resnet101_v1s') 252 | from .model_store import get_model_file 253 | model.load_state_dict(torch.load(get_model_file('resnet101s', root=root)), strict=False) 254 | return model 255 | 256 | 257 | def resnet152_v1s(pretrained=False, root=pretrained_save_dir, **kwargs): 258 | model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, **kwargs) 259 | if pretrained: 260 | print('load pretrain resnet152_v1s') 261 | from .model_store import get_model_file 262 | model.load_state_dict(torch.load(get_model_file('resnet152s', root=root)), strict=False) 263 | return model 264 | 265 | 266 | if __name__ == '__main__': 267 | import torch 268 | 269 | img = torch.randn(4, 3, 224, 224) 270 | model = resnet50_v1b(True) 271 | output = model(img) -------------------------------------------------------------------------------- /models/segbase.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | try: 3 | from .resnet import resnet50_v1b 4 | except: 5 | from resnet import resnet50_v1b 6 | import torch.nn.functional as F 7 | from models.head.seg import SegHead 8 | 9 | 10 | class SegBaseModel(nn.Module): 11 | r"""Base Model for Semantic Segmentation 12 | 13 | Parameters 14 | ---------- 15 | backbone : string 16 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 17 | 'resnet101' or 'resnet152'). 
18 | """ 19 | 20 | def __init__(self, nclass, backbone='resnet50', dilated=True, pretrained_base=False, **kwargs): 21 | super(SegBaseModel, self).__init__() 22 | self.nclass = nclass 23 | if backbone == 'resnet50': 24 | self.pretrained = resnet50_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs) 25 | 26 | def base_forward(self, x): 27 | """forwarding pre-trained network""" 28 | x = self.pretrained.conv1(x) 29 | x = self.pretrained.bn1(x) 30 | x = self.pretrained.relu(x) 31 | x = self.pretrained.maxpool(x) 32 | c1 = self.pretrained.layer1(x) 33 | c2 = self.pretrained.layer2(c1) 34 | c3 = self.pretrained.layer3(c2) 35 | c4 = self.pretrained.layer4(c3) 36 | 37 | return c1, c2, c3, c4 38 | 39 | 40 | class _FCNHead(nn.Module): 41 | def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, **kwargs): 42 | super(_FCNHead, self).__init__() 43 | inter_channels = in_channels // 4 44 | self.block = nn.Sequential( 45 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 46 | norm_layer(inter_channels), 47 | nn.ReLU(inplace=True), 48 | nn.Dropout(0.1), 49 | nn.Conv2d(inter_channels, channels, 1) 50 | ) 51 | 52 | def forward(self, x): 53 | return self.block(x) 54 | 55 | 56 | # class SegBase(SegBaseModel): 57 | # 58 | # def __init__(self, nclass, backbone='resnet50', pretrained_base=False, **kwargs): 59 | # super(SegBase, self).__init__(nclass, backbone, pretrained_base=pretrained_base, **kwargs) 60 | # self.head = _FCNHead(2048, nclass, **kwargs) 61 | # 62 | # def forward(self, x): 63 | # size = x.size()[2:] 64 | # _, _, c3, c4 = self.base_forward(x) 65 | # x = self.head(c4) 66 | # x = F.interpolate(x, size, mode='bilinear', align_corners=True) 67 | # 68 | # return x 69 | 70 | 71 | class SegBase(SegBaseModel): 72 | 73 | def __init__(self, nclass, backbone='resnet50', pretrained_base=False, **kwargs): 74 | super(SegBase, self).__init__(nclass, backbone, pretrained_base=pretrained_base, **kwargs) 75 | cnn_dict = { 76 | 'resnet18_v1b': 'resnet18_v1b', 'resnet18': 'resnet18_v1b', 77 | 'resnet34_v1b': 'resnet34_v1b', 'resnet34': 'resnet34_v1b', 78 | 'resnet50_v1b': 'resnet50_v1b', 'resnet50': 'resnet50_v1b', 79 | 'resnet101_v1b': 'resnet101_v1b', 'resnet101': 'resnet101_v1b', 80 | 'hrnet18': 'hrnet18', 'HRNet18': 'hrnet18', 81 | 'hrnet32': 'hrnet32', 'HRNet32': 'hrnet32', 82 | 'hrnet48': 'hrnet48', 'HRNet48': 'hrnet48', 83 | } 84 | cnn_name = backbone 85 | if 'resnet18' in cnn_dict[cnn_name] or 'resnet34' in cnn_dict[cnn_name]: 86 | self.cnn_head_dim = [64, 128, 256, 512] 87 | if 'resnet50' in cnn_dict[cnn_name] or 'resnet101' in cnn_dict[cnn_name]: 88 | self.cnn_head_dim = [256, 512, 1024, 2048] 89 | self.head = SegHead(in_channels=self.cnn_head_dim, num_classes=nclass, in_index=[0, 1, 2, 3]) 90 | 91 | def forward(self, x): 92 | size = x.size()[2:] 93 | out_backbone = self.base_forward(x) 94 | x = self.head(out_backbone) 95 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 96 | 97 | return x 98 | 99 | 100 | if __name__ == '__main__': 101 | from tools.flops_params_fps_count import flops_params_fps 102 | model = SegBase(nclass=6) 103 | flops_params_fps(model) 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.head import * 5 | from models.resT import rest_tiny, 
rest_small, rest_base, rest_large 6 | from models.swinT import swin_tiny, swin_small, swin_base, swin_large 7 | from models.volo import volo_d1, volo_d2, volo_d3, volo_d4, volo_d5 8 | from models.cswin import cswin_tiny, cswin_base, cswin_small, cswin_large 9 | from models.beit import beit_base, beit_large 10 | #from tools.heatmap_fun import draw_features 11 | 12 | up_kwargs = {'mode': 'bilinear', 'align_corners': False} 13 | 14 | 15 | class Transformer(nn.Module): 16 | 17 | def __init__(self, transformer_name, nclass, img_size, aux=False, pretrained=False, head='seghead', edge_aux=False): 18 | super(Transformer, self).__init__() 19 | self.aux = aux 20 | self.edge_aux = edge_aux 21 | self.head_name = head 22 | 23 | self.model = eval(transformer_name)(nclass=nclass, img_size=img_size, aux=aux, pretrained=pretrained) 24 | self.backbone = self.model.backbone 25 | 26 | head_dim = self.model.head_dim 27 | if self.head_name == 'apchead': 28 | self.decode_head = APCHead(in_channels=head_dim[3], num_classes=nclass, in_index=3, channels=512) 29 | 30 | if self.head_name == 'aspphead': 31 | self.decode_head = ASPPHead(in_channels=head_dim[3], num_classes=nclass, in_index=3) 32 | 33 | if self.head_name == 'asppplushead': 34 | self.decode_head = ASPPPlusHead(in_channels=head_dim[3], num_classes=nclass, in_index=[0, 3]) 35 | 36 | if self.head_name == 'dahead': 37 | self.decode_head = DAHead(in_channels=head_dim[3], num_classes=nclass, in_index=3) 38 | 39 | if self.head_name == 'dnlhead': 40 | self.decode_head = DNLHead(in_channels=head_dim[3], num_classes=nclass, in_index=3, channels=512) 41 | 42 | if self.head_name == 'fcfpnhead': 43 | self.decode_head = FCFPNHead(in_channels=head_dim, num_classes=nclass, in_index=[0, 1, 2, 3], channels=256) 44 | 45 | if self.head_name == 'cefpnhead': 46 | self.decode_head = CEFPNHead(in_channels=head_dim, num_classes=nclass, in_index=[0, 1, 2, 3], channels=256) 47 | 48 | if self.head_name == 'fcnhead': 49 | self.decode_head = FCNHead(in_channels=head_dim[3], num_classes=nclass, in_index=3, channels=512) 50 | 51 | if self.head_name == 'gchead': 52 | self.decode_head = GCHead(in_channels=head_dim[3], num_classes=nclass, in_index=3, channels=512) 53 | 54 | if self.head_name == 'psahead': 55 | self.decode_head = PSAHead(in_channels=head_dim[3], num_classes=nclass, in_index=3) 56 | 57 | if self.head_name == 'psphead': 58 | self.decode_head = PSPHead(in_channels=head_dim[3], num_classes=nclass, in_index=3) 59 | 60 | if self.head_name == 'seghead': 61 | self.decode_head = SegHead(in_channels=head_dim, num_classes=nclass, in_index=[0, 1, 2, 3]) 62 | 63 | if self.head_name == 'unethead': 64 | self.decode_head = UNetHead(in_channels=head_dim, num_classes=nclass, in_index=[0, 1, 2, 3]) 65 | 66 | if self.head_name == 'uperhead': 67 | self.decode_head = UPerHead(in_channels=head_dim, num_classes=nclass) 68 | 69 | if self.head_name == 'annhead': 70 | self.decode_head = ANNHead(in_channels=head_dim[2:], num_classes=nclass, in_index=[2, 3], channels=512) 71 | 72 | if self.head_name == 'mlphead': 73 | self.decode_head = MLPHead(in_channels=head_dim, num_classes=nclass, in_index=[0, 1, 2, 3], channels=256) 74 | 75 | if self.aux: 76 | self.auxiliary_head = FCNHead(num_convs=1, in_channels=head_dim[2], num_classes=nclass, in_index=2, channels=256) 77 | 78 | if self.edge_aux: 79 | self.edge_head = EdgeHead(in_channels=head_dim[0:2], in_index=[0, 1], channels=head_dim[0]) 80 | 81 | def forward(self, x): 82 | size = x.size()[2:] 83 | outputs = [] 84 | 85 | out_backbone = self.backbone(x) 86 
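        # out_backbone holds one feature map per backbone stage; each decode head
        # selects the stages it needs through its in_index argument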
| 87 | # for i, out in enumerate(out_backbone): 88 | # draw_features(out, f'C{i}') 89 | 90 | x0 = self.decode_head(out_backbone) 91 | if isinstance(x0, (list, tuple)): 92 | for out in x0: 93 | out = F.interpolate(out, size, **up_kwargs) 94 | outputs.append(out) 95 | else: 96 | x0 = F.interpolate(x0, size, **up_kwargs) 97 | outputs.append(x0) 98 | 99 | if self.aux: 100 | x1 = self.auxiliary_head(out_backbone) 101 | x1 = F.interpolate(x1, size, **up_kwargs) 102 | outputs.append(x1) 103 | 104 | if self.edge_aux: 105 | edge = self.edge_head(out_backbone) 106 | edge = F.interpolate(edge, size, **up_kwargs) 107 | outputs.append(edge) 108 | 109 | return outputs 110 | 111 | 112 | if __name__ == '__main__': 113 | """Notice if torch1.6, try to replace a / b with torch.true_divide(a, b)""" 114 | from tools.flops_params_fps_count import flops_params_fps 115 | 116 | model = Transformer(transformer_name='cswin_tiny', nclass=6, img_size=512, aux=True, edge_aux=False, 117 | head='uperhead', pretrained=False) 118 | flops_params_fps(model) 119 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class DoubleConv(nn.Module): 7 | """(convolution => [BN] => ReLU) * 2""" 8 | 9 | def __init__(self, in_channels, out_channels): 10 | super().__init__() 11 | self.double_conv = nn.Sequential( 12 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 13 | nn.BatchNorm2d(out_channels), 14 | nn.ReLU(inplace=True), 15 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 16 | nn.BatchNorm2d(out_channels), 17 | nn.ReLU(inplace=True) 18 | ) 19 | 20 | def forward(self, x): 21 | return self.double_conv(x) 22 | 23 | 24 | class Down(nn.Module): 25 | """Downscaling with maxpool then double conv""" 26 | 27 | def __init__(self, in_channels, out_channels): 28 | super().__init__() 29 | self.maxpool_conv = nn.Sequential( 30 | nn.MaxPool2d(2), 31 | DoubleConv(in_channels, out_channels) 32 | ) 33 | 34 | def forward(self, x): 35 | return self.maxpool_conv(x) 36 | 37 | 38 | class Up(nn.Module): 39 | """Upscaling then double conv""" 40 | 41 | def __init__(self, in_channels, out_channels, bilinear=True): 42 | super().__init__() 43 | 44 | # if bilinear, use the normal convolutions to reduce the number of channels 45 | if bilinear: 46 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 47 | else: 48 | self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2) 49 | 50 | self.conv = DoubleConv(in_channels, out_channels) 51 | 52 | def forward(self, x1, x2): 53 | x1 = self.up(x1) 54 | # input is CHW 55 | diffY = x2.size()[2] - x1.size()[2] 56 | diffX = x2.size()[3] - x1.size()[3] 57 | 58 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 59 | diffY // 2, diffY - diffY // 2]) 60 | # if you have padding issues, see 61 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 62 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 63 | x = torch.cat([x2, x1], dim=1) 64 | return self.conv(x) 65 | 66 | 67 | class OutConv(nn.Module): 68 | def __init__(self, in_channels, out_channels): 69 | super(OutConv, self).__init__() 70 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 71 | 72 | def forward(self, x): 73 | return self.conv(x) 74 | 75 | 
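# Channel bookkeeping of the decoder below (a descriptive sketch of the existing wiring):
# down4 outputs 512 channels and the bilinear Upsample in Up preserves channel counts,
# so up1 receives cat(512 upsampled + 512 skip) = 1024 channels and reduces them to 256;
# up2, up3 and up4 follow the same pattern with 512 -> 128, 256 -> 64 and 128 -> 64.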
class UNet(nn.Module): 76 | def __init__(self, nclass, bilinear=True): 77 | super(UNet, self).__init__() 78 | self.n_channels = 3 79 | self.n_classes = nclass 80 | self.bilinear = bilinear 81 | 82 | self.inc = DoubleConv(self.n_channels, 64) 83 | self.down1 = Down(64, 128) 84 | self.down2 = Down(128, 256) 85 | self.down3 = Down(256, 512) 86 | self.down4 = Down(512, 512) 87 | self.up1 = Up(1024, 256, bilinear) 88 | self.up2 = Up(512, 128, bilinear) 89 | self.up3 = Up(256, 64, bilinear) 90 | self.up4 = Up(128, 64, bilinear) 91 | self.outc = OutConv(64, self.n_classes) 92 | 93 | def forward(self, x): 94 | x1 = self.inc(x) 95 | x2 = self.down1(x1) 96 | x3 = self.down2(x2) 97 | x4 = self.down3(x3) 98 | x5 = self.down4(x4) 99 | x = self.up1(x5, x4) 100 | x = self.up2(x, x3) 101 | x = self.up3(x, x2) 102 | x = self.up4(x, x1) 103 | logits = self.outc(x) 104 | return logits 105 | 106 | 107 | if __name__ == '__main__': 108 | from tools.flops_params_fps_count import flops_params_fps 109 | model = UNet(nclass=6) 110 | flops_params_fps(model) 111 | 112 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import errno 4 | import shutil 5 | import hashlib 6 | from tqdm import tqdm 7 | import torch 8 | 9 | __all__ = ['save_checkpoint', 'download', 'mkdir', 'check_sha1'] 10 | 11 | def save_checkpoint(state, args, is_best, filename='checkpoint.pth.tar'): 12 | """Saves checkpoint to disk""" 13 | if hasattr(args, 'backbone'): 14 | directory = "runs/%s/%s/%s/%s/"%(args.dataset, args.model, args.backbone, args.checkname) 15 | else: 16 | directory = "runs/%s/%s/%s/"%(args.dataset, args.model, args.checkname) 17 | if not os.path.exists(directory): 18 | os.makedirs(directory) 19 | filename = directory + filename 20 | torch.save(state, filename) 21 | if is_best: 22 | shutil.copyfile(filename, directory + 'model_best.pth.tar') 23 | 24 | 25 | def download(url, path=None, overwrite=False, sha1_hash=None): 26 | """Download an given URL 27 | Parameters 28 | ---------- 29 | url : str 30 | URL to download 31 | path : str, optional 32 | Destination path to store downloaded file. By default stores to the 33 | current directory with same name as in url. 34 | overwrite : bool, optional 35 | Whether to overwrite destination file if already exists. 36 | sha1_hash : str, optional 37 | Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified 38 | but doesn't match. 39 | Returns 40 | ------- 41 | str 42 | The file path of the downloaded file. 
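    Example (illustrative; the URL and destination below are placeholders)
        >>> fpath = download('https://example.com/checkpoint.pth',
        ...                  path='./pretrained_weights', overwrite=False)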
43 | """ 44 | if path is None: 45 | fname = url.split('/')[-1] 46 | else: 47 | path = os.path.expanduser(path) 48 | if os.path.isdir(path): 49 | fname = os.path.join(path, url.split('/')[-1]) 50 | else: 51 | fname = path 52 | 53 | if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): 54 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) 55 | if not os.path.exists(dirname): 56 | os.makedirs(dirname) 57 | 58 | print('Downloading %s from %s...'%(fname, url)) 59 | r = requests.get(url, stream=True) 60 | if r.status_code != 200: 61 | raise RuntimeError("Failed downloading url %s"%url) 62 | total_length = r.headers.get('content-length') 63 | with open(fname, 'wb') as f: 64 | if total_length is None: # no content length header 65 | for chunk in r.iter_content(chunk_size=1024): 66 | if chunk: # filter out keep-alive new chunks 67 | f.write(chunk) 68 | else: 69 | total_length = int(total_length) 70 | for chunk in tqdm(r.iter_content(chunk_size=1024), 71 | total=int(total_length / 1024. + 0.5), 72 | unit='KB', unit_scale=False, dynamic_ncols=True): 73 | f.write(chunk) 74 | 75 | if sha1_hash and not check_sha1(fname, sha1_hash): 76 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \ 77 | 'The repo may be outdated or download may be incomplete. ' \ 78 | 'If the "repo_url" is overridden, consider switching to ' \ 79 | 'the default repo.'.format(fname)) 80 | 81 | return fname 82 | 83 | 84 | def check_sha1(filename, sha1_hash): 85 | """Check whether the sha1 hash of the file content matches the expected hash. 86 | Parameters 87 | ---------- 88 | filename : str 89 | Path to the file. 90 | sha1_hash : str 91 | Expected sha1 hash in hexadecimal digits. 92 | Returns 93 | ------- 94 | bool 95 | Whether the file content matches the expected hash. 
96 | """ 97 | sha1 = hashlib.sha1() 98 | with open(filename, 'rb') as f: 99 | while True: 100 | data = f.read(1048576) 101 | if not data: 102 | break 103 | sha1.update(data) 104 | 105 | return sha1.hexdigest() == sha1_hash 106 | 107 | 108 | def mkdir(path): 109 | """make dir exists okay""" 110 | try: 111 | os.makedirs(path) 112 | except OSError as exc: # Python >2.5 113 | if exc.errno == errno.EEXIST and os.path.isdir(path): 114 | pass 115 | else: 116 | raise -------------------------------------------------------------------------------- /mutil_scale_test.py: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # Created by: Hang Zhang 3 | # Email: zhang.hang@rutgers.edu 4 | # Copyright (c) 2017 5 | ########################################################################### 6 | 7 | import math 8 | import torch 9 | import torch.nn.functional as F 10 | import numpy as np 11 | import torch.nn as nn 12 | from torch.nn.parallel.data_parallel import DataParallel 13 | 14 | 15 | up_kwargs = {'mode': 'bilinear', 'align_corners': False} 16 | 17 | 18 | def module_inference(module, image, flip=True): 19 | if flip: 20 | h_img = h_flip_image(image) 21 | v_img = v_flip_image(image) 22 | img = torch.cat([image, h_img, v_img], dim=0) 23 | cat_output = module(img) 24 | if isinstance(cat_output, (list, tuple)): 25 | cat_output = cat_output[0] 26 | output, h_output, v_output = cat_output.chunk(3, dim=0) 27 | output = output + h_flip_image(h_output) + v_flip_image(v_output) 28 | else: 29 | output = module(image) 30 | if isinstance(output, (list, tuple)): 31 | output = output[0] 32 | 33 | return output 34 | 35 | 36 | def resize_image(img, h, w, **up_kwargs): 37 | return F.upsample(img, (h, w), **up_kwargs) 38 | 39 | 40 | def pad_image(img, crop_size): 41 | """crop_size could be list:[h, w] or int""" 42 | b,c,h,w = img.size() 43 | # assert(c==3) 44 | if len(crop_size) > 1: 45 | padh = crop_size[0] - h if h < crop_size[0] else 0 46 | padw = crop_size[1] - w if w < crop_size[1] else 0 47 | else: 48 | padh = crop_size - h if h < crop_size else 0 49 | padw = crop_size - w if w < crop_size else 0 50 | # pad_values = -np.array(mean) / np.array(std) 51 | img_pad = img.new().resize_(b,c,h+padh,w+padw) 52 | # for i in range(c): 53 | # note that pytorch pad params is in reversed orders 54 | min_padh = min(padh, h) 55 | min_padw = min(padw, w) 56 | if padw < w and padh < h: 57 | img_pad[:, :, :, :] = F.pad(img[:, :, :, :], (0, padw, 0, padh), mode='reflect') 58 | else: 59 | img_pad[:, :, 0:h + min_padh - 1, 0:w + min_padw - 1] = \ 60 | F.pad(img[:, :, :, :], (0, min_padw - 1, 0, min_padh - 1), mode='reflect') 61 | 62 | img_pad[:, :, :, :] = F.pad(img_pad[:, :, 0:h + min_padh - 1, 0:w + min_padw - 1], 63 | (0, padw - min_padw + 1, 0, padh - min_padh + 1), mode='constant', value=0) 64 | if len(crop_size) > 1: 65 | assert (img_pad.size(2) >= crop_size[0] and img_pad.size(3) >= crop_size[1]) 66 | else: 67 | assert(img_pad.size(2)>=crop_size and img_pad.size(3)>=crop_size) 68 | return img_pad 69 | 70 | 71 | def crop_image(img, h0, h1, w0, w1): 72 | return img[:,:,h0:h1,w0:w1] 73 | 74 | 75 | def h_flip_image(img): 76 | assert(img.dim()==4) 77 | with torch.cuda.device_of(img): 78 | idx = torch.arange(img.size(3)-1, -1, -1).type_as(img).long() 79 | return img.index_select(3, idx) 80 | 81 | 82 | def v_flip_image(img): 83 | assert(img.dim()==4) 84 | with torch.cuda.device_of(img): 85 | idx = torch.arange(img.size(3)-1, 
-1, -1).type_as(img).long() 86 | return img.index_select(2, idx) 87 | 88 | 89 | def hv_flip_image(img): 90 | assert(img.dim()==4) 91 | with torch.cuda.device_of(img): 92 | idx = torch.arange(img.size(3)-1, -1, -1).type_as(img).long() 93 | img = img.index_select(3, idx) 94 | return img.index_select(2, idx) 95 | 96 | 97 | class MultiEvalModule_Fullimg(DataParallel): 98 | """Multi-size Segmentation Eavluator""" 99 | def __init__(self, module, nclass, device_ids=None, flip=True, 100 | # scales=[1.0]): 101 | # scales=[1.0,1.25]): 102 | # scales=[0.5, 0.75,1.0,1.25,1.5]): 103 | scales=[1.0]): 104 | super(MultiEvalModule_Fullimg, self).__init__(module, device_ids) 105 | self.nclass = nclass 106 | self.base_size = 256 107 | self.crop_size = 256 108 | self.scales = scales 109 | self.flip = flip 110 | print('MultiEvalModule_Fullimg: base_size {}, crop_size {}'. \ 111 | format(self.base_size, self.crop_size)) 112 | 113 | def forward(self, image): 114 | """Mult-size Evaluation""" 115 | batch, _, h, w = image.size() 116 | 117 | with torch.cuda.device_of(image): 118 | scores = image.new().resize_(batch,self.nclass,h,w).zero_().cuda() 119 | for scale in self.scales: 120 | crop_size = int(math.ceil(self.crop_size * scale)) 121 | 122 | cur_img = resize_image(image, crop_size, crop_size, **up_kwargs) 123 | outputs = module_inference(self.module, cur_img, self.flip) 124 | score = resize_image(outputs, h, w, **up_kwargs) 125 | scores += score 126 | 127 | return scores 128 | 129 | 130 | class MultiEvalModule(nn.Module): 131 | """Multi-size Segmentation Eavluator""" 132 | def __init__(self, module, nclass, device_ids=None, flip=True, save_gpu_memory=False, 133 | scales=[1.0], get_batch=1, crop_size=[512, 512], stride_rate=1/2): 134 | #scales=[0.5,0.75,1,1.25]): 135 | #scales=[0.5,0.75,1.0,1.25,1.4,1.6,1.8]): 136 | #scales=[1]): 137 | # super(MultiEvalModule, self).__init__(module, device_ids) 138 | super(MultiEvalModule, self).__init__() 139 | self.module = module 140 | self.devices_ids = device_ids 141 | self.nclass = nclass 142 | self.crop_size = np.array(crop_size) 143 | self.scales = scales 144 | self.flip = flip 145 | self.get_batch = get_batch 146 | self.stride_rate = stride_rate 147 | self.save_gpu_memory = save_gpu_memory # if over memory, can try this 148 | 149 | def forward(self, image): 150 | """Mult-size Evaluation""" 151 | # only single image is supported for evaluation 152 | batch, _, h, w = image.size() 153 | # assert(batch == 1) 154 | stride_rate = self.stride_rate 155 | with torch.cuda.device_of(image): 156 | if self.save_gpu_memory: 157 | scores = image.new().resize_(batch, self.nclass, h, w).zero_().cpu() 158 | else: 159 | scores = image.new().resize_(batch,self.nclass,h,w).zero_().cuda() 160 | 161 | for scale in self.scales: 162 | crop_size = self.crop_size 163 | stride = (crop_size * stride_rate).astype(np.int) 164 | 165 | if h > w: 166 | long_size = int(math.ceil(h * scale)) 167 | height = long_size 168 | width = int(1.0 * w * long_size / h + 0.5) 169 | short_size = width 170 | else: 171 | long_size = int(math.ceil(w * scale)) 172 | width = long_size 173 | height = int(1.0 * h * long_size / w + 0.5) 174 | short_size = height 175 | 176 | # resize image to current size 177 | cur_img = resize_image(image, height, width, **up_kwargs) 178 | if long_size <= np.max(crop_size): 179 | pad_img = pad_image(cur_img, crop_size) 180 | outputs = module_inference(self.module, pad_img, self.flip) 181 | outputs = crop_image(outputs, 0, height, 0, width) 182 | 183 | else: 184 | if short_size < 
np.min(crop_size): 185 | # pad if needed 186 | pad_img = pad_image(cur_img, crop_size) 187 | else: 188 | pad_img = cur_img 189 | _,_,ph,pw = pad_img.size() 190 | # assert(ph >= height and pw >= width) 191 | # grid forward and normalize 192 | h_grids = int(math.ceil(1.0 * (ph-crop_size[0])/stride[0])) + 1 193 | w_grids = int(math.ceil(1.0 * (pw-crop_size[1])/stride[1])) + 1 194 | with torch.cuda.device_of(image): 195 | if self.save_gpu_memory: 196 | outputs = image.new().resize_(batch, self.nclass, ph, pw).zero_().cpu() 197 | count_norm = image.new().resize_(batch, 1, ph, pw).zero_().cpu() 198 | else: 199 | outputs = image.new().resize_(batch,self.nclass,ph,pw).zero_().cuda() 200 | count_norm = image.new().resize_(batch,1,ph,pw).zero_().cuda() 201 | # grid evaluation 202 | location = [] 203 | batch_size = [] 204 | pad_img = pad_image(pad_img, [ph + crop_size[0], pw + crop_size[1]]) # expand pad_image 205 | 206 | for idh in range(h_grids): 207 | for idw in range(w_grids): 208 | h0 = idh * stride[0] 209 | w0 = idw * stride[1] 210 | h1 = min(h0 + crop_size[0], ph) 211 | w1 = min(w0 + crop_size[1], pw) 212 | 213 | crop_img = crop_image(pad_img, h0, h0 + crop_size[0], w0, w0 + crop_size[1]) 214 | # pad if needed 215 | pad_crop_img = pad_image(crop_img, crop_size) 216 | size_h, size_w = pad_crop_img.shape[-2:] 217 | pad_crop_img = resize_image(pad_crop_img, crop_size[0], crop_size[1], **up_kwargs) 218 | if self.get_batch > 1: 219 | location.append([h0, w0, h1, w1]) 220 | batch_size.append(pad_crop_img) 221 | if len(location) == self.get_batch or (idh + idw + 2) == (h_grids + w_grids): 222 | batch_size = torch.cat(batch_size, dim=0).cuda() 223 | location = np.array(location) 224 | output = module_inference(self.module, batch_size, self.flip) 225 | output = output.detach() 226 | output = resize_image(output, size_h, size_w, **up_kwargs) 227 | if self.save_gpu_memory: 228 | output = output.detach().cpu() # to save gpu memory 229 | else: 230 | output = output.detach() 231 | for i in range(batch_size.shape[0]): 232 | outputs[:, :, location[i][0]:location[i][2], location[i][1]:location[i][3]] += \ 233 | crop_image(output[i, ...].unsqueeze(dim=0), 0, location[i][2]-location[i][0], 0, location[i][3]-location[i][1]) 234 | count_norm[:, :, location[i][0]:location[i][2], location[i][1]:location[i][3]] += 1 235 | location = [] 236 | batch_size = [] 237 | else: 238 | output = module_inference(self.module, pad_crop_img, self.flip) 239 | if self.save_gpu_memory: 240 | output = output.detach().cpu() # to save gpu memory 241 | else: 242 | output = output.detach() 243 | output = resize_image(output, size_h, size_w, **up_kwargs) 244 | outputs[:,:,h0:h1,w0:w1] += crop_image(output, 245 | 0, h1-h0, 0, w1-w0) 246 | count_norm[:,:,h0:h1,w0:w1] += 1 247 | assert((count_norm==0).sum()==0) 248 | outputs = outputs / count_norm 249 | outputs = outputs[:,:,:height,:width] 250 | score = resize_image(outputs, h, w, **up_kwargs) 251 | scores += score 252 | return scores 253 | 254 | -------------------------------------------------------------------------------- /post_process.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author : now more 3 | Connect : lin.honghui@qq.com 4 | LastEditors: Please set LastEditors 5 | Description : 6 | LastEditTime: 2020-11-27 03:42:46 7 | ''' 8 | import os 9 | import threading 10 | import cv2 as cv 11 | import numpy as np 12 | from skimage.morphology import remove_small_holes, remove_small_objects 13 | from argparse import ArgumentParser 14 | from 
PIL import Image 15 | 16 | Image.MAX_IMAGE_PIXELS = None 17 | 18 | 19 | def to_categorical(y, num_classes=None, dtype='float32'): 20 | """Converts a class vector (integers) to binary class matrix. 21 | 22 | E.g. for use with categorical_crossentropy. 23 | 24 | # Arguments 25 | y: class vector to be converted into a matrix 26 | (integers from 0 to num_classes). 27 | num_classes: total number of classes. 28 | dtype: The data type expected by the input, as a string 29 | (`float32`, `float64`, `int32`...) 30 | 31 | # Returns 32 | A binary matrix representation of the input. The classes axis 33 | is placed last. 34 | """ 35 | y = np.array(y, dtype='int') 36 | input_shape = y.shape 37 | if input_shape and input_shape[-1] == 1 and len(input_shape) > 1: 38 | input_shape = tuple(input_shape[:-1]) 39 | y = y.ravel() 40 | if not num_classes: 41 | num_classes = np.max(y) + 1 42 | n = y.shape[0] 43 | categorical = np.zeros((n, num_classes), dtype=dtype) 44 | categorical[np.arange(n), y] = 1 45 | output_shape = input_shape + (num_classes,) 46 | categorical = np.reshape(categorical, output_shape) 47 | return categorical 48 | 49 | 50 | class MyThread(threading.Thread): 51 | 52 | def __init__(self, func, args=()): 53 | super(MyThread, self).__init__() 54 | self.func = func 55 | self.args = args 56 | 57 | def run(self): 58 | self.result = self.func(*self.args) 59 | 60 | def get_result(self): 61 | try: 62 | return self.result # 如果子线程不使用join方法,此处可能会报没有self.result的错误 63 | except Exception: 64 | return None 65 | 66 | 67 | def label_resize_vis(label, img=None, alpha=0.5): 68 | ''' 69 | :param label:原始标签 70 | :param img: 原始图像 71 | :param alpha: 透明度 72 | :return: 可视化标签 73 | ''' 74 | def label_to_RGB(image, classes=6): 75 | RGB = np.zeros(shape=[image.shape[0], image.shape[1], 3], dtype=np.uint8) 76 | if classes == 6: # potsdam and vaihingen 77 | palette = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]] 78 | if classes == 4: # barley 79 | palette = [[255, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]] 80 | for i in range(classes): 81 | index = image == i 82 | RGB[index] = np.array(palette[i]) 83 | return RGB 84 | 85 | # label = cv.resize(label.copy(), None, fx=0.1, fy=0.1) 86 | anno_vis = label_to_RGB(label, classes=4) 87 | if img is None: 88 | return anno_vis 89 | else: 90 | overlapping = cv.addWeighted(img, alpha, anno_vis, 1 - alpha, 0) 91 | return overlapping 92 | 93 | 94 | def remove_small_objects_and_holes(class_type, label, min_size, area_threshold, in_place=True): 95 | print("------------- class_n : {} start ------------".format(class_type)) 96 | if class_type == 3: 97 | # kernel = cv.getStructuringElement(cv.MORPH_RECT,(500,500)) 98 | # label = cv.dilate(label,kernel) 99 | # kernel = cv.getStructuringElement(cv.MORPH_RECT,(10,10)) 100 | # label = cv.erode(label,kernel) 101 | label = remove_small_objects(label == 1, min_size=min_size, connectivity=1, in_place=in_place) 102 | label = remove_small_holes(label == 1, area_threshold=area_threshold, connectivity=1, in_place=in_place) 103 | else: 104 | label = remove_small_objects(label == 1, min_size=min_size, connectivity=1, in_place=in_place) 105 | label = remove_small_holes(label == 1, area_threshold=area_threshold, connectivity=1, in_place=in_place) 106 | print("------------- class_n : {} finished ------------".format(class_type)) 107 | return label 108 | 109 | 110 | def RGB_to_label(image=None, classes=6): 111 | if classes == 6: # potsdam and vaihingen 112 | palette = [[255, 255, 255], [0, 0, 255], [0, 255, 255], 
[0, 255, 0], [255, 255, 0], [255, 0, 0]] 113 | if classes == 4: # barley 114 | palette = [[255, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]] 115 | label = np.zeros(shape=[image.shape[0], image.shape[1]], dtype=np.uint8) 116 | for i in range(len(palette)): 117 | index = image == np.array(palette[i]) 118 | index[..., 0][index[..., 1] == False] = False 119 | index[..., 0][index[..., 2] == False] = False 120 | label[index[..., 0]] = i 121 | return label 122 | 123 | 124 | 125 | if __name__ == "__main__": 126 | parser = ArgumentParser() 127 | parser.add_argument("--image_n", type=int, default=2, help="pass 1 or 2 to specify which image to process") 128 | parser.add_argument("--image_path", type=str, default='./outputs', help="path to the directory containing image_n_predict") 129 | parser.add_argument("--threshold", type=int, default=2000) 130 | arg = parser.parse_args() 131 | image_n = arg.image_n 132 | image_path = arg.image_path 133 | threshold = arg.threshold 134 | 135 | if image_n == 1: 136 | source_image = cv.imread("../../data/barley/images_size0.1/image_1.png") 137 | elif image_n == 2: 138 | source_image = cv.imread("../../data/barley/images_size0.1/image_2.png") 139 | else: 140 | raise ValueError("image_n should be 1 or 2, Got {} ".format(image_n)) 141 | 142 | img_mask_dir = os.path.join(image_path, f'image_{image_n}_mask.png') 143 | img_dir = os.path.join(image_path, f'image_{image_n}.png') 144 | if os.path.exists(img_mask_dir): 145 | image = np.asarray(Image.open(img_mask_dir)) 146 | elif os.path.exists(img_dir): 147 | image = np.asarray(Image.open(img_dir)) 148 | else: 149 | raise ValueError(f"Not found image_{image_n}_mask.png or image_{image_n}.png") 150 | 151 | if len(image.shape) == 3: 152 | image = RGB_to_label(image, classes=4) 153 | image_save = Image.fromarray(image) 154 | image_save.save(os.path.join(image_path, f'image_{image_n}_mask.png')) 155 | 156 | image = cv.resize(image, None, fx=0.1, fy=0.1, interpolation=cv.INTER_NEAREST) # downscale to avoid running out of memory 157 | 158 | label = to_categorical(image, num_classes=4, dtype='uint8') 159 | 160 | threading_list = [] 161 | for i in range(4): 162 | t = MyThread(remove_small_objects_and_holes, args=(i, label[:, :, i], threshold, threshold, True)) 163 | threading_list.append(t) 164 | t.start() 165 | 166 | # wait for all threads to finish 167 | result = [] 168 | for t in threading_list: 169 | t.join() 170 | result.append(t.get_result()[:, :, None]) 171 | 172 | label = np.concatenate(result, axis=2) 173 | 174 | label = np.argmax(label, axis=2).astype(np.uint8) 175 | cv.imwrite('./outputs/image_' + str(image_n) + "_predict.png", label) 176 | mask = label_resize_vis(label, source_image) 177 | cv.imwrite('./outputs/vis_image_' + str(image_n) + "_predict.jpg", mask[..., ::-1]) -------------------------------------------------------------------------------- /pre_process.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import numpy as np 4 | from PIL import ImageFile 5 | import math 6 | import cv2 7 | ImageFile.LOAD_TRUNCATED_IMAGES = True 8 | Image.MAX_IMAGE_PIXELS = None 9 | 10 | 11 | def label_to_RGB(image, classes=4): 12 | RGB = np.zeros(shape=[image.shape[0], image.shape[1], 3], dtype=np.uint8) 13 | if classes == 6: 14 | palette = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]] 15 | if classes == 4: 16 | palette = [[255, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]] 17 | for i in range(classes): 18 | index = image == i 19 | RGB[index] = np.array(palette[i]) 20 | return RGB 21 | 22 | 23 | def 
divide_img(image, oh, ow, filename, save_dir, write_txt_dir=None): 24 | """No overlap, and the last square may be smaller than oh x ow""" 25 | if write_txt_dir is not None: 26 | txt = open(write_txt_dir, 'w') 27 | h, w = image.shape[0:2] 28 | num_h, num_w = h // oh + 1, w // ow + 1 29 | for i in range(num_w): 30 | for j in range(num_h): 31 | h1 = min((j + 1) * oh, h) 32 | w1 = min((i + 1) * ow, w) 33 | if len(image.shape) == 2: 34 | image_part = image[j * oh:h1, i * ow:w1] 35 | else: 36 | image_part = image[j * oh:h1, i * ow:w1, :] 37 | image_part = Image.fromarray(image_part) 38 | image_part.save(os.path.join(save_dir, f'{filename}_{j}_{i}.png')) # j:h, i:w 39 | if write_txt_dir is not None: 40 | txt.write(f'{filename}_{j}_{i}' + '\n') 41 | 42 | 43 | def divide_img_overlap(image, oh, ow, filename, save_dir, write_txt_dir=None, overlap=1024): 44 | """Divide img with an overlap, the last square is traced back to oh x ow""" 45 | if write_txt_dir is not None: 46 | txt = open(write_txt_dir, 'w') 47 | path, name = os.path.split(write_txt_dir) 48 | txt_clean = open(os.path.join(path, os.path.splitext(name)[0] + '_clean.txt'), 'w') 49 | h, w = image.shape[0:2] 50 | if len(image.shape) == 2: 51 | image = np.expand_dims(image, axis=2) 52 | num_h, num_w = math.ceil((h - oh) / (oh - overlap)) + 1, math.ceil((w - ow) / (ow - overlap)) + 1 53 | for i in range(num_w): 54 | for j in range(num_h): 55 | if i < num_w - 1 and j < num_h - 1: 56 | image_part = image[(oh - overlap) * j:(oh - overlap) * j + oh, (ow - overlap) * i:(ow - overlap) * i + ow, :] 57 | if i < num_w - 1 and j == num_h - 1: 58 | image_part = image[h - oh:h, (ow - overlap) * i:(ow - overlap) * i + ow, :] 59 | if i == num_w - 1 and j < num_h - 1: 60 | image_part = image[(oh - overlap) * j:(oh - overlap) * j + oh, w - ow:w, :] 61 | if i == num_w - 1 and j == num_h - 1: 62 | image_part = image[h - oh:h, w - ow:w, :] 63 | image_part = image_part.squeeze() 64 | if write_txt_dir is not None: 65 | if np.any(image_part[..., 3] > 0): 66 | txt_clean.write(f'{filename}_{j}_{i}' + '\n') 67 | txt.write(f'{filename}_{j}_{i}' + '\n') 68 | image_part = Image.fromarray(image_part) 69 | image_part.save(os.path.join(save_dir, f'{filename}_{j}_{i}.png')) # j:h, i:w 70 | 71 | 72 | def restore_part_img(oh=6000, ow=6000, overlap=1024, filename='image_1'): 73 | """restore patches of the image, the last square is traced back to oh x ow""" 74 | root = '/data/xzy/datasets/' 75 | dataset = f'barley_hw6000_s{overlap}' 76 | if filename == 'image_1': 77 | h, w = 50141, 47161 78 | if filename == 'image_2': 79 | h, w = 46050, 77470 80 | num_h, num_w = math.ceil((h - oh) / (oh - overlap)) + 1, math.ceil((w - ow) / (ow - overlap)) + 1 81 | for i in range(num_w): 82 | for j in range(num_h): 83 | part_img = Image.open(os.path.join(root, dataset, f'images/{filename}_{j}_{i}.png')) 84 | part_img = np.array(part_img) 85 | if j == 0: 86 | w_patch = part_img[0:oh - overlap // 2, ...] 
87 | elif j < num_h - 1: 88 | w_patch = np.concatenate((w_patch, part_img[overlap // 2:oh - overlap // 2, ...]), 0) 89 | else: 90 | end_h = w_patch.shape[0] 91 | w_patch = np.concatenate((w_patch, part_img[oh - (h - end_h):oh, ...]), 0) 92 | if i == 0: 93 | h_patch = w_patch[:, 0:ow - overlap // 2, :] 94 | elif i < num_w - 1: 95 | h_patch = np.concatenate((h_patch, w_patch[:, overlap // 2:ow - overlap // 2, :]), 1) 96 | else: 97 | end_w = h_patch.shape[1] 98 | h_patch = np.concatenate((h_patch, w_patch[:, ow - (w - end_w):ow, :]), 1) 99 | print(h_patch.shape) 100 | h_patch = cv2.resize(h_patch, None, fx=0.1, fy=0.1, interpolation=cv2.INTER_NEAREST) 101 | full_img = Image.fromarray(h_patch) 102 | full_img.save(os.path.join(root, dataset, f'{filename}.png')) 103 | 104 | 105 | def apply_divide_img_label(overlap=1024): 106 | """read images and divide them with an overlap""" 107 | root = '/data/zyxu/dataset/barley/' 108 | save = f'barley_hw6000_s{overlap}' 109 | img_dir = 'images_complete/' 110 | label_dir = 'labels_complete/' 111 | img_save_dir = os.path.join(root, f'{save}/images/') 112 | label_save_dir = os.path.join(root, f'{save}/labels/') 113 | file = [f'image_{1}', f'image_{2}'] 114 | 115 | for filename in file: 116 | if filename == f'image_{1}': 117 | txt_name = os.path.join(root, f'{save}/annotations/train_full.txt') 118 | if filename == f'image_{2}': 119 | txt_name = os.path.join(root, f'{save}/annotations/test_full.txt') 120 | 121 | img = Image.open(os.path.join(root, img_dir, filename + '.png')) 122 | img = np.array(img) 123 | divide_img_overlap(img, 6000, 6000, filename, img_save_dir, txt_name, overlap=overlap) 124 | 125 | label = Image.open(os.path.join(root, label_dir, filename + '.png')) 126 | label = np.array(label) 127 | divide_img_overlap(label, 6000, 6000, filename, label_save_dir, overlap=overlap) 128 | 129 | 130 | def clean_white_background(): 131 | """remove transparent patches from the train and test txt """ 132 | img_dir = './data/barley/images/' 133 | img_list = os.listdir(img_dir) 134 | train_no_alpha = open('./data/barley/annotations/train_no_alpha.txt', 'w') 135 | test_no_alpha = open('./data/barley/annotations/test_no_alpha.txt', 'w') 136 | for file in img_list: 137 | file = file.strip() 138 | img = Image.open(os.path.join(img_dir, file)) 139 | img = np.array(img) 140 | alpha = img[..., 3] 141 | if np.any(alpha > 0): 142 | if 'image_1' in file: 143 | train_no_alpha.write(file[:-4] + '\n') 144 | if 'image_2' in file: 145 | test_no_alpha.write(file[:-4] + '\n') 146 | 147 | 148 | def count_nums(): 149 | """count pixels of each classes in images""" 150 | label_train_dir = '/data/xzy/datasets/barley/labels_full/image_1_label.png' 151 | label_test_dir = '/data/xzy/datasets/barley/labels_full/image_2_label.png' 152 | label_train = np.array(Image.open(label_train_dir)) 153 | label_test = np.array(Image.open(label_test_dir)) 154 | h0, w0, h1, w1 = label_train.shape[0], label_train.shape[1], label_test.shape[0], label_test.shape[1] 155 | for i in range(4): 156 | print('train pixel{}:{:.6f}'.format(i, np.sum(label_train == i) / (h0 * w0))) 157 | # 0.870185 0.066412 0.006110 0.057294 158 | # class012: [0.51158563 0.04706662 0.44134775] 159 | for i in range(4): 160 | print('test pixel{}:{:.6f}'.format(i, np.sum(label_test == i) / (h1 * w1))) 161 | # 0.926607 0.002005 0.033732 0.037655 162 | # class012: [0.02731905 0.45961413 0.51306682] 163 | 164 | 165 | def rearrange_dataset(oh=6000, ow=6000, overlap=1024): 166 | """fuse image_1 and image_2 to generate new train and test 
file""" 167 | root = '/data/zyxu/dataset/barley/' 168 | dataset = f'barley_hw6000_s{overlap}' 169 | train_txt = open(os.path.join(root, dataset, f'annotations/train.txt'), 'w') 170 | test_txt = open(os.path.join(root, dataset, f'annotations/test.txt'), 'w') 171 | for filename in ['image_1', 'image_2']: 172 | if filename == 'image_1': 173 | h, w = 50141, 47161 174 | if filename == 'image_2': 175 | h, w = 46050, 77470 176 | num_h, num_w = math.ceil((h - oh) / (oh - overlap)) + 1, math.ceil((w - ow) / (ow - overlap)) + 1 177 | for i in range(num_w): 178 | for j in range(num_h): 179 | part_img = Image.open(os.path.join(root, dataset, f'images/{filename}_{j}_{i}.png')) 180 | part_img = np.array(part_img) 181 | if (i + j) % 2 == 0 and np.any(part_img[..., 3] > 0): 182 | train_txt.write(f'{filename}_{j}_{i}' + '\n') 183 | if (i + j) % 2 == 1 and np.any(part_img[..., 3] > 0): 184 | test_txt.write(f'{filename}_{j}_{i}' + '\n') 185 | 186 | 187 | def resize(): 188 | """resize image to 1/10""" 189 | filename = 'image_2' 190 | img1_dir = f'./data/barley/labels_view/{filename}.png' 191 | img1 = np.array(Image.open(img1_dir)) 192 | img = cv2.resize(img1, None, fx=0.1, fy=0.1, interpolation=cv2.INTER_NEAREST) 193 | img = Image.fromarray(img) 194 | img.save(f'./data/barley/images_size0.1/{filename}_labels_view.png') 195 | 196 | 197 | def label_view(): 198 | """view label to rgb""" 199 | root = '/data/xzy/datasets/barley/' 200 | dataset = f'' 201 | train_txt = os.path.join(root, dataset, 'annotations/train.txt') 202 | test_txt = os.path.join(root, dataset, 'annotations/test.txt') 203 | file = open(train_txt, 'r').readlines() + open(test_txt, 'r').readlines() 204 | for name in file: 205 | name = name.strip() 206 | label = np.array(Image.open(os.path.join(root, dataset, 'labels', name + '.png'))) 207 | label = label_to_RGB(label, 4) 208 | label = Image.fromarray(label) 209 | label.save(os.path.join(root, dataset, f'labels_view/{name}.png')) 210 | 211 | 212 | def get_alpha(): 213 | """get alpha channel and save it""" 214 | name = 'image_2.png' 215 | image_dir = './data/barley/images_complete/' 216 | image = np.array(Image.open(os.path.join(image_dir, name))) 217 | alpha = image[..., 3] 218 | alpha = Image.fromarray(alpha) 219 | alpha.save(f'./data/barley/alphas_complete/{name}') 220 | 221 | 222 | if __name__ == '__main__': 223 | # apply_divide_img_label() 224 | # restore_part_img() 225 | 226 | # label1 = np.array(Image.open('./data/barley/barley_hw6000_s1024/image_2.png')) 227 | # label2 = np.array(Image.open('./data/barley/labels_complete/image_2.png')) 228 | # print(label1.shape, label2.shape) 229 | # print(np.all(label1 == label2)) 230 | # print(np.sum(label1), np.sum(label2)) 231 | 232 | # apply_divide_img_label(overlap=0) 233 | # restore_part_img(overlap=0) 234 | # rearrange_dataset(overlap=0) 235 | label_view() 236 | 237 | 238 | 239 | -------------------------------------------------------------------------------- /pretrained_weights/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyxu1996/CCTNet/5a5db40d2e38bd478b404583050049eedca90844/pretrained_weights/.gitignore -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.3.0 2 | fvcore==0.1.5.post20210604 3 | matplotlib==3.3.2 4 | mmcv==1.3.5 5 | numpy==1.18.5 6 | opencv_python==4.5.2.54 7 | pavi==0.0.1 8 | Pillow==9.1.0 9 | 
portalocker==2.0.0 10 | requests==2.24.0 11 | scikit_image==0.17.2 12 | scipy==1.5.2 13 | setproctitle==1.2.3 14 | skimage==0.0 15 | tifffile==2020.9.3 16 | timm==0.3.2 17 | torch==1.6.0 18 | torchvision==0.7.0 19 | tqdm==4.50.2 20 | -------------------------------------------------------------------------------- /seg_metric.py: -------------------------------------------------------------------------------- 1 | """ 2 | refer to https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/utils/metrics.py 3 | """ 4 | import numpy as np 5 | 6 | __all__ = ['SegmentationMetric'] 7 | 8 | """ 9 | confusionMetric 10 | P\L P N 11 | P TP FP 12 | N FN TN 13 | """ 14 | 15 | 16 | class SegmentationMetric(object): 17 | def __init__(self, numClass): 18 | self.numClass = numClass 19 | self.confusionMatrix = np.zeros((self.numClass,) * 2) 20 | 21 | def Accuracy(self): 22 | # return all class overall pixel accuracy 23 | # acc = (TP + TN) / (TP + TN + FP + FN) 24 | acc = np.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum() 25 | return acc 26 | 27 | def Precision(self): 28 | # return each category pixel accuracy (more accurately, precision) 29 | # precision = TP / (TP + FP) 30 | precision = np.diag(self.confusionMatrix) / self.confusionMatrix.sum(axis=0) 31 | return precision 32 | 33 | def meanPrecision(self): 34 | precision = self.Precision() 35 | mPrecision = np.nanmean(precision) 36 | return mPrecision 37 | 38 | def Recall(self): 39 | # Recall = (TP) / (TP + FN) 40 | recall = np.diag(self.confusionMatrix) / self.confusionMatrix.sum(axis=1) 41 | return recall 42 | 43 | def meanRecall(self): 44 | recall = self.Recall() 45 | mRecall = np.nanmean(recall) 46 | return mRecall 47 | 48 | def F1(self): 49 | # 2*precision*recall / (precision + recall) 50 | f1 = 2 * self.Precision() * self.Recall() / (self.Precision() + self.Recall()) 51 | return f1 52 | 53 | def meanF1(self): 54 | f1 = self.F1() 55 | mF1 = np.nanmean(f1) 56 | return mF1 57 | 58 | def IntersectionOverUnion(self): 59 | # Intersection = TP Union = TP + FP + FN 60 | # IoU = TP / (TP + FP + FN) 61 | intersection = np.diag(self.confusionMatrix) 62 | union = np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag( 63 | self.confusionMatrix) 64 | IoU = intersection / union 65 | return IoU 66 | 67 | def meanIntersectionOverUnion(self): 68 | # Intersection = TP Union = TP + FP + FN 69 | # IoU = TP / (TP + FP + FN) 70 | IoU = self.IntersectionOverUnion() 71 | mIoU = np.nanmean(IoU) 72 | return mIoU 73 | 74 | def genConfusionMatrix(self, imgPredict, imgLabel): 75 | # remove classes from unlabeled pixels in gt image and predict 76 | mask = (imgLabel >= 0) & (imgLabel < self.numClass) 77 | label = self.numClass * imgLabel[mask] + imgPredict[mask] 78 | count = np.bincount(label, minlength=self.numClass ** 2) 79 | confusionMatrix = count.reshape(self.numClass, self.numClass) 80 | return confusionMatrix 81 | 82 | def Frequency_Weighted_Intersection_over_Union(self): 83 | # FWIOU = [(TP+FN)/(TP+FP+TN+FN)] *[TP / (TP + FP + FN)] 84 | freq = np.sum(self.confusionMatrix, axis=1) / np.sum(self.confusionMatrix) 85 | iu = np.diag(self.confusionMatrix) / ( 86 | np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - 87 | np.diag(self.confusionMatrix)) 88 | iu = [i if not np.isnan(i) else 0.0 for i in iu] 89 | iu = np.array(iu) 90 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() 91 | return FWIoU 92 | 93 | def Frequency_Weighted(self): 94 | # FWIOU = [(TP+FN)/(TP+FP+TN+FN)] *[TP / (TP + FP + FN)] 95 
| freq = np.sum(self.confusionMatrix, axis=1) / np.sum(self.confusionMatrix) 96 | 97 | return freq 98 | 99 | def addBatch(self, imgPredict, imgLabel): 100 | assert imgPredict.shape == imgLabel.shape 101 | self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel) 102 | def reset(self): 103 | self.confusionMatrix = np.zeros((self.numClass, self.numClass)) 104 | 105 | 106 | if __name__ == '__main__': 107 | imgPredict = np.array([0, 0, 1, 1, 2, 2]) 108 | imgLabel = np.array([0, 0, 1, 1, 2, 2]) 109 | metric = SegmentationMetric(3) 110 | metric.addBatch(imgPredict, imgLabel) 111 | acc = metric.Accuracy() 112 | mIoU = metric.meanIntersectionOverUnion() 113 | print(acc, mIoU) -------------------------------------------------------------------------------- /tools/edge/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyxu1996/CCTNet/5a5db40d2e38bd478b404583050049eedca90844/tools/edge/.gitignore -------------------------------------------------------------------------------- /tools/flops_params_fps_count.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | from fvcore.nn import FlopCountAnalysis, parameter_count 4 | from tqdm import tqdm 5 | import torch 6 | 7 | 8 | def flops_params_fps(model, input_shape=(1, 3, 512, 512)): 9 | """count flops:G params:M fps:img/s 10 | input shape tensor[1, c, h, w] 11 | """ 12 | total_time = [] 13 | with torch.no_grad(): 14 | model = model.cuda().eval() 15 | input = torch.randn(size=input_shape, dtype=torch.float32).cuda() 16 | flops = FlopCountAnalysis(model, input) 17 | params = parameter_count(model) 18 | 19 | for i in tqdm(range(100)): 20 | torch.cuda.synchronize() 21 | start = time.time() 22 | output = model(input) 23 | torch.cuda.synchronize() 24 | end = time.time() 25 | total_time.append(end - start) 26 | mean_time = np.mean(np.array(total_time)) 27 | print(model.__class__.__name__) 28 | print('img/s:{:.2f}'.format(1 / mean_time)) 29 | print('flops:{:.2f}G params:{:.2f}M'.format(flops.total() / 1e9, params[''] / 1e6)) 30 | -------------------------------------------------------------------------------- /tools/generate_edge.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import tifffile as tiff 3 | import os 4 | import numpy as np 5 | from tools.utils import label_to_RGB 6 | from models.resT import rest_tiny 7 | import torch 8 | from torchvision import transforms 9 | 10 | 11 | def to_tensor(image): 12 | image = torch.from_numpy(image).permute(2, 0, 1).float().div(255) 13 | normalize = transforms.Normalize((.485, .456, .406), (.229, .224, .225)) 14 | image = normalize(image).unsqueeze(0) 15 | 16 | return image 17 | 18 | 19 | def init_model(): 20 | model = rest_tiny(nclass=6, aux=False, head='mlphead', edge_aux=True) 21 | weight_dir = '../work_dir/' \ 22 | 'resT_lr0.0003_epoch100_batchsize16_vaihingen_resT_tiny_mlphead_imagenetpretrain_noaux_edge_edgeup_AdamW_num128' \ 23 | '/weights/best_weight.pkl' 24 | checkpoint = torch.load(weight_dir, map_location=lambda storage, loc: storage) 25 | if 'state_dict' in checkpoint: 26 | checkpoint = checkpoint['state_dict'] 27 | checkpoint = {k.replace('module.model.', ''): v for k, v in checkpoint.items()} 28 | model.load_state_dict(checkpoint) 29 | return model 30 | 31 | 32 | def read_img_label(save_dir): 33 | img_dir = '../data/vaihingen/images/top_mosaic_09cm_area10.tif' 34 | label_dir = 
'../data/vaihingen/annotations/labels/top_mosaic_09cm_area10.png' 35 | image = tiff.imread(img_dir) 36 | image = image[1000:1000 + 512, 1000:1000+512, :] 37 | label = cv2.imread(label_dir, cv2.IMREAD_UNCHANGED) 38 | label = label[1000:1000 + 512, 1000:1000 + 512] 39 | cv2.imwrite(os.path.join(save_dir, 'ori_img.png'), image[..., ::-1]) 40 | cv2.imwrite(os.path.join(save_dir, 'ori_label.png'), label_to_RGB(label)[..., ::-1]) 41 | 42 | return image, label 43 | 44 | 45 | def canny_edge(img, edge_width=3): 46 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 47 | gray = cv2.GaussianBlur(gray, (11, 11), 0) 48 | edge = cv2.Canny(gray, 30, 150) 49 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (edge_width, edge_width)) 50 | edge = cv2.dilate(edge, kernel) 51 | 52 | return edge 53 | 54 | 55 | def groundtruth_edge(label, edge_width=3): 56 | if len(label.shape) == 2: 57 | label = label[np.newaxis, ...] 58 | label = label.astype(np.int) 59 | b, h, w = label.shape 60 | edge = np.zeros(label.shape) 61 | 62 | # right 63 | edge_right = edge[:, 1:h, :] 64 | edge_right[(label[:, 1:h, :] != label[:, :h - 1, :]) & (label[:, 1:h, :] != 255) 65 | & (label[:, :h - 1, :] != 255)] = 1 66 | 67 | # up 68 | edge_up = edge[:, :, :w - 1] 69 | edge_up[(label[:, :, :w - 1] != label[:, :, 1:w]) 70 | & (label[:, :, :w - 1] != 255) 71 | & (label[:, :, 1:w] != 255)] = 1 72 | 73 | # upright 74 | edge_upright = edge[:, :h - 1, :w - 1] 75 | edge_upright[(label[:, :h - 1, :w - 1] != label[:, 1:h, 1:w]) 76 | & (label[:, :h - 1, :w - 1] != 255) 77 | & (label[:, 1:h, 1:w] != 255)] = 1 78 | 79 | # bottomright 80 | edge_bottomright = edge[:, :h - 1, 1:w] 81 | edge_bottomright[(label[:, :h - 1, 1:w] != label[:, 1:h, :w - 1]) 82 | & (label[:, :h - 1, 1:w] != 255) 83 | & (label[:, 1:h, :w - 1] != 255)] = 1 84 | 85 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (edge_width, edge_width)) 86 | for i in range(edge.shape[0]): 87 | edge[i] = cv2.dilate(edge[i], kernel) 88 | edge = edge.squeeze(axis=0) 89 | return edge 90 | 91 | 92 | def get_edge_predict(img): 93 | img = to_tensor(img).cuda() 94 | model = init_model().cuda().eval() 95 | with torch.no_grad(): 96 | output = model(img) 97 | edge_predict = torch.argmax(output[1], dim=1) 98 | edge_predict = edge_predict.squeeze().cpu().numpy().astype(np.uint8) 99 | edge_predict = edge_predict * 255 100 | # kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (4, 4)) 101 | # edge_predict = cv2.erode(edge_predict, kernel) 102 | 103 | return edge_predict 104 | 105 | 106 | def main(): 107 | img, label = read_img_label(save_path) 108 | # canny_ = canny_edge(img) 109 | # cv2.imwrite(os.path.join(save_path, 'canny_edge.png'), canny_) 110 | # groundtruth_ = groundtruth_edge(label) * 255 111 | # cv2.imwrite(os.path.join(save_path, 'groundtruth_edge.png'), groundtruth_) 112 | edge_predict = get_edge_predict(img) 113 | cv2.imwrite(os.path.join(save_path, 'predict_edge.png'), edge_predict) 114 | 115 | 116 | if __name__ == '__main__': 117 | save_path = './edge/' 118 | if not os.path.exists(save_path): 119 | os.mkdir(save_path) 120 | main() 121 | 122 | -------------------------------------------------------------------------------- /tools/generate_heatmap.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import sys 4 | sys.path.append('./') 5 | import numpy as np 6 | from models.swinT import swin_base 7 | import torch 8 | import tifffile as tiff 9 | from PIL import Image 10 | from tools.utils import label_to_RGB 11 | from torchvision 
import transforms 12 | 13 | 14 | def init_model(): 15 | model = swin_base(nclass=6, aux=True, head='uperhead') 16 | weight_dir = 'work_dir/' \ 17 | 'swinT_lr0.0003_epoch100_batchsize16_swinT_upernet_base_imagenetpretrain_aux_AdamW_num73' \ 18 | '/weights/best_weight.pkl' 19 | checkpoint = torch.load(weight_dir, map_location=lambda storage, loc: storage) 20 | if 'state_dict' in checkpoint: 21 | checkpoint = checkpoint['state_dict'] 22 | checkpoint = {k.replace('module.model.', ''): v for k, v in checkpoint.items()} 23 | model.load_state_dict(checkpoint) 24 | return model 25 | 26 | 27 | def read_img(save_dir): 28 | img_dir = 'data/barley/images/image_2_4_10.png' 29 | # image = tiff.imread(img_dir) 30 | image = Image.open(img_dir) 31 | image = np.array(image) 32 | image = image[1000:1000 + 512, 0:0+512, 0:3] 33 | cv2.imwrite(os.path.join(save_dir, 'ori_image.png'), image[..., ::-1]) 34 | 35 | return image 36 | 37 | 38 | def read_label(save_dir): 39 | img_dir = 'data/barley/labels_view/image_2_4_10.png' 40 | image = Image.open(img_dir) 41 | image = np.array(image) 42 | image = image[1000:1000 + 512, 0:0+512, 0:3] 43 | cv2.imwrite(os.path.join(save_dir, 'ori_label.png'), image[..., ::-1]) 44 | 45 | return image 46 | 47 | 48 | def to_tensor(image): 49 | image = torch.from_numpy(image).permute(2, 0, 1).float().div(255) 50 | normalize = transforms.Normalize((.485, .456, .406), (.229, .224, .225)) 51 | image = normalize(image).unsqueeze(0) 52 | 53 | return image 54 | 55 | 56 | def main(): 57 | save_img_dir = os.path.join(save_path, 'origin_img') 58 | if not os.path.exists(save_img_dir): 59 | os.mkdir(save_img_dir) 60 | save_out_dir = os.path.join(save_path, 'output') 61 | if not os.path.exists(save_out_dir): 62 | os.mkdir(save_out_dir) 63 | 64 | image = read_img(save_img_dir) 65 | image = to_tensor(image).cuda() 66 | model = init_model().cuda().eval() 67 | with torch.no_grad(): 68 | output = model(image) 69 | output = torch.argmax(output[0], dim=1) 70 | output = output.squeeze() 71 | output = output.cpu().numpy() 72 | output = output.astype(np.uint8) 73 | output = label_to_RGB(output) 74 | cv2.imwrite(os.path.join(save_out_dir, 'out.png'), output[..., ::-1]) 75 | 76 | 77 | if __name__ == '__main__': 78 | save_path = 'tools/heatmap/outputs/' 79 | if not os.path.exists(save_path): 80 | os.mkdir(save_path) 81 | # main() 82 | read_img(save_path) 83 | read_label(save_path) 84 | 85 | -------------------------------------------------------------------------------- /tools/heat_map.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import time 3 | import os 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | 8 | save_path='./heatmap/' 9 | if not os.path.exists(save_path): 10 | os.mkdir(save_path) 11 | 12 | def draw_features(x,savename): 13 | tic = time.time() 14 | fig = plt.figure(figsize=(16, 16)) 15 | fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.05, hspace=0.05) 16 | b, c, h, w = x.shape 17 | for i in range(int(c)): 18 | plt.subplot(h, w, i + 1) 19 | plt.axis('off') 20 | img = x[0, i, :, :].cpu().numpy() 21 | print('img_shape', img.shape) 22 | # print('img', img) 23 | # print(width,height) 24 | pmin = np.min(img) 25 | pmax = np.max(img) 26 | img = ((img - pmin) / (pmax - pmin + 0.000001))*255 # scale floats in [0, 1] to 0-255 27 | img=img.astype(np.uint8) # convert to uint8 28 | img=cv2.applyColorMap(img, cv2.COLORMAP_JET) # generate heat map 29 | # img = img[:, :, ::-1] # note that cv2 (BGR) and matplotlib (RGB) channel orders are reversed 30 | plt.imshow(img) 31 | 
print("{}/{}".format(i, c)) 32 | print(img.shape) 33 | img = cv2.resize(img, (768, 768), interpolation=cv2.INTER_LINEAR) 34 | cv2.imwrite(save_path + savename + str(i) + '.png', img) 35 | fig.clf() 36 | plt.close() 37 | print("time:{}".format(time.time()-tic)) 38 | 39 | -------------------------------------------------------------------------------- /tools/heatmap/outputs/ori_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyxu1996/CCTNet/5a5db40d2e38bd478b404583050049eedca90844/tools/heatmap/outputs/ori_image.png -------------------------------------------------------------------------------- /tools/heatmap/outputs/ori_label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyxu1996/CCTNet/5a5db40d2e38bd478b404583050049eedca90844/tools/heatmap/outputs/ori_label.png -------------------------------------------------------------------------------- /tools/heatmap_fun.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import cv2 4 | import time 5 | import os 6 | from tqdm import tqdm 7 | 8 | save_path = './heatmap/uperhead/' 9 | if not os.path.exists(save_path): 10 | os.mkdir(save_path) 11 | 12 | 13 | def draw_features(x, savename): 14 | tic = time.time() 15 | b, c, h, w = x.shape 16 | for i in tqdm(range(int(c))): 17 | img = x[0, i, :, :].cpu().numpy() 18 | pmin = np.min(img) 19 | pmax = np.max(img) 20 | img = ((img - pmin) / (pmax - pmin + 0.000001))*255 # change value [0, 1] to [0, 255] 21 | img = img.astype(np.uint8) 22 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET) # generate heat map 23 | img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR) 24 | if not os.path.exists(os.path.join(save_path, savename)): 25 | os.mkdir(os.path.join(save_path, savename)) 26 | cv2.imwrite(os.path.join(save_path, savename, savename + '_' + str(i) + '.png'), img) 27 | plt.close() 28 | print("{} time:{}".format(savename, time.time()-tic)) 29 | -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def label_to_RGB(image): 5 | RGB = np.zeros(shape=[image.shape[0], image.shape[1], 3], dtype=np.uint8) 6 | index = image == 0 7 | RGB[index] = np.array([255, 255, 255]) 8 | index = image == 1 9 | RGB[index] = np.array([0, 0, 255]) 10 | index = image == 2 11 | RGB[index] = np.array([0, 255, 255]) 12 | index = image == 3 13 | RGB[index] = np.array([0, 255, 0]) 14 | index = image == 4 15 | RGB[index] = np.array([255, 255, 0]) 16 | index = image == 5 17 | RGB[index] = np.array([255, 0, 0]) 18 | return RGB -------------------------------------------------------------------------------- /work_dir/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyxu1996/CCTNet/5a5db40d2e38bd478b404583050049eedca90844/work_dir/.gitignore --------------------------------------------------------------------------------