├── CCTNet.png
├── LICENSE
├── README.md
├── __init__.py
├── autorun.sh
├── custom_transforms.py
├── data
│   └── .gitignore
├── dataset.py
├── loss.py
├── models
│   ├── __init__.py
│   ├── banet.py
│   ├── beit.py
│   ├── bisenetv2.py
│   ├── cctnet.py
│   ├── checkpoint.py
│   ├── cswin.py
│   ├── danet.py
│   ├── deeplabv3.py
│   ├── edgenet.py
│   ├── fcn.py
│   ├── fpn.py
│   ├── head
│   │   ├── __init__.py
│   │   ├── ann.py
│   │   ├── apc.py
│   │   ├── aspp.py
│   │   ├── aspp_plus.py
│   │   ├── base_decoder.py
│   │   ├── cefpn.py
│   │   ├── da.py
│   │   ├── dnl.py
│   │   ├── edge.py
│   │   ├── fcfpn.py
│   │   ├── fcn.py
│   │   ├── gc.py
│   │   ├── mlp.py
│   │   ├── psa.py
│   │   ├── psp.py
│   │   ├── seg.py
│   │   ├── unet.py
│   │   └── uper.py
│   ├── hrnet.py
│   ├── model_store.py
│   ├── pspnet.py
│   ├── resT.py
│   ├── resnet.py
│   ├── segbase.py
│   ├── swinT.py
│   ├── transformer.py
│   ├── unet.py
│   ├── utils.py
│   └── volo.py
├── mutil_scale_test.py
├── post_process.py
├── pre_process.py
├── pretrained_weights
│   └── .gitignore
├── requirements.txt
├── seg_metric.py
├── test.py
├── tools
│   ├── edge
│   │   └── .gitignore
│   ├── flops_params_fps_count.py
│   ├── generate_edge.py
│   ├── generate_heatmap.py
│   ├── heat_map.py
│   ├── heatmap
│   │   └── outputs
│   │       ├── ori_image.png
│   │       └── ori_label.png
│   ├── heatmap_fun.py
│   └── utils.py
├── train.py
└── work_dir
    └── .gitignore
/CCTNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyxu1996/CCTNet/5a5db40d2e38bd478b404583050049eedca90844/CCTNet.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CCTNet: Coupled CNN and Transformer Network for Crop Segmentation of Remote Sensing Images, [RemoteSensing](https://www.mdpi.com/2072-4292/14/9/1956/htm)
2 | ## Introduction
3 | We propose a Coupled CNN and Transformer Network that combines the local modeling advantage of the CNN with the global modeling advantage of the Transformer, achieving SOTA performance on the [Barley Remote Sensing Dataset](https://tianchi.aliyun.com/dataset/dataDetail?spm=5176.12281978.0.0.76944054ZQD0l2&dataId=74952). With our code base, you can easily handle ultra-high-resolution remote sensing images. If our work is helpful to you, please star us.
4 |
5 | ![CCTNet](./CCTNet.png)
6 | ## Usage
7 | * Install packages
8 |
9 | This repository is based on `python 3.6.12` and `torch 1.6.0`.
10 |
11 | ```
12 | git clone https://github.com/zyxu1996/CCTNet.git
13 | cd CCTNet
14 | ```
15 | ```
16 | pip install -r requirements.txt
17 | ```
18 | * Prepare datasets and pretrained weights
19 |
20 | * The code base supports three high-resolution datasets: Barley, Potsdam and Vaihingen.
21 | * Download the `Barley, Potsdam and Vaihingen` datasets from BaiduYun, and put them in `./data`
22 | `BaiduYun`: [https://pan.baidu.com/s/1MyDw1qncPKYJFK_zjFxFBA](https://pan.baidu.com/s/1MyDw1qncPKYJFK_zjFxFBA)
23 | `Password`: s7f2
24 |
25 | The data file structure of the above three datasets is as follows.
26 | ```
27 | ├── data                                  ├── data                                  ├── data
28 | │   ├── barley                            │   ├── potsdam                           │   ├── vaihingen
29 | │   │   ├── images                        │   │   ├── images                        │   │   ├── images
30 | │   │   │   ├── image_1_0_0.png           │   │   │   ├── top_potsdam_2_10.tif      │   │   │   ├── top_mosaic_09cm_area1.tif
31 | │   │   │   ├── image_1_0_1.png           │   │   │   ├── top_potsdam_2_11.tif      │   │   │   ├── top_mosaic_09cm_area2.tif
32 | │   │   │   ...                           │   │   │   ...                           │   │   │   ...
33 | │   │   ├── labels                        │   │   ├── labels                        │   │   ├── labels
34 | │   │   │   ├── image_1_0_0.png           │   │   │   ├── top_potsdam_2_10.png      │   │   │   ├── top_mosaic_09cm_area1.png
35 | │   │   │   ├── image_1_0_1.png           │   │   │   ├── top_potsdam_2_11.png      │   │   │   ├── top_mosaic_09cm_area2.png
36 | │   │   │   ...                           │   │   │   ...                           │   │   │   ...
37 | │   │   └── annotations                   │   │   └── annotations                   │   │   └── annotations
38 | │   │       ├── train.txt                 │   │       ├── train.txt                 │   │       ├── train.txt
39 | │   │       └── test.txt                  │   │       └── test.txt                  │   │       └── test.txt
40 |
41 | ```
42 |
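For reference, each line of `train.txt` / `test.txt` is a bare filename without the extension; `dataset.py` appends `.png` (labels, and Barley images) or `.tif` (Potsdam/Vaihingen images) when loading. An illustrative `barley/annotations/train.txt` would therefore look like:
```
image_1_0_0
image_1_0_1
...
```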
43 | * Download the pretrained weights from [CSwin-Transformer](https://github.com/microsoft/CSWin-Transformer), and put them in `./pretrained_weights`
44 | CSwin: `CSwin Tiny, Small, Base and Large` models pretrained on `ImageNet-1K` and `ImageNet-22K` are used.
45 | ResNet: `ResNet 18, 34, 50 and 101` pretrained models are used; the download links are contained in our code.
46 |
47 | * Training
48 |
49 | * The training and testing settings are written in the script, including the selection of datasets and models.
50 | ```
51 | sh autorun.sh
52 | ```
53 | * If you run train.py directly, please uncomment the following code.
54 | ```
55 | if __name__ == '__main__':
56 |     os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
57 |     os.environ.setdefault('RANK', '0')
58 |     os.environ.setdefault('WORLD_SIZE', '1')
59 |     os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
60 |     os.environ.setdefault('MASTER_PORT', '29556')
61 | ```
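With those lines uncommented, `train.py` can also be launched without `torch.distributed.launch`; a minimal single-GPU sketch reusing the Vaihingen flags from `autorun.sh` (adjust the dataset and hyperparameters to your run) is:
```
python train.py --dataset vaihingen --end_epoch 100 --lr 0.0003 --train_batchsize 4 --models cctnet --head seghead --crop_size 512 512 --trans_cnn cswin_tiny resnet50 --use_mixup 0 --information num2
```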
62 | * Testing
63 | * Generate the final results and visualize the predictions.
64 | ```
65 | cd ./work_dir/your_work
66 | ```
67 | * Remember to uncomment the test command in `autorun.sh`, and keep the `--information num1` in the testing command the same as the information used in the training command.
68 | `CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 --master_port 29505 test.py --dataset barley --val_batchsize 8 --models cctnet --head seghead --crop_size 512 512 --trans_cnn cswin_tiny resnet50 --save_dir work_dir --base_dir ../../ --information num1
69 | `
70 | * Then run the script autorun.sh.
71 | ```
72 | sh autorun.sh
73 | ```
74 | ## Acknowledgments
75 | Thanks to Guangzhou Jingwei Information Technology Co., Ltd., and the Xingren City government for providing the Barley Remote Sensing Dataset.
76 | Thanks to the ISPRS for providing the Potsdam and Vaihingen datasets.
77 | ## Citation
78 | ```
79 | @article{wang2022cctnet,
80 | title={CCTNet: Coupled CNN and Transformer Network for Crop Segmentation of Remote Sensing Images},
81 | author={Wang, Hong and Chen, Xianzhong and Zhang, Tianxiang and Xu, Zhiyong and Li, Jiangyun},
82 | journal={Remote Sensing},
83 | volume={14},
84 | number={9},
85 | pages={1956},
86 | year={2022},
87 | publisher={MDPI}
88 | }
89 | ```
90 | ## Other Links
91 | * [HRCNet: High-Resolution Context Extraction Network for Semantic Segmentation of Remote Sensing Images](https://github.com/zyxu1996/HRCNet-High-Resolution-Context-Extraction-Network)
92 | * [Efficient Transformer for Remote Sensing Image Segmentation](https://github.com/zyxu1996/Efficient-Transformer)
93 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import *
--------------------------------------------------------------------------------
/autorun.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | ################### Test #################
3 |
4 | #CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 --master_port 29505 test.py --dataset barley --val_batchsize 8 --models cctnet --head seghead --crop_size 512 512 --trans_cnn cswin_tiny resnet50 --save_dir work_dir --base_dir ../../ --information num1
5 |
6 |
7 | ################### Train #################
8 |
9 | # barley
10 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port 29506 train.py --dataset barley --end_epoch 50 --lr 0.0001 --train_batchsize 4 --models cctnet --head seghead --crop_size 512 512 --trans_cnn cswin_tiny resnet50 --use_mixup 0 --information num1
11 |
12 | # vaihingen
13 | #CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port 29507 train.py --dataset vaihingen --end_epoch 100 --lr 0.0003 --train_batchsize 4 --models cctnet --head seghead --crop_size 512 512 --trans_cnn cswin_tiny resnet50 --use_mixup 0 --information num2
14 |
15 |
16 | # potsdam
17 | #CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port 29508 train.py --dataset potsdam --end_epoch 50 --lr 0.0001 --train_batchsize 4 --models cctnet --head seghead --crop_size 512 512 --trans_cnn cswin_tiny resnet50 --use_mixup 0 --information num3
--------------------------------------------------------------------------------
/custom_transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | import cv2
5 | import os
6 | import torch.nn as nn
7 | from torchvision import transforms
8 |
9 | class RandomHorizontalFlip(object):
10 | def __call__(self, sample):
11 | image = sample['image']
12 | label = sample['label']
13 | if random.random() < 0.5:
14 | image = cv2.flip(image, 1)
15 | label = cv2.flip(label, 1)
16 |
17 | return {'image': image, 'label': label}
18 |
19 |
20 | class RandomVerticalFlip(object):
21 | def __call__(self, sample):
22 | image = sample['image']
23 | label = sample['label']
24 | if random.random() < 0.5:
25 | image = cv2.flip(image, 0)
26 | label = cv2.flip(label, 0)
27 |
28 | return {'image': image, 'label': label}
29 |
30 |
31 | class RandomScaleCrop(object):
32 | def __init__(self, base_size=None, crop_size=None, fill=0):
33 | """shape [H, W]"""
34 | if base_size is None:
35 | base_size = [512, 512]
36 | if crop_size is None:
37 | crop_size = [512, 512]
38 | self.base_size = np.array(base_size)
39 | self.crop_size = np.array(crop_size)
40 | self.fill = fill
41 |
42 | def __call__(self, sample):
43 | img = sample['image']
44 | mask = sample['label']
45 | # random scale (short edge)
46 | short_size = random.choice([self.base_size * 0.5, self.base_size * 0.75, self.base_size,
47 | self.base_size * 1.25, self.base_size * 1.5])
48 | short_size = short_size.astype(np.int)
49 | h, w = img.shape[0:2]
50 | if h > w:
51 | ow = short_size[1]
52 | oh = int(1.0 * h * ow / w)
53 | else:
54 | oh = short_size[0]
55 | ow = int(1.0 * w * oh / h)
56 | #img = img.resize((ow, oh), Image.BILINEAR)
57 | #mask = mask.resize((ow, oh), Image.NEAREST)
58 | img = cv2.resize(img, (ow, oh), interpolation=cv2.INTER_LINEAR)
59 | mask = cv2.resize(mask, (ow, oh), interpolation=cv2.INTER_NEAREST)
60 | # pad crop
61 | if short_size[0] < self.crop_size[0] or short_size[1] < self.crop_size[1]:
62 | padh = self.crop_size[0] - oh if oh < self.crop_size[0] else 0
63 | padw = self.crop_size[1] - ow if ow < self.crop_size[1] else 0
64 | #img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
65 | #mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
66 | img = cv2.copyMakeBorder(img, 0, padh, 0, padw, borderType=cv2.BORDER_DEFAULT)
67 | mask = cv2.copyMakeBorder(mask, 0, padh, 0, padw, borderType=cv2.BORDER_DEFAULT)
68 | # random crop crop_size
69 | h, w = img.shape[0:2]
70 | x1 = random.randint(0, w - self.crop_size[1])
71 | y1 = random.randint(0, h - self.crop_size[0])
72 | img = img[y1:y1+self.crop_size[0], x1:x1+self.crop_size[1], :]
73 | mask = mask[y1:y1+self.crop_size[0], x1:x1+self.crop_size[1]]
74 | return {'image': img, 'label': mask}
75 |
76 |
77 | class ImageSplit(nn.Module):
78 | def __init__(self, numbers=None):
79 | super(ImageSplit, self).__init__()
80 | """numbers [H, W]
81 | split from left to right, top to bottom"""
82 | if numbers is None:
83 | numbers = [2, 2]
84 | self.num = numbers
85 |
86 | def forward(self, x):
87 | flag = None
88 | if len(x.shape) == 3:
89 | x = x.unsqueeze(dim=1)
90 | flag = 1
91 | b, c, h, w = x.shape
92 | num_h, num_w = self.num[0], self.num[1]
93 | assert h % num_h == 0 and w % num_w == 0
94 | split_h, split_w = h // num_h, w // num_w
95 |
96 | outputs = []
97 | outputss = []
98 | for i in range(b):
99 | for h_i in range(num_h):
100 | for w_i in range(num_w):
101 | output = x[i][:, split_h * h_i: split_h * (h_i + 1),
102 | split_w * w_i: split_w * (w_i + 1)].unsqueeze(dim=0)
103 | outputs.append(output)
104 | outputs = torch.cat(outputs, dim=0).unsqueeze(dim=0)
105 | outputss.append(outputs)
106 | outputs = []
107 | outputss = torch.cat(outputss, dim=0).contiguous()
108 | if flag is not None:
109 | outputss = outputss.squeeze(dim=2)
110 | return outputss
111 |
112 |
113 | class ToTensor(object):
114 | """Convert ndarrays in sample to Tensors."""
115 | def __init__(self, add_edge=True):
116 | """imagenet normalize"""
117 | self.normalize = transforms.Normalize((.485, .456, .406), (.229, .224, .225))
118 | self.add_edge = add_edge
119 |
120 | def get_edge(self, img, edge_width=3):
121 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
122 | gray = cv2.GaussianBlur(gray, (11, 11), 0)
123 | edge = cv2.Canny(gray, 50, 150)
124 | # cv2.imshow('edge', edge)
125 | # cv2.waitKey(0)
126 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (edge_width, edge_width))
127 | edge = cv2.dilate(edge, kernel)
128 | edge = edge / 255
129 | edge = torch.from_numpy(edge).unsqueeze(dim=0).float()
130 |
131 | return edge
132 |
133 | def __call__(self, sample):
134 | # swap color axis because
135 | # numpy image: H x W x C
136 | # torch image: C X H X W
137 | img = sample['image']
138 | mask = sample['label']
139 |
140 | mask = np.expand_dims(mask, axis=2)
141 | img = np.array(img).astype(np.float32).transpose((2, 0, 1))
142 | mask = np.array(mask).astype(np.int64).transpose((2, 0, 1))
143 |
144 | img = torch.from_numpy(img).float().div(255)
145 | img = self.normalize(img)
146 | mask = torch.from_numpy(mask).float()
147 |
148 | if self.add_edge:
149 | edge = self.get_edge(sample['image'])
150 | img = img + edge
151 |
152 | return {'image': img, 'label': mask}
153 |
154 |
155 | class RGBGrayExchange():
156 | def __init__(self, path=None, palette=None):
157 | self.palette = palette
158 | """RGB format"""
159 | if palette is None:
160 | self.palette = [[255, 255, 255], [0, 0, 255], [0, 255, 255],
161 | [0, 255, 0], [255, 255, 0], [255, 0, 0]]
162 | self.path = path
163 |
164 | def read_img(self):
165 | img = cv2.imread(self.path, cv2.IMREAD_UNCHANGED)
166 | if len(img.shape) == 3:
167 | img = img[:, :, ::-1]
168 | return img
169 |
170 | def RGB_to_Gray(self, image=None):
171 | if not self.path is None:
172 | image = self.read_img()
173 | Gray = np.zeros(shape=[image.shape[0], image.shape[1]], dtype=np.uint8)
174 | for i in range(len(self.palette)):
175 | index = image == np.array(self.palette[i])
176 | index[..., 0][index[..., 1] == False] = False
177 | index[..., 0][index[..., 2] == False] = False
178 | Gray[index[..., 0]] = i
179 | print('unique pixels:{}'.format(np.unique(Gray)))
180 | return Gray
181 |
182 | def Gray_to_RGB(self, image=None):
183 | if not self.path is None:
184 | image = self.read_img()
185 | RGB = np.zeros(shape=[image.shape[0], image.shape[1], 3], dtype=np.uint8)
186 | for i in range(len(self.palette)):
187 | index = image == i
188 | RGB[index] = np.array(self.palette[i])
189 | print('unique pixels:{}'.format(np.unique(RGB)))
190 | return RGB
191 |
192 |
193 | class Mixup(nn.Module):
194 | def __init__(self, alpha=1.0, use_edge=False):
195 | super(Mixup, self).__init__()
196 | self.alpha = alpha
197 | self.use_edge = use_edge
198 |
199 | def criterion(self, lam, outputs, targets_a, targets_b, criterion):
200 | return lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)
201 |
202 | def forward(self, inputs, targets, criterion, model):
203 | if self.alpha > 0:
204 | lam = np.random.beta(self.alpha, self.alpha)
205 | else:
206 | lam = 1
207 | batch_size = inputs.size(0)
208 | index = torch.randperm(batch_size).cuda()
209 | mix_inputs = lam*inputs + (1-lam)*inputs[index, :]
210 | targets_a, targets_b = targets, targets[index]
211 | outputs = model(mix_inputs)
212 |
213 | losses = 0
214 | if isinstance(outputs, (list, tuple)):
215 | if self.use_edge:
216 | for i in range(len(outputs) - 1):
217 | loss = self.criterion(lam, outputs[i], targets_a, targets_b, criterion[0])
218 | losses += loss
219 | edge_targets_a = edge_contour(targets).long()
220 | edge_targets_b = edge_targets_a[index]
221 | loss2 = self.criterion(lam, outputs[-1], edge_targets_a, edge_targets_b, criterion[1])
222 | losses += loss2
223 | else:
224 | for i in range(len(outputs)):
225 | loss = self.criterion(lam, outputs[i], targets_a, targets_b, criterion)
226 | losses += loss
227 | else:
228 | losses = self.criterion(lam, outputs, targets_a, targets_b, criterion)
229 | return losses
230 |
231 |
232 | def edge_contour(label, edge_width=3):
233 | import cv2
234 | cuda_type = label.is_cuda
235 | label = label.cpu().numpy().astype(np.int)
236 | b, h, w = label.shape
237 | edge = np.zeros(label.shape)
238 |
239 | # right
240 | edge_right = edge[:, 1:h, :]
241 | edge_right[(label[:, 1:h, :] != label[:, :h - 1, :]) & (label[:, 1:h, :] != 255)
242 | & (label[:, :h - 1, :] != 255)] = 1
243 |
244 | # up
245 | edge_up = edge[:, :, :w - 1]
246 | edge_up[(label[:, :, :w - 1] != label[:, :, 1:w])
247 | & (label[:, :, :w - 1] != 255)
248 | & (label[:, :, 1:w] != 255)] = 1
249 |
250 | # upright
251 | edge_upright = edge[:, :h - 1, :w - 1]
252 | edge_upright[(label[:, :h - 1, :w - 1] != label[:, 1:h, 1:w])
253 | & (label[:, :h - 1, :w - 1] != 255)
254 | & (label[:, 1:h, 1:w] != 255)] = 1
255 |
256 | # bottomright
257 | edge_bottomright = edge[:, :h - 1, 1:w]
258 | edge_bottomright[(label[:, :h - 1, 1:w] != label[:, 1:h, :w - 1])
259 | & (label[:, :h - 1, 1:w] != 255)
260 | & (label[:, 1:h, :w - 1] != 255)] = 1
261 |
262 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (edge_width, edge_width))
263 | for i in range(edge.shape[0]):
264 | edge[i] = cv2.dilate(edge[i], kernel)
265 |
266 | # edge[edge == 1] = 255 # view edge
267 | # import random
268 | # cv2.imwrite(os.path.join('./edge', '{}.png'.format(random.random())), edge[0])
269 | if cuda_type:
270 | edge = torch.from_numpy(edge).cuda()
271 | else:
272 | edge = torch.from_numpy(edge)
273 |
274 | return edge
275 |
276 |
277 | if __name__ == '__main__':
278 | path = './data/vaihingen/annotations/labels'
279 | filelist = os.listdir(path)
280 | for file in filelist:
281 | print(file)
282 | img = cv2.imread(os.path.join(path, file), cv2.IMREAD_UNCHANGED)
283 | img = torch.from_numpy(img).unsqueeze(dim=0).repeat(2, 1, 1)
284 | img = edge_contour(img)
285 | # cv2.imwrite(os.path.join(save_path, os.path.splitext(file)[0] + '.png'), gray)
286 |
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyxu1996/CCTNet/5a5db40d2e38bd478b404583050049eedca90844/data/.gitignore
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 | import torch.utils.data as data
5 | from torchvision import transforms
6 | import custom_transforms as tr
7 | import tifffile as tiff
8 | import math
9 |
10 |
11 | class RemoteData(data.Dataset):
12 | def __init__(self, base_dir='./data/', train=True, dataset='vaihingen', crop_szie=None, val_full_img=False):
13 | super(RemoteData, self).__init__()
14 | self.dataset_dir = base_dir
15 | self.train = train
16 | self.dataset = dataset
17 | self.val_full_img = val_full_img
18 | self.images = []
19 | self.labels = []
20 | self.names = []
21 | self.alphas = []
22 | alpha = None
23 | if crop_szie is None:
24 | crop_szie = [512, 512]
25 | self.crop_size = crop_szie
26 | if train:
27 | self.image_dir = os.path.join(self.dataset_dir, self.dataset + '/images')
28 | self.label_dir = os.path.join(self.dataset_dir, self.dataset + '/labels')
29 | txt = os.path.join(self.dataset_dir, self.dataset + '/annotations' + '/train.txt')
30 | else:
31 | self.image_dir = os.path.join(self.dataset_dir, self.dataset + '/images')
32 | self.label_dir = os.path.join(self.dataset_dir, self.dataset + '/labels')
33 | txt = os.path.join(self.dataset_dir, self.dataset + '/annotations' + '/test.txt')
34 |
35 | with open(txt, "r") as f:
36 | self.filename_list = f.readlines()
37 | for filename in self.filename_list:
38 | if self.dataset in ['barley']:
39 | image = os.path.join(self.image_dir, filename.strip() + '.png')
40 | image = Image.open(image)
41 | image = np.array(image)
42 | if image.shape[2] == 4:
43 | alpha = image[..., 3]
44 | image = image[..., 0:3]
45 | else:
46 | image = os.path.join(self.image_dir, filename.strip() + '.tif')
47 | image = tiff.imread(image)
48 | label = os.path.join(self.label_dir, filename.strip() + '.png')
49 | label = Image.open(label)
50 | label = np.array(label)
51 | if self.val_full_img:
52 | self.images.append(image)
53 | self.labels.append(label)
54 | self.names.append(filename.strip())
55 | if alpha is not None:
56 | self.alphas.append(alpha)
57 | else:
58 | if alpha is not None:
59 | slide_crop(image, label, self.crop_size, self.images, self.labels, self.dataset,
60 | alpha=alpha, alpha_patches=self.alphas, stride_rate=2/3)
61 | else:
62 | slide_crop(image, label, self.crop_size, self.images, self.labels, self.dataset, stride_rate=2/3)
63 | assert(len(self.images) == len(self.labels))
64 |
65 | def __len__(self):
66 | return len(self.images)
67 |
68 | def __getitem__(self, index):
69 | sample = {'image': self.images[index], 'label': self.labels[index]}
70 | sample = self.transform(sample)
71 | if self.val_full_img:
72 | sample['name'] = self.names[index]
73 | if self.alphas != [] and self.train == False:
74 | sample['alpha'] = self.alphas[index]
75 | return sample
76 |
77 | def transform(self, sample):
78 | if self.train:
79 | if self.dataset in ['barley']:
80 | composed_transforms = transforms.Compose([
81 | tr.RandomHorizontalFlip(),
82 | tr.RandomVerticalFlip(),
83 | tr.ToTensor(add_edge=False),
84 | ])
85 | else:
86 | composed_transforms = transforms.Compose([
87 | tr.RandomHorizontalFlip(),
88 | tr.RandomVerticalFlip(),
89 | tr.RandomScaleCrop(base_size=self.crop_size, crop_size=self.crop_size),
90 | tr.ToTensor(add_edge=False),
91 | ])
92 | else:
93 | composed_transforms = transforms.Compose([
94 | tr.ToTensor(add_edge=False),
95 | ])
96 | return composed_transforms(sample)
97 |
98 | def __str__(self):
99 | return 'dataset:{} train:{}'.format(self.dataset, self.train)
100 |
101 |
102 | def slide_crop(image, label, crop_size, image_patches, label_patches, dataset,
103 | stride_rate=1.0/2.0, alpha=None, alpha_patches=None):
104 | """images shape [h, w, c]"""
105 | if len(image.shape) == 2:
106 | image = np.expand_dims(image, axis=2)
107 | if len(label.shape) == 2:
108 | label = np.expand_dims(label, axis=2)
109 | if alpha is not None:
110 | alpha = np.expand_dims(alpha, axis=2)
111 | stride_rate = stride_rate
112 | h, w, c = image.shape
113 | H, W = crop_size
114 | stride_h = int(H * stride_rate)
115 | stride_w = int(W * stride_rate)
116 | assert h >= crop_size[0] and w >= crop_size[1]
117 | h_grids = int(math.ceil(1.0 * (h - H) / stride_h)) + 1
118 | w_grids = int(math.ceil(1.0 * (w - W) / stride_w)) + 1
119 | for idh in range(h_grids):
120 | for idw in range(w_grids):
121 | h0 = idh * stride_h
122 | w0 = idw * stride_w
123 | h1 = min(h0 + H, h)
124 | w1 = min(w0 + W, w)
125 | if h1 == h and w1 != w:
126 | crop_img = image[h - H:h, w0:w0 + W, :]
127 | crop_label = label[h - H:h, w0:w0 + W, :]
128 | if alpha is not None:
129 | crop_alpha = alpha[h - H:h, w0:w0 + W, :]
130 | if w1 == w and h1 != h:
131 | crop_img = image[h0:h0 + H, w - W:w, :]
132 | crop_label = label[h0:h0 + H, w - W:w, :]
133 | if alpha is not None:
134 | crop_alpha = alpha[h0:h0 + H, w - W:w, :]
135 | if h1 == h and w1 == w:
136 | crop_img = image[h - H:h, w - W:w, :]
137 | crop_label = label[h - H:h, w - W:w, :]
138 | if alpha is not None:
139 | crop_alpha = alpha[h - H:h, w - W:w, :]
140 | if w1 != w and h1 != h:
141 | crop_img = image[h0:h0 + H, w0:w0 + W, :]
142 | crop_label = label[h0:h0 + H, w0:w0 + W, :]
143 | if alpha is not None:
144 | crop_alpha = alpha[h0:h0 + H, w0:w0 + W, :]
145 | crop_img = crop_img.squeeze()
146 | crop_label = crop_label.squeeze()
147 | if alpha is not None:
148 | crop_alpha = crop_alpha.squeeze()
149 | if (dataset in ['barley'] and np.any(crop_alpha > 0)) or dataset not in ['barley']:
150 | image_patches.append(crop_img)
151 | label_patches.append(crop_label)
152 | if alpha is not None:
153 | alpha_patches.append(crop_alpha)
154 |
155 |
156 | def label_to_RGB(image, classes=6):
157 | RGB = np.zeros(shape=[image.shape[0], image.shape[1], 3], dtype=np.uint8)
158 | if classes == 6: # potsdam and vaihingen
159 | palette = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]]
160 | if classes == 4: # barley
161 | palette = [[255, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]]
162 | for i in range(classes):
163 | index = image == i
164 | RGB[index] = np.array(palette[i])
165 | return RGB
166 |
167 |
168 | def RGB_to_label(image=None, classes=6):
169 | if classes == 6: # potsdam and vaihingen
170 | palette = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]]
171 | if classes == 4: # barley
172 | palette = [[255, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]]
173 | label = np.zeros(shape=[image.shape[0], image.shape[1]], dtype=np.uint8)
174 | for i in range(len(palette)):
175 | index = image == np.array(palette[i])
176 | index[..., 0][index[..., 1] == False] = False
177 | index[..., 0][index[..., 2] == False] = False
178 | label[index[..., 0]] = i
179 | return label
180 |
181 |
182 | if __name__ == '__main__':
183 | from torch.utils.data import DataLoader
184 | import matplotlib.pyplot as plt
185 |
186 | remotedata_train = RemoteData(train=True, dataset='vaihingen')
187 | dataloader = DataLoader(remotedata_train, batch_size=1, shuffle=False, num_workers=1)
188 | # print(dataloader)
189 |
190 | for ii, sample in enumerate(dataloader):
191 | im = sample['label'].numpy().astype(np.uint8)
192 | pic = sample['image'].numpy().astype(np.uint8)
193 | print(im.shape)
194 | im = np.squeeze(im, axis=0)
195 | pic = np.squeeze(pic, axis=0)
196 | print(im.shape)
197 | im = np.transpose(im, axes=[1, 2, 0])[:, :, 0:3]
198 | pic = np.transpose(pic, axes=[1, 2, 0])[:, :, 0:3]
199 | print(im.shape)
200 | im = np.squeeze(im, axis=2)
201 | # print(im)
202 | im = label_to_RGB(im)
203 | plt.imshow(pic)
204 | plt.show()
205 | plt.imshow(im)
206 | plt.show()
207 | if ii == 10:
208 | break
209 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyxu1996/CCTNet/5a5db40d2e38bd478b404583050049eedca90844/models/__init__.py
--------------------------------------------------------------------------------
/models/danet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | try:
3 | from .resnet import resnet50_v1b
4 | except:
5 | from resnet import resnet50_v1b
6 | import torch.nn.functional as F
7 | import torch
8 |
9 |
10 | class SegBaseModel(nn.Module):
11 | r"""Base Model for Semantic Segmentation
12 |
13 | Parameters
14 | ----------
15 | backbone : string
16 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
17 | 'resnet101' or 'resnet152').
18 | """
19 |
20 | def __init__(self, nclass, aux, backbone='resnet50', dilated=True, pretrained_base=False, **kwargs):
21 | super(SegBaseModel, self).__init__()
22 | self.aux = aux
23 | self.nclass = nclass
24 | if backbone == 'resnet50':
25 | self.pretrained = resnet50_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs)
26 |
27 | def base_forward(self, x):
28 | """forwarding pre-trained network"""
29 | x = self.pretrained.conv1(x)
30 | x = self.pretrained.bn1(x)
31 | x = self.pretrained.relu(x)
32 | x = self.pretrained.maxpool(x)
33 | c1 = self.pretrained.layer1(x)
34 | c2 = self.pretrained.layer2(c1)
35 | c3 = self.pretrained.layer3(c2)
36 | c4 = self.pretrained.layer4(c3)
37 |
38 | return c1, c2, c3, c4
39 |
40 |
41 | class _FCNHead(nn.Module):
42 | def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, **kwargs):
43 | super(_FCNHead, self).__init__()
44 | inter_channels = in_channels // 4
45 | self.block = nn.Sequential(
46 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
47 | norm_layer(inter_channels),
48 | nn.ReLU(inplace=True),
49 | nn.Dropout(0.1),
50 | nn.Conv2d(inter_channels, channels, 1)
51 | )
52 |
53 | def forward(self, x):
54 | return self.block(x)
55 |
56 |
57 | class _PositionAttentionModule(nn.Module):
58 | """ Position attention module"""
59 |
60 | def __init__(self, in_channels, **kwargs):
61 | super(_PositionAttentionModule, self).__init__()
62 | self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1)
63 | self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1)
64 | self.conv_d = nn.Conv2d(in_channels, in_channels, 1)
65 | self.alpha = nn.Parameter(torch.zeros(1))
66 | self.softmax = nn.Softmax(dim=-1)
67 |
68 | def forward(self, x):
69 | batch_size, _, height, width = x.size()
70 | feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1)
71 | feat_c = self.conv_c(x).view(batch_size, -1, height * width)
72 | attention_s = self.softmax(torch.bmm(feat_b, feat_c))
73 | feat_d = self.conv_d(x).view(batch_size, -1, height * width)
74 | feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width)
75 | out = self.alpha * feat_e + x
76 |
77 | return out
78 |
79 |
80 | class _ChannelAttentionModule(nn.Module):
81 | """Channel attention module"""
82 |
83 | def __init__(self, **kwargs):
84 | super(_ChannelAttentionModule, self).__init__()
85 | self.beta = nn.Parameter(torch.zeros(1))
86 | self.softmax = nn.Softmax(dim=-1)
87 |
88 | def forward(self, x):
89 | batch_size, _, height, width = x.size()
90 | feat_a = x.view(batch_size, -1, height * width)
91 | feat_a_transpose = x.view(batch_size, -1, height * width).permute(0, 2, 1)
92 | attention = torch.bmm(feat_a, feat_a_transpose)
93 | attention_new = torch.max(attention, dim=-1, keepdim=True)[0].expand_as(attention) - attention
94 | attention = self.softmax(attention_new)
95 |
96 | feat_e = torch.bmm(attention, feat_a).view(batch_size, -1, height, width)
97 | out = self.beta * feat_e + x
98 |
99 | return out
100 |
101 |
102 | class _DAHead(nn.Module):
103 | def __init__(self, in_channels, nclass, aux=True, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
104 | super(_DAHead, self).__init__()
105 | self.aux = aux
106 | inter_channels = in_channels // 4
107 | self.conv_p1 = nn.Sequential(
108 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
109 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
110 | nn.ReLU(True)
111 | )
112 | self.conv_c1 = nn.Sequential(
113 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
114 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
115 | nn.ReLU(True)
116 | )
117 | self.pam = _PositionAttentionModule(inter_channels, **kwargs)
118 | self.cam = _ChannelAttentionModule(**kwargs)
119 | self.conv_p2 = nn.Sequential(
120 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
121 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
122 | nn.ReLU(True)
123 | )
124 | self.conv_c2 = nn.Sequential(
125 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
126 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
127 | nn.ReLU(True)
128 | )
129 | self.out = nn.Sequential(
130 | nn.Dropout(0.1),
131 | nn.Conv2d(inter_channels, nclass, 1)
132 | )
133 | if aux:
134 | self.conv_p3 = nn.Sequential(
135 | nn.Dropout(0.1),
136 | nn.Conv2d(inter_channels, nclass, 1)
137 | )
138 | self.conv_c3 = nn.Sequential(
139 | nn.Dropout(0.1),
140 | nn.Conv2d(inter_channels, nclass, 1)
141 | )
142 |
143 | def forward(self, x):
144 | feat_p = self.conv_p1(x)
145 | feat_p = self.pam(feat_p)
146 | feat_p = self.conv_p2(feat_p)
147 |
148 | feat_c = self.conv_c1(x)
149 | feat_c = self.cam(feat_c)
150 | feat_c = self.conv_c2(feat_c)
151 |
152 | feat_fusion = feat_p + feat_c
153 |
154 | outputs = []
155 | fusion_out = self.out(feat_fusion)
156 | outputs.append(fusion_out)
157 | if self.aux:
158 | p_out = self.conv_p3(feat_p)
159 | c_out = self.conv_c3(feat_c)
160 | outputs.append(p_out)
161 | outputs.append(c_out)
162 |
163 | return tuple(outputs)
164 |
165 |
166 | class DANet(SegBaseModel):
167 |     r"""Dual Attention Network for Scene Segmentation
168 |
169 | Parameters
170 | ----------
171 | nclass : int
172 | Number of categories for the training dataset.
173 | backbone : string
174 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
175 | 'resnet101' or 'resnet152').
176 | norm_layer : object
177 | Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
178 | for Synchronized Cross-GPU BachNormalization).
179 | aux : bool
180 | Auxiliary loss.
181 | Reference:
182 | Jun Fu, Jing Liu, Haijie Tian, Yong Li, Yongjun Bao, Zhiwei Fang,and Hanqing Lu.
183 | "Dual Attention Network for Scene Segmentation." *CVPR*, 2019
184 | """
185 |
186 | def __init__(self, nclass, backbone='resnet50', aux=False, pretrained_base=False, **kwargs):
187 | super(DANet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
188 | self.head = _DAHead(2048, nclass, aux, **kwargs)
189 |
190 | def forward(self, x):
191 | size = x.size()[2:]
192 | _, _, c3, c4 = self.base_forward(x)
193 | outputs = []
194 | x = self.head(c4)
195 | x0 = F.interpolate(x[0], size, mode='bilinear', align_corners=True)
196 | if self.aux:
197 | x1 = F.interpolate(x[1], size, mode='bilinear', align_corners=True)
198 | x2 = F.interpolate(x[2], size, mode='bilinear', align_corners=True)
199 | outputs.append(x0)
200 | outputs.append(x1)
201 | outputs.append(x2)
202 | return outputs
203 | return x0
204 |
205 |
206 | if __name__ == '__main__':
207 | from tools.flops_params_fps_count import flops_params_fps
208 | model = DANet(nclass=6)
209 | flops_params_fps(model)
210 |
--------------------------------------------------------------------------------
/models/deeplabv3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | try:
4 | from .resnet import resnet50_v1b
5 | except:
6 | from resnet import resnet50_v1b
7 | import torch.nn.functional as F
8 |
9 |
10 | class SegBaseModel(nn.Module):
11 | r"""Base Model for Semantic Segmentation
12 |
13 | Parameters
14 | ----------
15 | backbone : string
16 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
17 | 'resnet101' or 'resnet152').
18 | """
19 |
20 | def __init__(self, nclass, aux, backbone='resnet50', jpu=False, pretrained_base=True, **kwargs):
21 | super(SegBaseModel, self).__init__()
22 | dilated = False if jpu else True
23 | self.aux = aux
24 | self.nclass = nclass
25 | if backbone == 'resnet50':
26 | self.pretrained = resnet50_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs)
27 |
28 | # self.jpu = JPU([512, 1024, 2048], width=512, **kwargs) if jpu else None
29 |
30 | def base_forward(self, x):
31 | """forwarding pre-trained network"""
32 | x = self.pretrained.conv1(x)
33 | x = self.pretrained.bn1(x)
34 | x = self.pretrained.relu(x)
35 | x = self.pretrained.maxpool(x)
36 | c1 = self.pretrained.layer1(x)
37 | c2 = self.pretrained.layer2(c1)
38 | c3 = self.pretrained.layer3(c2)
39 | c4 = self.pretrained.layer4(c3)
40 |
41 | return c1, c2, c3, c4
42 |
43 | def evaluate(self, x):
44 | """evaluating network with inputs and targets"""
45 | return self.forward(x)[0]
46 |
47 | def demo(self, x):
48 | pred = self.forward(x)
49 | if self.aux:
50 | pred = pred[0]
51 | return pred
52 |
53 |
54 | class _FCNHead(nn.Module):
55 | def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, **kwargs):
56 | super(_FCNHead, self).__init__()
57 | inter_channels = in_channels // 4
58 | self.block = nn.Sequential(
59 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
60 | norm_layer(inter_channels),
61 | nn.ReLU(inplace=True),
62 | nn.Dropout(0.1),
63 | nn.Conv2d(inter_channels, channels, 1)
64 | )
65 |
66 | def forward(self, x):
67 | return self.block(x)
68 |
69 |
70 | class DeepLabV3(SegBaseModel):
71 | r"""DeepLabV3
72 |
73 | Parameters
74 | ----------
75 | nclass : int
76 | Number of categories for the training dataset.
77 | backbone : string
78 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
79 | 'resnet101' or 'resnet152').
80 | norm_layer : object
81 | Normalization layer used in backbone network (default: :class:`nn.BatchNorm`;
82 | for Synchronized Cross-GPU BachNormalization).
83 | aux : bool
84 | Auxiliary loss.
85 |
86 | Reference:
87 | Chen, Liang-Chieh, et al. "Rethinking atrous convolution for semantic image segmentation."
88 | arXiv preprint arXiv:1706.05587 (2017).
89 | """
90 |
91 | def __init__(self, nclass, backbone='resnet50', aux=False, pretrained_base=False, **kwargs):
92 | super(DeepLabV3, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
93 | self.head = _DeepLabHead(nclass, **kwargs)
94 | if self.aux:
95 | self.auxlayer = _FCNHead(1024, nclass, **kwargs)
96 |
97 | self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])
98 |
99 | def forward(self, x):
100 | size = x.size()[2:]
101 | _, _, c3, c4 = self.base_forward(x)
102 | outputs = []
103 | x = self.head(c4)
104 | x = F.interpolate(x, size, mode='bilinear', align_corners=True)
105 |
106 |
107 | if self.aux:
108 | auxout = self.auxlayer(c3)
109 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
110 | outputs.append(auxout)
111 | return x
112 |
113 |
114 | class _DeepLabHead(nn.Module):
115 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
116 | super(_DeepLabHead, self).__init__()
117 | self.aspp = _ASPP(2048, [12, 24, 36], norm_layer=norm_layer, norm_kwargs=norm_kwargs, **kwargs)
118 | self.block = nn.Sequential(
119 | nn.Conv2d(256, 256, 3, padding=1, bias=False),
120 | norm_layer(256, **({} if norm_kwargs is None else norm_kwargs)),
121 | nn.ReLU(True),
122 | nn.Dropout(0.1),
123 | nn.Conv2d(256, nclass, 1)
124 | )
125 |
126 | def forward(self, x):
127 | x = self.aspp(x)
128 | return self.block(x)
129 |
130 |
131 | class _ASPPConv(nn.Module):
132 | def __init__(self, in_channels, out_channels, atrous_rate, norm_layer, norm_kwargs):
133 | super(_ASPPConv, self).__init__()
134 | self.block = nn.Sequential(
135 | nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rate, dilation=atrous_rate, bias=False),
136 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
137 | nn.ReLU(True)
138 | )
139 |
140 | def forward(self, x):
141 | return self.block(x)
142 |
143 |
144 | class _AsppPooling(nn.Module):
145 | def __init__(self, in_channels, out_channels, norm_layer, norm_kwargs, **kwargs):
146 | super(_AsppPooling, self).__init__()
147 | self.gap = nn.Sequential(
148 | nn.AdaptiveAvgPool2d(1),
149 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
150 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
151 | nn.ReLU(True)
152 | )
153 |
154 | def forward(self, x):
155 | size = x.size()[2:]
156 | pool = self.gap(x)
157 | out = F.interpolate(pool, size, mode='bilinear', align_corners=True)
158 | return out
159 |
160 |
161 | class _ASPP(nn.Module):
162 | def __init__(self, in_channels, atrous_rates, norm_layer, norm_kwargs, **kwargs):
163 | super(_ASPP, self).__init__()
164 | out_channels = 256
165 | self.b0 = nn.Sequential(
166 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
167 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
168 | nn.ReLU(True)
169 | )
170 |
171 | rate1, rate2, rate3 = tuple(atrous_rates)
172 | self.b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer, norm_kwargs)
173 | self.b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer, norm_kwargs)
174 | self.b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer, norm_kwargs)
175 | self.b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
176 |
177 | self.project = nn.Sequential(
178 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
179 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
180 | nn.ReLU(True),
181 | nn.Dropout(0.5)
182 | )
183 |
184 | def forward(self, x):
185 | feat1 = self.b0(x)
186 | feat2 = self.b1(x)
187 | feat3 = self.b2(x)
188 | feat4 = self.b3(x)
189 | feat5 = self.b4(x)
190 | x = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
191 | x = self.project(x)
192 | return x
193 |
194 |
195 | if __name__ == '__main__':
196 | from tools.flops_params_fps_count import flops_params_fps
197 | model = DeepLabV3(nclass=6)
198 | flops_params_fps(model)
199 |
200 |
201 |
202 |
203 |
204 |
--------------------------------------------------------------------------------
/models/edgenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from models.resT import rest_tiny
4 |
5 |
6 | def EdgeNet(nclass=6):
7 | model = rest_tiny(nclass=nclass, pretrained=True, aux=True, edge_aux=False, head='mlphead')
8 | return model
9 |
10 |
11 | def edgenet_init(weight_dir):
12 | with torch.no_grad():
13 | model = rest_tiny(nclass=6, pretrained=False, aux=True, edge_aux=False, head='mlphead').eval()
14 | if os.path.isfile(weight_dir):
15 | print('loaded edge model successfully')
16 | checkpoint = torch.load(weight_dir, map_location=lambda storage, loc: storage)
17 | checkpoint = {k: v for k, v in checkpoint.items() if not 'loss' in k}
18 | checkpoint = {k.replace('module.model.', ''): v for k, v in checkpoint.items()}
19 | model.load_state_dict(checkpoint)
20 | return model
21 |
22 |
23 | if __name__ == '__main__':
24 | from tools.flops_params_fps_count import flops_params_fps
25 | model = EdgeNet(nclass=6)
26 | flops_params_fps(model)
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
/models/fcn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.model_zoo as model_zoo
6 |
7 |
8 | class VGG(nn.Module):
9 | def __init__(self, features, num_classes=1000, init_weights=True):
10 | super(VGG, self).__init__()
11 | self.features = features
12 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
13 | self.classifier = nn.Sequential(
14 | nn.Linear(512 * 7 * 7, 4096),
15 | nn.ReLU(True),
16 | nn.Dropout(),
17 | nn.Linear(4096, 4096),
18 | nn.ReLU(True),
19 | nn.Dropout(),
20 | nn.Linear(4096, num_classes)
21 | )
22 | if init_weights:
23 | self._initialize_weights()
24 |
25 | def forward(self, x):
26 | x = self.features(x)
27 | x = self.avgpool(x)
28 | x = x.view(x.size(0), -1)
29 | x = self.classifier(x)
30 | return x
31 |
32 | def _initialize_weights(self):
33 | for m in self.modules():
34 | if isinstance(m, nn.Conv2d):
35 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
36 | if m.bias is not None:
37 | nn.init.constant_(m.bias, 0)
38 | elif isinstance(m, nn.BatchNorm2d):
39 | nn.init.constant_(m.weight, 1)
40 | nn.init.constant_(m.bias, 0)
41 | elif isinstance(m, nn.Linear):
42 | nn.init.normal_(m.weight, 0, 0.01)
43 | nn.init.constant_(m.bias, 0)
44 |
45 |
46 | def make_layers(cfg, batch_norm=False):
47 | layers = []
48 | in_channels = 3
49 | for v in cfg:
50 | if v == 'M':
51 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
52 | else:
53 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
54 | if batch_norm:
55 | layers += (conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True))
56 | else:
57 | layers += [conv2d, nn.ReLU(inplace=True)]
58 | in_channels = v
59 | return nn.Sequential(*layers)
60 |
61 |
62 | cfg = {
63 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
64 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
65 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
66 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
67 | }
68 |
69 |
70 | def vgg16(**kwargs):
71 |
72 | model = VGG(make_layers(cfg['D']), **kwargs)
73 |
74 | return model
75 |
76 |
77 | class FCN16s(nn.Module):
78 | def __init__(self, nclass, backbone='vgg16', aux=False, norm_layer=nn.BatchNorm2d, **kwargs):
79 | super(FCN16s, self).__init__()
80 | self.aux = aux
81 | if backbone == 'vgg16':
82 | self.pretrained = vgg16().features
83 | else:
84 | raise RuntimeError('unknown backbone: {}'.format(backbone))
85 | self.pool4 = nn.Sequential(*self.pretrained[:24])
86 | self.pool5 = nn.Sequential(*self.pretrained[24:])
87 | self.head = _FCNHead(512, nclass, norm_layer)
88 | self.score_pool4 = nn.Conv2d(512, nclass, 1)
89 | if aux:
90 | self.auxlayer = _FCNHead(512, nclass, norm_layer)
91 |
92 | self.__setattr__('exclusive', ['head', 'score_pool4', 'auxlayer'] if aux else ['head', 'score_pool4'])
93 |
94 | def forward(self, x):
95 | pool4 = self.pool4(x)
96 | pool5 = self.pool5(pool4)
97 |
98 | outputs = []
99 | score_fr = self.head(pool5)
100 |
101 | score_pool4 = self.score_pool4(pool4)
102 |
103 | upscore2 = F.interpolate(score_fr, score_pool4.size()[2:], mode='bilinear', align_corners=True)
104 | fuse_pool4 = upscore2 + score_pool4
105 |
106 | out = F.interpolate(fuse_pool4, x.size()[2:], mode='bilinear', align_corners=True)
107 | outputs = out
108 |
109 | if self.aux:
110 | auxout = self.auxlayer(pool5)
111 | auxout = F.interpolate(auxout, x.size()[2:], mode='bilinear', align_corners=True)
112 | outputs.append(auxout)
113 |
114 | return outputs
115 |
116 |
117 | class _FCNHead(nn.Module):
118 | def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, **kwargs):
119 | super(_FCNHead, self).__init__()
120 | inter_channels = in_channels // 4
121 | self.block = nn.Sequential(
122 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
123 | norm_layer(inter_channels),
124 | nn.ReLU(inplace=True),
125 | nn.Dropout(0.1),
126 | nn.Conv2d(inter_channels, channels, 1)
127 | )
128 |
129 | def forward(self, x):
130 | return self.block(x)
131 |
132 |
133 | if __name__ == '__main__':
134 | from tools.flops_params_fps_count import flops_params_fps
135 | model = FCN16s(nclass=6)
136 | flops_params_fps(model)
137 |
--------------------------------------------------------------------------------
/models/fpn.py:
--------------------------------------------------------------------------------
1 | '''FPN in PyTorch.
2 |
3 | See the paper "Feature Pyramid Networks for Object Detection" for more details.
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from torch.autograd import Variable
10 |
11 |
12 | def conv3x3(in_planes, out_planes, stride=1):
13 | """3x3 convolution with padding"""
14 | conv3x3 = nn.Sequential(
15 | nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
16 | padding=1, bias=False),
17 | nn.BatchNorm2d(out_planes),
18 | nn.ReLU(inplace=True),
19 | )
20 | return conv3x3
21 |
22 |
23 | class Seg_head(nn.Module):
24 | def __init__(self, in_planes=256, out_planes=128, n_class=6):
25 | super(Seg_head, self).__init__()
26 | self.conv1 = conv3x3(in_planes, out_planes)
27 | self.conv2 = conv3x3(in_planes, out_planes)
28 | self.conv3 = conv3x3(in_planes, out_planes)
29 | self.conv4 = conv3x3(in_planes, out_planes)
30 | self.final_layer = nn.Sequential(
31 | nn.Conv2d(
32 | in_channels=out_planes * 4,
33 | out_channels=out_planes * 4,
34 | kernel_size=1,
35 | stride=1,
36 | padding=0),
37 | nn.BatchNorm2d(out_planes * 4),
38 | nn.ReLU(inplace=True),
39 | nn.Conv2d(
40 | in_channels=out_planes * 4,
41 | out_channels=n_class,
42 | kernel_size=1,
43 | stride=1,
44 | padding=0)
45 | )
46 |
47 | def forward(self, p2, p3, p4, p5):
48 | x2 = self.conv1(p2)
49 | x3 = F.interpolate(self.conv2(p3), scale_factor=2, mode='bilinear')
50 |         x4 = F.interpolate(self.conv3(p4), scale_factor=4, mode='bilinear')
51 |         x5 = F.interpolate(self.conv4(p5), scale_factor=8, mode='bilinear')
52 | x = torch.cat((x2, x3, x4, x5), dim=1)
53 | x = self.final_layer(x)
54 | output = F.interpolate(x, scale_factor=4, mode='bilinear')
55 |
56 | return output
57 |
58 |
59 | class Bottleneck(nn.Module):
60 | expansion = 4
61 |
62 | def __init__(self, in_planes, planes, stride=1):
63 | super(Bottleneck, self).__init__()
64 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
65 | self.bn1 = nn.BatchNorm2d(planes)
66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
67 | self.bn2 = nn.BatchNorm2d(planes)
68 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
69 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
70 |
71 | self.shortcut = nn.Sequential()
72 | if stride != 1 or in_planes != self.expansion*planes:
73 | self.shortcut = nn.Sequential(
74 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
75 | nn.BatchNorm2d(self.expansion*planes)
76 | )
77 |
78 | def forward(self, x):
79 | out = F.relu(self.bn1(self.conv1(x)))
80 | out = F.relu(self.bn2(self.conv2(out)))
81 | out = self.bn3(self.conv3(out))
82 | out += self.shortcut(x)
83 | out = F.relu(out)
84 | return out
85 |
86 |
87 | class FPN(nn.Module):
88 | def __init__(self, block=Bottleneck, num_blocks=[3, 4, 6, 3], nclass=6):
89 | super(FPN, self).__init__()
90 | self.in_planes = 64
91 |
92 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
93 | self.bn1 = nn.BatchNorm2d(64)
94 |
95 | # Bottom-up layers
96 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
97 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
98 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
99 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
100 |
101 | # Top layer
102 | self.toplayer = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0) # Reduce channels
103 |
104 | # Smooth layers
105 | self.smooth1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
106 | self.smooth2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
107 | self.smooth3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
108 |
109 | # Lateral layers
110 | self.latlayer1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
111 | self.latlayer2 = nn.Conv2d( 512, 256, kernel_size=1, stride=1, padding=0)
112 | self.latlayer3 = nn.Conv2d( 256, 256, kernel_size=1, stride=1, padding=0)
113 |
114 | self.seg_head = Seg_head(in_planes=256, out_planes=128, n_class=nclass)
115 |
116 | def _make_layer(self, block, planes, num_blocks, stride):
117 | strides = [stride] + [1]*(num_blocks-1)
118 | layers = []
119 | for stride in strides:
120 | layers.append(block(self.in_planes, planes, stride))
121 | self.in_planes = planes * block.expansion
122 | return nn.Sequential(*layers)
123 |
124 | def _upsample_add(self, x, y):
125 | '''Upsample and add two feature maps.
126 |
127 | Args:
128 | x: (Variable) top feature map to be upsampled.
129 | y: (Variable) lateral feature map.
130 |
131 | Returns:
132 | (Variable) added feature map.
133 |
134 | Note in PyTorch, when input size is odd, the upsampled feature map
135 | with `F.upsample(..., scale_factor=2, mode='nearest')`
136 | maybe not equal to the lateral feature map size.
137 |
138 | e.g.
139 | original input size: [N,_,15,15] ->
140 | conv2d feature map size: [N,_,8,8] ->
141 | upsampled feature map size: [N,_,16,16]
142 |
143 | So we choose bilinear upsample which supports arbitrary output sizes.
144 | '''
145 | _,_,H,W = y.size()
146 |         return F.interpolate(x, size=(H, W), mode='bilinear') + y
147 |
148 | def forward(self, x):
149 | # Bottom-up
150 | c1 = F.relu(self.bn1(self.conv1(x)))
151 | c1 = F.max_pool2d(c1, kernel_size=3, stride=2, padding=1)
152 | c2 = self.layer1(c1)
153 | c3 = self.layer2(c2)
154 | c4 = self.layer3(c3)
155 | c5 = self.layer4(c4)
156 | # Top-down
157 | p5 = self.toplayer(c5)
158 | p4 = self._upsample_add(p5, self.latlayer1(c4))
159 | p3 = self._upsample_add(p4, self.latlayer2(c3))
160 | p2 = self._upsample_add(p3, self.latlayer3(c2))
161 | # Smooth
162 | p4 = self.smooth1(p4)
163 | p3 = self.smooth2(p3)
164 | p2 = self.smooth3(p2)
165 |
166 | output = self.seg_head(p2, p3, p4, p5)
167 | return output
168 |
169 |
170 | if __name__ == '__main__':
171 | from tools.flops_params_fps_count import flops_params_fps
172 | model = FPN(nclass=6)
173 | flops_params_fps(model)
174 |
175 |
176 |
--------------------------------------------------------------------------------
/models/head/__init__.py:
--------------------------------------------------------------------------------
1 | from .ann import ANNHead
2 | from .apc import APCHead
3 | from .aspp import ASPPHead
4 | from .aspp_plus import ASPPPlusHead
5 | from .da import DAHead
6 | from .dnl import DNLHead
7 | from .fcfpn import FCFPNHead
8 | from .fcn import FCNHead
9 | from .gc import GCHead
10 | from .psa import PSAHead
11 | from .psp import PSPHead
12 | from .unet import UNetHead
13 | from .uper import UPerHead
14 | from .seg import SegHead
15 | from .cefpn import CEFPNHead
16 | from .mlp import MLPHead
17 | from .edge import EdgeHead
18 |
19 | __all__ = [
20 | 'ANNHead', 'APCHead', 'ASPPHead', 'ASPPPlusHead', 'DAHead', 'DNLHead', 'FCFPNHead', 'FCNHead',
21 | 'GCHead', 'PSAHead', 'PSPHead', 'UNetHead', 'UPerHead', 'SegHead', 'CEFPNHead', 'MLPHead', 'EdgeHead'
22 | ]
--------------------------------------------------------------------------------
/models/head/apc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from mmcv.cnn import ConvModule
5 | from .base_decoder import BaseDecodeHead, resize
6 |
7 | norm_cfg = dict(type='BN', requires_grad=True)
8 |
9 |
10 | class ACM(nn.Module):
11 | """Adaptive Context Module used in APCNet.
12 | Args:
13 | pool_scale (int): Pooling scale used in Adaptive Context
14 | Module to extract region features.
15 | fusion (bool): Add one conv to fuse residual feature.
16 | in_channels (int): Input channels.
17 | channels (int): Channels after modules, before conv_seg.
18 | conv_cfg (dict | None): Config of conv layers.
19 | norm_cfg (dict | None): Config of norm layers.
20 | act_cfg (dict): Config of activation layers.
21 | """
22 |
23 | def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg,
24 | norm_cfg, act_cfg):
25 | super(ACM, self).__init__()
26 | self.pool_scale = pool_scale
27 | self.fusion = fusion
28 | self.in_channels = in_channels
29 | self.channels = channels
30 | self.conv_cfg = conv_cfg
31 | self.norm_cfg = norm_cfg
32 | self.act_cfg = act_cfg
33 | self.pooled_redu_conv = ConvModule(
34 | self.in_channels,
35 | self.channels,
36 | 1,
37 | conv_cfg=self.conv_cfg,
38 | norm_cfg=self.norm_cfg,
39 | act_cfg=self.act_cfg)
40 |
41 | self.input_redu_conv = ConvModule(
42 | self.in_channels,
43 | self.channels,
44 | 1,
45 | conv_cfg=self.conv_cfg,
46 | norm_cfg=self.norm_cfg,
47 | act_cfg=self.act_cfg)
48 |
49 | self.global_info = ConvModule(
50 | self.channels,
51 | self.channels,
52 | 1,
53 | conv_cfg=self.conv_cfg,
54 | norm_cfg=self.norm_cfg,
55 | act_cfg=self.act_cfg)
56 |
57 | self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)
58 |
59 | self.residual_conv = ConvModule(
60 | self.channels,
61 | self.channels,
62 | 1,
63 | conv_cfg=self.conv_cfg,
64 | norm_cfg=self.norm_cfg,
65 | act_cfg=self.act_cfg)
66 |
67 | if self.fusion:
68 | self.fusion_conv = ConvModule(
69 | self.channels,
70 | self.channels,
71 | 1,
72 | conv_cfg=self.conv_cfg,
73 | norm_cfg=self.norm_cfg,
74 | act_cfg=self.act_cfg)
75 |
76 | def forward(self, x):
77 | """Forward function."""
78 | pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale)
79 | # [batch_size, channels, h, w]
80 | x = self.input_redu_conv(x)
81 | # [batch_size, channels, pool_scale, pool_scale]
82 | pooled_x = self.pooled_redu_conv(pooled_x)
83 | batch_size = x.size(0)
84 | # [batch_size, pool_scale * pool_scale, channels]
85 | pooled_x = pooled_x.view(batch_size, self.channels,
86 | -1).permute(0, 2, 1).contiguous()
87 | # [batch_size, h * w, pool_scale * pool_scale]
88 | affinity_matrix = self.gla(x + resize(
89 | self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:])
90 | ).permute(0, 2, 3, 1).reshape(
91 | batch_size, -1, self.pool_scale**2)
92 | affinity_matrix = F.sigmoid(affinity_matrix)
93 | # [batch_size, h * w, channels]
94 | z_out = torch.matmul(affinity_matrix, pooled_x)
95 | # [batch_size, channels, h * w]
96 | z_out = z_out.permute(0, 2, 1).contiguous()
97 | # [batch_size, channels, h, w]
98 | z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3))
99 | z_out = self.residual_conv(z_out)
100 | z_out = F.relu(z_out + x)
101 | if self.fusion:
102 | z_out = self.fusion_conv(z_out)
103 |
104 | return z_out
105 |
106 |
107 | class APCHead(BaseDecodeHead):
108 | """Adaptive Pyramid Context Network for Semantic Segmentation.
109 |     This head is the implementation of APCNet
110 |     (Adaptive Pyramid Context Network for Semantic Segmentation, CVPR 2019).
113 | Args:
114 | pool_scales (tuple[int]): Pooling scales used in Adaptive Context
115 | Module. Default: (1, 2, 3, 6).
116 | fusion (bool): Add one conv to fuse residual feature.
117 | """
118 |
119 | def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, in_channels=768, num_classes=6, channels=512, in_index=3):
120 | super(APCHead, self).__init__(in_index=in_index, in_channels=in_channels,
121 | num_classes=num_classes, channels=channels, dropout_ratio=0.1, norm_cfg=norm_cfg, align_corners=False)
122 | assert isinstance(pool_scales, (list, tuple))
123 | self.pool_scales = pool_scales
124 | self.fusion = fusion
125 | acm_modules = []
126 | for pool_scale in self.pool_scales:
127 | acm_modules.append(
128 | ACM(pool_scale,
129 | self.fusion,
130 | self.in_channels,
131 | self.channels,
132 | conv_cfg=self.conv_cfg,
133 | norm_cfg=self.norm_cfg,
134 | act_cfg=self.act_cfg))
135 | self.acm_modules = nn.ModuleList(acm_modules)
136 | self.bottleneck = ConvModule(
137 | self.in_channels + len(pool_scales) * self.channels,
138 | self.channels,
139 | 3,
140 | padding=1,
141 | conv_cfg=self.conv_cfg,
142 | norm_cfg=self.norm_cfg,
143 | act_cfg=self.act_cfg)
144 |
145 | def forward(self, inputs):
146 | """Forward function."""
147 | x = self._transform_inputs(inputs)
148 | acm_outs = [x]
149 | for acm_module in self.acm_modules:
150 | acm_outs.append(acm_module(x))
151 | acm_outs = torch.cat(acm_outs, dim=1)
152 | output = self.bottleneck(acm_outs)
153 | output = self.cls_seg(output)
154 | return output
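155 |
156 |
157 | if __name__ == '__main__':
158 |     # Hedged usage sketch (not part of the original repo): APCHead only reads the
159 |     # pyramid level selected by in_index, so only that tensor needs in_channels
160 |     # channels. The feature shapes below are made up for a quick smoke test.
161 |     feats = [torch.randn(2, c, s, s) for c, s in zip((96, 192, 384, 768), (128, 64, 32, 16))]
162 |     head = APCHead(in_channels=768, num_classes=6, channels=512, in_index=3)
163 |     print(head(feats).shape)  # expected: torch.Size([2, 6, 16, 16])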
--------------------------------------------------------------------------------
/models/head/aspp.py:
--------------------------------------------------------------------------------
1 | ###########################################################################
2 | # Created by: Hang Zhang
3 | # Email: zhang.hang@rutgers.edu
4 | # Copyright (c) 2017
5 | ###########################################################################
6 | from __future__ import division
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.functional import interpolate
10 |
11 |
12 | up_kwargs = {'mode': 'bilinear', 'align_corners': False}
13 | norm_layer = nn.BatchNorm2d
14 |
15 |
16 | def ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
17 | block = nn.Sequential(
18 | nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rate,
19 | dilation=atrous_rate, bias=False),
20 | norm_layer(out_channels),
21 | nn.ReLU(True))
22 | return block
23 |
24 |
25 | class AsppPooling(nn.Module):
26 | def __init__(self, in_channels, out_channels, norm_layer, up_kwargs):
27 | super(AsppPooling, self).__init__()
28 | self._up_kwargs = up_kwargs
29 | self.gap = nn.Sequential(nn.AdaptiveAvgPool2d(1),
30 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
31 | norm_layer(out_channels),
32 | nn.ReLU(True))
33 |
34 | def forward(self, x):
35 | _, _, h, w = x.size()
36 | pool = self.gap(x)
37 | return interpolate(pool, (h,w), **self._up_kwargs)
38 |
39 |
40 | class ASPP_Module(nn.Module):
41 | def __init__(self, in_channels, atrous_rates, norm_layer, up_kwargs):
42 | super(ASPP_Module, self).__init__()
43 | out_channels = in_channels // 8
44 | rate1, rate2, rate3 = tuple(atrous_rates)
45 | self.b0 = nn.Sequential(
46 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
47 | norm_layer(out_channels),
48 | nn.ReLU(True))
49 | self.b1 = ASPPConv(in_channels, out_channels, rate1, norm_layer)
50 | self.b2 = ASPPConv(in_channels, out_channels, rate2, norm_layer)
51 | self.b3 = ASPPConv(in_channels, out_channels, rate3, norm_layer)
52 | self.b4 = AsppPooling(in_channels, out_channels, norm_layer, up_kwargs)
53 |
54 | self.project = nn.Sequential(
55 | nn.Conv2d(5*out_channels, out_channels, 1, bias=False),
56 | norm_layer(out_channels),
57 | nn.ReLU(True),
58 | nn.Dropout2d(0.5, False))
59 |
60 | def forward(self, x):
61 | feat0 = self.b0(x)
62 | feat1 = self.b1(x)
63 | feat2 = self.b2(x)
64 | feat3 = self.b3(x)
65 | feat4 = self.b4(x)
66 | y = torch.cat((feat0, feat1, feat2, feat3, feat4), 1)
67 | return self.project(y)
68 |
69 |
70 | class ASPPHead(nn.Module):
71 | def __init__(self, in_channels, num_classes, norm_layer=norm_layer, up_kwargs=up_kwargs, atrous_rates=[12, 24, 36], in_index=3):
72 | super(ASPPHead, self).__init__()
73 | inter_channels = in_channels // 8
74 | self.in_index = in_index
75 | self.aspp = ASPP_Module(in_channels, atrous_rates, norm_layer, up_kwargs)
76 | self.block = nn.Sequential(
77 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
78 | norm_layer(inter_channels),
79 | nn.ReLU(True),
80 | nn.Dropout(0.1, False),
81 | nn.Conv2d(inter_channels, num_classes, 1))
82 |
83 | def _transform_inputs(self, inputs):
84 | if isinstance(self.in_index, (list, tuple)):
85 | inputs = [inputs[i] for i in self.in_index]
86 | elif isinstance(self.in_index, int):
87 | inputs = inputs[self.in_index]
88 | return inputs
89 |
90 | def forward(self, inputs):
91 | x = self._transform_inputs(inputs)
92 | x = self.aspp(x)
93 | x = self.block(x)
94 | return x
95 |
96 |
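97 |
98 | if __name__ == '__main__':
99 |     # Hedged usage sketch (not part of the original repo): ASPPHead reduces the
100 |     # selected feature map to in_channels // 8 channels inside ASPP_Module and
101 |     # classifies at the same spatial size. Shapes below are made up.
102 |     feats = [torch.randn(2, c, s, s) for c, s in zip((96, 192, 384, 768), (128, 64, 32, 16))]
103 |     head = ASPPHead(in_channels=768, num_classes=6, in_index=3)
104 |     print(head(feats).shape)  # expected: torch.Size([2, 6, 16, 16])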
--------------------------------------------------------------------------------
/models/head/aspp_plus.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .aspp import ASPP_Module
6 |
7 |
8 | up_kwargs = {'mode': 'bilinear', 'align_corners': False}
9 | norm_layer = nn.BatchNorm2d
10 |
11 |
12 | class _ConvBNReLU(nn.Module):
13 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
14 | dilation=1, groups=1, relu6=False, norm_layer=norm_layer):
15 | super(_ConvBNReLU, self).__init__()
16 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
17 | self.bn = norm_layer(out_channels)
18 | self.relu = nn.ReLU6(True) if relu6 else nn.ReLU(True)
19 |
20 | def forward(self, x):
21 | x = self.conv(x)
22 | x = self.bn(x)
23 | x = self.relu(x)
24 | return x
25 |
26 |
27 | class ASPPPlusHead(nn.Module):
28 | def __init__(self, num_classes, in_channels, norm_layer=norm_layer, up_kwargs=up_kwargs, in_index=[0, 3]):
29 | super(ASPPPlusHead, self).__init__()
30 | self._up_kwargs = up_kwargs
31 | self.in_index = in_index
32 | self.channels = in_channels // 2 ** in_index[1]
33 | self.aspp = ASPP_Module(in_channels, [12, 24, 36], norm_layer=norm_layer, up_kwargs=up_kwargs)
34 | self.c1_block = _ConvBNReLU(self.channels, self.channels, 3, padding=1, norm_layer=norm_layer)
35 | self.block = nn.Sequential(
36 | _ConvBNReLU(self.channels + in_channels // 8, self.channels + in_channels // 8, 3, padding=1, norm_layer=norm_layer),
37 | nn.Dropout(0.5),
38 | _ConvBNReLU(self.channels + in_channels // 8, self.channels + in_channels // 8, 3, padding=1, norm_layer=norm_layer),
39 | nn.Dropout(0.1),
40 | nn.Conv2d(self.channels + in_channels // 8, num_classes, 1))
41 |
42 | def _transform_inputs(self, inputs):
43 | if isinstance(self.in_index, (list, tuple)):
44 | inputs = [inputs[i] for i in self.in_index]
45 | elif isinstance(self.in_index, int):
46 | inputs = inputs[self.in_index]
47 | return inputs
48 |
49 | def forward(self, inputs):
50 | inputs = self._transform_inputs(inputs)
51 | c1, x = inputs
52 | size = c1.size()[2:]
53 | c1 = self.c1_block(c1)
54 | x = self.aspp(x)
55 | x = F.interpolate(x, size, **self._up_kwargs)
56 | return self.block(torch.cat([x, c1], dim=1))
57 |
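58 |
59 | if __name__ == '__main__':
60 |     # Hedged usage sketch (not part of the original repo): with in_index=[0, 3]
61 |     # the head expects a ResNet-like pyramid where the low-level map has
62 |     # in_channels // 2**3 channels; the shapes below are made up accordingly.
63 |     feats = [torch.randn(2, c, s, s) for c, s in zip((256, 512, 1024, 2048), (64, 32, 16, 8))]
64 |     head = ASPPPlusHead(num_classes=6, in_channels=2048)
65 |     print(head(feats).shape)  # expected: torch.Size([2, 6, 64, 64])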
--------------------------------------------------------------------------------
/models/head/base_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from mmcv.cnn import normal_init
4 | import warnings
5 | import torch.nn.functional as F
6 |
7 | norm_cfg = dict(type='BN', requires_grad=True)
8 |
9 |
10 | def resize(input,
11 | size=None,
12 | scale_factor=None,
13 | mode='nearest',
14 | align_corners=None,
15 | warning=True):
16 | if warning:
17 | if size is not None and align_corners:
18 | input_h, input_w = tuple(int(x) for x in input.shape[2:])
19 | output_h, output_w = tuple(int(x) for x in size)
20 |             if output_h > input_h or output_w > input_w:
21 | if ((output_h > 1 and output_w > 1 and input_h > 1
22 | and input_w > 1) and (output_h - 1) % (input_h - 1)
23 | and (output_w - 1) % (input_w - 1)):
24 | warnings.warn(
25 | f'When align_corners={align_corners}, '
26 |                         'the output would be more aligned if '
27 | f'input size {(input_h, input_w)} is `x+1` and '
28 | f'out size {(output_h, output_w)} is `nx+1`')
29 | if isinstance(size, torch.Size):
30 | size = tuple(int(x) for x in size)
31 | return F.interpolate(input, size, scale_factor, mode, align_corners)
32 |
33 |
34 | class BaseDecodeHead(nn.Module):
35 | """Base class for BaseDecodeHead.
36 |
37 | Args:
38 | in_channels (int|Sequence[int]): Input channels.
39 | channels (int): Channels after modules, before conv_seg.
40 | num_classes (int): Number of classes.
41 | dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
42 | conv_cfg (dict|None): Config of conv layers. Default: None.
43 | norm_cfg (dict|None): Config of norm layers. Default: None.
44 | act_cfg (dict): Config of activation layers.
45 | Default: dict(type='ReLU')
46 | in_index (int|Sequence[int]): Input feature index. Default: -1
47 | input_transform (str|None): Transformation type of input features.
48 | Options: 'resize_concat', 'multiple_select', None.
49 |             'resize_concat': Multiple feature maps will be resized to the
50 |                 same size as the first one and then concatenated together.
51 |                 Usually used in FCN head of HRNet.
52 |             'multiple_select': Multiple feature maps will be bundled into
53 |                 a list and passed into the decode head.
54 | None: Only one select feature map is allowed.
55 | Default: None.
56 | loss_decode (dict): Config of decode loss.
57 | Default: dict(type='CrossEntropyLoss').
58 | ignore_index (int | None): The label index to be ignored. When using
59 | masked BCE loss, ignore_index should be set to None. Default: 255
60 | sampler (dict|None): The config of segmentation map sampler.
61 | Default: None.
62 | align_corners (bool): align_corners argument of F.interpolate.
63 | Default: False.
64 | """
65 |
66 | def __init__(self,
67 | in_channels,
68 | channels,
69 | *,
70 | num_classes,
71 | dropout_ratio=0.1,
72 | conv_cfg=None,
73 | norm_cfg=None,
74 | act_cfg=dict(type='ReLU'),
75 | in_index=-1,
76 | input_transform=None,
77 | ignore_index=255,
78 | sampler=None,
79 | align_corners=False):
80 | super(BaseDecodeHead, self).__init__()
81 | self._init_inputs(in_channels, in_index, input_transform)
82 | self.channels = channels
83 | self.num_classes = num_classes
84 | self.dropout_ratio = dropout_ratio
85 | self.conv_cfg = conv_cfg
86 | self.norm_cfg = norm_cfg
87 | self.act_cfg = act_cfg
88 | self.in_index = in_index
89 | self.ignore_index = ignore_index
90 | self.align_corners = align_corners
91 | self.sampler = None
92 |
93 | self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
94 | if dropout_ratio > 0:
95 | self.dropout = nn.Dropout2d(dropout_ratio)
96 | else:
97 | self.dropout = None
98 | self.fp16_enabled = False
99 |
100 | def extra_repr(self):
101 | """Extra repr."""
102 | s = f'input_transform={self.input_transform}, ' \
103 | f'ignore_index={self.ignore_index}, ' \
104 | f'align_corners={self.align_corners}'
105 | return s
106 |
107 | def _init_inputs(self, in_channels, in_index, input_transform):
108 | """Check and initialize input transforms.
109 |
110 | The in_channels, in_index and input_transform must match.
111 |         Specifically, when input_transform is None, only a single feature map
112 |         will be selected, so in_channels and in_index must be of type int. When
113 |         input_transform is not None, they must be lists/tuples of the same length.
114 |
115 | Args:
116 | in_channels (int|Sequence[int]): Input channels.
117 | in_index (int|Sequence[int]): Input feature index.
118 | input_transform (str|None): Transformation type of input features.
119 | Options: 'resize_concat', 'multiple_select', None.
120 |                 'resize_concat': Multiple feature maps will be resized to the
121 |                     same size as the first one and then concatenated together.
122 |                     Usually used in FCN head of HRNet.
123 |                 'multiple_select': Multiple feature maps will be bundled into
124 |                     a list and passed into the decode head.
125 | None: Only one select feature map is allowed.
126 | """
127 |
128 | if input_transform is not None:
129 | assert input_transform in ['resize_concat', 'multiple_select']
130 | self.input_transform = input_transform
131 | self.in_index = in_index
132 | if input_transform is not None:
133 | assert isinstance(in_channels, (list, tuple))
134 | assert isinstance(in_index, (list, tuple))
135 | assert len(in_channels) == len(in_index)
136 | if input_transform == 'resize_concat':
137 | self.in_channels = sum(in_channels)
138 | else:
139 | self.in_channels = in_channels
140 | else:
141 | assert isinstance(in_channels, int)
142 | assert isinstance(in_index, int)
143 | self.in_channels = in_channels
144 |
145 | def init_weights(self):
146 | """Initialize weights of classification layer."""
147 | normal_init(self.conv_seg, mean=0, std=0.01)
148 |
149 | def _transform_inputs(self, inputs):
150 | """Transform inputs for decoder.
151 |
152 | Args:
153 | inputs (list[Tensor]): List of multi-level img features.
154 |
155 | Returns:
156 | Tensor: The transformed inputs
157 | """
158 |
159 | if self.input_transform == 'resize_concat':
160 | inputs = [inputs[i] for i in self.in_index]
161 | upsampled_inputs = [
162 | resize(
163 | input=x,
164 | size=inputs[0].shape[2:],
165 | mode='bilinear',
166 | align_corners=self.align_corners) for x in inputs
167 | ]
168 | inputs = torch.cat(upsampled_inputs, dim=1)
169 | elif self.input_transform == 'multiple_select':
170 | inputs = [inputs[i] for i in self.in_index]
171 | else:
172 | inputs = inputs[self.in_index]
173 |
174 | return inputs
175 |
176 | def forward(self, inputs):
177 | """Placeholder of forward function."""
178 | pass
179 |
180 | def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
181 | """Forward function for training.
182 | Args:
183 | inputs (list[Tensor]): List of multi-level img features.
184 | img_metas (list[dict]): List of image info dict where each dict
185 | has: 'img_shape', 'scale_factor', 'flip', and may also contain
186 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
187 | For details on the values of these keys see
188 | `mmseg/datasets/pipelines/formatting.py:Collect`.
189 | gt_semantic_seg (Tensor): Semantic segmentation masks
190 | used if the architecture supports semantic segmentation task.
191 | train_cfg (dict): The training config.
192 |
193 | Returns:
194 | dict[str, Tensor]: a dictionary of loss components
195 | """
196 | seg_logits = self.forward(inputs)
197 | losses = self.losses(seg_logits, gt_semantic_seg)
198 | return losses
199 |
200 | def forward_test(self, inputs, img_metas, test_cfg):
201 | """Forward function for testing.
202 |
203 | Args:
204 | inputs (list[Tensor]): List of multi-level img features.
205 | img_metas (list[dict]): List of image info dict where each dict
206 | has: 'img_shape', 'scale_factor', 'flip', and may also contain
207 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
208 | For details on the values of these keys see
209 | `mmseg/datasets/pipelines/formatting.py:Collect`.
210 | test_cfg (dict): The testing config.
211 |
212 | Returns:
213 | Tensor: Output segmentation map.
214 | """
215 | return self.forward(inputs)
216 |
217 | def cls_seg(self, feat):
218 | """Classify each pixel."""
219 | if self.dropout is not None:
220 | feat = self.dropout(feat)
221 | output = self.conv_seg(feat)
222 | return output
223 |
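224 |
225 | if __name__ == '__main__':
226 |     # Hedged sketch (not part of the original repo): a minimal subclass showing
227 |     # how _transform_inputs() and cls_seg() are meant to be combined. Channel
228 |     # sizes are made up; channels must match in_channels here because the toy
229 |     # head applies no convs before conv_seg.
230 |     class _ToyHead(BaseDecodeHead):
231 |         def forward(self, inputs):
232 |             x = self._transform_inputs(inputs)  # picks inputs[self.in_index]
233 |             return self.cls_seg(x)              # dropout + 1x1 conv_seg
234 |
235 |     head = _ToyHead(in_channels=768, channels=768, num_classes=6, in_index=3)
236 |     feats = [torch.randn(2, c, 16, 16) for c in (96, 192, 384, 768)]
237 |     print(head(feats).shape)  # expected: torch.Size([2, 6, 16, 16])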
--------------------------------------------------------------------------------
/models/head/cefpn.py:
--------------------------------------------------------------------------------
1 | ###########################################################################
2 | # Created by: Hang Zhang
3 | # Email: zhang.hang@rutgers.edu
4 | # Copyright (c) 2017
5 | ###########################################################################
6 | from __future__ import division
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.functional import interpolate as upsample  # F.upsample is deprecated
10 |
11 | up_kwargs = {'mode': 'bilinear', 'align_corners': True}
12 | norm_layer = nn.BatchNorm2d
13 |
14 |
15 | class CEFPNHead(nn.Module):
16 | def __init__(self, in_channels=[256, 512, 1024, 2048], num_classes=6, channels=256,
17 | norm_layer=norm_layer, up_kwargs=up_kwargs, in_index=[0, 1, 2, 3]):
18 | super(CEFPNHead, self).__init__()
19 | assert up_kwargs is not None
20 | self._up_kwargs = up_kwargs
21 | self.in_index = in_index
22 | self.C5_2_F4 = nn.Sequential(
23 | nn.Conv2d(in_channels[3], in_channels[2], kernel_size=1, bias=False),
24 | norm_layer(in_channels[2]),
25 | nn.ReLU(inplace=True))
26 | self.C4_2_F4 = nn.Sequential(
27 | nn.Conv2d(in_channels[2], channels, kernel_size=1, bias=False),
28 | norm_layer(channels),
29 | nn.ReLU(inplace=True))
30 | self.C3_2_F3 = nn.Sequential(
31 | nn.Conv2d(in_channels[1], channels, kernel_size=1, bias=False),
32 | norm_layer(channels),
33 | nn.ReLU(inplace=True))
34 | self.C2_2_F2 = nn.Sequential(
35 | nn.Conv2d(in_channels[0], channels, kernel_size=1, bias=False),
36 | norm_layer(channels),
37 | nn.ReLU(inplace=True))
38 |
39 | fpn_out = []
40 | for _ in range(len(in_channels)):
41 | fpn_out.append(nn.Sequential(
42 | nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
43 | norm_layer(channels),
44 | nn.ReLU(inplace=True),
45 | ))
46 | self.fpn_out = nn.ModuleList(fpn_out)
47 | inter_channels = len(in_channels) * channels
48 | self.conv5 = nn.Sequential(nn.Conv2d(inter_channels, 512, 3, padding=1, bias=False),
49 | norm_layer(512),
50 | nn.ReLU(),
51 | nn.Dropout(0.1, False),
52 | nn.Conv2d(512, num_classes, 1))
53 | # channel_attention_guide
54 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
55 | self.max_pool = nn.AdaptiveMaxPool2d(1)
56 | self.shared_MLP = nn.Sequential(
57 | nn.Linear(in_features=channels, out_features=channels // 16),
58 | nn.ReLU(inplace=True),
59 | nn.Linear(in_features=channels // 16, out_features=channels))
60 | self.sigmoid = nn.Sigmoid()
61 |
62 | # sub_pixel_context_enhancement
63 | self.conv1 = nn.Sequential(nn.Conv2d(in_channels[-1], in_channels[-1] // 2, kernel_size=3, padding=1, bias=False),
64 | norm_layer(in_channels[-1] // 2),
65 | nn.ReLU())
66 | self.max_pool2 = nn.MaxPool2d(3, stride=2, padding=1)
67 | self.conv2 = nn.Sequential(nn.Conv2d(in_channels[-1], in_channels[-1] * 2, 1, bias=False),
68 | norm_layer(in_channels[-1] * 2),
69 | nn.ReLU())
70 | self.global_pool = nn.AdaptiveAvgPool2d(1)
71 | self.conv3 = nn.Sequential(nn.Conv2d(in_channels[-1], in_channels[-1] // 8, 1, bias=False),
72 | norm_layer(in_channels[-1] // 8),
73 | nn.ReLU())
74 | # inchannels to channels
75 | self.smooth1 = nn.Sequential(
76 | nn.Conv2d(in_channels[0], channels, kernel_size=1, bias=False),
77 | norm_layer(channels),
78 | nn.ReLU(inplace=True))
79 | self.smooth2 = nn.Sequential(
80 | nn.Conv2d(in_channels[0], channels, kernel_size=1, bias=False),
81 | norm_layer(channels),
82 | nn.ReLU(inplace=True))
83 | self.smooth3 = nn.Sequential(
84 | nn.Conv2d(in_channels[0], channels, kernel_size=1, bias=False),
85 | norm_layer(channels),
86 | nn.ReLU(inplace=True))
87 |
88 | def sub_pixel_conv(self, inputs, up_factor=2):
89 | b, c, h, w = inputs.shape
90 | assert c % (up_factor * up_factor) == 0
91 | inputs = inputs.permute(0, 2, 3, 1) # b h w c
92 | inputs = inputs.view(b, h, w, c // (up_factor * up_factor), up_factor, up_factor)
93 | inputs = inputs.permute(0, 1, 4, 2, 5, 3).contiguous()
94 | inputs = inputs.view(b, h * up_factor, w * up_factor, c // (up_factor * up_factor)).permute(0, 3, 1, 2)
95 | inputs = inputs.contiguous()
96 | return inputs
97 |
98 | def channel_attention_guide(self, inputs):
99 | avgout = self.shared_MLP(self.avg_pool(inputs).view(inputs.size(0), -1)).unsqueeze(2).unsqueeze(3)
100 | maxout = self.shared_MLP(self.max_pool(inputs).view(inputs.size(0), -1)).unsqueeze(2).unsqueeze(3)
101 | weights = self.sigmoid(avgout + maxout)
102 | output = weights * inputs
103 | return output
104 |
105 | def sub_pixel_context_enhancement(self, inputs):
106 | h, w = inputs.size()[2:]
107 | input1 = self.sub_pixel_conv(self.conv1(inputs))
108 | input2 = self.sub_pixel_conv(self.conv2(self.max_pool2(inputs)), up_factor=4)
109 | input3 = upsample(self.conv3(inputs), (h * 2, w * 2), **self._up_kwargs)
110 | output = input1 + input2 + input3
111 | output = self.smooth3(output)
112 | return output
113 |
114 | def _transform_inputs(self, inputs):
115 | if isinstance(self.in_index, (list, tuple)):
116 | inputs = [inputs[i] for i in self.in_index]
117 | elif isinstance(self.in_index, int):
118 | inputs = inputs[self.in_index]
119 | return inputs
120 |
121 | def forward(self, inputs):
122 | inputs = self._transform_inputs(inputs)
123 | c5 = inputs[-1]
124 | c1_size = inputs[0].size()[2:]
125 | if hasattr(self, 'extramodule'):
126 | c5 = self.extramodule(c5)
127 |
128 | feat = self.sub_pixel_context_enhancement(c5)
129 | feat_up = upsample(self.channel_attention_guide(self.fpn_out[3](feat)), c1_size, **self._up_kwargs)
130 | fpn_features = [feat_up]
131 |
132 | feat = self.smooth1(self.sub_pixel_conv(self.C5_2_F4(c5))) + self.C4_2_F4(inputs[2])
133 | feat_up = upsample(self.channel_attention_guide(self.fpn_out[2](feat)), c1_size, **self._up_kwargs)
134 | fpn_features.append(feat_up)
135 |
136 | feats = []
137 | feats.append(self.C2_2_F2(inputs[0]))
138 | feats.append(self.smooth2(self.sub_pixel_conv(inputs[2])) + self.C3_2_F3(inputs[1]))
139 |
140 | for i in reversed(range(len(inputs) - 2)):
141 | feat_i = feats[i]
142 | feat = upsample(feat, feat_i.size()[2:], **self._up_kwargs)
143 | feat = feat + feat_i
144 | feat_up = upsample(self.channel_attention_guide(self.fpn_out[i](feat)), c1_size, **self._up_kwargs)
145 | fpn_features.append(feat_up)
146 | fpn_features = torch.cat(fpn_features, 1)
147 |
148 | return self.conv5(fpn_features)
149 |
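150 |
151 | if __name__ == '__main__':
152 |     # Hedged sanity check (not part of the original repo): sub_pixel_conv is a
153 |     # hand-written pixel shuffle, so it should match nn.PixelShuffle for the same
154 |     # factor. The small channel config below is made up just to build the module.
155 |     head = CEFPNHead(in_channels=[16, 32, 64, 128], channels=16)
156 |     x = torch.randn(2, 16, 8, 8)
157 |     print(torch.allclose(head.sub_pixel_conv(x, up_factor=2), nn.PixelShuffle(2)(x)))
158 |     feats = [torch.randn(2, c, s, s) for c, s in zip((16, 32, 64, 128), (64, 32, 16, 8))]
159 |     print(head(feats).shape)  # expected: torch.Size([2, 6, 64, 64])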
--------------------------------------------------------------------------------
/models/head/da.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 | up_kwargs = {'mode': 'bilinear', 'align_corners': False}
5 | norm_layer = nn.BatchNorm2d
6 |
7 |
8 | class _PositionAttentionModule(nn.Module):
9 | """ Position attention module"""
10 |
11 | def __init__(self, in_channels):
12 | super(_PositionAttentionModule, self).__init__()
13 | self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1)
14 | self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1)
15 | self.conv_d = nn.Conv2d(in_channels, in_channels, 1)
16 | self.alpha = nn.Parameter(torch.zeros(1))
17 | self.softmax = nn.Softmax(dim=-1)
18 |
19 | def forward(self, x):
20 | batch_size, _, height, width = x.size()
21 | feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1)
22 | feat_c = self.conv_c(x).view(batch_size, -1, height * width)
23 | attention_s = self.softmax(torch.bmm(feat_b, feat_c))
24 | feat_d = self.conv_d(x).view(batch_size, -1, height * width)
25 | feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width)
26 | out = self.alpha * feat_e + x
27 |
28 | return out
29 |
30 |
31 | class _ChannelAttentionModule(nn.Module):
32 | """Channel attention module"""
33 |
34 | def __init__(self):
35 | super(_ChannelAttentionModule, self).__init__()
36 | self.beta = nn.Parameter(torch.zeros(1))
37 | self.softmax = nn.Softmax(dim=-1)
38 |
39 | def forward(self, x):
40 | batch_size, _, height, width = x.size()
41 | feat_a = x.view(batch_size, -1, height * width)
42 | feat_a_transpose = x.view(batch_size, -1, height * width).permute(0, 2, 1)
43 | attention = torch.bmm(feat_a, feat_a_transpose)
44 | attention_new = torch.max(attention, dim=-1, keepdim=True)[0].expand_as(attention) - attention
45 | attention = self.softmax(attention_new)
46 |
47 | feat_e = torch.bmm(attention, feat_a).view(batch_size, -1, height, width)
48 | out = self.beta * feat_e + x
49 |
50 | return out
51 |
52 |
53 | class DAHead(nn.Module):
54 | def __init__(self, in_channels, num_classes, aux=False, norm_layer=norm_layer, norm_kwargs=None, in_index=3):
55 | super(DAHead, self).__init__()
56 | self.aux = aux
57 | self.in_index = in_index
58 | inter_channels = in_channels // 4
59 | self.conv_p1 = nn.Sequential(
60 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
61 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
62 | nn.ReLU(True)
63 | )
64 | self.conv_c1 = nn.Sequential(
65 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
66 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
67 | nn.ReLU(True)
68 | )
69 | self.pam = _PositionAttentionModule(inter_channels)
70 | self.cam = _ChannelAttentionModule()
71 | self.conv_p2 = nn.Sequential(
72 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
73 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
74 | nn.ReLU(True)
75 | )
76 | self.conv_c2 = nn.Sequential(
77 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
78 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
79 | nn.ReLU(True)
80 | )
81 | self.out = nn.Sequential(
82 | nn.Dropout(0.1),
83 | nn.Conv2d(inter_channels, num_classes, 1)
84 | )
85 | if aux:
86 | self.conv_p3 = nn.Sequential(
87 | nn.Dropout(0.1),
88 | nn.Conv2d(inter_channels, num_classes, 1)
89 | )
90 | self.conv_c3 = nn.Sequential(
91 | nn.Dropout(0.1),
92 | nn.Conv2d(inter_channels, num_classes, 1)
93 | )
94 |
95 | def _transform_inputs(self, inputs):
96 | if isinstance(self.in_index, (list, tuple)):
97 | inputs = [inputs[i] for i in self.in_index]
98 | elif isinstance(self.in_index, int):
99 | inputs = inputs[self.in_index]
100 | return inputs
101 |
102 | def forward(self, inputs):
103 | x = self._transform_inputs(inputs)
104 | feat_p = self.conv_p1(x)
105 | feat_p = self.pam(feat_p)
106 | feat_p = self.conv_p2(feat_p)
107 |
108 | feat_c = self.conv_c1(x)
109 | feat_c = self.cam(feat_c)
110 | feat_c = self.conv_c2(feat_c)
111 |
112 | feat_fusion = feat_p + feat_c
113 |
114 | outputs = []
115 | fusion_out = self.out(feat_fusion)
116 | outputs.append(fusion_out)
117 | if self.aux:
118 | p_out = self.conv_p3(feat_p)
119 | c_out = self.conv_c3(feat_c)
120 | outputs.append(p_out)
121 | outputs.append(c_out)
122 |
123 | return outputs
124 |
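125 |
126 | if __name__ == '__main__':
127 |     # Hedged usage sketch (not part of the original repo): DAHead returns a list;
128 |     # with aux=True it also appends the position-attention-only and
129 |     # channel-attention-only predictions. Shapes below are made up.
130 |     feats = [torch.randn(2, c, s, s) for c, s in zip((96, 192, 384, 768), (64, 32, 16, 8))]
131 |     head = DAHead(in_channels=768, num_classes=6, aux=True)
132 |     print([out.shape for out in head(feats)])  # three maps of [2, 6, 8, 8]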
--------------------------------------------------------------------------------
/models/head/dnl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from mmcv.cnn import NonLocal2d
3 | from torch import nn
4 | from .fcn import FCNHead
5 |
6 | norm_cfg = dict(type='BN', requires_grad=True)
7 |
8 |
9 | class DisentangledNonLocal2d(NonLocal2d):
10 | """Disentangled Non-Local Blocks.
11 | Args:
12 | temperature (float): Temperature to adjust attention. Default: 0.05
13 | """
14 |
15 | def __init__(self, *arg, temperature, **kwargs):
16 | super().__init__(*arg, **kwargs)
17 | self.temperature = temperature
18 | self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)
19 |
20 | def embedded_gaussian(self, theta_x, phi_x):
21 | """Embedded gaussian with temperature."""
22 |
23 | # NonLocal2d pairwise_weight: [N, HxW, HxW]
24 | pairwise_weight = torch.matmul(theta_x, phi_x)
25 | if self.use_scale:
26 | # theta_x.shape[-1] is `self.inter_channels`
27 | pairwise_weight /= theta_x.shape[-1]**0.5
28 | pairwise_weight /= self.temperature
29 | pairwise_weight = pairwise_weight.softmax(dim=-1)
30 | return pairwise_weight
31 |
32 | def forward(self, x):
33 | # x: [N, C, H, W]
34 | n = x.size(0)
35 |
36 | # g_x: [N, HxW, C]
37 | g_x = self.g(x).view(n, self.inter_channels, -1)
38 | g_x = g_x.permute(0, 2, 1)
39 |
40 | # theta_x: [N, HxW, C], phi_x: [N, C, HxW]
41 | if self.mode == 'gaussian':
42 | theta_x = x.view(n, self.in_channels, -1)
43 | theta_x = theta_x.permute(0, 2, 1)
44 | if self.sub_sample:
45 | phi_x = self.phi(x).view(n, self.in_channels, -1)
46 | else:
47 | phi_x = x.view(n, self.in_channels, -1)
48 | elif self.mode == 'concatenation':
49 | theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
50 | phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
51 | else:
52 | theta_x = self.theta(x).view(n, self.inter_channels, -1)
53 | theta_x = theta_x.permute(0, 2, 1)
54 | phi_x = self.phi(x).view(n, self.inter_channels, -1)
55 |
56 | # subtract mean
57 | theta_x -= theta_x.mean(dim=-2, keepdim=True)
58 | phi_x -= phi_x.mean(dim=-1, keepdim=True)
59 |
60 | pairwise_func = getattr(self, self.mode)
61 | # pairwise_weight: [N, HxW, HxW]
62 | pairwise_weight = pairwise_func(theta_x, phi_x)
63 |
64 | # y: [N, HxW, C]
65 | y = torch.matmul(pairwise_weight, g_x)
66 | # y: [N, C, H, W]
67 | y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
68 | *x.size()[2:])
69 |
70 | # unary_mask: [N, 1, HxW]
71 | unary_mask = self.conv_mask(x)
72 | unary_mask = unary_mask.view(n, 1, -1)
73 | unary_mask = unary_mask.softmax(dim=-1)
74 | # unary_x: [N, 1, C]
75 | unary_x = torch.matmul(unary_mask, g_x)
76 | # unary_x: [N, C, 1, 1]
77 | unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
78 | n, self.inter_channels, 1, 1)
79 |
80 | output = x + self.conv_out(y + unary_x)
81 |
82 | return output
83 |
84 |
85 | class DNLHead(FCNHead):
86 | """Disentangled Non-Local Neural Networks.
87 |     This head is the implementation of DNLNet
88 |     (Disentangled Non-Local Neural Networks).
89 | Args:
90 | reduction (int): Reduction factor of projection transform. Default: 2.
91 | use_scale (bool): Whether to scale pairwise_weight by
92 |             sqrt(1/inter_channels). Default: True.
93 |         mode (str): The nonlocal mode. Options are 'embedded_gaussian',
94 |             'dot_product'. Default: 'embedded_gaussian'.
95 | temperature (float): Temperature to adjust attention. Default: 0.05
96 | """
97 |
98 | def __init__(self,
99 | reduction=2,
100 | use_scale=True,
101 | mode='embedded_gaussian',
102 | temperature=0.05,
103 | in_channels=768,
104 | num_classes=6,
105 | in_index=3,
106 | channels=512,
107 | ):
108 | super(DNLHead, self).__init__(num_convs=2, in_channels=in_channels, num_classes=num_classes, in_index=in_index, channels=channels)
109 | self.reduction = reduction
110 | self.use_scale = use_scale
111 | self.mode = mode
112 | self.temperature = temperature
113 | self.dnl_block = DisentangledNonLocal2d(
114 | in_channels=self.channels,
115 | reduction=self.reduction,
116 | use_scale=self.use_scale,
117 | conv_cfg=self.conv_cfg,
118 | norm_cfg=self.norm_cfg,
119 | mode=self.mode,
120 | temperature=self.temperature)
121 |
122 | def forward(self, inputs):
123 | """Forward function."""
124 | x = self._transform_inputs(inputs)
125 | output = self.convs[0](x)
126 | output = self.dnl_block(output)
127 | output = self.convs[1](output)
128 | if self.concat_input:
129 | output = self.conv_cat(torch.cat([x, output], dim=1))
130 | output = self.cls_seg(output)
131 | return output
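132 |
133 |
134 | if __name__ == '__main__':
135 |     # Hedged usage sketch (not part of the original repo): DNLHead behaves like a
136 |     # two-conv FCNHead with a disentangled non-local block in between. Shapes
137 |     # below are made up for a quick smoke test.
138 |     feats = [torch.randn(2, c, s, s) for c, s in zip((96, 192, 384, 768), (64, 32, 16, 8))]
139 |     head = DNLHead(in_channels=768, num_classes=6, channels=512, in_index=3)
140 |     print(head(feats).shape)  # expected: torch.Size([2, 6, 8, 8])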
--------------------------------------------------------------------------------
/models/head/edge.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | up_kwargs = {'mode': 'bilinear', 'align_corners': False}
7 |
8 |
9 | class EdgeHead(nn.Module):
10 | """Edge awareness module"""
11 |
12 | def __init__(self, in_channels=[96, 192], channels=96, out_fea=2, in_index=[0, 1]):
13 | super(EdgeHead, self).__init__()
14 | self.in_index = in_index
15 | self.conv1 = nn.Sequential(
16 | nn.Conv2d(in_channels[0], in_channels[0], 1, 1, 0),
17 | nn.BatchNorm2d(in_channels[0]),
18 | nn.ReLU(True),
19 | nn.Conv2d(in_channels[0], channels, 1, 1, 0),
20 | nn.BatchNorm2d(channels),
21 | nn.ReLU(True),
22 | )
23 | # self.conv2 = nn.Sequential(
24 | # nn.Conv2d(in_channels[1], in_channels[1], 1, 1, 0),
25 | # nn.BatchNorm2d(in_channels[1]),
26 | # nn.ReLU(True),
27 | # nn.Conv2d(in_channels[1], channels, 1, 1, 0),
28 | # nn.BatchNorm2d(channels),
29 | # nn.ReLU(True),
30 | # )
31 | self.conv3 = nn.Conv2d(channels, out_fea, 1, 1, 0)
32 |
33 | def _transform_inputs(self, inputs):
34 | if isinstance(self.in_index, (list, tuple)):
35 | inputs = [inputs[i] for i in self.in_index]
36 | elif isinstance(self.in_index, int):
37 | inputs = inputs[self.in_index]
38 | return inputs
39 |
40 | def forward(self, inputs):
41 | inputs = self._transform_inputs(inputs)
42 | x1, x2 = inputs
43 | _, _, h, w = x1.size()
44 |
45 | edge1_fea = self.conv1(x1)
46 | # edge2_fea = self.conv2(x2)
47 |
48 | edge1_fea = F.interpolate(edge1_fea, size=(h, w), **up_kwargs)
49 | # edge2_fea = F.interpolate(edge2_fea, size=(h, w), **up_kwargs)
50 |
51 | # edge_fea = torch.cat([edge1_fea, edge2_fea], dim=1)
52 |
53 | edge = self.conv3(edge1_fea)
54 |
55 | return edge
56 |
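57 |
58 | if __name__ == '__main__':
59 |     # Hedged usage sketch (not part of the original repo): EdgeHead currently uses
60 |     # only the first selected feature map and predicts a 2-channel edge map at its
61 |     # resolution. Shapes below are made up.
62 |     feats = [torch.randn(2, 96, 64, 64), torch.randn(2, 192, 32, 32)]
63 |     head = EdgeHead(in_channels=[96, 192], channels=96, out_fea=2, in_index=[0, 1])
64 |     print(head(feats).shape)  # expected: torch.Size([2, 2, 64, 64])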
--------------------------------------------------------------------------------
/models/head/fcfpn.py:
--------------------------------------------------------------------------------
1 | ###########################################################################
2 | # Created by: Hang Zhang
3 | # Email: zhang.hang@rutgers.edu
4 | # Copyright (c) 2017
5 | ###########################################################################
6 | from __future__ import division
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.functional import interpolate as upsample  # F.upsample is deprecated
10 |
11 | up_kwargs = {'mode': 'bilinear', 'align_corners': True}
12 | norm_layer = nn.BatchNorm2d
13 |
14 |
15 | class FCFPNHead(nn.Module):
16 | def __init__(self, in_channels=[256, 512, 1024, 2048], num_classes=6, channels=256,
17 | norm_layer=norm_layer, up_kwargs=up_kwargs, in_index=[0, 1, 2, 3]):
18 | super(FCFPNHead, self).__init__()
19 | assert up_kwargs is not None
20 | self._up_kwargs = up_kwargs
21 | self.in_index = in_index
22 | fpn_lateral = []
23 | for inchannel in in_channels[:-1]:
24 | fpn_lateral.append(nn.Sequential(
25 | nn.Conv2d(inchannel, channels, kernel_size=1, bias=False),
26 | norm_layer(channels),
27 | nn.ReLU(inplace=True),
28 | ))
29 | self.fpn_lateral = nn.ModuleList(fpn_lateral)
30 | fpn_out = []
31 | for _ in range(len(in_channels) - 1):
32 | fpn_out.append(nn.Sequential(
33 | nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
34 | norm_layer(channels),
35 | nn.ReLU(inplace=True),
36 | ))
37 | self.fpn_out = nn.ModuleList(fpn_out)
38 | self.c4conv = nn.Sequential(nn.Conv2d(in_channels[-1], channels, 3, padding=1, bias=False),
39 | norm_layer(channels),
40 | nn.ReLU())
41 | inter_channels = len(in_channels) * channels
42 | self.conv5 = nn.Sequential(nn.Conv2d(inter_channels, 512, 3, padding=1, bias=False),
43 | norm_layer(512),
44 | nn.ReLU(),
45 | nn.Dropout(0.1, False),
46 | nn.Conv2d(512, num_classes, 1))
47 |
48 | def _transform_inputs(self, inputs):
49 | if isinstance(self.in_index, (list, tuple)):
50 | inputs = [inputs[i] for i in self.in_index]
51 | elif isinstance(self.in_index, int):
52 | inputs = inputs[self.in_index]
53 | return inputs
54 |
55 | def forward(self, inputs):
56 | inputs = self._transform_inputs(inputs)
57 | c4 = inputs[-1]
58 | if hasattr(self, 'extramodule'):
59 | c4 = self.extramodule(c4)
60 | feat = self.c4conv(c4)
61 | c1_size = inputs[0].size()[2:]
62 | feat_up = upsample(feat, c1_size, **self._up_kwargs)
63 | fpn_features = [feat_up]
64 |
65 | for i in reversed(range(len(inputs) - 1)):
66 | feat_i = self.fpn_lateral[i](inputs[i])
67 | feat = upsample(feat, feat_i.size()[2:], **self._up_kwargs)
68 | feat = feat + feat_i
69 | feat_up = upsample(self.fpn_out[i](feat), c1_size, **self._up_kwargs)
70 | fpn_features.append(feat_up)
71 | fpn_features = torch.cat(fpn_features, 1)
72 |
73 | return self.conv5(fpn_features)
74 |
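75 |
76 | if __name__ == '__main__':
77 |     # Hedged usage sketch (not part of the original repo): FCFPNHead fuses all four
78 |     # pyramid levels and predicts at the resolution of the finest one. Shapes below
79 |     # are made up for a quick smoke test.
80 |     feats = [torch.randn(2, c, s, s) for c, s in zip((256, 512, 1024, 2048), (64, 32, 16, 8))]
81 |     head = FCFPNHead(in_channels=[256, 512, 1024, 2048], num_classes=6, channels=256)
82 |     print(head(feats).shape)  # expected: torch.Size([2, 6, 64, 64])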
--------------------------------------------------------------------------------
/models/head/fcn.py:
--------------------------------------------------------------------------------
1 | ###########################################################################
2 | # Created by: Hang Zhang
3 | # Email: zhang.hang@rutgers.edu
4 | # Copyright (c) 2017
5 | ###########################################################################
6 | from __future__ import division
7 | import torch.nn as nn
8 | import torch
9 | from .base_decoder import BaseDecodeHead
10 | from mmcv.cnn import ConvModule
11 |
12 | norm_cfg = dict(type='BN', requires_grad=True)
13 |
14 |
15 | class FCNHead(BaseDecodeHead):
16 | """Fully Convolution Networks for Semantic Segmentation.
17 |     This head is the implementation of FCN (Fully Convolutional Networks for Semantic Segmentation).
18 | Args:
19 | num_convs (int): Number of convs in the head. Default: 2.
20 | kernel_size (int): The kernel size for convs in the head. Default: 3.
21 |         concat_input (bool): Whether to concatenate the input and the output of
22 |             convs before the classification layer. Default: False.
23 | """
24 |
25 | def __init__(self,
26 | num_convs=2,
27 | kernel_size=3,
28 | concat_input=False,
29 | in_channels=768,
30 | num_classes=6,
31 | in_index=3,
32 | channels=512
33 | ):
34 | assert num_convs >= 0
35 | self.num_convs = num_convs
36 | self.concat_input = concat_input
37 | self.kernel_size = kernel_size
38 | super(FCNHead, self).__init__(in_channels=in_channels, in_index=in_index, channels=channels, dropout_ratio=0.1,
39 | num_classes=num_classes, norm_cfg=norm_cfg, align_corners=False)
40 | if num_convs == 0:
41 | assert self.in_channels == self.channels
42 | convs = []
43 | convs.append(
44 | ConvModule(
45 | self.in_channels,
46 | self.channels,
47 | kernel_size=kernel_size,
48 | padding=kernel_size // 2,
49 | conv_cfg=self.conv_cfg,
50 | norm_cfg=self.norm_cfg,
51 | act_cfg=self.act_cfg))
52 | for i in range(num_convs - 1):
53 | convs.append(
54 | ConvModule(
55 | self.channels,
56 | self.channels,
57 | kernel_size=kernel_size,
58 | padding=kernel_size // 2,
59 | conv_cfg=self.conv_cfg,
60 | norm_cfg=self.norm_cfg,
61 | act_cfg=self.act_cfg))
62 | if num_convs == 0:
63 | self.convs = nn.Identity()
64 | else:
65 | self.convs = nn.Sequential(*convs)
66 | if self.concat_input:
67 | self.conv_cat = ConvModule(
68 | self.in_channels + self.channels,
69 | self.channels,
70 | kernel_size=kernel_size,
71 | padding=kernel_size // 2,
72 | conv_cfg=self.conv_cfg,
73 | norm_cfg=self.norm_cfg,
74 | act_cfg=self.act_cfg)
75 |
76 | def forward(self, inputs):
77 | """Forward function."""
78 | x = self._transform_inputs(inputs)
79 | output = self.convs(x)
80 | if self.concat_input:
81 | output = self.conv_cat(torch.cat([x, output], dim=1))
82 | output = self.cls_seg(output)
83 | return output
84 |
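85 |
86 | if __name__ == '__main__':
87 |     # Hedged usage sketch (not part of the original repo): a plain two-conv decode
88 |     # head over the last backbone stage. Shapes below are made up.
89 |     feats = [torch.randn(2, c, s, s) for c, s in zip((96, 192, 384, 768), (64, 32, 16, 8))]
90 |     head = FCNHead(num_convs=2, in_channels=768, num_classes=6, in_index=3, channels=512)
91 |     print(head(feats).shape)  # expected: torch.Size([2, 6, 8, 8])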
--------------------------------------------------------------------------------
/models/head/gc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from mmcv.cnn import ContextBlock
3 | from .fcn import FCNHead
4 |
5 |
6 | class GCHead(FCNHead):
7 | """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.
8 |     This head is the implementation of GCNet
9 |     (Global Context Network).
10 | Args:
11 | ratio (float): Multiplier of channels ratio. Default: 1/4.
12 | pooling_type (str): The pooling type of context aggregation.
13 |             Options are 'att', 'avg'. Default: 'att'.
14 | fusion_types (tuple[str]): The fusion type for feature fusion.
15 | Options are 'channel_add', 'channel_mul'. Default: ('channel_add',)
16 | """
17 |
18 | def __init__(self,
19 | ratio=1 / 4.,
20 | pooling_type='att',
21 | fusion_types=('channel_add', ),
22 | in_channels=768,
23 | num_classes=6,
24 | in_index=3,
25 | channels=512,
26 | ):
27 | super(GCHead, self).__init__(num_convs=2, in_channels=in_channels, num_classes=num_classes, in_index=in_index, channels=channels)
28 | self.ratio = ratio
29 | self.pooling_type = pooling_type
30 | self.fusion_types = fusion_types
31 | self.gc_block = ContextBlock(
32 | in_channels=self.channels,
33 | ratio=self.ratio,
34 | pooling_type=self.pooling_type,
35 | fusion_types=self.fusion_types)
36 |
37 | def forward(self, inputs):
38 | """Forward function."""
39 | x = self._transform_inputs(inputs)
40 | output = self.convs[0](x)
41 | output = self.gc_block(output)
42 | output = self.convs[1](output)
43 | if self.concat_input:
44 | output = self.conv_cat(torch.cat([x, output], dim=1))
45 | output = self.cls_seg(output)
46 | return output
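47 |
48 |
49 | if __name__ == '__main__':
50 |     # Hedged usage sketch (not part of the original repo): GCHead is FCNHead with a
51 |     # global-context block between its two convs. Shapes below are made up.
52 |     feats = [torch.randn(2, c, s, s) for c, s in zip((96, 192, 384, 768), (64, 32, 16, 8))]
53 |     head = GCHead(in_channels=768, num_classes=6, in_index=3, channels=512)
54 |     print(head(feats).shape)  # expected: torch.Size([2, 6, 8, 8])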
--------------------------------------------------------------------------------
/models/head/mlp.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from .base_decoder import BaseDecodeHead, resize
4 |
5 | up_kwargs = {'mode': 'bilinear', 'align_corners': False}
6 |
7 |
8 | class MLP(nn.Module):
9 | """
10 | Linear Embedding
11 | """
12 |
13 | def __init__(self, input_dim=2048, embed_dim=768, norm_act=True):
14 | super().__init__()
15 | self.proj = nn.Linear(input_dim, embed_dim)
16 | self.norm_act = norm_act
17 | if self.norm_act:
18 | self.norm = nn.LayerNorm(input_dim)
19 | self.act = nn.GELU()
20 |
21 | def forward(self, x):
22 | x = x.flatten(2).transpose(1, 2)
23 | if self.norm_act:
24 | x = self.norm(x)
25 | x = self.proj(x)
26 | if self.norm_act:
27 | x = self.act(x)
28 | return x
29 |
30 |
31 | class MLPHead(BaseDecodeHead):
32 | """
33 | SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
34 | """
35 |
36 | def __init__(self, in_channels=[96, 192, 384, 768], channels=512, num_classes=6, in_index=[0, 1, 2, 3]):
37 | super(MLPHead, self).__init__(input_transform='multiple_select', in_index=in_index,
38 | in_channels=in_channels, num_classes=num_classes, channels=channels)
39 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
40 |
41 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=channels)
42 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=channels)
43 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=channels)
44 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=channels)
45 |
46 | self.linear_c3_out = MLP(input_dim=channels, embed_dim=channels)
47 | self.linear_c2_out = MLP(input_dim=channels, embed_dim=channels)
48 | self.linear_c1_out = MLP(input_dim=channels, embed_dim=channels)
49 |
50 | self.linear_fuse = MLP(input_dim=channels * 4, embed_dim=channels)
51 | self.linear_pred = MLP(input_dim=channels, embed_dim=num_classes, norm_act=False)
52 |
53 | def forward(self, inputs):
54 | x = self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32
55 | c1, c2, c3, c4 = x
56 | out = []
57 | ############## MLP decoder on C1-C4 ###########
58 | n, _, h, w = c4.shape
59 |
60 | _c4 = self.linear_c4(c4).permute(0, 2, 1).contiguous().reshape(n, -1, c4.shape[2], c4.shape[3])
61 | _c4 = resize(_c4, size=c3.size()[2:], **up_kwargs)
62 |
63 | out.append(resize(_c4, size=c1.size()[2:], **up_kwargs))
64 |
65 | _c3 = self.linear_c3(c3).permute(0, 2, 1).contiguous().reshape(n, -1, c3.shape[2], c3.shape[3])
66 | _c3 = _c4 + _c3
67 |
68 | _c3_out = self.linear_c3_out(_c3).permute(0, 2, 1).contiguous().reshape(n, -1, c3.shape[2], c3.shape[3])
69 | out.append(resize(_c3_out, size=c1.size()[2:], **up_kwargs))
70 |
71 | _c2 = self.linear_c2(c2).permute(0, 2, 1).contiguous().reshape(n, -1, c2.shape[2], c2.shape[3])
72 | _c3 = resize(_c3, size=c2.size()[2:], **up_kwargs)
73 | _c2 = _c3 + _c2
74 |
75 | _c2_out = self.linear_c2_out(_c2).permute(0, 2, 1).contiguous().reshape(n, -1, c2.shape[2], c2.shape[3])
76 | out.append(resize(_c2_out, size=c1.size()[2:], **up_kwargs))
77 |
78 | _c1 = self.linear_c1(c1).permute(0, 2, 1).contiguous().reshape(n, -1, c1.shape[2], c1.shape[3])
79 | _c2 = resize(_c2, size=c1.size()[2:], **up_kwargs)
80 | _c1 = _c2 + _c1
81 |
82 | _c1_out = self.linear_c1_out(_c1).permute(0, 2, 1).contiguous().reshape(n, -1, c1.shape[2], c1.shape[3])
83 | out.append(_c1_out)
84 |
85 | _c = self.linear_fuse(torch.cat(out, dim=1)).permute(0, 2, 1).contiguous().reshape(n, -1, c1.shape[2], c1.shape[3])
86 | _c = self.dropout(_c)
87 | x = self.linear_pred(_c).permute(0, 2, 1).contiguous().reshape(n, -1, c1.shape[2], c1.shape[3])
88 |
89 | return x
90 |
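91 |
92 | if __name__ == '__main__':
93 |     # Hedged usage sketch (not part of the original repo): MLPHead projects all four
94 |     # levels with linear layers, fuses them, and predicts at the c1 resolution.
95 |     # Shapes below are made up for a quick smoke test.
96 |     feats = [torch.randn(2, c, s, s) for c, s in zip((96, 192, 384, 768), (64, 32, 16, 8))]
97 |     head = MLPHead(in_channels=[96, 192, 384, 768], channels=512, num_classes=6)
98 |     print(head(feats).shape)  # expected: torch.Size([2, 6, 64, 64])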
--------------------------------------------------------------------------------
/models/head/psa.py:
--------------------------------------------------------------------------------
1 | """Point-wise Spatial Attention Network"""
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | up_kwargs = {'mode': 'bilinear', 'align_corners': True}
7 | norm_layer = nn.BatchNorm2d
8 |
9 |
10 | class _ConvBNReLU(nn.Module):
11 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
12 | dilation=1, groups=1, relu6=False, norm_layer=norm_layer):
13 | super(_ConvBNReLU, self).__init__()
14 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
15 | self.bn = norm_layer(out_channels)
16 | self.relu = nn.ReLU6(True) if relu6 else nn.ReLU(True)
17 |
18 | def forward(self, x):
19 | x = self.conv(x)
20 | x = self.bn(x)
21 | x = self.relu(x)
22 | return x
23 |
24 |
25 | class PSAHead(nn.Module):
26 | def __init__(self, in_channels=768, num_classes=6, norm_layer=norm_layer, in_index=3):
27 | super(PSAHead, self).__init__()
28 | self.in_index = in_index
29 |         # psa_out_channels = (crop_size // output_stride) ** 2
30 | psa_out_channels = (512 // 32) ** 2
31 | self.psa = _PointwiseSpatialAttention(in_channels, psa_out_channels, norm_layer)
32 |
33 | self.conv_post = _ConvBNReLU(psa_out_channels, in_channels, 1, norm_layer=norm_layer)
34 | self.project = nn.Sequential(
35 | _ConvBNReLU(in_channels * 2, in_channels // 2, 3, padding=1, norm_layer=norm_layer),
36 | nn.Dropout2d(0.1, False),
37 | nn.Conv2d(in_channels // 2, num_classes, 1))
38 |
39 | def _transform_inputs(self, inputs):
40 | if isinstance(self.in_index, (list, tuple)):
41 | inputs = [inputs[i] for i in self.in_index]
42 | elif isinstance(self.in_index, int):
43 | inputs = inputs[self.in_index]
44 | return inputs
45 |
46 | def forward(self, inputs):
47 | x = self._transform_inputs(inputs)
48 | global_feature = self.psa(x)
49 | out = self.conv_post(global_feature)
50 | out = torch.cat([x, out], dim=1)
51 | out = self.project(out)
52 |
53 | return out
54 |
55 |
56 | class _PointwiseSpatialAttention(nn.Module):
57 | def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
58 | super(_PointwiseSpatialAttention, self).__init__()
59 | reduced_channels = out_channels // 2
60 | self.collect_attention = _AttentionGeneration(in_channels, reduced_channels, out_channels, norm_layer)
61 | self.distribute_attention = _AttentionGeneration(in_channels, reduced_channels, out_channels, norm_layer)
62 |
63 | def forward(self, x):
64 | collect_fm = self.collect_attention(x)
65 | distribute_fm = self.distribute_attention(x)
66 | psa_fm = torch.cat([collect_fm, distribute_fm], dim=1)
67 | return psa_fm
68 |
69 |
70 | class _AttentionGeneration(nn.Module):
71 | def __init__(self, in_channels, reduced_channels, out_channels, norm_layer):
72 | super(_AttentionGeneration, self).__init__()
73 | self.conv_reduce = _ConvBNReLU(in_channels, reduced_channels, 1, norm_layer=norm_layer)
74 | self.attention = nn.Sequential(
75 | _ConvBNReLU(reduced_channels, reduced_channels, 1, norm_layer=norm_layer),
76 | nn.Conv2d(reduced_channels, out_channels, 1, bias=False))
77 |
78 | self.reduced_channels = reduced_channels
79 |
80 | def forward(self, x):
81 | reduce_x = self.conv_reduce(x)
82 | attention = self.attention(reduce_x)
83 | n, c, h, w = attention.size()
84 | attention = attention.view(n, c, -1)
85 | reduce_x = reduce_x.view(n, self.reduced_channels, -1)
86 | fm = torch.bmm(reduce_x, torch.softmax(attention, dim=1))
87 | fm = fm.view(n, self.reduced_channels, h, w)
88 |
89 | return fm
90 |
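91 |
92 | if __name__ == '__main__':
93 |     # Hedged usage sketch (not part of the original repo): the attention size is
94 |     # hard-coded to (512 // 32) ** 2 = 256, so the selected feature map must be
95 |     # 16 x 16 (e.g. a 512-pixel crop at output stride 32). Shapes below are made up.
96 |     feats = [torch.randn(2, c, s, s) for c, s in zip((96, 192, 384, 768), (128, 64, 32, 16))]
97 |     head = PSAHead(in_channels=768, num_classes=6, in_index=3)
98 |     print(head(feats).shape)  # expected: torch.Size([2, 6, 16, 16])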
--------------------------------------------------------------------------------
/models/head/psp.py:
--------------------------------------------------------------------------------
1 | ###########################################################################
2 | # Created by: Hang Zhang
3 | # Email: zhang.hang@rutgers.edu
4 | # Copyright (c) 2017
5 | ###########################################################################
6 | from __future__ import division
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | up_kwargs = {'mode': 'bilinear', 'align_corners': True}
13 | norm_layer = nn.BatchNorm2d
14 |
15 |
16 | class PyramidPooling(nn.Module):
17 | """
18 | Reference:
19 | Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*
20 | """
21 | def __init__(self, in_channels, norm_layer, up_kwargs):
22 | super(PyramidPooling, self).__init__()
23 | self.pool1 = nn.AdaptiveAvgPool2d(1)
24 | self.pool2 = nn.AdaptiveAvgPool2d(2)
25 | self.pool3 = nn.AdaptiveAvgPool2d(3)
26 | self.pool4 = nn.AdaptiveAvgPool2d(6)
27 |
28 | out_channels = int(in_channels/4)
29 | self.conv1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),
30 | norm_layer(out_channels),
31 | nn.ReLU(True))
32 | self.conv2 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),
33 | norm_layer(out_channels),
34 | nn.ReLU(True))
35 | self.conv3 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),
36 | norm_layer(out_channels),
37 | nn.ReLU(True))
38 | self.conv4 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),
39 | norm_layer(out_channels),
40 | nn.ReLU(True))
41 | # bilinear interpolate options
42 | self._up_kwargs = up_kwargs
43 |
44 | def forward(self, x):
45 | _, _, h, w = x.size()
46 | feat1 = F.interpolate(self.conv1(self.pool1(x)), (h, w), **self._up_kwargs)
47 | feat2 = F.interpolate(self.conv2(self.pool2(x)), (h, w), **self._up_kwargs)
48 | feat3 = F.interpolate(self.conv3(self.pool3(x)), (h, w), **self._up_kwargs)
49 | feat4 = F.interpolate(self.conv4(self.pool4(x)), (h, w), **self._up_kwargs)
50 | return torch.cat((x, feat1, feat2, feat3, feat4), 1)
51 |
52 |
53 | class PSPHead(nn.Module):
54 | def __init__(self, in_channels, num_classes, norm_layer=norm_layer, up_kwargs=up_kwargs, in_index=3):
55 | super(PSPHead, self).__init__()
56 | inter_channels = in_channels // 4
57 | self.in_index = in_index
58 | self.conv5 = nn.Sequential(PyramidPooling(in_channels, norm_layer, up_kwargs),
59 | nn.Conv2d(in_channels * 2, inter_channels, 3, padding=1, bias=False),
60 | norm_layer(inter_channels),
61 | nn.ReLU(True),
62 | nn.Dropout(0.1, False),
63 | nn.Conv2d(inter_channels, num_classes, 1))
64 |
65 | def _transform_inputs(self, inputs):
66 | if isinstance(self.in_index, (list, tuple)):
67 | inputs = [inputs[i] for i in self.in_index]
68 | elif isinstance(self.in_index, int):
69 | inputs = inputs[self.in_index]
70 | return inputs
71 |
72 | def forward(self, inputs):
73 | x = self._transform_inputs(inputs)
74 | return self.conv5(x)
75 |
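76 |
77 | if __name__ == '__main__':
78 |     # Hedged usage sketch (not part of the original repo): PyramidPooling doubles
79 |     # the channel count before the final classifier convs. Shapes below are made up.
80 |     feats = [torch.randn(2, c, s, s) for c, s in zip((96, 192, 384, 768), (64, 32, 16, 8))]
81 |     head = PSPHead(in_channels=768, num_classes=6)
82 |     print(head(feats).shape)  # expected: torch.Size([2, 6, 8, 8])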
--------------------------------------------------------------------------------
/models/head/seg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | up_kwargs = {'mode': 'bilinear', 'align_corners': False}
6 |
7 |
8 | def conv3x3(in_planes, out_planes, stride=1):
9 | """3x3 convolution with padding"""
10 | conv3x3 = nn.Sequential(
11 | nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
12 | padding=1, bias=False),
13 | nn.BatchNorm2d(out_planes),
14 | nn.ReLU(inplace=True),
15 | )
16 | return conv3x3
17 |
18 |
19 | class SegHead(nn.Module):
20 | def __init__(self, in_channels=[96, 192, 384, 768], num_classes=6, in_index=[0, 1, 2, 3]):
21 | super(SegHead, self).__init__()
22 | self.in_index = in_index
23 |
24 | self.conv1 = conv3x3(in_channels[0], in_channels[0])
25 | self.conv2 = conv3x3(in_channels[1], in_channels[0])
26 | self.conv3 = conv3x3(in_channels[2], in_channels[0])
27 | self.conv4 = conv3x3(in_channels[3], in_channels[0])
28 | self.final_layer = nn.Sequential(
29 | nn.Conv2d(
30 | in_channels=in_channels[0] * 4,
31 | out_channels=in_channels[0] * 4,
32 | kernel_size=1,
33 | stride=1,
34 | padding=0),
35 | nn.BatchNorm2d(in_channels[0] * 4),
36 | nn.ReLU(inplace=True),
37 | nn.Conv2d(
38 | in_channels=in_channels[0] * 4,
39 | out_channels=num_classes,
40 | kernel_size=1,
41 | stride=1,
42 | padding=0)
43 | )
44 |
45 | def _transform_inputs(self, inputs):
46 | if isinstance(self.in_index, (list, tuple)):
47 | inputs = [inputs[i] for i in self.in_index]
48 | elif isinstance(self.in_index, int):
49 | inputs = inputs[self.in_index]
50 | return inputs
51 |
52 | def forward(self, inputs):
53 | inputs = self._transform_inputs(inputs)
54 | p2, p3, p4, p5 = inputs
55 | h, w = p2.shape[-2:]
56 | x2 = self.conv1(p2)
57 | x3 = F.interpolate(self.conv2(p3), size=(h, w), **up_kwargs)
58 | x4 = F.interpolate(self.conv3(p4), size=(h, w), **up_kwargs)
59 | x5 = F.interpolate(self.conv4(p5), size=(h, w), **up_kwargs)
60 | x = torch.cat((x2, x3, x4, x5), dim=1)
61 | x = self.final_layer(x)
62 |
63 | return x
64 |
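65 |
66 | if __name__ == '__main__':
67 |     # Hedged usage sketch (not part of the original repo): SegHead upsamples every
68 |     # level to the p2 resolution, concatenates, and classifies. Shapes are made up.
69 |     feats = [torch.randn(2, c, s, s) for c, s in zip((96, 192, 384, 768), (64, 32, 16, 8))]
70 |     head = SegHead(in_channels=[96, 192, 384, 768], num_classes=6)
71 |     print(head(feats).shape)  # expected: torch.Size([2, 6, 64, 64])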
--------------------------------------------------------------------------------
/models/head/unet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | up_kwargs = {'mode': 'bilinear', 'align_corners': True}
7 | norm_layer = nn.BatchNorm2d
8 |
9 |
10 | class Conv2dReLU(nn.Sequential):
11 | def __init__(
12 | self,
13 | in_channels,
14 | out_channels,
15 | kernel_size,
16 | padding=0,
17 | stride=1,
18 | use_batchnorm=True,
19 | ):
20 |
21 | conv = nn.Conv2d(
22 | in_channels,
23 | out_channels,
24 | kernel_size,
25 | stride=stride,
26 | padding=padding,
27 | bias=not (use_batchnorm),
28 | )
29 | relu = nn.ReLU(inplace=True)
30 |
31 | if use_batchnorm:
32 | bn = nn.BatchNorm2d(out_channels)
33 | else:
34 | bn = nn.Identity()
35 |
36 | super(Conv2dReLU, self).__init__(conv, bn, relu)
37 |
38 |
39 | class SCSEAttention(nn.Module):
40 | def __init__(self, in_channels, reduction=16):
41 | super().__init__()
42 | self.cSE = nn.Sequential(
43 | nn.AdaptiveAvgPool2d(1),
44 | nn.Conv2d(in_channels, in_channels // reduction, 1),
45 | nn.ReLU(inplace=True),
46 | nn.Conv2d(in_channels // reduction, in_channels, 1),
47 | nn.Sigmoid(),
48 | )
49 | self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())
50 |
51 | def forward(self, x):
52 | return x * self.cSE(x) + x * self.sSE(x)
53 |
54 |
55 | class DecoderBlock(nn.Module):
56 | def __init__(
57 | self,
58 | in_channels,
59 | skip_channels,
60 | out_channels,
61 | use_batchnorm=True,
62 | use_attention=False,
63 | ):
64 | super().__init__()
65 | self.conv1 = Conv2dReLU(
66 | in_channels + skip_channels,
67 | out_channels,
68 | kernel_size=3,
69 | padding=1,
70 | use_batchnorm=use_batchnorm,
71 | )
72 |
73 | self.conv2 = Conv2dReLU(
74 | out_channels,
75 | out_channels,
76 | kernel_size=3,
77 | padding=1,
78 | use_batchnorm=use_batchnorm,
79 | )
80 | self.use_attention = use_attention
81 | if self.use_attention:
82 | self.attention1 = SCSEAttention(in_channels=in_channels + skip_channels)
83 | self.attention2 = SCSEAttention(in_channels=out_channels)
84 |
85 | def forward(self, x, skip=None):
86 | x = F.interpolate(x, scale_factor=2, **up_kwargs)
87 | if skip is not None:
88 | x = torch.cat([x, skip], dim=1)
89 | if self.use_attention:
90 | x = self.attention1(x)
91 | x = self.conv1(x)
92 | x = self.conv2(x)
93 | if self.use_attention:
94 | x = self.attention2(x)
95 |
96 | return x
97 |
98 |
99 | class CenterBlock(nn.Sequential):
100 | def __init__(self, in_channels, out_channels, use_batchnorm=True):
101 | conv1 = Conv2dReLU(
102 | in_channels,
103 | out_channels,
104 | kernel_size=3,
105 | padding=1,
106 | use_batchnorm=use_batchnorm,
107 | )
108 | conv2 = Conv2dReLU(
109 | out_channels,
110 | out_channels,
111 | kernel_size=3,
112 | padding=1,
113 | use_batchnorm=use_batchnorm,
114 | )
115 | super().__init__(conv1, conv2)
116 |
117 |
118 | class UNetHead(nn.Module):
119 | def __init__(
120 | self,
121 | in_channels,
122 | num_classes=6,
123 | n_blocks=4,
124 | use_batchnorm=True,
125 | use_attention=False,
126 | center=False,
127 | in_index=[0, 1, 2, 3],
128 | ):
129 | super(UNetHead, self).__init__()
130 | self.in_index = in_index
131 | decoder_channels = [in_channels[i] // 4 for i in self.in_index]
132 | if n_blocks != len(decoder_channels):
133 | raise ValueError(
134 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
135 | n_blocks, len(decoder_channels)
136 | )
137 | )
138 | encoder_channels = in_channels[::-1] # reverse channels to start from head of encoder
139 |
140 | # computing blocks input and output channels
141 | head_channels = encoder_channels[0]
142 | in_channels = [head_channels] + list(decoder_channels[:-1])
143 | skip_channels = list(encoder_channels[1:]) + [0]
144 | out_channels = decoder_channels
145 |
146 | if center:
147 | self.center = CenterBlock(
148 | head_channels, head_channels, use_batchnorm=use_batchnorm
149 | )
150 | else:
151 | self.center = nn.Identity()
152 | # combine decoder keyword arguments
153 | kwargs = dict(use_batchnorm=use_batchnorm, use_attention=use_attention)
154 | blocks = [
155 | DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
156 | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
157 | ]
158 | self.blocks = nn.ModuleList(blocks)
159 | self.head = nn.Conv2d(out_channels[-1], num_classes, kernel_size=1)
160 |
161 | def _transform_inputs(self, inputs):
162 | if isinstance(self.in_index, (list, tuple)):
163 | inputs = [inputs[i] for i in self.in_index]
164 | elif isinstance(self.in_index, int):
165 | inputs = inputs[self.in_index]
166 | return inputs
167 |
168 | def forward(self, features):
169 |
170 | features = self._transform_inputs(features)
171 | features = features[::-1] # reverse channels to start from head of encoder
172 |
173 | head = features[0]
174 | skips = features[1:]
175 |
176 | x = self.center(head)
177 | for i, decoder_block in enumerate(self.blocks):
178 | skip = skips[i] if i < len(skips) else None
179 | x = decoder_block(x, skip)
180 | x = self.head(x)
181 | return x
182 |
--------------------------------------------------------------------------------
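A minimal usage sketch for the `UNetHead` above. The feature-pyramid shapes are illustrative (roughly a Swin-Tiny output at strides 4/8/16/32), and the import assumes the repository root is on `PYTHONPATH`; each decoder stage works on `in_channels[i] // 4` channels, so the head stays light even on large backbones.
```
import torch
from models.head.unet import UNetHead

# four backbone feature maps at strides 4/8/16/32 (channels are illustrative)
feats = [torch.randn(2, c, s, s) for c, s in zip([96, 192, 384, 768], [128, 64, 32, 16])]

head = UNetHead(in_channels=[96, 192, 384, 768], num_classes=6, n_blocks=4, in_index=[0, 1, 2, 3])
out = head(feats)
print(out.shape)  # torch.Size([2, 6, 256, 256]): each DecoderBlock upsamples x2 and fuses one skip
```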
/models/head/uper.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch.nn as nn
3 | import torch
4 | from .base_decoder import BaseDecodeHead, resize
5 | from mmcv.cnn import ConvModule
6 |
7 | norm_cfg = dict(type='BN', requires_grad=True)
8 |
9 |
10 | class PPM(nn.ModuleList):
11 | """Pooling Pyramid Module used in PSPNet.
12 | Args:
13 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
14 | Module.
15 | in_channels (int): Input channels.
16 | channels (int): Channels after modules, before conv_seg.
17 | conv_cfg (dict|None): Config of conv layers.
18 | norm_cfg (dict|None): Config of norm layers.
19 | act_cfg (dict): Config of activation layers.
20 | align_corners (bool): align_corners argument of F.interpolate.
21 | """
22 |
23 | def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
24 | act_cfg, align_corners):
25 | super(PPM, self).__init__()
26 | self.pool_scales = pool_scales
27 | self.align_corners = align_corners
28 | self.in_channels = in_channels
29 | self.channels = channels
30 | self.conv_cfg = conv_cfg
31 | self.norm_cfg = norm_cfg
32 | self.act_cfg = act_cfg
33 | for pool_scale in pool_scales:
34 | self.append(
35 | nn.Sequential(
36 | nn.AdaptiveAvgPool2d(pool_scale),
37 | ConvModule(
38 | self.in_channels,
39 | self.channels,
40 | 1,
41 | conv_cfg=self.conv_cfg,
42 | norm_cfg=self.norm_cfg,
43 | act_cfg=self.act_cfg)))
44 |
45 | def forward(self, x):
46 | """Forward function."""
47 | ppm_outs = []
48 | for ppm in self:
49 | """ppm work on batch > 1 when training"""
50 | ppm_out = ppm(x)
51 | upsampled_ppm_out = resize(
52 | ppm_out,
53 | size=x.size()[2:],
54 | mode='bilinear',
55 | align_corners=self.align_corners)
56 | ppm_outs.append(upsampled_ppm_out)
57 | return ppm_outs
58 |
59 |
60 | class UPerHead(BaseDecodeHead):
61 | """Unified Perceptual Parsing for Scene Understanding.
62 | This head is the implementation of `UPerNet
63 | <https://arxiv.org/abs/1807.10221>`_.
64 | Args:
65 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
66 | Module applied on the last feature. Default: (1, 2, 3, 6).
67 | """
68 |
69 | def __init__(self, pool_scales=(1, 2, 3, 6), in_channels=[96, 192, 384, 768], num_classes=6):
70 | super(UPerHead, self).__init__(
71 | input_transform='multiple_select', in_index=[0, 1, 2, 3], in_channels=in_channels, num_classes=num_classes,
72 | channels=512, dropout_ratio=0.1, norm_cfg=norm_cfg, align_corners=False)
73 | # PSP Module
74 | self.psp_modules = PPM(
75 | pool_scales,
76 | self.in_channels[-1],
77 | self.channels,
78 | conv_cfg=self.conv_cfg,
79 | norm_cfg=self.norm_cfg,
80 | act_cfg=self.act_cfg,
81 | align_corners=self.align_corners)
82 | self.bottleneck = ConvModule(
83 | self.in_channels[-1] + len(pool_scales) * self.channels,
84 | self.channels,
85 | 3,
86 | padding=1,
87 | conv_cfg=self.conv_cfg,
88 | norm_cfg=self.norm_cfg,
89 | act_cfg=self.act_cfg)
90 | # FPN Module
91 | self.lateral_convs = nn.ModuleList()
92 | self.fpn_convs = nn.ModuleList()
93 | for in_channels in self.in_channels[:-1]: # skip the top layer
94 | l_conv = ConvModule(
95 | in_channels,
96 | self.channels,
97 | 1,
98 | conv_cfg=self.conv_cfg,
99 | norm_cfg=self.norm_cfg,
100 | act_cfg=self.act_cfg,
101 | inplace=False)
102 | fpn_conv = ConvModule(
103 | self.channels,
104 | self.channels,
105 | 3,
106 | padding=1,
107 | conv_cfg=self.conv_cfg,
108 | norm_cfg=self.norm_cfg,
109 | act_cfg=self.act_cfg,
110 | inplace=False)
111 | self.lateral_convs.append(l_conv)
112 | self.fpn_convs.append(fpn_conv)
113 |
114 | self.fpn_bottleneck = ConvModule(
115 | len(self.in_channels) * self.channels,
116 | self.channels,
117 | 3,
118 | padding=1,
119 | conv_cfg=self.conv_cfg,
120 | norm_cfg=self.norm_cfg,
121 | act_cfg=self.act_cfg)
122 |
123 | def psp_forward(self, inputs):
124 | """Forward function of PSP module."""
125 | x = inputs[-1]
126 | psp_outs = [x]
127 | psp_outs.extend(self.psp_modules(x))
128 | psp_outs = torch.cat(psp_outs, dim=1)
129 | output = self.bottleneck(psp_outs)
130 |
131 | return output
132 |
133 | def forward(self, inputs):
134 | """Forward function."""
135 |
136 | inputs = self._transform_inputs(inputs)
137 |
138 | # build laterals
139 | laterals = [
140 | lateral_conv(inputs[i])
141 | for i, lateral_conv in enumerate(self.lateral_convs)
142 | ]
143 |
144 | laterals.append(self.psp_forward(inputs))
145 |
146 | # build top-down path
147 | used_backbone_levels = len(laterals)
148 | for i in range(used_backbone_levels - 1, 0, -1):
149 | prev_shape = laterals[i - 1].shape[2:]
150 | laterals[i - 1] += resize(
151 | laterals[i],
152 | size=prev_shape,
153 | mode='bilinear',
154 | align_corners=self.align_corners)
155 |
156 | # build outputs
157 | fpn_outs = [
158 | self.fpn_convs[i](laterals[i])
159 | for i in range(used_backbone_levels - 1)
160 | ]
161 | # append psp feature
162 | fpn_outs.append(laterals[-1])
163 |
164 | for i in range(used_backbone_levels - 1, 0, -1):
165 | fpn_outs[i] = resize(
166 | fpn_outs[i],
167 | size=fpn_outs[0].shape[2:],
168 | mode='bilinear',
169 | align_corners=self.align_corners)
170 | fpn_outs = torch.cat(fpn_outs, dim=1)
171 | output = self.fpn_bottleneck(fpn_outs)
172 | output = self.cls_seg(output)
173 | return output
174 |
175 |
176 |
177 |
--------------------------------------------------------------------------------
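A small sanity-check for the `UPerHead` above. It assumes `mmcv` is installed (it provides `ConvModule`) and that the copied `BaseDecodeHead` follows the usual mmsegmentation interface; the feature shapes are again illustrative.
```
import torch
from models.head.uper import UPerHead

feats = [torch.randn(2, c, s, s) for c, s in zip([96, 192, 384, 768], [128, 64, 32, 16])]
head = UPerHead(in_channels=[96, 192, 384, 768], num_classes=6)

with torch.no_grad():
    out = head(feats)
print(out.shape)  # torch.Size([2, 6, 128, 128]): logits at the resolution of the finest feature map
```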
/models/model_store.py:
--------------------------------------------------------------------------------
1 | """Model store which provides pretrained models."""
2 | from __future__ import print_function
3 | __all__ = ['get_model_file', 'purge']
4 | import os
5 | import zipfile
6 | import portalocker
7 |
8 | from .utils import download, check_sha1
9 |
10 | _model_sha1 = {name: checksum for checksum, name in [
11 | # resnest
12 | ('fb9de5b360976e3e8bd3679d3e93c5409a5eff3c', 'resnest50'),
13 | ('966fb78c22323b0c68097c5c1242bd16d3e07fd5', 'resnest101'),
14 | ('d7fd712f5a1fcee5b3ce176026fbb6d0d278454a', 'resnest200'),
15 | ('51ae5f19032e22af4ec08e695496547acdba5ce5', 'resnest269'),
16 | # rectified
17 | #('9b5dc32b3b36ca1a6b41ecd4906830fc84dae8ed', 'resnet101_rt'),
18 | # resnet other variants
19 | ('a75c83cfc89a56a4e8ba71b14f1ec67e923787b3', 'resnet50s'),
20 | ('03a0f310d6447880f1b22a83bd7d1aa7fc702c6e', 'resnet101s'),
21 | ('36670e8bc2428ecd5b7db1578538e2dd23872813', 'resnet152s'),
22 | # other segmentation backbones
23 | ('da4785cfc837bf00ef95b52fb218feefe703011f', 'wideresnet38'),
24 | ('b41562160173ee2e979b795c551d3c7143b1e5b5', 'wideresnet50'),
25 | # deepten paper
26 | ('1225f149519c7a0113c43a056153c1bb15468ac0', 'deepten_resnet50_minc'),
27 | # segmentation resnet models
28 | ('662e979de25a389f11c65e9f1df7e06c2c356381', 'fcn_resnet50s_ade'),
29 | ('4de91d5922d4d3264f678b663f874da72e82db00', 'encnet_resnet50s_pcontext'),
30 | ('9f27ea13d514d7010e59988341bcbd4140fcc33d', 'encnet_resnet101s_pcontext'),
31 | ('07ac287cd77e53ea583f37454e17d30ce1509a4a', 'encnet_resnet50s_ade'),
32 | ('3f54fa3b67bac7619cd9b3673f5c8227cf8f4718', 'encnet_resnet101s_ade'),
33 | # resnest segmentation models
34 | ('4aba491aaf8e4866a9c9981b210e3e3266ac1f2a', 'fcn_resnest50_ade'),
35 | ('2225f09d0f40b9a168d9091652194bc35ec2a5a9', 'deeplab_resnest50_ade'),
36 | ('06ca799c8cc148fe0fafb5b6d052052935aa3cc8', 'deeplab_resnest101_ade'),
37 | ('7b9e7d3e6f0e2c763c7d77cad14d306c0a31fe05', 'deeplab_resnest200_ade'),
38 | ('0074dd10a6e6696f6f521653fb98224e75955496', 'deeplab_resnest269_ade'),
39 | ('77a2161deeb1564e8b9c41a4bb7a3f33998b00ad', 'fcn_resnest50_pcontext'),
40 | ('08dccbc4f4694baab631e037a374d76d8108c61f', 'deeplab_resnest50_pcontext'),
41 | ('faf5841853aae64bd965a7bdc2cdc6e7a2b5d898', 'deeplab_resnest101_pcontext'),
42 | ('fe76a26551dd5dcf2d474fd37cba99d43f6e984e', 'deeplab_resnest200_pcontext'),
43 | ('b661fd26c49656e01e9487cd9245babb12f37449', 'deeplab_resnest269_pcontext'),
44 | ]}
45 |
46 | encoding_repo_url = 'https://s3.us-west-1.wasabisys.com/encoding'
47 | _url_format = '{repo_url}models/{file_name}.zip'
48 |
49 | def short_hash(name):
50 | if name not in _model_sha1:
51 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
52 | return _model_sha1[name][:8]
53 |
54 | def get_model_file(name, root=os.path.join('~', '.encoding', 'models')):
55 | r"""Return location for the pretrained on local file system.
56 | This function will download from online model zoo when model cannot be found or has mismatch.
57 | The root directory will be created if it doesn't exist.
58 | Parameters
59 | ----------
60 | name : str
61 | Name of the model.
62 | root : str, default '~/.encoding/models'
63 | Location for keeping the model parameters.
64 | Returns
65 | -------
66 | file_path
67 | Path to the requested pretrained model file.
68 | """
69 | if name not in _model_sha1:
70 | from torchvision.models.resnet import model_urls
71 | if name not in model_urls:
72 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
73 | root = os.path.expanduser(root)
74 | return download(model_urls[name],
75 | path=root,
76 | overwrite=True)
77 | file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name))
78 | root = os.path.expanduser(root)
79 | if not os.path.exists(root):
80 | os.makedirs(root)
81 |
82 | file_path = os.path.join(root, file_name+'.pth')
83 | sha1_hash = _model_sha1[name]
84 |
85 | lockfile = os.path.join(root, file_name + '.lock')
86 | with portalocker.Lock(lockfile, timeout=300):
87 | if os.path.exists(file_path):
88 | if check_sha1(file_path, sha1_hash):
89 | return file_path
90 | else:
91 | print('Mismatch in the content of model file {} detected. '
92 | 'Downloading again.'.format(file_path))
93 | else:
94 | print('Model file {} is not found. Downloading.'.format(file_path))
95 |
96 | zip_file_path = os.path.join(root, file_name+'.zip')
97 | repo_url = os.environ.get('ENCODING_REPO', encoding_repo_url)
98 | if repo_url[-1] != '/':
99 | repo_url = repo_url + '/'
100 | download(_url_format.format(repo_url=repo_url, file_name=file_name),
101 | path=zip_file_path,
102 | overwrite=True)
103 | with zipfile.ZipFile(zip_file_path) as zf:
104 | zf.extractall(root)
105 | os.remove(zip_file_path)
106 |
107 | if check_sha1(file_path, sha1_hash):
108 | return file_path
109 | else:
110 | raise ValueError('Downloaded file has different hash. Please try again.')
111 |
112 | def purge(root=os.path.join('~', '.encoding', 'models')):
113 | r"""Purge all pretrained model files in local file store.
114 | Parameters
115 | ----------
116 | root : str, default '~/.encoding/models'
117 | Location for keeping the model parameters.
118 | """
119 | root = os.path.expanduser(root)
120 | files = os.listdir(root)
121 | for f in files:
122 | if f.endswith(".pth"):
123 | os.remove(os.path.join(root, f))
124 |
125 | def pretrained_model_list():
126 | return list(_model_sha1.keys())
--------------------------------------------------------------------------------
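A hedged usage sketch for the model store above: `pretrained_model_list()` only inspects the checksum table, while `get_model_file` downloads and SHA-1-verifies the weights on first use (it also needs `portalocker`, as imported at the top of the file). The download line is left commented out because it fetches a large archive.
```
from models.model_store import get_model_file, pretrained_model_list

print(pretrained_model_list()[:5])  # names with known SHA-1 checksums, e.g. ['resnest50', 'resnest101', ...]

# First call downloads <name>-<hash>.zip from the encoding repo, unzips it and checks the SHA-1;
# later calls simply return the cached .pth path.
# path = get_model_file('resnet50s', root='./pretrained_weights')
```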
/models/pspnet.py:
--------------------------------------------------------------------------------
1 | """Pyramid Scene Parsing Network"""
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | try:
6 | from .resnet import resnet50_v1b
7 | except:
8 | from resnet import resnet50_v1b
9 |
10 |
11 | class SeparableConv2d(nn.Module):
12 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1,
13 | dilation=1, bias=False, norm_layer=nn.BatchNorm2d):
14 | super(SeparableConv2d, self).__init__()
15 | self.conv = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation, groups=inplanes, bias=bias)
16 | self.bn = norm_layer(inplanes)
17 | self.pointwise = nn.Conv2d(inplanes, planes, 1, bias=bias)
18 |
19 | def forward(self, x):
20 | x = self.conv(x)
21 | x = self.bn(x)
22 | x = self.pointwise(x)
23 | return x
24 |
25 |
26 | # copy from: https://github.com/wuhuikai/FastFCN/blob/master/encoding/nn/customize.py
27 | class JPU(nn.Module):
28 | def __init__(self, in_channels, width=512, norm_layer=nn.BatchNorm2d, **kwargs):
29 | super(JPU, self).__init__()
30 |
31 | self.conv5 = nn.Sequential(
32 | nn.Conv2d(in_channels[-1], width, 3, padding=1, bias=False),
33 | norm_layer(width),
34 | nn.ReLU(True))
35 | self.conv4 = nn.Sequential(
36 | nn.Conv2d(in_channels[-2], width, 3, padding=1, bias=False),
37 | norm_layer(width),
38 | nn.ReLU(True))
39 | self.conv3 = nn.Sequential(
40 | nn.Conv2d(in_channels[-3], width, 3, padding=1, bias=False),
41 | norm_layer(width),
42 | nn.ReLU(True))
43 |
44 | self.dilation1 = nn.Sequential(
45 | SeparableConv2d(3 * width, width, 3, padding=1, dilation=1, bias=False),
46 | norm_layer(width),
47 | nn.ReLU(True))
48 | self.dilation2 = nn.Sequential(
49 | SeparableConv2d(3 * width, width, 3, padding=2, dilation=2, bias=False),
50 | norm_layer(width),
51 | nn.ReLU(True))
52 | self.dilation3 = nn.Sequential(
53 | SeparableConv2d(3 * width, width, 3, padding=4, dilation=4, bias=False),
54 | norm_layer(width),
55 | nn.ReLU(True))
56 | self.dilation4 = nn.Sequential(
57 | SeparableConv2d(3 * width, width, 3, padding=8, dilation=8, bias=False),
58 | norm_layer(width),
59 | nn.ReLU(True))
60 |
61 | def forward(self, *inputs):
62 | feats = [self.conv5(inputs[-1]), self.conv4(inputs[-2]), self.conv3(inputs[-3])]
63 | size = feats[-1].size()[2:]
64 | feats[-2] = F.interpolate(feats[-2], size, mode='bilinear', align_corners=True)
65 | feats[-3] = F.interpolate(feats[-3], size, mode='bilinear', align_corners=True)
66 | feat = torch.cat(feats, dim=1)
67 | feat = torch.cat([self.dilation1(feat), self.dilation2(feat), self.dilation3(feat), self.dilation4(feat)],
68 | dim=1)
69 |
70 | return inputs[0], inputs[1], inputs[2], feat
71 |
72 |
73 | class SegBaseModel(nn.Module):
74 | r"""Base Model for Semantic Segmentation
75 |
76 | Parameters
77 | ----------
78 | backbone : string
79 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
80 | 'resnet101' or 'resnet152').
81 | """
82 |
83 | def __init__(self, nclass, aux, backbone='resnet50', jpu=False, pretrained_base=False, **kwargs):
84 | super(SegBaseModel, self).__init__()
85 | dilated = False if jpu else True
86 | self.aux = aux
87 | self.nclass = nclass
88 | if backbone == 'resnet50':
89 | self.pretrained = resnet50_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs)
90 |
91 | else:
92 | raise RuntimeError('unknown backbone: {}'.format(backbone))
93 |
94 | self.jpu = JPU([512, 1024, 2048], width=512, **kwargs) if jpu else None
95 |
96 | def base_forward(self, x):
97 | """forwarding pre-trained network"""
98 | x = self.pretrained.conv1(x)
99 | x = self.pretrained.bn1(x)
100 | x = self.pretrained.relu(x)
101 | x = self.pretrained.maxpool(x)
102 | c1 = self.pretrained.layer1(x)
103 | c2 = self.pretrained.layer2(c1)
104 | c3 = self.pretrained.layer3(c2)
105 | c4 = self.pretrained.layer4(c3)
106 |
107 | if self.jpu:
108 | return self.jpu(c1, c2, c3, c4)
109 | else:
110 | return c1, c2, c3, c4
111 |
112 | def evaluate(self, x):
113 | """evaluating network with inputs and targets"""
114 | return self.forward(x)[0]
115 |
116 | def demo(self, x):
117 | pred = self.forward(x)
118 | if self.aux:
119 | pred = pred[0]
120 | return pred
121 |
122 |
123 | class _FCNHead(nn.Module):
124 | def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, **kwargs):
125 | super(_FCNHead, self).__init__()
126 | inter_channels = in_channels // 4
127 | self.block = nn.Sequential(
128 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
129 | norm_layer(inter_channels),
130 | nn.ReLU(inplace=True),
131 | nn.Dropout(0.1),
132 | nn.Conv2d(inter_channels, channels, 1)
133 | )
134 |
135 | def forward(self, x):
136 | return self.block(x)
137 |
138 |
139 | class PSPNet(SegBaseModel):
140 | r"""Pyramid Scene Parsing Network
141 |
142 | Parameters
143 | ----------
144 | nclass : int
145 | Number of categories for the training dataset.
146 | backbone : string
147 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
148 | 'resnet101' or 'resnet152').
149 | norm_layer : object
150 | Normalization layer used in backbone network (default: :class:`nn.BatchNorm`;
151 | for Synchronized Cross-GPU BatchNormalization).
152 | aux : bool
153 | Auxiliary loss.
154 |
155 | Reference:
156 | Zhao, Hengshuang, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, and Jiaya Jia.
157 | "Pyramid scene parsing network." *CVPR*, 2017
158 | """
159 |
160 | def __init__(self, nclass, backbone='resnet50', aux=False, pretrained_base=False, **kwargs):
161 | super(PSPNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
162 | self.head = _PSPHead(nclass, **kwargs)
163 | if self.aux:
164 | self.auxlayer = _FCNHead(1024, nclass, **kwargs)
165 |
166 | self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])
167 |
168 | def forward(self, x):
169 | size = x.size()[2:]
170 | _, _, c3, c4 = self.base_forward(x)
171 | outputs = []
172 | x = self.head(c4)
173 | x = F.interpolate(x, size, mode='bilinear', align_corners=True)
174 | outputs.append(x)
175 |
176 | if self.aux:
177 | auxout = self.auxlayer(c3)
178 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
179 | outputs.append(auxout)
180 | return outputs
181 |
182 |
183 | def _PSP1x1Conv(in_channels, out_channels, norm_layer, norm_kwargs):
184 | return nn.Sequential(
185 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
186 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
187 | nn.ReLU(True)
188 | )
189 |
190 |
191 | class _PyramidPooling(nn.Module):
192 | def __init__(self, in_channels, **kwargs):
193 | super(_PyramidPooling, self).__init__()
194 | out_channels = int(in_channels / 4)
195 | self.avgpool1 = nn.AdaptiveAvgPool2d(1)
196 | self.avgpool2 = nn.AdaptiveAvgPool2d(2)
197 | self.avgpool3 = nn.AdaptiveAvgPool2d(3)
198 | self.avgpool4 = nn.AdaptiveAvgPool2d(6)
199 | self.conv1 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
200 | self.conv2 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
201 | self.conv3 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
202 | self.conv4 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
203 |
204 | def forward(self, x):
205 | size = x.size()[2:]
206 | feat1 = F.interpolate(self.conv1(self.avgpool1(x)), size, mode='bilinear', align_corners=True)
207 | feat2 = F.interpolate(self.conv2(self.avgpool2(x)), size, mode='bilinear', align_corners=True)
208 | feat3 = F.interpolate(self.conv3(self.avgpool3(x)), size, mode='bilinear', align_corners=True)
209 | feat4 = F.interpolate(self.conv4(self.avgpool4(x)), size, mode='bilinear', align_corners=True)
210 | return torch.cat([x, feat1, feat2, feat3, feat4], dim=1)
211 |
212 |
213 | class _PSPHead(nn.Module):
214 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
215 | super(_PSPHead, self).__init__()
216 | self.psp = _PyramidPooling(2048, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
217 | self.block = nn.Sequential(
218 | nn.Conv2d(4096, 512, 3, padding=1, bias=False),
219 | norm_layer(512, **({} if norm_kwargs is None else norm_kwargs)),
220 | nn.ReLU(True),
221 | nn.Dropout(0.1),
222 | nn.Conv2d(512, nclass, 1)
223 | )
224 |
225 | def forward(self, x):
226 | x = self.psp(x)
227 | return self.block(x)
228 |
229 |
230 | if __name__ == '__main__':
231 | from tools.flops_params_fps_count import flops_params_fps
232 | model = PSPNet(nclass=6)
233 | flops_params_fps(model)
--------------------------------------------------------------------------------
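A quick forward-pass sketch for `PSPNet` (random weights, CPU). With `aux=True` the forward pass returns a list: the main PSP prediction first and the FCN auxiliary prediction second, both already interpolated to the input size.
```
import torch
from models.pspnet import PSPNet

model = PSPNet(nclass=6, aux=True, pretrained_base=False).eval()
with torch.no_grad():
    outs = model(torch.randn(1, 3, 512, 512))
print([o.shape for o in outs])  # [torch.Size([1, 6, 512, 512]), torch.Size([1, 6, 512, 512])]
```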
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 |
5 | __all__ = ['ResNetV1b', 'resnet18_v1b', 'resnet34_v1b', 'resnet50_v1b',
6 | 'resnet101_v1b', 'resnet152_v1b', 'resnet152_v1s', 'resnet101_v1s', 'resnet50_v1s']
7 |
8 | model_urls = {
9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
14 | }
15 |
16 | pretrained_save_dir = './pretrained_weights'
17 |
18 | class BasicBlockV1b(nn.Module):
19 | expansion = 1
20 |
21 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
22 | previous_dilation=1, norm_layer=nn.BatchNorm2d):
23 | super(BasicBlockV1b, self).__init__()
24 | self.conv1 = nn.Conv2d(inplanes, planes, 3, stride,
25 | dilation, dilation, bias=False)
26 | self.bn1 = norm_layer(planes)
27 | self.relu = nn.ReLU(True)
28 | self.conv2 = nn.Conv2d(planes, planes, 3, 1, previous_dilation,
29 | dilation=previous_dilation, bias=False)
30 | self.bn2 = norm_layer(planes)
31 | self.downsample = downsample
32 | self.stride = stride
33 |
34 | def forward(self, x):
35 | identity = x
36 |
37 | out = self.conv1(x)
38 | out = self.bn1(out)
39 | out = self.relu(out)
40 |
41 | out = self.conv2(out)
42 | out = self.bn2(out)
43 |
44 | if self.downsample is not None:
45 | identity = self.downsample(x)
46 |
47 | out += identity
48 | out = self.relu(out)
49 |
50 | return out
51 |
52 |
53 | class BottleneckV1b(nn.Module):
54 | expansion = 4
55 |
56 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
57 | previous_dilation=1, norm_layer=nn.BatchNorm2d):
58 | super(BottleneckV1b, self).__init__()
59 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
60 | self.bn1 = norm_layer(planes)
61 | self.conv2 = nn.Conv2d(planes, planes, 3, stride,
62 | dilation, dilation, bias=False)
63 | self.bn2 = norm_layer(planes)
64 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
65 | self.bn3 = norm_layer(planes * self.expansion)
66 | self.relu = nn.ReLU(True)
67 | self.downsample = downsample
68 | self.stride = stride
69 |
70 | def forward(self, x):
71 | identity = x
72 |
73 | out = self.conv1(x)
74 | out = self.bn1(out)
75 | out = self.relu(out)
76 |
77 | out = self.conv2(out)
78 | out = self.bn2(out)
79 | out = self.relu(out)
80 |
81 | out = self.conv3(out)
82 | out = self.bn3(out)
83 |
84 | if self.downsample is not None:
85 | identity = self.downsample(x)
86 |
87 | out += identity
88 | out = self.relu(out)
89 |
90 | return out
91 |
92 |
93 | class ResNetV1b(nn.Module):
94 |
95 | def __init__(self, block, layers, num_classes=1000, dilated=True, deep_stem=False,
96 | zero_init_residual=False, norm_layer=nn.BatchNorm2d):
97 | self.inplanes = 128 if deep_stem else 64
98 | super(ResNetV1b, self).__init__()
99 | if deep_stem:
100 | self.conv1 = nn.Sequential(
101 | nn.Conv2d(3, 64, 3, 2, 1, bias=False),
102 | norm_layer(64),
103 | nn.ReLU(True),
104 | nn.Conv2d(64, 64, 3, 1, 1, bias=False),
105 | norm_layer(64),
106 | nn.ReLU(True),
107 | nn.Conv2d(64, 128, 3, 1, 1, bias=False)
108 | )
109 | else:
110 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
111 | self.bn1 = norm_layer(self.inplanes)
112 | self.relu = nn.ReLU(True)
113 | self.maxpool = nn.MaxPool2d(3, 2, 1)
114 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
115 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
116 | if dilated:
117 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, norm_layer=norm_layer)
118 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, norm_layer=norm_layer)
119 | else:
120 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer)
121 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer)
122 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
123 | self.fc = nn.Linear(512 * block.expansion, num_classes)
124 |
125 | for m in self.modules():
126 | if isinstance(m, nn.Conv2d):
127 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
128 | elif isinstance(m, nn.BatchNorm2d):
129 | nn.init.constant_(m.weight, 1)
130 | nn.init.constant_(m.bias, 0)
131 |
132 | if zero_init_residual:
133 | for m in self.modules():
134 | if isinstance(m, BottleneckV1b):
135 | nn.init.constant_(m.bn3.weight, 0)
136 | elif isinstance(m, BasicBlockV1b):
137 | nn.init.constant_(m.bn2.weight, 0)
138 |
139 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):
140 | downsample = None
141 | if stride != 1 or self.inplanes != planes * block.expansion:
142 | downsample = nn.Sequential(
143 | nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False),
144 | norm_layer(planes * block.expansion),
145 | )
146 |
147 | layers = []
148 | if dilation in (1, 2):
149 | layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample,
150 | previous_dilation=dilation, norm_layer=norm_layer))
151 | elif dilation == 4:
152 | layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample,
153 | previous_dilation=dilation, norm_layer=norm_layer))
154 | else:
155 | raise RuntimeError("=> unknown dilation size: {}".format(dilation))
156 | self.inplanes = planes * block.expansion
157 | for _ in range(1, blocks):
158 | layers.append(block(self.inplanes, planes, dilation=dilation,
159 | previous_dilation=dilation, norm_layer=norm_layer))
160 |
161 | return nn.Sequential(*layers)
162 |
163 | def forward(self, x):
164 | x = self.conv1(x)
165 | x = self.bn1(x)
166 | x = self.relu(x)
167 | x = self.maxpool(x)
168 |
169 | x = self.layer1(x)
170 | x = self.layer2(x)
171 | x = self.layer3(x)
172 | x = self.layer4(x)
173 |
174 | x = self.avgpool(x)
175 | x = x.view(x.size(0), -1)
176 | x = self.fc(x)
177 |
178 | return x
179 |
180 |
181 | def resnet18_v1b(pretrained=False, **kwargs):
182 | model = ResNetV1b(BasicBlockV1b, [2, 2, 2, 2], **kwargs)
183 | if pretrained:
184 | old_dict = model_zoo.load_url(model_urls['resnet18'], model_dir=pretrained_save_dir)
185 | model_dict = model.state_dict()
186 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
187 | model_dict.update(old_dict)
188 | model.load_state_dict(model_dict)
189 | return model
190 |
191 |
192 | def resnet34_v1b(pretrained=False, **kwargs):
193 | model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs)
194 | if pretrained:
195 | old_dict = model_zoo.load_url(model_urls['resnet34'], model_dir=pretrained_save_dir)
196 | model_dict = model.state_dict()
197 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
198 | model_dict.update(old_dict)
199 | model.load_state_dict(model_dict)
200 | return model
201 |
202 |
203 | def resnet50_v1b(pretrained=False, **kwargs):
204 | model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], **kwargs)
205 | if pretrained:
206 | print('load pretrain resnet50_v1b')
207 | old_dict = model_zoo.load_url(model_urls['resnet50'], model_dir=pretrained_save_dir)
208 | model_dict = model.state_dict()
209 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
210 | model_dict.update(old_dict)
211 | model.load_state_dict(model_dict)
212 | return model
213 |
214 |
215 | def resnet101_v1b(pretrained=False, **kwargs):
216 | model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], **kwargs)
217 | if pretrained:
218 | print('load pretrain resnet101_v1b')
219 | old_dict = model_zoo.load_url(model_urls['resnet101'], model_dir=pretrained_save_dir)
220 | model_dict = model.state_dict()
221 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
222 | model_dict.update(old_dict)
223 | model.load_state_dict(model_dict)
224 | return model
225 |
226 |
227 | def resnet152_v1b(pretrained=False, **kwargs):
228 | model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], **kwargs)
229 | if pretrained:
230 | print('load pretrain resnet152_v1b')
231 | old_dict = model_zoo.load_url(model_urls['resnet152'], model_dir=pretrained_save_dir)
232 | model_dict = model.state_dict()
233 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
234 | model_dict.update(old_dict)
235 | model.load_state_dict(model_dict)
236 | return model
237 |
238 |
239 | def resnet50_v1s(pretrained=False, root=pretrained_save_dir, **kwargs):
240 | model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, **kwargs)
241 | if pretrained:
242 | print('load pretrain resnet50_v1s')
243 | from models.model_store import get_model_file
244 | model.load_state_dict(torch.load(get_model_file('resnet50s', root=root)), strict=False)
245 | return model
246 |
247 |
248 | def resnet101_v1s(pretrained=False, root=pretrained_save_dir, **kwargs):
249 | model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, **kwargs)
250 | if pretrained:
251 | print('load pretrain resnet101_v1s')
252 | from .model_store import get_model_file
253 | model.load_state_dict(torch.load(get_model_file('resnet101s', root=root)), strict=False)
254 | return model
255 |
256 |
257 | def resnet152_v1s(pretrained=False, root=pretrained_save_dir, **kwargs):
258 | model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, **kwargs)
259 | if pretrained:
260 | print('load pretrain resnet152_v1s')
261 | from .model_store import get_model_file
262 | model.load_state_dict(torch.load(get_model_file('resnet152s', root=root)), strict=False)
263 | return model
264 |
265 |
266 | if __name__ == '__main__':
267 | import torch
268 |
269 | img = torch.randn(4, 3, 224, 224)
270 | model = resnet50_v1b(True)
271 | output = model(img)
--------------------------------------------------------------------------------
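A short check of the `dilated` flag in `ResNetV1b`: with dilation enabled, `layer3`/`layer4` keep the spatial resolution at output stride 8, which is what the segmentation heads in this repository expect. Random weights on CPU; the stage-by-stage calls just mirror `base_forward` in the segmentation wrappers.
```
import torch
from models.resnet import resnet50_v1b

m = resnet50_v1b(pretrained=False, dilated=True).eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    x = m.maxpool(m.relu(m.bn1(m.conv1(x))))
    c3 = m.layer3(m.layer2(m.layer1(x)))
    c4 = m.layer4(c3)
print(c3.shape, c4.shape)  # both 28 x 28 (stride 8); with dilated=False they would be 14 x 14 and 7 x 7
```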
/models/segbase.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | try:
3 | from .resnet import resnet50_v1b
4 | except:
5 | from resnet import resnet50_v1b
6 | import torch.nn.functional as F
7 | from models.head.seg import SegHead
8 |
9 |
10 | class SegBaseModel(nn.Module):
11 | r"""Base Model for Semantic Segmentation
12 |
13 | Parameters
14 | ----------
15 | backbone : string
16 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
17 | 'resnet101' or 'resnet152').
18 | """
19 |
20 | def __init__(self, nclass, backbone='resnet50', dilated=True, pretrained_base=False, **kwargs):
21 | super(SegBaseModel, self).__init__()
22 | self.nclass = nclass
23 | if backbone == 'resnet50':
24 | self.pretrained = resnet50_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs)
25 |
26 | def base_forward(self, x):
27 | """forwarding pre-trained network"""
28 | x = self.pretrained.conv1(x)
29 | x = self.pretrained.bn1(x)
30 | x = self.pretrained.relu(x)
31 | x = self.pretrained.maxpool(x)
32 | c1 = self.pretrained.layer1(x)
33 | c2 = self.pretrained.layer2(c1)
34 | c3 = self.pretrained.layer3(c2)
35 | c4 = self.pretrained.layer4(c3)
36 |
37 | return c1, c2, c3, c4
38 |
39 |
40 | class _FCNHead(nn.Module):
41 | def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, **kwargs):
42 | super(_FCNHead, self).__init__()
43 | inter_channels = in_channels // 4
44 | self.block = nn.Sequential(
45 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
46 | norm_layer(inter_channels),
47 | nn.ReLU(inplace=True),
48 | nn.Dropout(0.1),
49 | nn.Conv2d(inter_channels, channels, 1)
50 | )
51 |
52 | def forward(self, x):
53 | return self.block(x)
54 |
55 |
56 | # class SegBase(SegBaseModel):
57 | #
58 | # def __init__(self, nclass, backbone='resnet50', pretrained_base=False, **kwargs):
59 | # super(SegBase, self).__init__(nclass, backbone, pretrained_base=pretrained_base, **kwargs)
60 | # self.head = _FCNHead(2048, nclass, **kwargs)
61 | #
62 | # def forward(self, x):
63 | # size = x.size()[2:]
64 | # _, _, c3, c4 = self.base_forward(x)
65 | # x = self.head(c4)
66 | # x = F.interpolate(x, size, mode='bilinear', align_corners=True)
67 | #
68 | # return x
69 |
70 |
71 | class SegBase(SegBaseModel):
72 |
73 | def __init__(self, nclass, backbone='resnet50', pretrained_base=False, **kwargs):
74 | super(SegBase, self).__init__(nclass, backbone, pretrained_base=pretrained_base, **kwargs)
75 | cnn_dict = {
76 | 'resnet18_v1b': 'resnet18_v1b', 'resnet18': 'resnet18_v1b',
77 | 'resnet34_v1b': 'resnet34_v1b', 'resnet34': 'resnet34_v1b',
78 | 'resnet50_v1b': 'resnet50_v1b', 'resnet50': 'resnet50_v1b',
79 | 'resnet101_v1b': 'resnet101_v1b', 'resnet101': 'resnet101_v1b',
80 | 'hrnet18': 'hrnet18', 'HRNet18': 'hrnet18',
81 | 'hrnet32': 'hrnet32', 'HRNet32': 'hrnet32',
82 | 'hrnet48': 'hrnet48', 'HRNet48': 'hrnet48',
83 | }
84 | cnn_name = backbone
85 | if 'resnet18' in cnn_dict[cnn_name] or 'resnet34' in cnn_dict[cnn_name]:
86 | self.cnn_head_dim = [64, 128, 256, 512]
87 | if 'resnet50' in cnn_dict[cnn_name] or 'resnet101' in cnn_dict[cnn_name]:
88 | self.cnn_head_dim = [256, 512, 1024, 2048]
89 | self.head = SegHead(in_channels=self.cnn_head_dim, num_classes=nclass, in_index=[0, 1, 2, 3])
90 |
91 | def forward(self, x):
92 | size = x.size()[2:]
93 | out_backbone = self.base_forward(x)
94 | x = self.head(out_backbone)
95 | x = F.interpolate(x, size, mode='bilinear', align_corners=True)
96 |
97 | return x
98 |
99 |
100 | if __name__ == '__main__':
101 | from tools.flops_params_fps_count import flops_params_fps
102 | model = SegBase(nclass=6)
103 | flops_params_fps(model)
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
--------------------------------------------------------------------------------
/models/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.head import *
5 | from models.resT import rest_tiny, rest_small, rest_base, rest_large
6 | from models.swinT import swin_tiny, swin_small, swin_base, swin_large
7 | from models.volo import volo_d1, volo_d2, volo_d3, volo_d4, volo_d5
8 | from models.cswin import cswin_tiny, cswin_base, cswin_small, cswin_large
9 | from models.beit import beit_base, beit_large
10 | #from tools.heatmap_fun import draw_features
11 |
12 | up_kwargs = {'mode': 'bilinear', 'align_corners': False}
13 |
14 |
15 | class Transformer(nn.Module):
16 |
17 | def __init__(self, transformer_name, nclass, img_size, aux=False, pretrained=False, head='seghead', edge_aux=False):
18 | super(Transformer, self).__init__()
19 | self.aux = aux
20 | self.edge_aux = edge_aux
21 | self.head_name = head
22 |
23 | self.model = eval(transformer_name)(nclass=nclass, img_size=img_size, aux=aux, pretrained=pretrained)
24 | self.backbone = self.model.backbone
25 |
26 | head_dim = self.model.head_dim
27 | if self.head_name == 'apchead':
28 | self.decode_head = APCHead(in_channels=head_dim[3], num_classes=nclass, in_index=3, channels=512)
29 |
30 | if self.head_name == 'aspphead':
31 | self.decode_head = ASPPHead(in_channels=head_dim[3], num_classes=nclass, in_index=3)
32 |
33 | if self.head_name == 'asppplushead':
34 | self.decode_head = ASPPPlusHead(in_channels=head_dim[3], num_classes=nclass, in_index=[0, 3])
35 |
36 | if self.head_name == 'dahead':
37 | self.decode_head = DAHead(in_channels=head_dim[3], num_classes=nclass, in_index=3)
38 |
39 | if self.head_name == 'dnlhead':
40 | self.decode_head = DNLHead(in_channels=head_dim[3], num_classes=nclass, in_index=3, channels=512)
41 |
42 | if self.head_name == 'fcfpnhead':
43 | self.decode_head = FCFPNHead(in_channels=head_dim, num_classes=nclass, in_index=[0, 1, 2, 3], channels=256)
44 |
45 | if self.head_name == 'cefpnhead':
46 | self.decode_head = CEFPNHead(in_channels=head_dim, num_classes=nclass, in_index=[0, 1, 2, 3], channels=256)
47 |
48 | if self.head_name == 'fcnhead':
49 | self.decode_head = FCNHead(in_channels=head_dim[3], num_classes=nclass, in_index=3, channels=512)
50 |
51 | if self.head_name == 'gchead':
52 | self.decode_head = GCHead(in_channels=head_dim[3], num_classes=nclass, in_index=3, channels=512)
53 |
54 | if self.head_name == 'psahead':
55 | self.decode_head = PSAHead(in_channels=head_dim[3], num_classes=nclass, in_index=3)
56 |
57 | if self.head_name == 'psphead':
58 | self.decode_head = PSPHead(in_channels=head_dim[3], num_classes=nclass, in_index=3)
59 |
60 | if self.head_name == 'seghead':
61 | self.decode_head = SegHead(in_channels=head_dim, num_classes=nclass, in_index=[0, 1, 2, 3])
62 |
63 | if self.head_name == 'unethead':
64 | self.decode_head = UNetHead(in_channels=head_dim, num_classes=nclass, in_index=[0, 1, 2, 3])
65 |
66 | if self.head_name == 'uperhead':
67 | self.decode_head = UPerHead(in_channels=head_dim, num_classes=nclass)
68 |
69 | if self.head_name == 'annhead':
70 | self.decode_head = ANNHead(in_channels=head_dim[2:], num_classes=nclass, in_index=[2, 3], channels=512)
71 |
72 | if self.head_name == 'mlphead':
73 | self.decode_head = MLPHead(in_channels=head_dim, num_classes=nclass, in_index=[0, 1, 2, 3], channels=256)
74 |
75 | if self.aux:
76 | self.auxiliary_head = FCNHead(num_convs=1, in_channels=head_dim[2], num_classes=nclass, in_index=2, channels=256)
77 |
78 | if self.edge_aux:
79 | self.edge_head = EdgeHead(in_channels=head_dim[0:2], in_index=[0, 1], channels=head_dim[0])
80 |
81 | def forward(self, x):
82 | size = x.size()[2:]
83 | outputs = []
84 |
85 | out_backbone = self.backbone(x)
86 |
87 | # for i, out in enumerate(out_backbone):
88 | # draw_features(out, f'C{i}')
89 |
90 | x0 = self.decode_head(out_backbone)
91 | if isinstance(x0, (list, tuple)):
92 | for out in x0:
93 | out = F.interpolate(out, size, **up_kwargs)
94 | outputs.append(out)
95 | else:
96 | x0 = F.interpolate(x0, size, **up_kwargs)
97 | outputs.append(x0)
98 |
99 | if self.aux:
100 | x1 = self.auxiliary_head(out_backbone)
101 | x1 = F.interpolate(x1, size, **up_kwargs)
102 | outputs.append(x1)
103 |
104 | if self.edge_aux:
105 | edge = self.edge_head(out_backbone)
106 | edge = F.interpolate(edge, size, **up_kwargs)
107 | outputs.append(edge)
108 |
109 | return outputs
110 |
111 |
112 | if __name__ == '__main__':
113 | """Notice if torch1.6, try to replace a / b with torch.true_divide(a, b)"""
114 | from tools.flops_params_fps_count import flops_params_fps
115 |
116 | model = Transformer(transformer_name='cswin_tiny', nclass=6, img_size=512, aux=True, edge_aux=False,
117 | head='uperhead', pretrained=False)
118 | flops_params_fps(model)
119 |
--------------------------------------------------------------------------------
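A usage sketch for the `Transformer` wrapper above, mirroring its own `__main__` configuration (CSWin-Tiny backbone + UPerHead with an auxiliary FCN head). It assumes `mmcv` is available and that a CPU forward pass with random weights is acceptable.
```
import torch
from models.transformer import Transformer

model = Transformer(transformer_name='cswin_tiny', nclass=6, img_size=512,
                    aux=True, edge_aux=False, head='uperhead', pretrained=False).eval()
with torch.no_grad():
    outs = model(torch.randn(1, 3, 512, 512))
print([o.shape for o in outs])  # [main logits, auxiliary logits], both [1, 6, 512, 512]
```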
/models/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class DoubleConv(nn.Module):
7 | """(convolution => [BN] => ReLU) * 2"""
8 |
9 | def __init__(self, in_channels, out_channels):
10 | super().__init__()
11 | self.double_conv = nn.Sequential(
12 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
13 | nn.BatchNorm2d(out_channels),
14 | nn.ReLU(inplace=True),
15 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
16 | nn.BatchNorm2d(out_channels),
17 | nn.ReLU(inplace=True)
18 | )
19 |
20 | def forward(self, x):
21 | return self.double_conv(x)
22 |
23 |
24 | class Down(nn.Module):
25 | """Downscaling with maxpool then double conv"""
26 |
27 | def __init__(self, in_channels, out_channels):
28 | super().__init__()
29 | self.maxpool_conv = nn.Sequential(
30 | nn.MaxPool2d(2),
31 | DoubleConv(in_channels, out_channels)
32 | )
33 |
34 | def forward(self, x):
35 | return self.maxpool_conv(x)
36 |
37 |
38 | class Up(nn.Module):
39 | """Upscaling then double conv"""
40 |
41 | def __init__(self, in_channels, out_channels, bilinear=True):
42 | super().__init__()
43 |
44 | # if bilinear, use the normal convolutions to reduce the number of channels
45 | if bilinear:
46 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
47 | else:
48 | self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
49 |
50 | self.conv = DoubleConv(in_channels, out_channels)
51 |
52 | def forward(self, x1, x2):
53 | x1 = self.up(x1)
54 | # input is CHW
55 | diffY = x2.size()[2] - x1.size()[2]
56 | diffX = x2.size()[3] - x1.size()[3]
57 |
58 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
59 | diffY // 2, diffY - diffY // 2])
60 | # if you have padding issues, see
61 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
62 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
63 | x = torch.cat([x2, x1], dim=1)
64 | return self.conv(x)
65 |
66 |
67 | class OutConv(nn.Module):
68 | def __init__(self, in_channels, out_channels):
69 | super(OutConv, self).__init__()
70 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
71 |
72 | def forward(self, x):
73 | return self.conv(x)
74 |
75 | class UNet(nn.Module):
76 | def __init__(self, nclass, bilinear=True):
77 | super(UNet, self).__init__()
78 | self.n_channels = 3
79 | self.n_classes = nclass
80 | self.bilinear = bilinear
81 |
82 | self.inc = DoubleConv(self.n_channels, 64)
83 | self.down1 = Down(64, 128)
84 | self.down2 = Down(128, 256)
85 | self.down3 = Down(256, 512)
86 | self.down4 = Down(512, 512)
87 | self.up1 = Up(1024, 256, bilinear)
88 | self.up2 = Up(512, 128, bilinear)
89 | self.up3 = Up(256, 64, bilinear)
90 | self.up4 = Up(128, 64, bilinear)
91 | self.outc = OutConv(64, self.n_classes)
92 |
93 | def forward(self, x):
94 | x1 = self.inc(x)
95 | x2 = self.down1(x1)
96 | x3 = self.down2(x2)
97 | x4 = self.down3(x3)
98 | x5 = self.down4(x4)
99 | x = self.up1(x5, x4)
100 | x = self.up2(x, x3)
101 | x = self.up3(x, x2)
102 | x = self.up4(x, x1)
103 | logits = self.outc(x)
104 | return logits
105 |
106 |
107 | if __name__ == '__main__':
108 | from tools.flops_params_fps_count import flops_params_fps
109 | model = UNet(nclass=6)
110 | flops_params_fps(model)
111 |
112 |
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import requests
3 | import errno
4 | import shutil
5 | import hashlib
6 | from tqdm import tqdm
7 | import torch
8 |
9 | __all__ = ['save_checkpoint', 'download', 'mkdir', 'check_sha1']
10 |
11 | def save_checkpoint(state, args, is_best, filename='checkpoint.pth.tar'):
12 | """Saves checkpoint to disk"""
13 | if hasattr(args, 'backbone'):
14 | directory = "runs/%s/%s/%s/%s/"%(args.dataset, args.model, args.backbone, args.checkname)
15 | else:
16 | directory = "runs/%s/%s/%s/"%(args.dataset, args.model, args.checkname)
17 | if not os.path.exists(directory):
18 | os.makedirs(directory)
19 | filename = directory + filename
20 | torch.save(state, filename)
21 | if is_best:
22 | shutil.copyfile(filename, directory + 'model_best.pth.tar')
23 |
24 |
25 | def download(url, path=None, overwrite=False, sha1_hash=None):
26 | """Download an given URL
27 | Parameters
28 | ----------
29 | url : str
30 | URL to download
31 | path : str, optional
32 | Destination path to store downloaded file. By default stores to the
33 | current directory with same name as in url.
34 | overwrite : bool, optional
35 | Whether to overwrite destination file if already exists.
36 | sha1_hash : str, optional
37 | Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
38 | but doesn't match.
39 | Returns
40 | -------
41 | str
42 | The file path of the downloaded file.
43 | """
44 | if path is None:
45 | fname = url.split('/')[-1]
46 | else:
47 | path = os.path.expanduser(path)
48 | if os.path.isdir(path):
49 | fname = os.path.join(path, url.split('/')[-1])
50 | else:
51 | fname = path
52 |
53 | if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
54 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
55 | if not os.path.exists(dirname):
56 | os.makedirs(dirname)
57 |
58 | print('Downloading %s from %s...'%(fname, url))
59 | r = requests.get(url, stream=True)
60 | if r.status_code != 200:
61 | raise RuntimeError("Failed downloading url %s"%url)
62 | total_length = r.headers.get('content-length')
63 | with open(fname, 'wb') as f:
64 | if total_length is None: # no content length header
65 | for chunk in r.iter_content(chunk_size=1024):
66 | if chunk: # filter out keep-alive new chunks
67 | f.write(chunk)
68 | else:
69 | total_length = int(total_length)
70 | for chunk in tqdm(r.iter_content(chunk_size=1024),
71 | total=int(total_length / 1024. + 0.5),
72 | unit='KB', unit_scale=False, dynamic_ncols=True):
73 | f.write(chunk)
74 |
75 | if sha1_hash and not check_sha1(fname, sha1_hash):
76 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \
77 | 'The repo may be outdated or download may be incomplete. ' \
78 | 'If the "repo_url" is overridden, consider switching to ' \
79 | 'the default repo.'.format(fname))
80 |
81 | return fname
82 |
83 |
84 | def check_sha1(filename, sha1_hash):
85 | """Check whether the sha1 hash of the file content matches the expected hash.
86 | Parameters
87 | ----------
88 | filename : str
89 | Path to the file.
90 | sha1_hash : str
91 | Expected sha1 hash in hexadecimal digits.
92 | Returns
93 | -------
94 | bool
95 | Whether the file content matches the expected hash.
96 | """
97 | sha1 = hashlib.sha1()
98 | with open(filename, 'rb') as f:
99 | while True:
100 | data = f.read(1048576)
101 | if not data:
102 | break
103 | sha1.update(data)
104 |
105 | return sha1.hexdigest() == sha1_hash
106 |
107 |
108 | def mkdir(path):
109 | """make dir exists okay"""
110 | try:
111 | os.makedirs(path)
112 | except OSError as exc: # Python >2.5
113 | if exc.errno == errno.EEXIST and os.path.isdir(path):
114 | pass
115 | else:
116 | raise
--------------------------------------------------------------------------------
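A small sketch for the helpers above. `check_sha1` compares the full 40-character hex digest, and `download` streams a URL to disk with an optional hash check; the actual download line is commented out. Running from the repository root is assumed so that `requirements.txt` resolves.
```
import hashlib
from models.utils import check_sha1, download

# verify a local file against its own digest (trivially True, just to show the call)
digest = hashlib.sha1(open('requirements.txt', 'rb').read()).hexdigest()
print(check_sha1('requirements.txt', digest))

# download('https://download.pytorch.org/models/resnet18-5c106cde.pth',
#          path='./pretrained_weights')  # stores the file under ./pretrained_weights/
```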
/mutil_scale_test.py:
--------------------------------------------------------------------------------
1 | ###########################################################################
2 | # Created by: Hang Zhang
3 | # Email: zhang.hang@rutgers.edu
4 | # Copyright (c) 2017
5 | ###########################################################################
6 |
7 | import math
8 | import torch
9 | import torch.nn.functional as F
10 | import numpy as np
11 | import torch.nn as nn
12 | from torch.nn.parallel.data_parallel import DataParallel
13 |
14 |
15 | up_kwargs = {'mode': 'bilinear', 'align_corners': False}
16 |
17 |
18 | def module_inference(module, image, flip=True):
19 | if flip:
20 | h_img = h_flip_image(image)
21 | v_img = v_flip_image(image)
22 | img = torch.cat([image, h_img, v_img], dim=0)
23 | cat_output = module(img)
24 | if isinstance(cat_output, (list, tuple)):
25 | cat_output = cat_output[0]
26 | output, h_output, v_output = cat_output.chunk(3, dim=0)
27 | output = output + h_flip_image(h_output) + v_flip_image(v_output)
28 | else:
29 | output = module(image)
30 | if isinstance(output, (list, tuple)):
31 | output = output[0]
32 |
33 | return output
34 |
35 |
36 | def resize_image(img, h, w, **up_kwargs):
37 | return F.interpolate(img, (h, w), **up_kwargs)
38 |
39 |
40 | def pad_image(img, crop_size):
41 | """crop_size could be list:[h, w] or int"""
42 | b,c,h,w = img.size()
43 | # assert(c==3)
44 | if len(crop_size) > 1:
45 | padh = crop_size[0] - h if h < crop_size[0] else 0
46 | padw = crop_size[1] - w if w < crop_size[1] else 0
47 | else:
48 | padh = crop_size - h if h < crop_size else 0
49 | padw = crop_size - w if w < crop_size else 0
50 | # pad_values = -np.array(mean) / np.array(std)
51 | img_pad = img.new().resize_(b,c,h+padh,w+padw)
52 | # for i in range(c):
53 | # note that pytorch pad params is in reversed orders
54 | min_padh = min(padh, h)
55 | min_padw = min(padw, w)
56 | if padw < w and padh < h:
57 | img_pad[:, :, :, :] = F.pad(img[:, :, :, :], (0, padw, 0, padh), mode='reflect')
58 | else:
59 | img_pad[:, :, 0:h + min_padh - 1, 0:w + min_padw - 1] = \
60 | F.pad(img[:, :, :, :], (0, min_padw - 1, 0, min_padh - 1), mode='reflect')
61 |
62 | img_pad[:, :, :, :] = F.pad(img_pad[:, :, 0:h + min_padh - 1, 0:w + min_padw - 1],
63 | (0, padw - min_padw + 1, 0, padh - min_padh + 1), mode='constant', value=0)
64 | if len(crop_size) > 1:
65 | assert (img_pad.size(2) >= crop_size[0] and img_pad.size(3) >= crop_size[1])
66 | else:
67 | assert(img_pad.size(2)>=crop_size and img_pad.size(3)>=crop_size)
68 | return img_pad
69 |
70 |
71 | def crop_image(img, h0, h1, w0, w1):
72 | return img[:,:,h0:h1,w0:w1]
73 |
74 |
75 | def h_flip_image(img):
76 | assert(img.dim()==4)
77 | with torch.cuda.device_of(img):
78 | idx = torch.arange(img.size(3)-1, -1, -1).type_as(img).long()
79 | return img.index_select(3, idx)
80 |
81 |
82 | def v_flip_image(img):
83 | assert(img.dim()==4)
84 | with torch.cuda.device_of(img):
85 | idx = torch.arange(img.size(2)-1, -1, -1).type_as(img).long()
86 | return img.index_select(2, idx)
87 |
88 |
89 | def hv_flip_image(img):
90 | assert(img.dim()==4)
91 | with torch.cuda.device_of(img):
92 | idx_w = torch.arange(img.size(3)-1, -1, -1).type_as(img).long()
93 | idx_h = torch.arange(img.size(2)-1, -1, -1).type_as(img).long()
94 | return img.index_select(3, idx_w).index_select(2, idx_h)
95 |
96 |
97 | class MultiEvalModule_Fullimg(DataParallel):
98 | """Multi-size Segmentation Eavluator"""
99 | def __init__(self, module, nclass, device_ids=None, flip=True,
100 | # scales=[1.0]):
101 | # scales=[1.0,1.25]):
102 | # scales=[0.5, 0.75,1.0,1.25,1.5]):
103 | scales=[1.0]):
104 | super(MultiEvalModule_Fullimg, self).__init__(module, device_ids)
105 | self.nclass = nclass
106 | self.base_size = 256
107 | self.crop_size = 256
108 | self.scales = scales
109 | self.flip = flip
110 | print('MultiEvalModule_Fullimg: base_size {}, crop_size {}'. \
111 | format(self.base_size, self.crop_size))
112 |
113 | def forward(self, image):
114 | """Mult-size Evaluation"""
115 | batch, _, h, w = image.size()
116 |
117 | with torch.cuda.device_of(image):
118 | scores = image.new().resize_(batch,self.nclass,h,w).zero_().cuda()
119 | for scale in self.scales:
120 | crop_size = int(math.ceil(self.crop_size * scale))
121 |
122 | cur_img = resize_image(image, crop_size, crop_size, **up_kwargs)
123 | outputs = module_inference(self.module, cur_img, self.flip)
124 | score = resize_image(outputs, h, w, **up_kwargs)
125 | scores += score
126 |
127 | return scores
128 |
129 |
130 | class MultiEvalModule(nn.Module):
131 | """Multi-size Segmentation Eavluator"""
132 | def __init__(self, module, nclass, device_ids=None, flip=True, save_gpu_memory=False,
133 | scales=[1.0], get_batch=1, crop_size=[512, 512], stride_rate=1/2):
134 | #scales=[0.5,0.75,1,1.25]):
135 | #scales=[0.5,0.75,1.0,1.25,1.4,1.6,1.8]):
136 | #scales=[1]):
137 | # super(MultiEvalModule, self).__init__(module, device_ids)
138 | super(MultiEvalModule, self).__init__()
139 | self.module = module
140 | self.device_ids = device_ids
141 | self.nclass = nclass
142 | self.crop_size = np.array(crop_size)
143 | self.scales = scales
144 | self.flip = flip
145 | self.get_batch = get_batch
146 | self.stride_rate = stride_rate
147 | self.save_gpu_memory = save_gpu_memory  # accumulate scores on CPU to reduce GPU memory usage
148 |
149 | def forward(self, image):
150 | """Mult-size Evaluation"""
151 | # only single image is supported for evaluation
152 | batch, _, h, w = image.size()
153 | # assert(batch == 1)
154 | stride_rate = self.stride_rate
155 | with torch.cuda.device_of(image):
156 | if self.save_gpu_memory:
157 | scores = image.new().resize_(batch, self.nclass, h, w).zero_().cpu()
158 | else:
159 | scores = image.new().resize_(batch,self.nclass,h,w).zero_().cuda()
160 |
161 | for scale in self.scales:
162 | crop_size = self.crop_size
163 | stride = (crop_size * stride_rate).astype(int)
164 |
165 | if h > w:
166 | long_size = int(math.ceil(h * scale))
167 | height = long_size
168 | width = int(1.0 * w * long_size / h + 0.5)
169 | short_size = width
170 | else:
171 | long_size = int(math.ceil(w * scale))
172 | width = long_size
173 | height = int(1.0 * h * long_size / w + 0.5)
174 | short_size = height
175 |
176 | # resize image to current size
177 | cur_img = resize_image(image, height, width, **up_kwargs)
178 | if long_size <= np.max(crop_size):
179 | pad_img = pad_image(cur_img, crop_size)
180 | outputs = module_inference(self.module, pad_img, self.flip)
181 | outputs = crop_image(outputs, 0, height, 0, width)
182 |
183 | else:
184 | if short_size < np.min(crop_size):
185 | # pad if needed
186 | pad_img = pad_image(cur_img, crop_size)
187 | else:
188 | pad_img = cur_img
189 | _,_,ph,pw = pad_img.size()
190 | # assert(ph >= height and pw >= width)
191 | # grid forward and normalize
192 | h_grids = int(math.ceil(1.0 * (ph-crop_size[0])/stride[0])) + 1
193 | w_grids = int(math.ceil(1.0 * (pw-crop_size[1])/stride[1])) + 1
194 | with torch.cuda.device_of(image):
195 | if self.save_gpu_memory:
196 | outputs = image.new().resize_(batch, self.nclass, ph, pw).zero_().cpu()
197 | count_norm = image.new().resize_(batch, 1, ph, pw).zero_().cpu()
198 | else:
199 | outputs = image.new().resize_(batch,self.nclass,ph,pw).zero_().cuda()
200 | count_norm = image.new().resize_(batch,1,ph,pw).zero_().cuda()
201 | # grid evaluation
202 | location = []
203 | batch_size = []
204 | pad_img = pad_image(pad_img, [ph + crop_size[0], pw + crop_size[1]]) # expand pad_image
205 |
206 | for idh in range(h_grids):
207 | for idw in range(w_grids):
208 | h0 = idh * stride[0]
209 | w0 = idw * stride[1]
210 | h1 = min(h0 + crop_size[0], ph)
211 | w1 = min(w0 + crop_size[1], pw)
212 |
213 | crop_img = crop_image(pad_img, h0, h0 + crop_size[0], w0, w0 + crop_size[1])
214 | # pad if needed
215 | pad_crop_img = pad_image(crop_img, crop_size)
216 | size_h, size_w = pad_crop_img.shape[-2:]
217 | pad_crop_img = resize_image(pad_crop_img, crop_size[0], crop_size[1], **up_kwargs)
218 | if self.get_batch > 1:
219 | location.append([h0, w0, h1, w1])
220 | batch_size.append(pad_crop_img)
221 | if len(location) == self.get_batch or (idh + idw + 2) == (h_grids + w_grids):
222 | batch_size = torch.cat(batch_size, dim=0).cuda()
223 | location = np.array(location)
224 | output = module_inference(self.module, batch_size, self.flip)
225 | output = output.detach()
226 | output = resize_image(output, size_h, size_w, **up_kwargs)
227 | if self.save_gpu_memory:
228 | output = output.detach().cpu() # to save gpu memory
229 | else:
230 | output = output.detach()
231 | for i in range(batch_size.shape[0]):
232 | outputs[:, :, location[i][0]:location[i][2], location[i][1]:location[i][3]] += \
233 | crop_image(output[i, ...].unsqueeze(dim=0), 0, location[i][2]-location[i][0], 0, location[i][3]-location[i][1])
234 | count_norm[:, :, location[i][0]:location[i][2], location[i][1]:location[i][3]] += 1
235 | location = []
236 | batch_size = []
237 | else:
238 | output = module_inference(self.module, pad_crop_img, self.flip)
239 | if self.save_gpu_memory:
240 | output = output.detach().cpu() # to save gpu memory
241 | else:
242 | output = output.detach()
243 | output = resize_image(output, size_h, size_w, **up_kwargs)
244 | outputs[:,:,h0:h1,w0:w1] += crop_image(output,
245 | 0, h1-h0, 0, w1-w0)
246 | count_norm[:,:,h0:h1,w0:w1] += 1
247 | assert((count_norm==0).sum()==0)
248 | outputs = outputs / count_norm
249 | outputs = outputs[:,:,:height,:width]
250 | score = resize_image(outputs, h, w, **up_kwargs)
251 | scores += score
252 | return scores
253 |
254 |
--------------------------------------------------------------------------------
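The sliding-window branch above accumulates overlapping crop predictions into `outputs` and divides by `count_norm`, so pixels covered by several crops are averaged. Below is a minimal NumPy sketch of that accumulate-and-normalize idea with toy sizes and a stand-in inference function; none of these names belong to the repository.

```python
import numpy as np

# toy accumulator: an 8 x 8 map, 4 x 4 crops, stride 2
ph, pw, crop, stride = 8, 8, 4, 2
outputs = np.zeros((ph, pw), dtype=np.float32)
count_norm = np.zeros((ph, pw), dtype=np.float32)

def fake_inference(patch):
    # stand-in for module_inference: just return ones
    return np.ones_like(patch)

image = np.random.rand(ph, pw).astype(np.float32)
for h0 in range(0, ph - crop + stride, stride):
    for w0 in range(0, pw - crop + stride, stride):
        h1, w1 = min(h0 + crop, ph), min(w0 + crop, pw)
        pred = fake_inference(image[h0:h1, w0:w1])
        outputs[h0:h1, w0:w1] += pred        # accumulate overlapping predictions
        count_norm[h0:h1, w0:w1] += 1        # count how many crops covered each pixel

assert (count_norm == 0).sum() == 0          # every pixel must be covered at least once
outputs = outputs / count_norm               # average where crops overlapped
```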
/post_process.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author : now more
3 | Connect : lin.honghui@qq.com
4 | LastEditors: Please set LastEditors
5 | Description :
6 | LastEditTime: 2020-11-27 03:42:46
7 | '''
8 | import os
9 | import threading
10 | import cv2 as cv
11 | import numpy as np
12 | from skimage.morphology import remove_small_holes, remove_small_objects
13 | from argparse import ArgumentParser
14 | from PIL import Image
15 |
16 | Image.MAX_IMAGE_PIXELS = None
17 |
18 |
19 | def to_categorical(y, num_classes=None, dtype='float32'):
20 | """Converts a class vector (integers) to binary class matrix.
21 |
22 | E.g. for use with categorical_crossentropy.
23 |
24 | # Arguments
25 | y: class vector to be converted into a matrix
26 | (integers from 0 to num_classes).
27 | num_classes: total number of classes.
28 | dtype: The data type expected by the input, as a string
29 | (`float32`, `float64`, `int32`...)
30 |
31 | # Returns
32 | A binary matrix representation of the input. The classes axis
33 | is placed last.
34 | """
35 | y = np.array(y, dtype='int')
36 | input_shape = y.shape
37 | if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
38 | input_shape = tuple(input_shape[:-1])
39 | y = y.ravel()
40 | if not num_classes:
41 | num_classes = np.max(y) + 1
42 | n = y.shape[0]
43 | categorical = np.zeros((n, num_classes), dtype=dtype)
44 | categorical[np.arange(n), y] = 1
45 | output_shape = input_shape + (num_classes,)
46 | categorical = np.reshape(categorical, output_shape)
47 | return categorical
48 |
49 |
50 | class MyThread(threading.Thread):
51 |
52 | def __init__(self, func, args=()):
53 | super(MyThread, self).__init__()
54 | self.func = func
55 | self.args = args
56 |
57 | def run(self):
58 | self.result = self.func(*self.args)
59 |
60 | def get_result(self):
61 | try:
62 |             return self.result  # if join() is not called on the thread first, self.result may not exist here
63 | except Exception:
64 | return None
65 |
66 |
67 | def label_resize_vis(label, img=None, alpha=0.5):
68 | '''
69 |     :param label: original label map
70 |     :param img: original image
71 |     :param alpha: blending transparency
72 |     :return: visualized label overlay
73 | '''
74 | def label_to_RGB(image, classes=6):
75 | RGB = np.zeros(shape=[image.shape[0], image.shape[1], 3], dtype=np.uint8)
76 | if classes == 6: # potsdam and vaihingen
77 | palette = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]]
78 | if classes == 4: # barley
79 | palette = [[255, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]]
80 | for i in range(classes):
81 | index = image == i
82 | RGB[index] = np.array(palette[i])
83 | return RGB
84 |
85 | # label = cv.resize(label.copy(), None, fx=0.1, fy=0.1)
86 | anno_vis = label_to_RGB(label, classes=4)
87 | if img is None:
88 | return anno_vis
89 | else:
90 | overlapping = cv.addWeighted(img, alpha, anno_vis, 1 - alpha, 0)
91 | return overlapping
92 |
93 |
94 | def remove_small_objects_and_holes(class_type, label, min_size, area_threshold, in_place=True):
95 | print("------------- class_n : {} start ------------".format(class_type))
96 | if class_type == 3:
97 | # kernel = cv.getStructuringElement(cv.MORPH_RECT,(500,500))
98 | # label = cv.dilate(label,kernel)
99 | # kernel = cv.getStructuringElement(cv.MORPH_RECT,(10,10))
100 | # label = cv.erode(label,kernel)
101 | label = remove_small_objects(label == 1, min_size=min_size, connectivity=1, in_place=in_place)
102 | label = remove_small_holes(label == 1, area_threshold=area_threshold, connectivity=1, in_place=in_place)
103 | else:
104 | label = remove_small_objects(label == 1, min_size=min_size, connectivity=1, in_place=in_place)
105 | label = remove_small_holes(label == 1, area_threshold=area_threshold, connectivity=1, in_place=in_place)
106 | print("------------- class_n : {} finished ------------".format(class_type))
107 | return label
108 |
109 |
110 | def RGB_to_label(image=None, classes=6):
111 | if classes == 6: # potsdam and vaihingen
112 | palette = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]]
113 | if classes == 4: # barley
114 | palette = [[255, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]]
115 | label = np.zeros(shape=[image.shape[0], image.shape[1]], dtype=np.uint8)
116 | for i in range(len(palette)):
117 | index = image == np.array(palette[i])
118 | index[..., 0][index[..., 1] == False] = False
119 | index[..., 0][index[..., 2] == False] = False
120 | label[index[..., 0]] = i
121 | return label
122 |
123 |
124 |
125 | if __name__ == "__main__":
126 | parser = ArgumentParser()
127 |     parser.add_argument("--image_n", type=int, default=2, help="pass 1 or 2 to specify which image to process")
128 |     parser.add_argument("--image_path", type=str, default='./outputs', help="path to the directory containing image_n_predict")
129 | parser.add_argument("--threshold", type=int, default=2000)
130 | arg = parser.parse_args()
131 | image_n = arg.image_n
132 | image_path = arg.image_path
133 | threshold = arg.threshold
134 |
135 | if image_n == 1:
136 | source_image = cv.imread("../../data/barley/images_size0.1/image_1.png")
137 | elif image_n == 2:
138 | source_image = cv.imread("../../data/barley/images_size0.1/image_2.png")
139 | else:
140 | raise ValueError("image_n should be 1 or 2, Got {} ".format(image_n))
141 |
142 | img_mask_dir = os.path.join(image_path, f'image_{image_n}_mask.png')
143 | img_dir = os.path.join(image_path, f'image_{image_n}.png')
144 | if os.path.exists(img_mask_dir):
145 | image = np.asarray(Image.open(img_mask_dir))
146 | elif os.path.exists(img_dir):
147 | image = np.asarray(Image.open(img_dir))
148 | else:
149 | raise ValueError(f"Not found image_{image_n}_mask.png or image_{image_n}.png")
150 |
151 | if len(image.shape) == 3:
152 | image = RGB_to_label(image, classes=4)
153 | image_save = Image.fromarray(image)
154 | image_save.save(os.path.join(image_path, f'image_{image_n}_mask.png'))
155 |
156 | image = cv.resize(image, None, fx=0.1, fy=0.1, interpolation=cv.INTER_NEAREST) # because over memory
157 |
158 | label = to_categorical(image, num_classes=4, dtype='uint8')
159 |
160 | threading_list = []
161 | for i in range(4):
162 | t = MyThread(remove_small_objects_and_holes, args=(i, label[:, :, i], threshold, threshold, True))
163 | threading_list.append(t)
164 | t.start()
165 |
166 |     # wait for all threads to finish
167 | result = []
168 | for t in threading_list:
169 | t.join()
170 | result.append(t.get_result()[:, :, None])
171 |
172 | label = np.concatenate(result, axis=2)
173 |
174 | label = np.argmax(label, axis=2).astype(np.uint8)
175 | cv.imwrite('./outputs/image_' + str(image_n) + "_predict.png", label)
176 | mask = label_resize_vis(label, source_image)
177 | cv.imwrite('./outputs/vis_image_' + str(image_n) + "_predict.jpg", mask[..., ::-1])
--------------------------------------------------------------------------------
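The post-processing above one-hot encodes the prediction with `to_categorical`, cleans each class channel with `remove_small_objects`/`remove_small_holes` in a separate thread, and takes an `argmax` to return to a label map. A minimal single-threaded sketch of the same pipeline on a toy label map (array sizes and thresholds here are illustrative only):

```python
import numpy as np
from skimage.morphology import remove_small_holes, remove_small_objects

label = np.random.randint(0, 4, size=(64, 64))   # toy 4-class prediction
onehot = np.eye(4, dtype=bool)[label]            # (64, 64, 4) one-hot encoding

cleaned = []
for c in range(4):
    ch = remove_small_objects(onehot[..., c], min_size=20, connectivity=1)
    ch = remove_small_holes(ch, area_threshold=20, connectivity=1)
    cleaned.append(ch[..., None])

label_clean = np.argmax(np.concatenate(cleaned, axis=2), axis=2).astype(np.uint8)
print(label_clean.shape)                         # (64, 64)
```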
/pre_process.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import os
3 | import numpy as np
4 | from PIL import ImageFile
5 | import math
6 | import cv2
7 | ImageFile.LOAD_TRUNCATED_IMAGES = True
8 | Image.MAX_IMAGE_PIXELS = None
9 |
10 |
11 | def label_to_RGB(image, classes=4):
12 | RGB = np.zeros(shape=[image.shape[0], image.shape[1], 3], dtype=np.uint8)
13 | if classes == 6:
14 | palette = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]]
15 | if classes == 4:
16 | palette = [[255, 255, 255], [0, 255, 0], [255, 255, 0], [255, 0, 0]]
17 | for i in range(classes):
18 | index = image == i
19 | RGB[index] = np.array(palette[i])
20 | return RGB
21 |
22 |
23 | def divide_img(image, oh, ow, filename, save_dir, write_txt_dir=None):
24 |     """Divide without overlap; the last patch may be smaller than oh x ow."""
25 | if write_txt_dir is not None:
26 | txt = open(write_txt_dir, 'w')
27 | h, w = image.shape[0:2]
28 | num_h, num_w = h // oh + 1, w // ow + 1
29 | for i in range(num_w):
30 | for j in range(num_h):
31 | h1 = min((j + 1) * oh, h)
32 | w1 = min((i + 1) * ow, w)
33 | if len(image.shape) == 2:
34 | image_part = image[j * oh:h1, i * ow:w1]
35 | else:
36 | image_part = image[j * oh:h1, i * ow:w1, :]
37 | image_part = Image.fromarray(image_part)
38 | image_part.save(os.path.join(save_dir, f'{filename}_{j}_{i}.png')) # j:h, i:w
39 | if write_txt_dir is not None:
40 | txt.write(f'{filename}_{j}_{i}' + '\n')
41 |
42 |
43 | def divide_img_overlap(image, oh, ow, filename, save_dir, write_txt_dir=None, overlap=1024):
44 |     """Divide the image with an overlap; the last patch is traced back so it still measures oh x ow."""
45 | if write_txt_dir is not None:
46 | txt = open(write_txt_dir, 'w')
47 | path, name = os.path.split(write_txt_dir)
48 | txt_clean = open(os.path.join(path, os.path.splitext(name)[0] + '_clean.txt'), 'w')
49 | h, w = image.shape[0:2]
50 | if len(image.shape) == 2:
51 | image = np.expand_dims(image, axis=2)
52 | num_h, num_w = math.ceil((h - oh) / (oh - overlap)) + 1, math.ceil((w - ow) / (ow - overlap)) + 1
53 | for i in range(num_w):
54 | for j in range(num_h):
55 | if i < num_w - 1 and j < num_h - 1:
56 | image_part = image[(oh - overlap) * j:(oh - overlap) * j + oh, (ow - overlap) * i:(ow - overlap) * i + ow, :]
57 | if i < num_w - 1 and j == num_h - 1:
58 | image_part = image[h - oh:h, (ow - overlap) * i:(ow - overlap) * i + ow, :]
59 | if i == num_w - 1 and j < num_h - 1:
60 | image_part = image[(oh - overlap) * j:(oh - overlap) * j + oh, w - ow:w, :]
61 | if i == num_w - 1 and j == num_h - 1:
62 | image_part = image[h - oh:h, w - ow:w, :]
63 | image_part = image_part.squeeze()
64 | if write_txt_dir is not None:
65 | if np.any(image_part[..., 3] > 0):
66 | txt_clean.write(f'{filename}_{j}_{i}' + '\n')
67 | txt.write(f'{filename}_{j}_{i}' + '\n')
68 | image_part = Image.fromarray(image_part)
69 | image_part.save(os.path.join(save_dir, f'{filename}_{j}_{i}.png')) # j:h, i:w
70 |
71 |
72 | def restore_part_img(oh=6000, ow=6000, overlap=1024, filename='image_1'):
73 |     """Restore the full image from its patches; the last patch was traced back to oh x ow."""
74 | root = '/data/xzy/datasets/'
75 | dataset = f'barley_hw6000_s{overlap}'
76 | if filename == 'image_1':
77 | h, w = 50141, 47161
78 | if filename == 'image_2':
79 | h, w = 46050, 77470
80 | num_h, num_w = math.ceil((h - oh) / (oh - overlap)) + 1, math.ceil((w - ow) / (ow - overlap)) + 1
81 | for i in range(num_w):
82 | for j in range(num_h):
83 | part_img = Image.open(os.path.join(root, dataset, f'images/{filename}_{j}_{i}.png'))
84 | part_img = np.array(part_img)
85 | if j == 0:
86 | w_patch = part_img[0:oh - overlap // 2, ...]
87 | elif j < num_h - 1:
88 | w_patch = np.concatenate((w_patch, part_img[overlap // 2:oh - overlap // 2, ...]), 0)
89 | else:
90 | end_h = w_patch.shape[0]
91 | w_patch = np.concatenate((w_patch, part_img[oh - (h - end_h):oh, ...]), 0)
92 | if i == 0:
93 | h_patch = w_patch[:, 0:ow - overlap // 2, :]
94 | elif i < num_w - 1:
95 | h_patch = np.concatenate((h_patch, w_patch[:, overlap // 2:ow - overlap // 2, :]), 1)
96 | else:
97 | end_w = h_patch.shape[1]
98 | h_patch = np.concatenate((h_patch, w_patch[:, ow - (w - end_w):ow, :]), 1)
99 | print(h_patch.shape)
100 | h_patch = cv2.resize(h_patch, None, fx=0.1, fy=0.1, interpolation=cv2.INTER_NEAREST)
101 | full_img = Image.fromarray(h_patch)
102 | full_img.save(os.path.join(root, dataset, f'{filename}.png'))
103 |
104 |
105 | def apply_divide_img_label(overlap=1024):
106 | """read images and divide them with an overlap"""
107 | root = '/data/zyxu/dataset/barley/'
108 | save = f'barley_hw6000_s{overlap}'
109 | img_dir = 'images_complete/'
110 | label_dir = 'labels_complete/'
111 | img_save_dir = os.path.join(root, f'{save}/images/')
112 | label_save_dir = os.path.join(root, f'{save}/labels/')
113 |     file = ['image_1', 'image_2']
114 | 
115 |     for filename in file:
116 |         if filename == 'image_1':
117 |             txt_name = os.path.join(root, f'{save}/annotations/train_full.txt')
118 |         if filename == 'image_2':
119 |             txt_name = os.path.join(root, f'{save}/annotations/test_full.txt')
120 |
121 | img = Image.open(os.path.join(root, img_dir, filename + '.png'))
122 | img = np.array(img)
123 | divide_img_overlap(img, 6000, 6000, filename, img_save_dir, txt_name, overlap=overlap)
124 |
125 | label = Image.open(os.path.join(root, label_dir, filename + '.png'))
126 | label = np.array(label)
127 | divide_img_overlap(label, 6000, 6000, filename, label_save_dir, overlap=overlap)
128 |
129 |
130 | def clean_white_background():
131 |     """Remove fully transparent patches from the train and test txt files."""
132 | img_dir = './data/barley/images/'
133 | img_list = os.listdir(img_dir)
134 | train_no_alpha = open('./data/barley/annotations/train_no_alpha.txt', 'w')
135 | test_no_alpha = open('./data/barley/annotations/test_no_alpha.txt', 'w')
136 | for file in img_list:
137 | file = file.strip()
138 | img = Image.open(os.path.join(img_dir, file))
139 | img = np.array(img)
140 | alpha = img[..., 3]
141 | if np.any(alpha > 0):
142 | if 'image_1' in file:
143 | train_no_alpha.write(file[:-4] + '\n')
144 | if 'image_2' in file:
145 | test_no_alpha.write(file[:-4] + '\n')
146 |
147 |
148 | def count_nums():
149 |     """Count the pixel proportion of each class in the train and test images."""
150 | label_train_dir = '/data/xzy/datasets/barley/labels_full/image_1_label.png'
151 | label_test_dir = '/data/xzy/datasets/barley/labels_full/image_2_label.png'
152 | label_train = np.array(Image.open(label_train_dir))
153 | label_test = np.array(Image.open(label_test_dir))
154 | h0, w0, h1, w1 = label_train.shape[0], label_train.shape[1], label_test.shape[0], label_test.shape[1]
155 | for i in range(4):
156 | print('train pixel{}:{:.6f}'.format(i, np.sum(label_train == i) / (h0 * w0)))
157 | # 0.870185 0.066412 0.006110 0.057294
158 | # class012: [0.51158563 0.04706662 0.44134775]
159 | for i in range(4):
160 | print('test pixel{}:{:.6f}'.format(i, np.sum(label_test == i) / (h1 * w1)))
161 | # 0.926607 0.002005 0.033732 0.037655
162 | # class012: [0.02731905 0.45961413 0.51306682]
163 |
164 |
165 | def rearrange_dataset(oh=6000, ow=6000, overlap=1024):
166 |     """Fuse image_1 and image_2 to generate new train and test files."""
167 | root = '/data/zyxu/dataset/barley/'
168 | dataset = f'barley_hw6000_s{overlap}'
169 | train_txt = open(os.path.join(root, dataset, f'annotations/train.txt'), 'w')
170 | test_txt = open(os.path.join(root, dataset, f'annotations/test.txt'), 'w')
171 | for filename in ['image_1', 'image_2']:
172 | if filename == 'image_1':
173 | h, w = 50141, 47161
174 | if filename == 'image_2':
175 | h, w = 46050, 77470
176 | num_h, num_w = math.ceil((h - oh) / (oh - overlap)) + 1, math.ceil((w - ow) / (ow - overlap)) + 1
177 | for i in range(num_w):
178 | for j in range(num_h):
179 | part_img = Image.open(os.path.join(root, dataset, f'images/{filename}_{j}_{i}.png'))
180 | part_img = np.array(part_img)
181 | if (i + j) % 2 == 0 and np.any(part_img[..., 3] > 0):
182 | train_txt.write(f'{filename}_{j}_{i}' + '\n')
183 | if (i + j) % 2 == 1 and np.any(part_img[..., 3] > 0):
184 | test_txt.write(f'{filename}_{j}_{i}' + '\n')
185 |
186 |
187 | def resize():
188 | """resize image to 1/10"""
189 | filename = 'image_2'
190 | img1_dir = f'./data/barley/labels_view/{filename}.png'
191 | img1 = np.array(Image.open(img1_dir))
192 | img = cv2.resize(img1, None, fx=0.1, fy=0.1, interpolation=cv2.INTER_NEAREST)
193 | img = Image.fromarray(img)
194 | img.save(f'./data/barley/images_size0.1/{filename}_labels_view.png')
195 |
196 |
197 | def label_view():
198 |     """Convert labels to RGB for visualization."""
199 | root = '/data/xzy/datasets/barley/'
200 |     dataset = ''
201 | train_txt = os.path.join(root, dataset, 'annotations/train.txt')
202 | test_txt = os.path.join(root, dataset, 'annotations/test.txt')
203 | file = open(train_txt, 'r').readlines() + open(test_txt, 'r').readlines()
204 | for name in file:
205 | name = name.strip()
206 | label = np.array(Image.open(os.path.join(root, dataset, 'labels', name + '.png')))
207 | label = label_to_RGB(label, 4)
208 | label = Image.fromarray(label)
209 | label.save(os.path.join(root, dataset, f'labels_view/{name}.png'))
210 |
211 |
212 | def get_alpha():
213 | """get alpha channel and save it"""
214 | name = 'image_2.png'
215 | image_dir = './data/barley/images_complete/'
216 | image = np.array(Image.open(os.path.join(image_dir, name)))
217 | alpha = image[..., 3]
218 | alpha = Image.fromarray(alpha)
219 | alpha.save(f'./data/barley/alphas_complete/{name}')
220 |
221 |
222 | if __name__ == '__main__':
223 | # apply_divide_img_label()
224 | # restore_part_img()
225 |
226 | # label1 = np.array(Image.open('./data/barley/barley_hw6000_s1024/image_2.png'))
227 | # label2 = np.array(Image.open('./data/barley/labels_complete/image_2.png'))
228 | # print(label1.shape, label2.shape)
229 | # print(np.all(label1 == label2))
230 | # print(np.sum(label1), np.sum(label2))
231 |
232 | # apply_divide_img_label(overlap=0)
233 | # restore_part_img(overlap=0)
234 | # rearrange_dataset(overlap=0)
235 | label_view()
236 |
237 |
238 |
239 |
--------------------------------------------------------------------------------
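`divide_img_overlap` and `restore_part_img` share the same tiling arithmetic: patches start every `oh - overlap` pixels, and the final row/column is traced back so the last patch still has the full size. A small sketch of just that coordinate computation, using toy numbers and a helper name that does not exist in the repository:

```python
import math

def tile_starts(length, patch, overlap):
    """Return the start offsets of patches along one axis."""
    n = math.ceil((length - patch) / (patch - overlap)) + 1
    starts = [(patch - overlap) * i for i in range(n - 1)]
    starts.append(length - patch)        # last patch traced back to keep the full size
    return starts

print(tile_starts(50141, 6000, 1024))    # e.g. row offsets for image_1's height
```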
/pretrained_weights/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyxu1996/CCTNet/5a5db40d2e38bd478b404583050049eedca90844/pretrained_weights/.gitignore
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.3.0
2 | fvcore==0.1.5.post20210604
3 | matplotlib==3.3.2
4 | mmcv==1.3.5
5 | numpy==1.18.5
6 | opencv_python==4.5.2.54
7 | pavi==0.0.1
8 | Pillow==9.1.0
9 | portalocker==2.0.0
10 | requests==2.24.0
11 | scikit_image==0.17.2
12 | scipy==1.5.2
13 | setproctitle==1.2.3
14 | skimage==0.0
15 | tifffile==2020.9.3
16 | timm==0.3.2
17 | torch==1.6.0
18 | torchvision==0.7.0
19 | tqdm==4.50.2
20 |
--------------------------------------------------------------------------------
/seg_metric.py:
--------------------------------------------------------------------------------
1 | """
2 | refer to https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/utils/metrics.py
3 | """
4 | import numpy as np
5 |
6 | __all__ = ['SegmentationMetric']
7 |
8 | """
9 | confusionMetric
10 | P\L P N
11 | P TP FP
12 | N FN TN
13 | """
14 |
15 |
16 | class SegmentationMetric(object):
17 | def __init__(self, numClass):
18 | self.numClass = numClass
19 | self.confusionMatrix = np.zeros((self.numClass,) * 2)
20 |
21 | def Accuracy(self):
22 | # return all class overall pixel accuracy
23 |         # acc = (TP + TN) / (TP + TN + FP + FN)
24 | acc = np.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum()
25 | return acc
26 |
27 | def Precision(self):
28 |         # return per-class pixel accuracy (more accurately called precision)
29 |         # precision = TP / (TP + FP)
30 | precision = np.diag(self.confusionMatrix) / self.confusionMatrix.sum(axis=0)
31 | return precision
32 |
33 | def meanPrecision(self):
34 | precision = self.Precision()
35 | mPrecision = np.nanmean(precision)
36 | return mPrecision
37 |
38 | def Recall(self):
39 | # Recall = (TP) / (TP + FN)
40 | recall = np.diag(self.confusionMatrix) / self.confusionMatrix.sum(axis=1)
41 | return recall
42 |
43 | def meanRecall(self):
44 | recall = self.Recall()
45 | mRecall = np.nanmean(recall)
46 | return mRecall
47 |
48 | def F1(self):
49 | # 2*precision*recall / (precision + recall)
50 | f1 = 2 * self.Precision() * self.Recall() / (self.Precision() + self.Recall())
51 | return f1
52 |
53 | def meanF1(self):
54 | f1 = self.F1()
55 | mF1 = np.nanmean(f1)
56 | return mF1
57 |
58 | def IntersectionOverUnion(self):
59 | # Intersection = TP Union = TP + FP + FN
60 | # IoU = TP / (TP + FP + FN)
61 | intersection = np.diag(self.confusionMatrix)
62 | union = np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag(
63 | self.confusionMatrix)
64 | IoU = intersection / union
65 | return IoU
66 |
67 | def meanIntersectionOverUnion(self):
68 | # Intersection = TP Union = TP + FP + FN
69 | # IoU = TP / (TP + FP + FN)
70 | IoU = self.IntersectionOverUnion()
71 | mIoU = np.nanmean(IoU)
72 | return mIoU
73 |
74 | def genConfusionMatrix(self, imgPredict, imgLabel):
75 | # remove classes from unlabeled pixels in gt image and predict
76 | mask = (imgLabel >= 0) & (imgLabel < self.numClass)
77 | label = self.numClass * imgLabel[mask] + imgPredict[mask]
78 | count = np.bincount(label, minlength=self.numClass ** 2)
79 | confusionMatrix = count.reshape(self.numClass, self.numClass)
80 | return confusionMatrix
81 |
82 | def Frequency_Weighted_Intersection_over_Union(self):
83 | # FWIOU = [(TP+FN)/(TP+FP+TN+FN)] *[TP / (TP + FP + FN)]
84 | freq = np.sum(self.confusionMatrix, axis=1) / np.sum(self.confusionMatrix)
85 | iu = np.diag(self.confusionMatrix) / (
86 | np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) -
87 | np.diag(self.confusionMatrix))
88 | iu = [i if not np.isnan(i) else 0.0 for i in iu]
89 | iu = np.array(iu)
90 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
91 | return FWIoU
92 |
93 | def Frequency_Weighted(self):
94 | # FWIOU = [(TP+FN)/(TP+FP+TN+FN)] *[TP / (TP + FP + FN)]
95 | freq = np.sum(self.confusionMatrix, axis=1) / np.sum(self.confusionMatrix)
96 |
97 | return freq
98 |
99 | def addBatch(self, imgPredict, imgLabel):
100 | assert imgPredict.shape == imgLabel.shape
101 | self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel)
102 | def reset(self):
103 | self.confusionMatrix = np.zeros((self.numClass, self.numClass))
104 |
105 |
106 | if __name__ == '__main__':
107 | imgPredict = np.array([0, 0, 1, 1, 2, 2])
108 | imgLabel = np.array([0, 0, 1, 1, 2, 2])
109 | metric = SegmentationMetric(3)
110 | metric.addBatch(imgPredict, imgLabel)
111 |     acc = metric.Accuracy()
112 | mIoU = metric.meanIntersectionOverUnion()
113 | print(acc, mIoU)
--------------------------------------------------------------------------------
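`genConfusionMatrix` builds the confusion matrix with a `bincount` trick: each (ground truth, prediction) pair is flattened to the index `numClass * label + prediction`, counted, and reshaped. A tiny sketch of the trick and of deriving per-class IoU from the resulting matrix (toy arrays, independent of the class above):

```python
import numpy as np

pred   = np.array([0, 0, 1, 1, 2, 0])
label  = np.array([0, 0, 1, 1, 2, 2])
nclass = 3

# each (gt, pred) pair maps to a unique index gt * nclass + pred
idx = nclass * label + pred
cm = np.bincount(idx, minlength=nclass ** 2).reshape(nclass, nclass)

intersection = np.diag(cm)
union = cm.sum(axis=0) + cm.sum(axis=1) - intersection
print(cm)
print('IoU per class:', intersection / union)   # the wrong pixel lowers class 2's IoU
```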
/tools/edge/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyxu1996/CCTNet/5a5db40d2e38bd478b404583050049eedca90844/tools/edge/.gitignore
--------------------------------------------------------------------------------
/tools/flops_params_fps_count.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | from fvcore.nn import FlopCountAnalysis, parameter_count
4 | from tqdm import tqdm
5 | import torch
6 |
7 |
8 | def flops_params_fps(model, input_shape=(1, 3, 512, 512)):
9 | """count flops:G params:M fps:img/s
10 | input shape tensor[1, c, h, w]
11 | """
12 | total_time = []
13 | with torch.no_grad():
14 | model = model.cuda().eval()
15 | input = torch.randn(size=input_shape, dtype=torch.float32).cuda()
16 | flops = FlopCountAnalysis(model, input)
17 | params = parameter_count(model)
18 |
19 | for i in tqdm(range(100)):
20 | torch.cuda.synchronize()
21 | start = time.time()
22 | output = model(input)
23 | torch.cuda.synchronize()
24 | end = time.time()
25 | total_time.append(end - start)
26 | mean_time = np.mean(np.array(total_time))
27 | print(model.__class__.__name__)
28 | print('img/s:{:.2f}'.format(1 / mean_time))
29 | print('flops:{:.2f}G params:{:.2f}M'.format(flops.total() / 1e9, params[''] / 1e6))
30 |
--------------------------------------------------------------------------------
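A hedged usage sketch of `flops_params_fps` with a small stand-in network. It assumes a CUDA device is available (the function moves the model and input to the GPU) and that the repository root is on the Python path; the toy model below is illustrative and not one of the repository's models.

```python
import torch.nn as nn
from tools.flops_params_fps_count import flops_params_fps

# tiny stand-in segmentation head, for demonstration only
toy_model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(16, 6, kernel_size=1),
)

if __name__ == '__main__':
    flops_params_fps(toy_model, input_shape=(1, 3, 512, 512))
```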
/tools/generate_edge.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import tifffile as tiff
3 | import os
4 | import numpy as np
5 | from tools.utils import label_to_RGB
6 | from models.resT import rest_tiny
7 | import torch
8 | from torchvision import transforms
9 |
10 |
11 | def to_tensor(image):
12 | image = torch.from_numpy(image).permute(2, 0, 1).float().div(255)
13 | normalize = transforms.Normalize((.485, .456, .406), (.229, .224, .225))
14 | image = normalize(image).unsqueeze(0)
15 |
16 | return image
17 |
18 |
19 | def init_model():
20 | model = rest_tiny(nclass=6, aux=False, head='mlphead', edge_aux=True)
21 | weight_dir = '../work_dir/' \
22 | 'resT_lr0.0003_epoch100_batchsize16_vaihingen_resT_tiny_mlphead_imagenetpretrain_noaux_edge_edgeup_AdamW_num128' \
23 | '/weights/best_weight.pkl'
24 | checkpoint = torch.load(weight_dir, map_location=lambda storage, loc: storage)
25 | if 'state_dict' in checkpoint:
26 | checkpoint = checkpoint['state_dict']
27 | checkpoint = {k.replace('module.model.', ''): v for k, v in checkpoint.items()}
28 | model.load_state_dict(checkpoint)
29 | return model
30 |
31 |
32 | def read_img_label(save_dir):
33 | img_dir = '../data/vaihingen/images/top_mosaic_09cm_area10.tif'
34 | label_dir = '../data/vaihingen/annotations/labels/top_mosaic_09cm_area10.png'
35 | image = tiff.imread(img_dir)
36 | image = image[1000:1000 + 512, 1000:1000+512, :]
37 | label = cv2.imread(label_dir, cv2.IMREAD_UNCHANGED)
38 | label = label[1000:1000 + 512, 1000:1000 + 512]
39 | cv2.imwrite(os.path.join(save_dir, 'ori_img.png'), image[..., ::-1])
40 | cv2.imwrite(os.path.join(save_dir, 'ori_label.png'), label_to_RGB(label)[..., ::-1])
41 |
42 | return image, label
43 |
44 |
45 | def canny_edge(img, edge_width=3):
46 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
47 | gray = cv2.GaussianBlur(gray, (11, 11), 0)
48 | edge = cv2.Canny(gray, 30, 150)
49 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (edge_width, edge_width))
50 | edge = cv2.dilate(edge, kernel)
51 |
52 | return edge
53 |
54 |
55 | def groundtruth_edge(label, edge_width=3):
56 | if len(label.shape) == 2:
57 | label = label[np.newaxis, ...]
58 | label = label.astype(np.int)
59 | b, h, w = label.shape
60 | edge = np.zeros(label.shape)
61 |
62 | # right
63 | edge_right = edge[:, 1:h, :]
64 | edge_right[(label[:, 1:h, :] != label[:, :h - 1, :]) & (label[:, 1:h, :] != 255)
65 | & (label[:, :h - 1, :] != 255)] = 1
66 |
67 | # up
68 | edge_up = edge[:, :, :w - 1]
69 | edge_up[(label[:, :, :w - 1] != label[:, :, 1:w])
70 | & (label[:, :, :w - 1] != 255)
71 | & (label[:, :, 1:w] != 255)] = 1
72 |
73 | # upright
74 | edge_upright = edge[:, :h - 1, :w - 1]
75 | edge_upright[(label[:, :h - 1, :w - 1] != label[:, 1:h, 1:w])
76 | & (label[:, :h - 1, :w - 1] != 255)
77 | & (label[:, 1:h, 1:w] != 255)] = 1
78 |
79 | # bottomright
80 | edge_bottomright = edge[:, :h - 1, 1:w]
81 | edge_bottomright[(label[:, :h - 1, 1:w] != label[:, 1:h, :w - 1])
82 | & (label[:, :h - 1, 1:w] != 255)
83 | & (label[:, 1:h, :w - 1] != 255)] = 1
84 |
85 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (edge_width, edge_width))
86 | for i in range(edge.shape[0]):
87 | edge[i] = cv2.dilate(edge[i], kernel)
88 | edge = edge.squeeze(axis=0)
89 | return edge
90 |
91 |
92 | def get_edge_predict(img):
93 | img = to_tensor(img).cuda()
94 | model = init_model().cuda().eval()
95 | with torch.no_grad():
96 | output = model(img)
97 | edge_predict = torch.argmax(output[1], dim=1)
98 | edge_predict = edge_predict.squeeze().cpu().numpy().astype(np.uint8)
99 | edge_predict = edge_predict * 255
100 | # kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (4, 4))
101 | # edge_predict = cv2.erode(edge_predict, kernel)
102 |
103 | return edge_predict
104 |
105 |
106 | def main():
107 | img, label = read_img_label(save_path)
108 | # canny_ = canny_edge(img)
109 | # cv2.imwrite(os.path.join(save_path, 'canny_edge.png'), canny_)
110 | # groundtruth_ = groundtruth_edge(label) * 255
111 | # cv2.imwrite(os.path.join(save_path, 'groundtruth_edge.png'), groundtruth_)
112 | edge_predict = get_edge_predict(img)
113 | cv2.imwrite(os.path.join(save_path, 'predict_edge.png'), edge_predict)
114 |
115 |
116 | if __name__ == '__main__':
117 | save_path = './edge/'
118 | if not os.path.exists(save_path):
119 | os.mkdir(save_path)
120 | main()
121 |
122 |
--------------------------------------------------------------------------------
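`groundtruth_edge` marks a pixel as an edge when it differs from a shifted copy of the label map in any of four directions, then dilates the result to the desired width. A minimal sketch of the same neighbour-difference idea on a toy 2-D label map, showing only the vertical and horizontal directions; all names are illustrative.

```python
import numpy as np
import cv2

label = np.zeros((8, 8), dtype=np.int64)
label[:, 4:] = 1                                   # two regions with a vertical boundary

edge = np.zeros_like(label)
# vertical neighbours: a pixel is an edge if it differs from the pixel above it
diff = label[1:, :] != label[:-1, :]
edge[1:, :][diff] = 1
# horizontal neighbours: differs from the pixel to its left
diff = label[:, 1:] != label[:, :-1]
edge[:, 1:][diff] = 1

kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
edge = cv2.dilate(edge.astype(np.uint8), kernel)   # thicken to a 3-pixel-wide edge
print(edge)
```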
/tools/generate_heatmap.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import sys
4 | sys.path.append('./')
5 | import numpy as np
6 | from models.swinT import swin_base
7 | import torch
8 | import tifffile as tiff
9 | from PIL import Image
10 | from tools.utils import label_to_RGB
11 | from torchvision import transforms
12 |
13 |
14 | def init_model():
15 | model = swin_base(nclass=6, aux=True, head='uperhead')
16 | weight_dir = 'work_dir/' \
17 | 'swinT_lr0.0003_epoch100_batchsize16_swinT_upernet_base_imagenetpretrain_aux_AdamW_num73' \
18 | '/weights/best_weight.pkl'
19 | checkpoint = torch.load(weight_dir, map_location=lambda storage, loc: storage)
20 | if 'state_dict' in checkpoint:
21 | checkpoint = checkpoint['state_dict']
22 | checkpoint = {k.replace('module.model.', ''): v for k, v in checkpoint.items()}
23 | model.load_state_dict(checkpoint)
24 | return model
25 |
26 |
27 | def read_img(save_dir):
28 | img_dir = 'data/barley/images/image_2_4_10.png'
29 | # image = tiff.imread(img_dir)
30 | image = Image.open(img_dir)
31 | image = np.array(image)
32 | image = image[1000:1000 + 512, 0:0+512, 0:3]
33 | cv2.imwrite(os.path.join(save_dir, 'ori_image.png'), image[..., ::-1])
34 |
35 | return image
36 |
37 |
38 | def read_label(save_dir):
39 | img_dir = 'data/barley/labels_view/image_2_4_10.png'
40 | image = Image.open(img_dir)
41 | image = np.array(image)
42 | image = image[1000:1000 + 512, 0:0+512, 0:3]
43 | cv2.imwrite(os.path.join(save_dir, 'ori_label.png'), image[..., ::-1])
44 |
45 | return image
46 |
47 |
48 | def to_tensor(image):
49 | image = torch.from_numpy(image).permute(2, 0, 1).float().div(255)
50 | normalize = transforms.Normalize((.485, .456, .406), (.229, .224, .225))
51 | image = normalize(image).unsqueeze(0)
52 |
53 | return image
54 |
55 |
56 | def main():
57 | save_img_dir = os.path.join(save_path, 'origin_img')
58 | if not os.path.exists(save_img_dir):
59 | os.mkdir(save_img_dir)
60 | save_out_dir = os.path.join(save_path, 'output')
61 | if not os.path.exists(save_out_dir):
62 | os.mkdir(save_out_dir)
63 |
64 | image = read_img(save_img_dir)
65 | image = to_tensor(image).cuda()
66 | model = init_model().cuda().eval()
67 | with torch.no_grad():
68 | output = model(image)
69 | output = torch.argmax(output[0], dim=1)
70 | output = output.squeeze()
71 | output = output.cpu().numpy()
72 | output = output.astype(np.uint8)
73 | output = label_to_RGB(output)
74 | cv2.imwrite(os.path.join(save_out_dir, 'out.png'), output[..., ::-1])
75 |
76 |
77 | if __name__ == '__main__':
78 | save_path = 'tools/heatmap/outputs/'
79 | if not os.path.exists(save_path):
80 | os.mkdir(save_path)
81 | # main()
82 | read_img(save_path)
83 | read_label(save_path)
84 |
85 |
--------------------------------------------------------------------------------
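Like `generate_edge.py`, `init_model` above strips the `module.model.` prefix that the training wrapper adds to checkpoint keys before loading weights. A hedged sketch of that renaming step in isolation; the checkpoint path and the commented model lines are placeholders, not repository paths.

```python
import torch

# checkpoint saved from a wrapped model, so keys look like 'module.model.backbone.conv1.weight'
checkpoint = torch.load('path/to/best_weight.pkl', map_location='cpu')
if 'state_dict' in checkpoint:
    checkpoint = checkpoint['state_dict']
state_dict = {k.replace('module.model.', ''): v for k, v in checkpoint.items()}

# model = build_model(...)           # placeholder for one of the repository's models
# model.load_state_dict(state_dict)
```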
/tools/heat_map.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import time
3 | import os
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 |
7 |
8 | save_path='./heatmap/'
9 | if not os.path.exists(save_path):
10 | os.mkdir(save_path)
11 |
12 | def draw_features(x,savename):
13 | tic = time.time()
14 | fig = plt.figure(figsize=(16, 16))
15 | fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.05, hspace=0.05)
16 | b, c, h, w = x.shape
17 | for i in range(int(c)):
18 | plt.subplot(h, w, i + 1)
19 | plt.axis('off')
20 | img = x[0, i, :, :].cpu().numpy()
21 | print('img_shape', img.shape)
22 | # print('img', img)
23 | # print(width,height)
24 | pmin = np.min(img)
25 | pmax = np.max(img)
26 |         img = ((img - pmin) / (pmax - pmin + 0.000001)) * 255  # scale floats in [0, 1] to [0, 255]
27 |         img = img.astype(np.uint8)  # convert to uint8
28 |         img = cv2.applyColorMap(img, cv2.COLORMAP_JET)  # generate heat map
29 |         # img = img[:, :, ::-1]  # note: cv2 uses BGR while matplotlib expects RGB
30 | plt.imshow(img)
31 | print("{}/{}".format(i, c))
32 | print(img.shape)
33 | img = cv2.resize(img, (768, 768), interpolation=cv2.INTER_LINEAR)
34 | cv2.imwrite(save_path + savename + str(i) + '.png', img)
35 | fig.clf()
36 | plt.close()
37 | print("time:{}".format(time.time()-tic))
38 |
39 |
--------------------------------------------------------------------------------
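The core of `draw_features` is a per-channel min-max normalisation followed by `cv2.applyColorMap`. A compact sketch of that step on a random feature map; the tensor and output filename are illustrative, no model is involved.

```python
import numpy as np
import cv2

feat = np.random.randn(32, 32).astype(np.float32)        # one feature channel

fmin, fmax = feat.min(), feat.max()
heat = ((feat - fmin) / (fmax - fmin + 1e-6) * 255).astype(np.uint8)  # scale to [0, 255]
heat = cv2.applyColorMap(heat, cv2.COLORMAP_JET)          # pseudo-colour heat map (BGR)
heat = cv2.resize(heat, (512, 512), interpolation=cv2.INTER_LINEAR)
cv2.imwrite('toy_heatmap.png', heat)                      # illustrative output path
```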
/tools/heatmap/outputs/ori_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyxu1996/CCTNet/5a5db40d2e38bd478b404583050049eedca90844/tools/heatmap/outputs/ori_image.png
--------------------------------------------------------------------------------
/tools/heatmap/outputs/ori_label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyxu1996/CCTNet/5a5db40d2e38bd478b404583050049eedca90844/tools/heatmap/outputs/ori_label.png
--------------------------------------------------------------------------------
/tools/heatmap_fun.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import cv2
4 | import time
5 | import os
6 | from tqdm import tqdm
7 |
8 | save_path = './heatmap/uperhead/'
9 | if not os.path.exists(save_path):
10 |     os.makedirs(save_path)
11 |
12 |
13 | def draw_features(x, savename):
14 | tic = time.time()
15 | b, c, h, w = x.shape
16 | for i in tqdm(range(int(c))):
17 | img = x[0, i, :, :].cpu().numpy()
18 | pmin = np.min(img)
19 | pmax = np.max(img)
20 | img = ((img - pmin) / (pmax - pmin + 0.000001))*255 # change value [0, 1] to [0, 255]
21 | img = img.astype(np.uint8)
22 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET) # generate heat map
23 | img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
24 | if not os.path.exists(os.path.join(save_path, savename)):
25 | os.mkdir(os.path.join(save_path, savename))
26 | cv2.imwrite(os.path.join(save_path, savename, savename + '_' + str(i) + '.png'), img)
27 | plt.close()
28 | print("{} time:{}".format(savename, time.time()-tic))
29 |
--------------------------------------------------------------------------------
/tools/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def label_to_RGB(image):
5 | RGB = np.zeros(shape=[image.shape[0], image.shape[1], 3], dtype=np.uint8)
6 | index = image == 0
7 | RGB[index] = np.array([255, 255, 255])
8 | index = image == 1
9 | RGB[index] = np.array([0, 0, 255])
10 | index = image == 2
11 | RGB[index] = np.array([0, 255, 255])
12 | index = image == 3
13 | RGB[index] = np.array([0, 255, 0])
14 | index = image == 4
15 | RGB[index] = np.array([255, 255, 0])
16 | index = image == 5
17 | RGB[index] = np.array([255, 0, 0])
18 | return RGB
--------------------------------------------------------------------------------
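Since the class indices are small consecutive integers, the same mapping can also be expressed as a single palette lookup. A short equivalent sketch, which should behave like `label_to_RGB` above for labels 0-5:

```python
import numpy as np

palette = np.array([[255, 255, 255], [0, 0, 255], [0, 255, 255],
                    [0, 255, 0], [255, 255, 0], [255, 0, 0]], dtype=np.uint8)

label = np.random.randint(0, 6, size=(4, 4))
rgb = palette[label]            # (4, 4, 3) RGB image via advanced indexing
print(rgb.shape)
```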
/work_dir/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyxu1996/CCTNet/5a5db40d2e38bd478b404583050049eedca90844/work_dir/.gitignore
--------------------------------------------------------------------------------