├── .gitignore ├── README.md ├── dataloaders ├── __init__.py ├── path.py ├── transforms.py └── voc_aug.py ├── libs ├── DenseCRF.py ├── __init__.py ├── criteria.py ├── lr_scheduler.py ├── metrics.py ├── schedular.py └── utils.py ├── main_epoch.py ├── main_iter.py ├── network ├── Auto_Deeplab │ ├── __init__.py │ ├── auto_deeplab.py │ └── layers.py ├── __init__.py ├── base │ ├── __init__.py │ ├── deform_conv │ │ ├── __init__.py │ │ ├── deform_conv.py │ │ └── deform_conv_v2.py │ ├── msc.py │ ├── oprations.py │ ├── resnet.py │ ├── sync_batchnorm │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py │ └── xception.py ├── deeplab.py ├── deeplab_deform_conv.py ├── deeplabv2.py ├── deeplabv3.py ├── deeplabv3plus_resnet.py ├── deeplabv3plus_xception.py └── get_models.py ├── test.py └── validation.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | .idea/ 29 | /.idea 30 | *.iml 31 | *.ppt 32 | *.pptx 33 | *.caffemodel 34 | result/ 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # Environments 98 | .env 99 | .venv 100 | env/ 101 | venv/ 102 | ENV/ 103 | env.bak/ 104 | venv.bak/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | .dmypy.json 119 | dmypy.json 120 | 121 | # Pyre type checker 122 | .pyre/ 123 | 124 | # JPG PNG 125 | *.jpg 126 | *.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deeplab_pytorch 2 | 3 | ## Papers 4 | 5 | **Deeplab** [Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs](https://arxiv.org/abs/1412.7062) 6 | 7 | 8 | **Deeplab V2** [DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs](http://arxiv.org/abs/1606.00915) 9 | 10 | Note that there are still some minor differences 
between argmax and softmax_loss layers for DeepLabv1 and v2 11 | 12 | **Deeplab v3** [Rethinking Atrous Convolution for Semantic Image Segmentation](http://arxiv.org/abs/1706.05587) 13 | 14 | 15 | **Deeplab V3+** [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](http://arxiv.org/abs/1802.02611) 16 | 17 | **Auto Deeplab** [Auto-DeepLab: Hierarchical Neural Architecture Search for Semantic Image Segmentation](https://arxiv.org/abs/1901.02985) 18 | 19 | 20 | ## Reference 21 | 22 | [1] https://github.com/jfzhang95/pytorch-deeplab-xception 23 | 24 | [2] https://github.com/kazuto1011/deeplab-pytorch 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/30 16:59 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ -------------------------------------------------------------------------------- /dataloaders/path.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/30 19:30 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | class Path(object): 9 | @staticmethod 10 | def db_root_dir(database): 11 | if database == 'pascal': 12 | return '/home/data/model/wangxin/VOCdevkit/VOC2012/' # folder that contains VOCdevkit/. 13 | elif database == 'vocaug': 14 | return '/home/data/model/wangxin/VOCAug/' 15 | else: 16 | print('Database {} not available.'.format(database)) 17 | raise NotImplementedError -------------------------------------------------------------------------------- /dataloaders/transforms.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/30 19:32 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | import torch 9 | import math 10 | import numbers 11 | import random 12 | import numpy as np 13 | 14 | from PIL import Image, ImageOps 15 | 16 | 17 | class RandomCrop(object): 18 | def __init__(self, size, padding=0): 19 | if isinstance(size, numbers.Number): 20 | self.size = (int(size), int(size)) 21 | else: 22 | self.size = size # h, w 23 | self.padding = padding 24 | 25 | def __call__(self, sample): 26 | img, mask = sample['image'], sample['label'] 27 | 28 | if self.padding > 0: 29 | img = ImageOps.expand(img, border=self.padding, fill=0) 30 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 31 | 32 | assert img.size == mask.size 33 | w, h = img.size 34 | th, tw = self.size # target size 35 | if w == tw and h == th: 36 | return {'image': img, 37 | 'label': mask} 38 | if w < tw or h < th: 39 | img = img.resize((tw, th), Image.BILINEAR) 40 | mask = mask.resize((tw, th), Image.NEAREST) 41 | return {'image': img, 42 | 'label': mask} 43 | 44 | x1 = random.randint(0, w - tw) 45 | y1 = random.randint(0, h - th) 46 | img = img.crop((x1, y1, x1 + tw, y1 + th)) 47 | mask = mask.crop((x1, y1, x1 + tw, y1 + th)) 48 | 49 | return {'image': img, 50 | 'label': mask} 51 | 52 | 53 | class CenterCrop(object): 54 | def __init__(self, size): 55 | if isinstance(size, numbers.Number): 56 | self.size = (int(size), int(size)) 57 | else: 58 | self.size = size 59 | 60 | def __call__(self, sample): 61 | img = sample['image'] 62 | mask = sample['label'] 63 | assert img.size == mask.size 64 | w, h = img.size 65 | th, tw = self.size 66 | x1 = int(round((w 
- tw) / 2.)) 67 | y1 = int(round((h - th) / 2.)) 68 | img = img.crop((x1, y1, x1 + tw, y1 + th)) 69 | mask = mask.crop((x1, y1, x1 + tw, y1 + th)) 70 | 71 | return {'image': img, 72 | 'label': mask} 73 | 74 | 75 | class RandomHorizontalFlip(object): 76 | def __call__(self, sample): 77 | img = sample['image'] 78 | mask = sample['label'] 79 | if random.random() < 0.5: 80 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 81 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 82 | 83 | return {'image': img, 84 | 'label': mask} 85 | 86 | 87 | class Normalize(object): 88 | """Normalize a tensor image with mean and standard deviation. 89 | Args: 90 | mean (tuple): means for each channel. 91 | std (tuple): standard deviations for each channel. 92 | """ 93 | 94 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): 95 | self.mean = mean 96 | self.std = std 97 | 98 | def __call__(self, sample): 99 | img = np.array(sample['image']).astype(np.float32) 100 | mask = np.array(sample['label']).astype(np.float32) 101 | img /= 255.0 102 | img -= self.mean 103 | img /= self.std 104 | 105 | return {'image': img, 106 | 'label': mask} 107 | 108 | 109 | class Normalize_cityscapes(object): 110 | """Normalize a tensor image with mean and standard deviation. 111 | Args: 112 | mean (tuple): means for each channel. 113 | std (tuple): standard deviations for each channel. 114 | """ 115 | 116 | def __init__(self, mean=(0., 0., 0.)): 117 | self.mean = mean 118 | 119 | def __call__(self, sample): 120 | img = np.array(sample['image']).astype(np.float32) 121 | mask = np.array(sample['label']).astype(np.float32) 122 | img -= self.mean 123 | img /= 255.0 124 | 125 | return {'image': img, 126 | 'label': mask} 127 | 128 | 129 | class ToTensor(object): 130 | """Convert ndarrays in sample to Tensors.""" 131 | 132 | def __call__(self, sample): 133 | # swap color axis because 134 | # numpy image: H x W x C 135 | # torch image: C X H X W 136 | img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1)) 137 | mask = np.expand_dims(np.array(sample['label']).astype(np.float32), -1).transpose((2, 0, 1)) 138 | mask[mask == 255] = 0 139 | 140 | img = torch.from_numpy(img).float() 141 | mask = torch.from_numpy(mask).float() 142 | 143 | return {'image': img, 144 | 'label': mask} 145 | 146 | 147 | class FixedResize(object): 148 | def __init__(self, size): 149 | self.size = tuple(reversed(size)) # size: (h, w) 150 | 151 | def __call__(self, sample): 152 | img = sample['image'] 153 | mask = sample['label'] 154 | 155 | assert img.size == mask.size 156 | 157 | img = img.resize(self.size, Image.BILINEAR) 158 | mask = mask.resize(self.size, Image.NEAREST) 159 | 160 | return {'image': img, 161 | 'label': mask} 162 | 163 | 164 | class Scale(object): 165 | def __init__(self, size): 166 | if isinstance(size, numbers.Number): 167 | self.size = (int(size), int(size)) 168 | else: 169 | self.size = size 170 | 171 | def __call__(self, sample): 172 | img = sample['image'] 173 | mask = sample['label'] 174 | assert img.size == mask.size 175 | w, h = img.size 176 | 177 | if (w >= h and w == self.size[1]) or (h >= w and h == self.size[0]): 178 | return {'image': img, 179 | 'label': mask} 180 | oh, ow = self.size 181 | img = img.resize((ow, oh), Image.BILINEAR) 182 | mask = mask.resize((ow, oh), Image.NEAREST) 183 | 184 | return {'image': img, 185 | 'label': mask} 186 | 187 | 188 | class RandomSizedCrop(object): 189 | def __init__(self, size): 190 | self.size = size 191 | 192 | def __call__(self, sample): 193 | img = sample['image'] 194 | mask = 
sample['label'] 195 | assert img.size == mask.size 196 | for attempt in range(10): 197 | area = img.size[0] * img.size[1] 198 | target_area = random.uniform(0.45, 1.0) * area 199 | aspect_ratio = random.uniform(0.5, 2) 200 | 201 | w = int(round(math.sqrt(target_area * aspect_ratio))) 202 | h = int(round(math.sqrt(target_area / aspect_ratio))) 203 | 204 | if random.random() < 0.5: 205 | w, h = h, w 206 | 207 | if w <= img.size[0] and h <= img.size[1]: 208 | x1 = random.randint(0, img.size[0] - w) 209 | y1 = random.randint(0, img.size[1] - h) 210 | 211 | img = img.crop((x1, y1, x1 + w, y1 + h)) 212 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 213 | assert (img.size == (w, h)) 214 | 215 | img = img.resize((self.size, self.size), Image.BILINEAR) 216 | mask = mask.resize((self.size, self.size), Image.NEAREST) 217 | 218 | return {'image': img, 219 | 'label': mask} 220 | 221 | # Fallback 222 | scale = Scale(self.size) 223 | crop = CenterCrop(self.size) 224 | sample = crop(scale(sample)) 225 | return sample 226 | 227 | 228 | class RandomRotate(object): 229 | def __init__(self, degree): 230 | self.degree = degree 231 | 232 | def __call__(self, sample): 233 | img = sample['image'] 234 | mask = sample['label'] 235 | rotate_degree = random.random() * 2 * self.degree - self.degree 236 | img = img.rotate(rotate_degree, Image.BILINEAR) 237 | mask = mask.rotate(rotate_degree, Image.NEAREST) 238 | 239 | return {'image': img, 240 | 'label': mask} 241 | 242 | 243 | class RandomSized(object): 244 | def __init__(self, size): 245 | self.size = size 246 | self.scale = Scale(self.size) 247 | self.crop = RandomCrop(self.size) 248 | 249 | def __call__(self, sample): 250 | img = sample['image'] 251 | mask = sample['label'] 252 | assert img.size == mask.size 253 | 254 | w = int(random.uniform(0.8, 2.5) * img.size[0]) 255 | h = int(random.uniform(0.8, 2.5) * img.size[1]) 256 | 257 | img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST) 258 | sample = {'image': img, 'label': mask} 259 | 260 | return self.crop(self.scale(sample)) 261 | 262 | 263 | class RandomScale(object): 264 | def __init__(self, limit): 265 | self.limit = limit 266 | 267 | def __call__(self, sample): 268 | img = sample['image'] 269 | mask = sample['label'] 270 | assert img.size == mask.size 271 | 272 | scale = random.uniform(self.limit[0], self.limit[1]) 273 | w = int(scale * img.size[0]) 274 | h = int(scale * img.size[1]) 275 | 276 | img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST) 277 | 278 | return {'image': img, 'label': mask} 279 | -------------------------------------------------------------------------------- /dataloaders/voc_aug.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/30 19:29 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | import os 9 | from PIL import Image 10 | from torch.utils.data import Dataset 11 | 12 | from dataloaders.path import Path 13 | 14 | 15 | class VOCAug(Dataset): 16 | 17 | def __init__(self, base_dir=Path.db_root_dir('vocaug'), 18 | split='train', 19 | transform=None): 20 | super().__init__() 21 | self._base_dir = base_dir 22 | self._image_dir = os.path.join(self._base_dir, 'img') 23 | self._cat_dir = os.path.join(self._base_dir, 'gt') 24 | self._list_dir = os.path.join(self._base_dir, 'list') 25 | 26 | self.transform = transform 27 | 28 | # print(self._base_dir) 29 | 30 | if split == 'train': 31 | list_path = os.path.join(self._list_dir, 
'train_aug.txt') 32 | elif split == 'val': 33 | list_path = os.path.join(self._list_dir, 'val.txt') 34 | else: 35 | print('error in split:', split) 36 | exit(-1) 37 | 38 | self.filenames = [i_id.strip() for i_id in open(list_path)] 39 | 40 | # Display stats 41 | print('Number of images in {}: {:d}'.format(split, len(self.filenames))) 42 | 43 | def __getitem__(self, index): 44 | _img, _target = self._make_img_gt_point_pair(index) 45 | sample = {'image': _img, 'label': _target} 46 | # print('test!!!!') 47 | if self.transform is not None: 48 | sample = self.transform(sample) 49 | 50 | return sample 51 | 52 | def __len__(self): 53 | return len(self.filenames) 54 | 55 | def _make_img_gt_point_pair(self, index): 56 | 57 | filename = self.filenames[index] 58 | # print('filename = ', filename) 59 | 60 | _img = Image.open(self._image_dir + "/" + str(filename) + '.jpg').convert('RGB') 61 | _target = Image.open(self._cat_dir + "/" + str(filename) + '.png') 62 | 63 | return _img, _target 64 | 65 | def __str__(self): 66 | return 'VOCAug(split=' + str(self.split) + ')' 67 | 68 | 69 | if __name__ == '__main__': 70 | from dataloaders import transforms as tr 71 | from libs.utils import decode_segmap 72 | from torch.utils.data import DataLoader 73 | from torchvision import transforms 74 | import matplotlib.pyplot as plt 75 | import numpy as np 76 | 77 | composed_transforms_tr = transforms.Compose([ 78 | tr.RandomHorizontalFlip(), 79 | tr.RandomSized(512), 80 | tr.RandomRotate(15), 81 | tr.ToTensor()]) 82 | 83 | voc_train = VOCAug(split='train', transform=composed_transforms_tr) 84 | 85 | dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=1) 86 | 87 | print(len(dataloader)) 88 | 89 | for ii, sample in enumerate(dataloader): 90 | print(sample['image'].size()) 91 | img = sample['image'].numpy() 92 | gt = sample['label'].numpy() 93 | 94 | for jj in range(sample["image"].size()[0]): 95 | tmp = np.array(gt[jj]).astype(np.uint8) 96 | print('#1 gt ', tmp.shape) 97 | tmp = np.squeeze(tmp, axis=0) 98 | print('#2 gt ', tmp.shape) 99 | segmap = decode_segmap(tmp, dataset='vocaug') 100 | print('#1 im ', img[jj].shape) 101 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]).astype(np.uint8) 102 | print('#2 im ', img_tmp.shape) 103 | # plt.figure() 104 | # plt.title('display') 105 | # plt.subplot(211) 106 | # plt.imshow(img_tmp) 107 | # plt.subplot(212) 108 | # plt.imshow(segmap) 109 | 110 | if ii == 1: 111 | break 112 | plt.show(block=True) -------------------------------------------------------------------------------- /libs/DenseCRF.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Author: Kazuto Nakashima 4 | # URL: https://kazuto1011.github.io 5 | # Date: 09 January 2019 6 | 7 | 8 | import numpy as np 9 | import pydensecrf.densecrf as dcrf 10 | import pydensecrf.utils as utils 11 | 12 | 13 | class DenseCRF(object): 14 | def __init__(self, iter_max, pos_w, pos_xy_std, bi_w, bi_xy_std, bi_rgb_std): 15 | self.iter_max = iter_max # iter num 16 | self.pos_w = pos_w # the weight of the Gaussian kernel which only depends on Pixel Position 17 | self.pos_xy_std = pos_xy_std 18 | self.bi_w = bi_w # the weight of bilateral kernel 19 | self.bi_xy_std = bi_xy_std 20 | self.bi_rgb_std = bi_rgb_std 21 | 22 | def __call__(self, image, probmap): 23 | C, H, W = probmap.shape 24 | 25 | U = utils.unary_from_softmax(probmap) 26 | U = np.ascontiguousarray(U) 27 | 28 | image = np.ascontiguousarray(image) 29 | 30 | d = dcrf.DenseCRF2D(W, H, C) 
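        # pydensecrf's DenseCRF2D takes (width, height, n_labels); the unary
        # energy set below is -log(softmax) of the network output, and the two
        # pairwise terms added afterwards implement the fully connected CRF
        # used in the DeepLab papers.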
31 | d.setUnaryEnergy(U) 32 | 33 | # the gaussian kernel depends only on pixel position 34 | d.addPairwiseGaussian(sxy=self.pos_xy_std, compat=self.pos_w) 35 | 36 | # bilateral kernel depends on both position and color 37 | d.addPairwiseBilateral( 38 | sxy=self.bi_xy_std, srgb=self.bi_rgb_std, rgbim=image, compat=self.bi_w 39 | ) 40 | 41 | Q = d.inference(self.iter_max) 42 | Q = np.array(Q).reshape((C, H, W)) 43 | 44 | return Q -------------------------------------------------------------------------------- /libs/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/31 19:24 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ -------------------------------------------------------------------------------- /libs/criteria.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/30 23:17 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True): 15 | n, c, h, w = logit.size() 16 | # logit = logit.permute(0, 2, 3, 1) 17 | target = target.squeeze(1) 18 | if weight is None: 19 | criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, size_average=False) 20 | else: 21 | criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(weight)).float().cuda(), 22 | ignore_index=ignore_index, size_average=False) 23 | loss = criterion(logit, target.long()) 24 | 25 | if size_average: 26 | loss /= (h * w) 27 | 28 | if batch_average: 29 | loss /= n 30 | 31 | return loss 32 | 33 | 34 | class CrossEntropyLoss2d(nn.Module): 35 | def __init__(self, weight=None, size_average=True, ignore_index=-1): 36 | super(CrossEntropyLoss2d, self).__init__() 37 | self.nll_loss = nn.NLLLoss(weight, size_average, ignore_index) 38 | 39 | def forward(self, inputs, targets): 40 | return self.nll_loss(F.log_softmax(inputs, dim=1), targets) 41 | 42 | 43 | class _CrossEntropyLoss2d(nn.Module): 44 | def __init__(self, ignore_index=255, weight=None, size_average=True, batch_average=True): 45 | super(_CrossEntropyLoss2d, self).__init__() 46 | 47 | if weight is None: 48 | self.loss = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, size_average=False) 49 | else: 50 | self.loss = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(weight)).float().cuda(), 51 | ignore_index=ignore_index, size_average=False) 52 | 53 | self.size_avgrage = size_average 54 | self.batch_avgrage = batch_average 55 | 56 | def forward(self, logit, target): 57 | 58 | N, C, H, W = logit.size() 59 | 60 | target = target.squeeze(1) 61 | 62 | loss = self.loss(logit, target.long()) 63 | 64 | if self.size_avgrage: 65 | loss /= (H * W) 66 | 67 | if self.batch_avgrage: 68 | loss /= N 69 | 70 | return loss 71 | -------------------------------------------------------------------------------- /libs/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/2/13 16:31 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | """ 9 | Original Author: Kazuto Nakashima 10 | URL: https://kazuto1011.github.io 11 | Date: 09 January 2019 12 | """ 13 | 14 | from torch.optim.lr_scheduler import _LRScheduler 15 | 16 | 17 | 
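# The "poly" schedule used in the DeepLab papers:
#     lr = base_lr * (1 - iter / iter_max) ** power
# For example, with the defaults of main_iter.py in this repo (base_lr=0.007,
# power=0.9, iter_max=30000), the learning rate decays smoothly from 0.007 to
# roughly 0.0014 at iteration 25000 and reaches 0 at iteration 30000.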
class PolynomialLR(_LRScheduler):
18 |     def __init__(self, optimizer, step_size, iter_max, power, last_epoch=-1):
19 |         self.step_size = step_size
20 |         self.iter_max = iter_max
21 |         self.power = power
22 |         super(PolynomialLR, self).__init__(optimizer, last_epoch)
23 | 
24 |     def polynomial_decay(self, lr):
25 |         return lr * (1 - float(self.last_epoch) / self.iter_max) ** self.power
26 | 
27 |     def get_lr(self):
28 |         if (
29 |             (self.last_epoch == 0)
30 |             or (self.last_epoch % self.step_size != 0)
31 |             or (self.last_epoch > self.iter_max)
32 |         ):
33 |             return [group["lr"] for group in self.optimizer.param_groups]
34 |         return [self.polynomial_decay(lr) for lr in self.base_lrs]
35 | 
--------------------------------------------------------------------------------
/libs/metrics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2019/1/30 23:38
4 | @Author : Wang Xin
5 | @Email : wangxin_buaa@163.com
6 | """
7 | 
8 | import torch
9 | import math
10 | import numpy as np
11 | 
12 | from libs import utils
13 | 
14 | 
15 | def log10(x):
16 |     """Return a new tensor with the base-10 logarithm of the elements of x."""
17 |     return torch.log(x) / math.log(10)
18 | 
19 | 
20 | """
21 | The two functions below, _fast_hist and scores, were originally written by wkentaro:
22 | https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
23 | """
24 | 
25 | 
26 | def _fast_hist(label_true, label_pred, n_class):
27 |     mask = (label_true >= 0) & (label_true < n_class)
28 |     hist = np.bincount(
29 |         n_class * label_true[mask].astype(int) + label_pred[mask],
30 |         minlength=n_class ** 2,
31 |     ).reshape(n_class, n_class)
32 |     return hist
33 | 
34 | 
35 | def scores(label_trues, label_preds, n_class):
36 |     hist = np.zeros((n_class, n_class))
37 |     for lt, lp in zip(label_trues, label_preds):
38 |         hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
39 |     acc = np.diag(hist).sum() / hist.sum()
40 |     acc_cls = np.diag(hist) / hist.sum(axis=1)
41 |     acc_cls = np.nanmean(acc_cls)
42 |     iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
43 |     valid = hist.sum(axis=1) > 0  # added: only average over classes present in the ground truth
44 |     mean_iu = np.nanmean(iu[valid])
45 |     freq = hist.sum(axis=1) / hist.sum()
46 |     fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
47 |     # cls_iu = dict(zip(range(n_class), iu))
48 | 
49 |     # print('# type:', type(mean_iu), type(acc_cls), type(acc), type(fwavacc))
50 | 
51 |     return mean_iu, acc_cls, acc, fwavacc
52 | 
53 | 
54 | class Result(object):
55 |     def __init__(self):
56 |         self.mean_iou = 0
57 |         self.mean_acc = 0
58 |         self.overall_acc = 0
59 |         self.freqw_acc = 0
60 |         self.data_time, self.gpu_time = 0, 0
61 | 
62 |     def set_to_worst(self):
63 |         self.mean_iou = 0  # all four metrics are "higher is better",
64 |         self.mean_acc = 0  # so the worst value is 0 rather than np.inf
65 |         self.overall_acc = 0
66 |         self.freqw_acc = 0
67 |         self.data_time, self.gpu_time = 0, 0
68 | 
69 |     def update(self, mean_iou, mean_acc, overall_acc, freqw_acc, gpu_time, data_time):
70 |         self.mean_iou = mean_iou
71 |         self.mean_acc, self.overall_acc, self.freqw_acc = mean_acc, overall_acc, freqw_acc
72 |         self.data_time, self.gpu_time = data_time, gpu_time
73 | 
74 |     def evaluate(self, output, target, n_class):
75 |         output = np.argmax(output, axis=1)
76 |         # print('output size:', output.shape)
77 |         # print('target size:', target.shape)
78 | 
79 |         outputs = []
80 |         targets = []
81 |         outputs.append(output)
82 |         targets.append(target)
83 |         self.mean_iou, self.mean_acc, self.overall_acc, self.freqw_acc = scores(targets, outputs, n_class=n_class)
84 | 
85 | 
86 | class
AverageMeter(object): 87 | def __init__(self): 88 | self.reset() 89 | 90 | def reset(self): 91 | self.count = 0.0 92 | 93 | self.sum_mean_iou = 0 94 | self.sum_mean_acc = 0 95 | self.sum_overall_acc = 0 96 | self.sum_freqw_acc = 0 97 | self.sum_data_time, self.sum_gpu_time = 0, 0 98 | 99 | def update(self, result, gpu_time, data_time, n=1): 100 | self.count += n 101 | 102 | self.sum_mean_iou += result.mean_iou * n 103 | self.sum_mean_acc += result.mean_acc * n 104 | self.sum_overall_acc += result.overall_acc * n 105 | self.sum_freqw_acc += result.freqw_acc * n 106 | 107 | self.sum_data_time += n * data_time 108 | self.sum_gpu_time += n * gpu_time 109 | 110 | def average(self): 111 | avg = Result() 112 | avg.update( 113 | self.sum_mean_iou / self.count, 114 | self.sum_mean_acc / self.count, 115 | self.sum_overall_acc / self.count, 116 | self.sum_freqw_acc / self.count, 117 | self.sum_gpu_time / self.count, 118 | self.sum_data_time / self.count) 119 | return avg 120 | -------------------------------------------------------------------------------- /libs/schedular.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Author: Kazuto Nakashima 4 | # URL: https://kazuto1011.github.io 5 | # Date: 09 January 2019 6 | 7 | 8 | from torch.optim.lr_scheduler import _LRScheduler 9 | 10 | 11 | class PolynomialLR(_LRScheduler): 12 | def __init__(self, optimizer, step_size, iter_max, power, last_epoch=-1): 13 | self.step_size = step_size 14 | self.iter_max = iter_max 15 | self.power = power 16 | super(PolynomialLR, self).__init__(optimizer, last_epoch) 17 | 18 | def polynomial_decay(self, lr): 19 | return lr * (1 - float(self.last_epoch) / self.iter_max) ** self.power 20 | 21 | def get_lr(self): 22 | if ( 23 | (self.last_epoch == 0) 24 | or (self.last_epoch % self.step_size != 0) 25 | or (self.last_epoch > self.iter_max) 26 | ): 27 | return [group["lr"] for group in self.optimizer.param_groups] 28 | return [self.polynomial_decay(lr) for lr in self.base_lrs] -------------------------------------------------------------------------------- /libs/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/30 19:34 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | import glob 8 | import os 9 | import shutil 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | from PIL import Image 16 | 17 | cmap = plt.cm.jet 18 | 19 | 20 | def get_output_directory(args, check=False): 21 | save_dir_root = os.getcwd() 22 | save_dir_root = os.path.join(save_dir_root, 'result', args.model, args.dataset) 23 | if args.resume: 24 | runs = sorted(glob.glob(os.path.join(save_dir_root, 'run_*'))) 25 | run_id = int(runs[-1].split('_')[-1]) if runs else 0 26 | else: 27 | runs = sorted(glob.glob(os.path.join(save_dir_root, 'run_*'))) 28 | if len(runs) > 10: 29 | print('please delete unnecessary runs, ensure run_id < 10.') 30 | if check: 31 | run_id = int(runs[-1].split('_')[-1]) if runs else 0 32 | else: 33 | run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0 34 | save_dir = os.path.join(save_dir_root, 'run_' + str(run_id)) 35 | return save_dir 36 | 37 | 38 | # save checkpoint 39 | def save_checkpoint(state, is_best, epoch, output_directory): 40 | checkpoint_filename = os.path.join(output_directory, 'checkpoint-' + str(epoch) + '.pth.tar') 41 | torch.save(state, checkpoint_filename) 42 | 
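    # in addition to the per-epoch checkpoint, keep the best-performing model
    # under a fixed filename so it can be located without knowing the epoch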
if is_best: 43 | best_filename = os.path.join(output_directory, 'model_best.pth.tar') 44 | shutil.copyfile(checkpoint_filename, best_filename) 45 | 46 | 47 | def resize_labels(labels, shape): 48 | # labels = labels.unsqueeze(1).float() # Add channel axis 49 | # print('#1 labels size:', labels.size()) 50 | 51 | labels = F.interpolate(labels, size=shape, mode="nearest") 52 | # labels = labels.squeeze(1).long() 53 | 54 | # print('#2 labels size:', labels.size()) 55 | return labels 56 | 57 | 58 | def get_cityscapes_labels(): 59 | return np.array([ 60 | # [ 0, 0, 0], 61 | [128, 64, 128], 62 | [244, 35, 232], 63 | [70, 70, 70], 64 | [102, 102, 156], 65 | [190, 153, 153], 66 | [153, 153, 153], 67 | [250, 170, 30], 68 | [220, 220, 0], 69 | [107, 142, 35], 70 | [152, 251, 152], 71 | [0, 130, 180], 72 | [220, 20, 60], 73 | [255, 0, 0], 74 | [0, 0, 142], 75 | [0, 0, 70], 76 | [0, 60, 100], 77 | [0, 80, 100], 78 | [0, 0, 230], 79 | [119, 11, 32]]) 80 | 81 | 82 | def get_pascal_labels(): 83 | """Load the mapping that associates pascal classes with label colors 84 | Returns: 85 | np.ndarray with dimensions (21, 3) 86 | """ 87 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 88 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 89 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 90 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 91 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 92 | [0, 64, 128]]) 93 | 94 | 95 | def encode_segmap(mask): 96 | """Encode segmentation label images as pascal classes 97 | Args: 98 | mask (np.ndarray): raw segmentation label image of dimension 99 | (M, N, 3), in which the Pascal classes are encoded as colours. 100 | Returns: 101 | (np.ndarray): class map with dimensions (M,N), where the value at 102 | a given location is the integer denoting the class index. 103 | """ 104 | mask = mask.astype(int) 105 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 106 | for ii, label in enumerate(get_pascal_labels()): 107 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 108 | label_mask = label_mask.astype(int) 109 | return label_mask 110 | 111 | 112 | def decode_seg_map_sequence(label_masks, dataset='pascal'): 113 | rgb_masks = [] 114 | for label_mask in label_masks: 115 | rgb_mask = decode_segmap(label_mask, dataset) 116 | rgb_masks.append(rgb_mask) 117 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) 118 | return rgb_masks 119 | 120 | 121 | def decode_segmap(label_mask, dataset, plot=False): 122 | """Decode segmentation class labels into a color image 123 | Args: 124 | label_mask (np.ndarray): an (M,N) array of integer values denoting 125 | the class label at each spatial location. 126 | plot (bool, optional): whether to show the resulting color image 127 | in a figure. 128 | Returns: 129 | (np.ndarray, optional): the resulting decoded color image. 
130 | """ 131 | if dataset == 'pascal': 132 | n_classes = 21 133 | label_colours = get_pascal_labels() 134 | elif dataset == 'vocaug': 135 | n_classes = 21 136 | label_colours = get_pascal_labels() 137 | elif dataset == 'cityscapes': 138 | n_classes = 19 139 | label_colours = get_cityscapes_labels() 140 | else: 141 | raise NotImplementedError 142 | 143 | r = label_mask.copy() 144 | g = label_mask.copy() 145 | b = label_mask.copy() 146 | for ll in range(0, n_classes): 147 | r[label_mask == ll] = label_colours[ll, 0] 148 | g[label_mask == ll] = label_colours[ll, 1] 149 | b[label_mask == ll] = label_colours[ll, 2] 150 | 151 | # print('label mask:', label_mask.shape) 152 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 153 | rgb[:, :, 0] = r / 255.0 154 | rgb[:, :, 1] = g / 255.0 155 | rgb[:, :, 2] = b / 255.0 156 | if plot: 157 | plt.imshow(rgb) 158 | plt.show() 159 | else: 160 | return rgb 161 | 162 | 163 | def de_normalize(rgb, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 164 | rgb = rgb * std + mean 165 | return 255 * rgb 166 | 167 | 168 | def merge_into_row(input, target, pred): 169 | rgb = np.transpose(np.squeeze(input), (1, 2, 0)) # H, W, C 170 | rgb = de_normalize(rgb) 171 | 172 | target = np.squeeze(target) 173 | pred = np.squeeze(pred) 174 | 175 | target = 255 * decode_segmap(target, dataset='vocaug') 176 | pred = 255 * decode_segmap(pred, dataset='vocaug') 177 | 178 | img_merge = np.hstack([rgb, target, pred]) 179 | 180 | return img_merge 181 | 182 | 183 | def add_row(img_merge, row): 184 | return np.vstack([img_merge, row]) 185 | 186 | 187 | def save_image(img_merge, filename): 188 | img_merge = Image.fromarray(img_merge.astype('uint8')) 189 | img_merge.save(filename) 190 | 191 | 192 | def generate_param_report(logfile, param): 193 | log_file = open(logfile, 'w') 194 | for key, val in param.items(): 195 | log_file.write(key + ':' + str(val) + '\n') 196 | log_file.close() 197 | 198 | 199 | def get_iou(pred, gt, n_classes=21): 200 | total_iou = 0.0 201 | for i in range(len(pred)): 202 | pred_tmp = pred[i] 203 | gt_tmp = gt[i] 204 | 205 | intersect = [0] * n_classes 206 | union = [0] * n_classes 207 | for j in range(n_classes): 208 | match = (pred_tmp == j) + (gt_tmp == j) 209 | 210 | it = torch.sum(match == 2).item() 211 | un = torch.sum(match > 0).item() 212 | 213 | intersect[j] += it 214 | union[j] += un 215 | 216 | iou = [] 217 | for k in range(n_classes): 218 | if union[k] == 0: 219 | continue 220 | iou.append(intersect[k] / union[k]) 221 | 222 | img_iou = (sum(iou) / len(iou)) 223 | total_iou += img_iou 224 | 225 | return total_iou 226 | 227 | 228 | def get_miou(pred, gt, n_classes=21): 229 | m_iou = 0.0 230 | for i in range(len(pred)): 231 | pred_tmp = pred[i] 232 | gt_tmp = gt[i] 233 | 234 | intersect = [0] * n_classes 235 | union = [0] * n_classes 236 | for j in range(n_classes): 237 | match = (pred_tmp == j) + (gt_tmp == j) 238 | 239 | it = torch.sum(match == 2).item() 240 | un = torch.sum(match > 0).item() 241 | 242 | intersect[j] += it 243 | union[j] += un 244 | 245 | iou = [] 246 | for k in range(n_classes): 247 | if union[k] == 0: 248 | continue 249 | iou.append(intersect[k] / union[k]) 250 | 251 | img_iou = (sum(iou) / len(iou)) 252 | m_iou += img_iou 253 | 254 | return m_iou / len(pred) 255 | -------------------------------------------------------------------------------- /main_epoch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/30 16:59 4 | @Author : 
Wang Xin
5 | @Email : wangxin_buaa@163.com
6 | """
7 | import os
8 | import shutil
9 | import socket
10 | import time
11 | from datetime import datetime
12 | 
13 | import numpy as np
14 | 
15 | from tensorboardX import SummaryWriter
16 | 
17 | import torch
18 | from torch import nn
19 | from torch.optim import lr_scheduler
20 | from torch.utils.data import DataLoader
21 | import torch.nn.functional as F
22 | 
23 | from torchvision.transforms import transforms
24 | from tqdm import tqdm
25 | 
26 | import dataloaders.transforms as tr
27 | from libs import utils, criteria
28 | from dataloaders.voc_aug import VOCAug
29 | 
30 | from libs.metrics import Result, AverageMeter
31 | from network.get_models import get_models
32 | 
33 | from validation import validate
34 | 
35 | 
36 | def parse_command():
37 |     import argparse
38 |     parser = argparse.ArgumentParser(description='Deeplab_pytorch (epoch-based training)')
39 |     parser.add_argument('--resume', default=None, type=str, metavar='PATH',
40 |                         help='path to latest checkpoint (e.g. ./run/run_1/checkpoint-5.pth.tar)')
41 |     parser.add_argument('--model', default='deeplabv2', type=str, help='which network to train')
42 |     parser.add_argument('--crf', default=False, type=bool, help='if true, use a dense CRF as post-processing.')
43 |     parser.add_argument('--msc', default=False, type=bool, help='if true, use multi-scale inputs.')
44 |     parser.add_argument('--freeze', default=True, type=bool, help='if true, freeze the BatchNorm layers of the backbone')
45 |     parser.add_argument('--iter_size', default=2, type=int, help='number of sub-iterations to accumulate gradients over before each optimizer step')
46 |     parser.add_argument('-b', '--batch_size', default=4, type=int, help='mini-batch size (default: 4)')
47 |     parser.add_argument('--epochs', default=40, type=int, metavar='N',
48 |                         help='number of total epochs to run (default: 40)')
49 |     parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
50 |                         metavar='LR', help='initial learning rate (default: 0.01)')
51 |     parser.add_argument('--lr_patience', default=2, type=int,
52 |                         help='Patience of LR scheduler.
See documentation of ReduceLROnPlateau.') 53 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 54 | help='momentum') 55 | parser.add_argument('--weight_decay', '--wd', default=0.0005, type=float, 56 | metavar='W', help='weight decay (default: 1e-4)') 57 | parser.add_argument('-j', '--workers', default=10, type=int, metavar='N', 58 | help='number of data loading workers (default: 10)') 59 | parser.add_argument('--dataset', default='vocaug', type=str, 60 | help='dataset used for training, kitti and nyu is available') 61 | parser.add_argument('--manual_seed', default=1, type=int, help='Manually set random seed') 62 | parser.add_argument('--gpu', default=None, type=str, help='if not none, use Single GPU') 63 | parser.add_argument('--print-freq', '-p', default=10, type=int, 64 | metavar='N', help='print frequency (default: 10)') 65 | args = parser.parse_args() 66 | return args 67 | 68 | 69 | def create_loader(args): 70 | if args.dataset == 'vocaug': 71 | composed_transforms_tr = transforms.Compose([ 72 | tr.RandomSized(512), 73 | tr.RandomRotate(15), 74 | tr.RandomHorizontalFlip(), 75 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 76 | tr.ToTensor()]) 77 | 78 | composed_transforms_ts = transforms.Compose([ 79 | tr.FixedResize(size=(512, 512)), 80 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 81 | tr.ToTensor()]) 82 | 83 | train_set = VOCAug(split='train', transform=composed_transforms_tr) 84 | val_set = VOCAug(split='val', transform=composed_transforms_ts) 85 | else: 86 | print('Database {} not available.'.format(args.dataset)) 87 | raise NotImplementedError 88 | 89 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, 90 | num_workers=args.workers, pin_memory=True) 91 | val_loader = DataLoader(val_set, batch_size=16, shuffle=False, num_workers=args.workers, pin_memory=True) 92 | 93 | return train_loader, val_loader 94 | 95 | 96 | def main(): 97 | args = parse_command() 98 | print(args) 99 | 100 | # if setting gpu id, the using single GPU 101 | if args.gpu: 102 | print('Single GPU Mode.') 103 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 104 | 105 | best_result = Result() 106 | best_result.set_to_worst() 107 | 108 | # set random seed 109 | torch.manual_seed(args.manual_seed) 110 | torch.cuda.manual_seed(args.manual_seed) 111 | np.random.seed(args.manual_seed) 112 | 113 | if torch.cuda.device_count() > 1: 114 | print("Let's use", torch.cuda.device_count(), "GPUs!") 115 | args.batch_size = args.batch_size * torch.cuda.device_count() 116 | else: 117 | print("Let's use GPU ", torch.cuda.current_device()) 118 | 119 | train_loader, val_loader = create_loader(args) 120 | 121 | if args.resume: 122 | assert os.path.isfile(args.resume), \ 123 | "=> no checkpoint found at '{}'".format(args.resume) 124 | print("=> loading checkpoint '{}'".format(args.resume)) 125 | checkpoint = torch.load(args.resume) 126 | 127 | start_epoch = checkpoint['epoch'] + 1 128 | best_result = checkpoint['best_result'] 129 | optimizer = checkpoint['optimizer'] 130 | 131 | # solve 'out of memory' 132 | model = checkpoint['model'] 133 | 134 | print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch'])) 135 | 136 | # clear memory 137 | del checkpoint 138 | # del model_dict 139 | torch.cuda.empty_cache() 140 | else: 141 | print("=> creating Model") 142 | model = get_models(args) 143 | print("=> model created.") 144 | start_epoch = 0 145 | 146 | # different modules have different learning rate 147 | train_params = [{'params': 
model.get_1x_lr_params(), 'lr': args.lr}, 148 | {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}] 149 | 150 | optimizer = torch.optim.SGD(train_params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 151 | 152 | # You can use DataParallel() whether you use Multi-GPUs or not 153 | model = nn.DataParallel(model).cuda() 154 | 155 | # when training, use reduceLROnPlateau to reduce learning rate 156 | scheduler = lr_scheduler.ReduceLROnPlateau( 157 | optimizer, 'min', patience=args.lr_patience) 158 | 159 | # loss function 160 | criterion = criteria._CrossEntropyLoss2d(size_average=True, batch_average=True) 161 | 162 | # create directory path 163 | output_directory = utils.get_output_directory(args) 164 | if not os.path.exists(output_directory): 165 | os.makedirs(output_directory) 166 | best_txt = os.path.join(output_directory, 'best.txt') 167 | config_txt = os.path.join(output_directory, 'config.txt') 168 | 169 | # write training parameters to config file 170 | if not os.path.exists(config_txt): 171 | with open(config_txt, 'w') as txtfile: 172 | args_ = vars(args) 173 | args_str = '' 174 | for k, v in args_.items(): 175 | args_str = args_str + str(k) + ':' + str(v) + ',\t\n' 176 | txtfile.write(args_str) 177 | 178 | # create log 179 | log_path = os.path.join(output_directory, 'logs', 180 | datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname()) 181 | if os.path.isdir(log_path): 182 | shutil.rmtree(log_path) 183 | os.makedirs(log_path) 184 | logger = SummaryWriter(log_path) 185 | 186 | start_iter = len(train_loader) * start_epoch + 1 187 | max_iter = len(train_loader) * (args.epochs - start_epoch + 1) + 1 188 | iter_save = len(train_loader) 189 | # iter_save = 1 190 | 191 | # train 192 | model.train() 193 | if args.freeze: 194 | model.module.freeze_backbone_bn() 195 | output_directory = utils.get_output_directory(args, check=True) 196 | 197 | average_meter = AverageMeter() 198 | train_meter = AverageMeter() 199 | 200 | for it in tqdm(range(start_iter, max_iter + 1), total=max_iter, leave=False, dynamic_ncols=True): 201 | optimizer.zero_grad() 202 | 203 | loss = 0 204 | 205 | data_time = 0 206 | gpu_time = 0 207 | 208 | for _ in range(args.iter_size): 209 | 210 | end = time.time() 211 | 212 | try: 213 | samples = next(loader_iter) 214 | except: 215 | loader_iter = iter(train_loader) 216 | samples = next(loader_iter) 217 | 218 | input = samples['image'].cuda() 219 | target = samples['label'].cuda() 220 | 221 | torch.cuda.synchronize() 222 | data_time_ = time.time() 223 | data_time += data_time_ - end 224 | 225 | with torch.autograd.detect_anomaly(): 226 | preds = model(input) # @wx 注意输出 227 | 228 | # print('#train preds size:', len(preds)) 229 | # print('#train preds[0] size:', preds[0].size()) 230 | iter_loss = 0 231 | if args.msc: 232 | for pred in preds: 233 | # Resize labels for {100%, 75%, 50%, Max} logits 234 | target_ = utils.resize_labels(target, shape=(pred.size()[-2], pred.size()[-1])) 235 | # print('#train pred size:', pred.size()) 236 | iter_loss += criterion(pred, target_) 237 | else: 238 | pred = preds 239 | target_ = utils.resize_labels(target, shape=(pred.size()[-2], pred.size()[-1])) 240 | # print('#train pred size:', pred.size()) 241 | # print('#train target size:', target.size()) 242 | iter_loss += criterion(pred, target_) 243 | 244 | # Backpropagate (just compute gradients wrt the loss) 245 | iter_loss /= args.iter_size 246 | iter_loss.backward() 247 | 248 | loss += float(iter_loss) 249 | 250 | gpu_time += time.time() - data_time_ 251 
| 252 |         torch.cuda.synchronize()
253 | 
254 |         # Update weights with accumulated gradients
255 |         optimizer.step()
256 | 
257 |         # measure accuracy and record loss
258 |         result = Result()
259 |         pred = F.softmax(pred, dim=1)
260 | 
261 |         result.evaluate(pred.data.cpu().numpy(), target.data.cpu().numpy(), n_class=21)
262 |         average_meter.update(result, gpu_time, data_time, input.size(0))
263 |         train_meter.update(result, gpu_time, data_time, input.size(0))
264 | 
265 |         if it % args.print_freq == 0:
266 |             print('=> output: {}'.format(output_directory))
267 |             print('Train Iter: [{0}/{1}]\t'
268 |                   't_Data={data_time:.3f}({average.data_time:.3f}) '
269 |                   't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
270 |                   'Loss={Loss:.5f} '
271 |                   'MeanAcc={result.mean_acc:.3f}({average.mean_acc:.3f}) '
272 |                   'MIOU={result.mean_iou:.3f}({average.mean_iou:.3f}) '
273 |                   .format(it, max_iter, data_time=data_time, gpu_time=gpu_time,
274 |                           Loss=loss, result=result, average=average_meter.average()))
275 |             logger.add_scalar('Train/Loss', loss, it)
276 |             logger.add_scalar('Train/mean_acc', result.mean_acc, it)
277 |             logger.add_scalar('Train/mean_iou', result.mean_iou, it)
278 | 
279 |         if it % iter_save == 0:
280 |             epoch = it // iter_save
281 |             result, img_merge = validate(args, val_loader, model, epoch=epoch, logger=logger)
282 | 
283 |             # when the validation mIoU stops improving, reduce the learning rate (for a "higher is better" metric such as mIoU, ReduceLROnPlateau should be constructed with mode='max')
284 |             scheduler.step(result.mean_iou)
285 | 
286 |             # log the change of learning rate
287 |             for i, param_group in enumerate(optimizer.param_groups):
288 |                 old_lr = float(param_group['lr'])
289 |                 logger.add_scalar('Lr/lr_' + str(i), old_lr, it)
290 | 
291 |             # visualize train vs. validation metrics
292 |             train_avg = train_meter.average()
293 |             logger.add_scalars('TrainVal/mean_acc',
294 |                                {'train_mean_acc': train_avg.mean_acc, 'test_mean_acc': result.mean_acc}, epoch)
295 |             logger.add_scalars('TrainVal/mean_iou',
296 |                                {'train_mean_iou': train_avg.mean_iou, 'test_mean_iou': result.mean_iou}, epoch)
297 |             train_meter.reset()
298 |             # remember the best mean IoU and save a checkpoint
299 |             is_best = result.mean_iou > best_result.mean_iou
300 |             if is_best:
301 |                 best_result = result
302 |                 with open(best_txt, 'w') as txtfile:
303 |                     txtfile.write(
304 |                         "epoch={}, mean_iou={:.3f}, mean_acc={:.3f}, "
305 |                         "t_gpu={:.4f}".
306 | format(epoch, result.mean_iou, result.mean_acc, result.gpu_time)) 307 | if img_merge is not None: 308 | img_filename = output_directory + '/comparison_best.png' 309 | utils.save_image(img_merge, img_filename) 310 | 311 | # save checkpoint for each epoch 312 | utils.save_checkpoint({ 313 | 'args': args, 314 | 'epoch': epoch, 315 | 'model': model, 316 | 'best_result': best_result, 317 | 'optimizer': optimizer, 318 | }, is_best, it, output_directory) 319 | 320 | # change to train mode 321 | model.train() 322 | if args.freeze: 323 | model.module.freeze_backbone_bn() 324 | 325 | logger.close() 326 | 327 | 328 | if __name__ == '__main__': 329 | main() 330 | -------------------------------------------------------------------------------- /main_iter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/2/13 21:12 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | import os 9 | import shutil 10 | import socket 11 | import time 12 | 13 | import numpy as np 14 | 15 | from tensorboardX import SummaryWriter 16 | 17 | import torch 18 | from torch import nn 19 | from torch.utils.data import DataLoader 20 | from datetime import datetime 21 | 22 | from torchvision.transforms import transforms 23 | from tqdm import tqdm 24 | 25 | import dataloaders.transforms as tr 26 | from libs import utils, criteria 27 | from dataloaders.voc_aug import VOCAug 28 | 29 | from libs.metrics import Result, AverageMeter 30 | from network.get_models import get_models 31 | 32 | from libs.lr_scheduler import PolynomialLR 33 | 34 | import torch.nn.functional as F 35 | 36 | from validation import validate 37 | 38 | 39 | def parse_command(): 40 | import argparse 41 | parser = argparse.ArgumentParser(description='DORN') 42 | parser.add_argument('--mode', default='train', type=str, help='train or test') 43 | parser.add_argument('--resume', default=None, type=str, metavar='PATH', 44 | help='path to latest checkpoint (default: ./run/run_1/checkpoint-5.pth.tar)') 45 | parser.add_argument('--model', default='deeplabv3plus', type=str, help='train which network') 46 | parser.add_argument('--crf', default=False, type=bool, help='if true, use crf as post process.') 47 | parser.add_argument('--msc', default=False, type=bool, help='if true, use multi-scale input.') 48 | parser.add_argument('--freeze', default=True, type=bool) 49 | parser.add_argument('--iter_size', default=2, type=int, help='when iter_size, opt step forward') 50 | parser.add_argument('-b', '--batch_size', default=8, type=int, help='mini-batch size (default: 4)') 51 | parser.add_argument('--max_iter', default=30000, type=int, metavar='N', 52 | help='number of total epochs to run (default: 15)') 53 | parser.add_argument('--lr', '--learning-rate', default=0.007, type=float, 54 | metavar='LR', help='initial learning rate (default 0.0001)') 55 | parser.add_argument('--lr_decay', default=10, type=int, 56 | help='Patience of LR scheduler. 
See documentation of ReduceLROnPlateau.') 57 | parser.add_argument('--power', default=0.9, type=float) 58 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 59 | help='momentum') 60 | parser.add_argument('--weight_decay', '--wd', default=0.0005, type=float, 61 | metavar='W', help='weight decay (default: 1e-4)') 62 | parser.add_argument('-j', '--workers', default=10, type=int, metavar='N', 63 | help='number of data loading workers (default: 10)') 64 | parser.add_argument('--dataset', default='vocaug', type=str, 65 | help='dataset used for training, kitti and nyu is available') 66 | parser.add_argument('--manual_seed', default=1, type=int, help='Manually set random seed') 67 | parser.add_argument('--gpu', default=None, type=str, help='if not none, use Single GPU') 68 | parser.add_argument('--print-freq', '-p', default=10, type=int, 69 | metavar='N', help='print frequency (default: 10)') 70 | parser.add_argument('--iter_save', default=500, type=int, help='every iter to save the model.') 71 | args = parser.parse_args() 72 | return args 73 | 74 | 75 | args = parse_command() 76 | print(args) 77 | 78 | # if setting gpu id, the using single GPU 79 | if args.gpu: 80 | print('Single GPU Mode.') 81 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 82 | 83 | best_result = Result() 84 | best_result.set_to_worst() 85 | 86 | 87 | def create_loader(args): 88 | if args.dataset == 'vocaug': 89 | composed_transforms_tr = transforms.Compose([ 90 | tr.RandomSized(512), 91 | tr.RandomRotate(15), 92 | tr.RandomHorizontalFlip(), 93 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 94 | tr.ToTensor()]) 95 | 96 | composed_transforms_ts = transforms.Compose([ 97 | tr.FixedResize(size=(512, 512)), 98 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 99 | tr.ToTensor()]) 100 | 101 | train_set = VOCAug(split='train', transform=composed_transforms_tr) 102 | val_set = VOCAug(split='val', transform=composed_transforms_ts) 103 | else: 104 | print('Database {} not available.'.format(args.dataset)) 105 | raise NotImplementedError 106 | 107 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, 108 | num_workers=args.workers, pin_memory=True) 109 | val_loader = DataLoader(val_set, batch_size=16, shuffle=False, num_workers=args.workers, pin_memory=True) 110 | 111 | return train_loader, val_loader 112 | 113 | 114 | def main(): 115 | global args, best_result, output_directory 116 | 117 | # set random seed 118 | torch.manual_seed(args.manual_seed) 119 | torch.cuda.manual_seed(args.manual_seed) 120 | np.random.seed(args.manual_seed) 121 | 122 | if torch.cuda.device_count() > 1: 123 | print("Let's use", torch.cuda.device_count(), "GPUs!") 124 | args.batch_size = args.batch_size * torch.cuda.device_count() 125 | else: 126 | print("Let's use GPU ", torch.cuda.current_device()) 127 | 128 | train_loader, val_loader = create_loader(args) 129 | 130 | if args.mode == 'test': 131 | if args.resume: 132 | assert os.path.isfile(args.resume), \ 133 | "=> no checkpoint found at '{}'".format(args.resume) 134 | print("=> loading checkpoint '{}'".format(args.resume)) 135 | checkpoint = torch.load(args.resume) 136 | 137 | epoch = checkpoint['epoch'] 138 | best_result = checkpoint['best_result'] 139 | 140 | # solve 'out of memory' 141 | model = checkpoint['model'] 142 | 143 | print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch'])) 144 | 145 | # clear memory 146 | del checkpoint 147 | # del model_dict 148 | torch.cuda.empty_cache() 149 | else: 150 | 
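            # NOTE: test mode requires --resume; without a checkpoint, `model`
            # and `epoch` below are never defined and validate() cannot run.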
print("no trained model to test.") 151 | 152 | result, img_merge = validate(args, val_loader, model, epoch, logger=None) 153 | 154 | print('Test Result: mean iou={result.mean_iou:.3f}, mean acc={result.mean_acc:.3f}.'.format(result=result)) 155 | elif args.mode == 'train': 156 | if args.resume: 157 | assert os.path.isfile(args.resume), \ 158 | "=> no checkpoint found at '{}'".format(args.resume) 159 | print("=> loading checkpoint '{}'".format(args.resume)) 160 | checkpoint = torch.load(args.resume) 161 | 162 | start_iter = checkpoint['epoch'] + 1 163 | best_result = checkpoint['best_result'] 164 | optimizer = checkpoint['optimizer'] 165 | 166 | # solve 'out of memory' 167 | model = checkpoint['model'] 168 | 169 | print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch'])) 170 | 171 | # clear memory 172 | del checkpoint 173 | # del model_dict 174 | torch.cuda.empty_cache() 175 | else: 176 | print("=> creating Model") 177 | model = get_models(args) 178 | print("=> model created.") 179 | start_iter = 1 180 | 181 | # different modules have different learning rate 182 | train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr}, 183 | {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}] 184 | 185 | print(train_params) 186 | 187 | optimizer = torch.optim.SGD(train_params, lr=args.lr, momentum=args.momentum, 188 | weight_decay=args.weight_decay) 189 | 190 | # You can use DataParallel() whether you use Multi-GPUs or not 191 | model = nn.DataParallel(model).cuda() 192 | 193 | scheduler = PolynomialLR(optimizer=optimizer, 194 | step_size=args.lr_decay, 195 | iter_max=args.max_iter, 196 | power=args.power) 197 | 198 | # loss function 199 | criterion = criteria._CrossEntropyLoss2d(size_average=True, batch_average=True) 200 | 201 | # create directory path 202 | output_directory = utils.get_output_directory(args) 203 | if not os.path.exists(output_directory): 204 | os.makedirs(output_directory) 205 | best_txt = os.path.join(output_directory, 'best.txt') 206 | config_txt = os.path.join(output_directory, 'config.txt') 207 | 208 | # write training parameters to config file 209 | if not os.path.exists(config_txt): 210 | with open(config_txt, 'w') as txtfile: 211 | args_ = vars(args) 212 | args_str = '' 213 | for k, v in args_.items(): 214 | args_str = args_str + str(k) + ':' + str(v) + ',\t\n' 215 | txtfile.write(args_str) 216 | 217 | # create log 218 | log_path = os.path.join(output_directory, 'logs', 219 | datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname()) 220 | if os.path.isdir(log_path): 221 | shutil.rmtree(log_path) 222 | os.makedirs(log_path) 223 | logger = SummaryWriter(log_path) 224 | 225 | # train 226 | model.train() 227 | if args.freeze: 228 | model.module.freeze_backbone_bn() 229 | output_directory = utils.get_output_directory(args, check=True) 230 | 231 | average_meter = AverageMeter() 232 | 233 | for it in tqdm(range(start_iter, args.max_iter + 1), total=args.max_iter, leave=False, dynamic_ncols=True): 234 | # for it in range(1, args.max_iter + 1): 235 | # Clear gradients (ready to accumulate) 236 | optimizer.zero_grad() 237 | 238 | loss = 0 239 | 240 | data_time = 0 241 | gpu_time = 0 242 | 243 | for _ in range(args.iter_size): 244 | end = time.time() 245 | try: 246 | samples = next(loader_iter) 247 | except: 248 | loader_iter = iter(train_loader) 249 | samples = next(loader_iter) 250 | 251 | input = samples['image'].cuda() 252 | target = samples['label'].cuda() 253 | 254 | torch.cuda.synchronize() 255 | data_time_ = time.time() 256 | data_time += 
data_time_ - end
257 | 
258 |             with torch.autograd.detect_anomaly():
259 |                 preds = model(input)  # @wx: note the output format here
260 | 
261 |                 # print('#train preds size:', len(preds))
262 |                 # print('#train preds[0] size:', preds[0].size())
263 |                 iter_loss = 0
264 |                 if args.msc:
265 |                     for pred in preds:
266 |                         # Resize labels for {100%, 75%, 50%, Max} logits
267 |                         target_ = utils.resize_labels(target, shape=(pred.size()[-2], pred.size()[-1]))
268 |                         # print('#train pred size:', pred.size())
269 |                         iter_loss += criterion(pred, target_)
270 |                 else:
271 |                     pred = preds
272 |                     target_ = utils.resize_labels(target, shape=(pred.size()[-2], pred.size()[-1]))
273 |                     # print('#train pred size:', pred.size())
274 |                     # print('#train target size:', target.size())
275 |                     iter_loss += criterion(pred, target_)
276 | 
277 |                 # Backpropagate (just compute gradients wrt the loss)
278 |                 iter_loss /= args.iter_size
279 |                 iter_loss.backward()
280 | 
281 |             loss += float(iter_loss)
282 | 
283 |             gpu_time += time.time() - data_time_
284 | 
285 |         torch.cuda.synchronize()
286 | 
287 |         # Update weights with accumulated gradients
288 |         optimizer.step()
289 | 
290 |         # Update learning rate
291 |         scheduler.step(epoch=it)
292 | 
293 |         # measure accuracy and record loss
294 |         result = Result()
295 |         pred = F.softmax(pred, dim=1)
296 | 
297 |         result.evaluate(pred.data.cpu().numpy(), target.data.cpu().numpy(), n_class=21)
298 |         average_meter.update(result, gpu_time, data_time, input.size(0))
299 | 
300 |         if it % args.print_freq == 0:
301 |             print('=> output: {}'.format(output_directory))
302 |             print('Train Iter: [{0}/{1}]\t'
303 |                   't_Data={data_time:.3f}({average.data_time:.3f}) '
304 |                   't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
305 |                   'Loss={Loss:.5f} '
306 |                   'MeanAcc={result.mean_acc:.3f}({average.mean_acc:.3f}) '
307 |                   'MIOU={result.mean_iou:.3f}({average.mean_iou:.3f}) '
308 |                   .format(it, args.max_iter, data_time=data_time, gpu_time=gpu_time,
309 |                           Loss=loss, result=result, average=average_meter.average()))
310 |             logger.add_scalar('Train/Loss', loss, it)
311 |             logger.add_scalar('Train/mean_acc', result.mean_acc, it)
312 |             logger.add_scalar('Train/mean_iou', result.mean_iou, it)
313 | 
314 |             for i, param_group in enumerate(optimizer.param_groups):
315 |                 old_lr = float(param_group['lr'])
316 |                 logger.add_scalar('Lr/lr_' + str(i), old_lr, it)
317 | 
318 |         if it % args.iter_save == 0:
319 |             result, img_merge = validate(args, val_loader, model, epoch=it, logger=logger)
320 | 
321 |             # remember the best mean IoU and save a checkpoint
322 |             is_best = result.mean_iou > best_result.mean_iou
323 |             if is_best:
324 |                 best_result = result
325 |                 with open(best_txt, 'w') as txtfile:
326 |                     txtfile.write(
327 |                         "Iter={}, mean_iou={:.3f}, mean_acc={:.3f}, "
328 |                         "t_gpu={:.4f}".
329 |                                 format(it, result.mean_iou, result.mean_acc, result.gpu_time))
330 |                     if img_merge is not None:
331 |                         img_filename = output_directory + '/comparison_best.png'
332 |                         utils.save_image(img_merge, img_filename)
333 | 
334 |                 # save a checkpoint at every validation interval
335 |                 utils.save_checkpoint({
336 |                     'args': args,
337 |                     'epoch': it,
338 |                     'model': model,
339 |                     'best_result': best_result,
340 |                     'optimizer': optimizer,
341 |                 }, is_best, it, output_directory)
342 | 
343 |                 # change back to train mode (validate switched the model to eval)
344 |                 model.train()
345 |                 if args.freeze:
346 |                     model.module.freeze_backbone_bn()
347 | 
348 |         logger.close()
349 |     else:
350 |         print('unknown mode: {}'.format(args.mode))
351 |         exit(-1)
352 | 
353 | 
354 | if __name__ == '__main__':
355 |     main()
356 | 
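The `PolynomialLR` scheduler stepped above applies the DeepLab-style "poly" decay. As a rough sketch of that policy (assuming the scheduler in `libs/lr_scheduler.py` follows the usual convention; its `step_size`-gated variant may differ in detail):

```python
def poly_lr(base_lr, it, max_iter, power=0.9):
    # DeepLab 'poly' policy: decay each group's base LR to 0 over max_iter iterations
    return base_lr * (1 - it / max_iter) ** power

print(poly_lr(0.007, it=10000, max_iter=20000))  # ~0.00375 halfway through
```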
--------------------------------------------------------------------------------
/network/Auto_Deeplab/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2019/2/17 22:15
4 | @Author : Wang Xin
5 | @Email : wangxin_buaa@163.com
6 | """
--------------------------------------------------------------------------------
/network/Auto_Deeplab/auto_deeplab.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2019/2/17 22:15
4 | @Author : Wang Xin
5 | @Email : wangxin_buaa@163.com
6 | """
7 | 
8 | import torch
9 | import torch.nn as nn
10 | from network.Auto_Deeplab import layers
11 | 
12 | 
13 | class AutoDeeplab(nn.Module):
14 |     def __init__(self, in_channels, out_channels, layout, cell=layers.Cell, activation=nn.ReLU6, upsample_at_end=True):
15 |         """
16 |         A general implementation of the network architecture presented in the Auto-DeepLab paper.
17 |         :param layout: a list of integers giving each cell's downsampling level (the zero-indexed y coordinate in the diagram used in the paper)
18 |         :param cell: the cell class to use
19 |         """
20 |         super(AutoDeeplab, self).__init__()
21 |         self.upsample_at_end = upsample_at_end
22 |         self.cells = nn.ModuleList()  # ModuleList (not a plain list), so the cells' parameters are registered with the module
23 | 
24 |         self.initial_stem = nn.Sequential(
25 |             nn.Conv2d(in_channels, 64, 3, stride=2, padding=1),
26 |             nn.BatchNorm2d(64),
27 |             activation()
28 |         ).cuda()
29 | 
30 |         self.cells.append(nn.Sequential(
31 |             nn.Conv2d(64, 64, 3, padding=1),
32 |             nn.BatchNorm2d(64),
33 |             activation()
34 |         ).cuda())
35 | 
36 |         self.cells.append(nn.Sequential(
37 |             nn.Conv2d(64, 128, 3, stride=2, padding=1),
38 |             nn.BatchNorm2d(128),
39 |             activation()
40 |         ).cuda())
41 | 
42 |         # self.stem = nn.Sequential(
43 |         #     nn.Conv2d(in_channels, 64, 3, stride=2, padding=1),
44 |         #     nn.Conv2d(64, 64, 3, padding=1),
45 |         #     nn.Conv2d(64, 128, 3, stride=2, padding=1),
46 |         # ).cuda()
47 | 
48 |         prev_channels = 64
49 |         channels = 128
50 |         assert layout[0] == 2
51 |         for i, depth in enumerate(layout):
52 |             curr_cell = cell(channels, prev_channels, channels).cuda()
53 |             prev_channels = channels
54 |             layer = []
55 |             # todo dilation?
56 | 
57 |             if i != len(layout) - 1:
58 |                 next_depth = layout[i + 1]
59 |                 assert abs(depth - next_depth) <= 1
60 |                 if next_depth > depth:
61 |                     # Downsampling
62 |                     layer.append(nn.Conv2d(channels, channels * 2, 3, stride=2, padding=1))
63 |                     channels = channels * 2
64 |                 elif next_depth < depth:
65 |                     # Upsampling
66 |                     layer.append(nn.Upsample(scale_factor=2, mode="bilinear"))
67 |                     layer.append(nn.Conv2d(channels, channels // 2, 1))
68 |                     channels = channels // 2
69 | 
70 |             # The cell is kept outside the Sequential as it takes two inputs; wrapping both in a ModuleList keeps them registered
71 |             self.cells.append(nn.ModuleList([curr_cell, nn.Sequential(*layer).cuda()]))
72 | 
73 |         # Pool, then reduce channels to the desired value
74 |         self.pool = nn.Sequential(
75 |             layers.ASPP(channels, 256, (6, 12, 18), (6, 12, 18)),
76 |             nn.Conv2d(256, out_channels, 3, padding=1)
77 |         ).cuda()
78 | 
79 |         self.upsampler = nn.Upsample(scale_factor=2 ** layout[-1], mode="bilinear")
80 | 
81 |     def forward(self, x):
82 |         x = self.initial_stem(x)
83 | 
84 |         # Run stem layers
85 |         prev_hs = [self.cells[0](x)]
86 |         prev_hs.append(self.cells[1](prev_hs[0]))
87 | 
88 |         for i, layer in enumerate(self.cells[2:], 2):
89 |             curr = layer[0](prev_hs[-1], prev_hs[-2])  # Execute cell
90 |             curr = layer[1](curr)  # Execute rest of the layer
91 |             prev_hs[-2] = prev_hs[-1]
92 |             prev_hs[-1] = curr
93 | 
94 |         x = self.pool(prev_hs[-1])
95 |         if self.upsample_at_end:
96 |             x = self.upsampler(x)
97 | 
98 |         return x
99 | 
100 | 
101 | if __name__ == '__main__':
102 |     layout = [2, 2, 2, 2, 3, 4, 3, 4, 4, 5, 5, 4, 3]
103 |     model = AutoDeeplab(3, 3, layout, layers.Cell)
104 |     print(model)
105 |     print(model.cells)
106 |     x = torch.rand((2, 3, 128, 128)).cuda()
107 |     model(x)
108 | 
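A quick shape check of the layout semantics (hypothetical layout; a CUDA device is required because the module moves its submodules to the GPU at construction):

```python
import torch
from network.Auto_Deeplab import layers
from network.Auto_Deeplab.auto_deeplab import AutoDeeplab

# layout[i] is the downsampling level after cell i (feature stride 2**level);
# layout[0] must be 2 and adjacent entries may differ by at most 1.
layout = [2, 2, 3, 3, 4, 3]
model = AutoDeeplab(3, 21, layout, layers.Cell)  # e.g. 21 classes for Pascal VOC
y = model(torch.rand(2, 3, 128, 128).cuda())
print(y.size())  # upsampled by 2**layout[-1] back to torch.Size([2, 21, 128, 128])
```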
--------------------------------------------------------------------------------
/network/Auto_Deeplab/layers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2019/2/17 22:15
4 | @Author : Wang Xin
5 | @Email : wangxin_buaa@163.com
6 | """
7 | 
8 | import torch.nn as nn
9 | import torch
10 | import torch.nn.functional as F
11 | 
12 | 
13 | def fixed_padding(inputs, kernel_size, dilation):
14 |     """
15 |     https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/backbone/xception.py
16 |     :param kernel_size: kernel size of the following convolution
17 |     :param dilation: dilation of the following convolution
18 |     :return: inputs padded so that a stride-1 convolution preserves the spatial size
19 |     """
20 |     kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
21 |     pad_total = kernel_size_effective - 1
22 |     pad_beg = pad_total // 2
23 |     pad_end = pad_total - pad_beg
24 |     padded_inputs = F.pad(inputs, [pad_beg, pad_end, pad_beg, pad_end])
25 |     return padded_inputs
26 | 
27 | 
28 | # from https://github.com/quark0/darts/blob/master/cnn/operations.py
29 | class DilConv(nn.Module):
30 |     def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
31 |         super(DilConv, self).__init__()
32 |         self.op = nn.Sequential(
33 |             nn.ReLU(inplace=False),
34 |             nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
35 |                       groups=C_in, bias=False),
36 |             nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
37 |             nn.BatchNorm2d(C_out, affine=affine),
38 |         )
39 | 
40 |     def forward(self, x):
41 |         return self.op(x)
42 | 
43 | 
44 | class SepConv(nn.Module):
45 | 
46 |     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
47 |         super(SepConv, self).__init__()
48 |         self.op = nn.Sequential(
49 |             nn.ReLU(inplace=False),
50 |             nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
51 |             nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
52 |             nn.BatchNorm2d(C_in, affine=affine),
53 |             nn.ReLU(inplace=False),
54 |             nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
55 |             nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
56 |             nn.BatchNorm2d(C_out, affine=affine),
57 |         )
58 | 
59 |     def forward(self, x):
60 |         return self.op(x)
61 | 
62 | 
63 | class Cell(nn.Module):
64 |     def __init__(self, in_channels_h1, in_channels_h2, out_channels, dilation=1, activation=nn.ReLU6,
65 |                  bn=nn.BatchNorm2d):
66 |         """
67 |         A single Auto-DeepLab cell: five mixed atrous/separable branches whose outputs are concatenated.
68 |         :param in_channels_h1: number of input channels of h-1 (the previous cell's output)
69 |         :param in_channels_h2: number of input channels of h-2 (the cell before that; resized to match h-1 if needed)
70 |         :param out_channels: number of output channels
71 |         :param dilation: dilation rate of the atrous (DilConv) branches
72 |         :param activation: activation class (currently unused; the branch ops carry their own ReLUs)
73 |         :param bn: batch-norm class (currently unused; the branch ops construct their own BatchNorm2d)
74 |         """
75 |         super(Cell, self).__init__()
76 |         self.in_ = in_channels_h1
77 |         self.out_ = out_channels
78 |         self.activation = activation
79 | 
80 |         if in_channels_h1 > in_channels_h2:
81 |             self.preprocess = FactorizedReduce(in_channels_h2, in_channels_h1)
82 |         elif in_channels_h1 < in_channels_h2:
83 |             # todo check this
84 |             self.preprocess = nn.ConvTranspose2d(in_channels_h2, in_channels_h1, 3, stride=2, padding=1, output_padding=1)
85 |         else:
86 |             self.preprocess = None
87 | 
88 |         # self.atr3x3 = DilConv(in_channels_h1, out_channels, 3, 1, 1, dilation)
89 |         # self.atr5x5 = DilConv(in_channels_h1, out_channels, 5, 1, 2, dilation)
90 | 
91 |         # self.sep3x3 = SepConv(in_channels_h1, out_channels, 3, 1, 1)
92 |         # self.sep5x5 = SepConv(in_channels_h1, out_channels, 5, 1, 2)
93 | 
94 |         # Top 1
95 |         self.top1_atr5x5 = DilConv(in_channels_h1, in_channels_h1, 5, 1, 2, dilation)
96 |         self.top1_sep3x3 = SepConv(in_channels_h1, in_channels_h1, 3, 1, 1)
97 | 
98 |         # Top 2
99 |         self.top2_sep5x5_1 = SepConv(in_channels_h1, in_channels_h1, 5, 1, 2)
100 |         self.top2_sep5x5_2 = SepConv(in_channels_h1, in_channels_h1, 5, 1, 2)
101 | 
102 |         # Middle
103 |         self.middle_sep3x3_1 = SepConv(in_channels_h1, in_channels_h1, 3, 1, 1)
104 |         self.middle_sep3x3_2 = SepConv(in_channels_h1, in_channels_h1, 3, 1, 1)
105 | 
106 |         # Bottom 1
107 |         self.bottom1_atr3x3 = DilConv(in_channels_h1, in_channels_h1, 3, 1, 1, dilation)
108 |         self.bottom1_sep3x3 = SepConv(in_channels_h1, in_channels_h1, 3, 1, 1)
109 | 
110 |         # Bottom 2
111 |         self.bottom2_atr5x5 = DilConv(in_channels_h1, in_channels_h1, 5, 1, 2, dilation)
112 |         self.bottom2_sep5x5 = SepConv(in_channels_h1, in_channels_h1, 5, 1, 2)
113 | 
114 |         self.concate_conv = nn.Conv2d(in_channels_h1 * 5, out_channels, 1)
115 | 
116 |     def forward(self, h_1, h_2):
117 |         """
118 |         :param h_1: output of the previous cell (or stem)
119 |         :param h_2: output of the cell before that
120 |         :return: the cell output with out_channels channels at h_1's resolution
121 |         """
122 | 
123 |         if self.preprocess is not None:
124 |             h_2 = self.preprocess(h_2)
125 | 
126 |         top1 = self.top1_atr5x5(h_2) + self.top1_sep3x3(h_1)
127 |         bottom1 = self.bottom1_atr3x3(h_1) + self.bottom1_sep3x3(h_2)
128 |         middle = self.middle_sep3x3_1(h_2) + self.middle_sep3x3_2(bottom1)
129 | 
130 |         top2 = self.top2_sep5x5_1(top1) + self.top2_sep5x5_2(middle)
131 |         bottom2 = self.bottom2_atr5x5(top2) + self.bottom2_sep5x5(bottom1)
132 | 
133 |         concat = torch.cat([top1, top2, middle, bottom2, bottom1], dim=1)
134 | 
135 |         return self.concate_conv(concat)
136 | 
137 | 
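A minimal shape check for `Cell` (hypothetical sizes; here h-2 enters at twice h-1's resolution with half its channels, so `FactorizedReduce` aligns it):

```python
import torch
from network.Auto_Deeplab.layers import Cell

cell = Cell(in_channels_h1=64, in_channels_h2=32, out_channels=64)
h_1 = torch.rand(1, 64, 32, 32)   # previous cell, stride s
h_2 = torch.rand(1, 32, 64, 64)   # cell before that, stride s/2
print(cell(h_1, h_2).size())      # torch.Size([1, 64, 32, 32])
```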
138 | class ASPP(nn.Module):
139 |     def __init__(self, in_channels, out_channels, paddings, dilations):
140 |         # todo: use depthwise separable convs here
141 |         super(ASPP, self).__init__()
142 |         self.conv11 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),
143 |                                     nn.BatchNorm2d(out_channels))
144 |         self.conv33_1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3,
145 |                                                 padding=paddings[0], dilation=dilations[0], bias=False),
146 |                                       nn.BatchNorm2d(out_channels))
147 |         self.conv33_2 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3,
148 |                                                 padding=paddings[1], dilation=dilations[1], bias=False),
149 |                                       nn.BatchNorm2d(out_channels))
150 |         self.conv33_3 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3,
151 |                                                 padding=paddings[2], dilation=dilations[2], bias=False),
152 |                                       nn.BatchNorm2d(out_channels))
153 |         self.concate_conv = nn.Sequential(nn.Conv2d(out_channels * 5, out_channels, 1, bias=False),
154 |                                           nn.BatchNorm2d(out_channels))
155 |         # self.upsample = nn.Upsample(mode='bilinear', align_corners=True)
156 | 
157 |     def forward(self, x):
158 |         conv11 = self.conv11(x)
159 |         conv33_1 = self.conv33_1(x)
160 |         conv33_2 = self.conv33_2(x)
161 |         conv33_3 = self.conv33_3(x)
162 | 
163 |         # global average pool, 1x1 conv, then upsample back to x's size
164 |         image_pool = nn.AvgPool2d(kernel_size=x.size()[2:])
165 |         image_pool = image_pool(x)
166 |         image_pool = self.conv11(image_pool)
167 |         upsample = nn.Upsample(size=x.size()[2:], mode='bilinear', align_corners=True)
168 |         upsample = upsample(image_pool)
169 | 
170 |         # concatenate the five branches
171 |         concate = torch.cat([conv11, conv33_1, conv33_2, conv33_3, upsample], dim=1)
172 | 
173 |         return self.concate_conv(concate)
174 | 
175 | 
176 | # Based on quark0/darts on github
177 | class FactorizedReduce(nn.Module):
178 | 
179 |     def __init__(self, C_in, C_out, affine=True):
180 |         super(FactorizedReduce, self).__init__()
181 |         assert C_out % 2 == 0
182 |         self.relu = nn.ReLU(inplace=False)
183 |         self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
184 |         self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
185 |         self.bn = nn.BatchNorm2d(C_out, affine=affine)
186 | 
187 |     def forward(self, x):
188 |         x = self.relu(x)
189 |         padded = F.pad(x, (0, 1, 0, 1), "constant", 0)
190 |         path2 = self.conv_2(padded[:, :, 1:, 1:])  # shifted by one pixel so the two stride-2 paths see complementary positions
191 |         out = torch.cat([self.conv_1(x), path2], dim=1)
192 |         out = self.bn(out)
193 |         return out
--------------------------------------------------------------------------------
/network/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2019/1/30 16:59
4 | @Author : Wang Xin
5 | @Email : wangxin_buaa@163.com
6 | """
--------------------------------------------------------------------------------
/network/base/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2019/2/17 22:50
4 | @Author : Wang Xin
5 | @Email : wangxin_buaa@163.com
6 | """
--------------------------------------------------------------------------------
/network/base/deform_conv/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2019/2/18 15:19
4 | @Author : Wang Xin
5 | @Email : wangxin_buaa@163.com
6 | """
--------------------------------------------------------------------------------
/network/base/deform_conv/deform_conv.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2019/2/18 15:23
4 | @Author : Wang Xin
5 | @Email : wangxin_buaa@163.com
6 | """
7 | 
8 | import numpy as np
9
| 10 | import torch 11 | import torch.nn as nn 12 | from torch.autograd import Variable 13 | 14 | 15 | """ 16 | https://github.com/ChunhuanLin/deform_conv_pytorch/blob/master/deform_conv.py 17 | """ 18 | 19 | 20 | class DeformConv2D(nn.Module): 21 | def __init__(self, inc, outc, kernel_size=3, padding=1, bias=None): 22 | super(DeformConv2D, self).__init__() 23 | self.kernel_size = kernel_size 24 | self.padding = padding 25 | self.zero_padding = nn.ZeroPad2d(padding) 26 | self.conv_kernel = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias) 27 | 28 | def forward(self, x, offset): 29 | dtype = offset.data.type() 30 | ks = self.kernel_size 31 | N = offset.size(1) // 2 32 | 33 | # Change offset's order from [x1, x2, ..., y1, y2, ...] to [x1, y1, x2, y2, ...] 34 | # Codes below are written to make sure same results of MXNet implementation. 35 | # You can remove them, and it won't influence the module's performance. 36 | offsets_index = Variable(torch.cat([torch.arange(0, 2 * N, 2), torch.arange(1, 2 * N + 1, 2)]), 37 | requires_grad=False).type_as(x).long() 38 | offsets_index = offsets_index.unsqueeze(dim=0).unsqueeze(dim=-1).unsqueeze(dim=-1).expand(*offset.size()) 39 | offset = torch.gather(offset, dim=1, index=offsets_index) 40 | # ------------------------------------------------------------------------ 41 | 42 | if self.padding: 43 | x = self.zero_padding(x) 44 | 45 | # (b, 2N, h, w) 46 | p = self._get_p(offset, dtype) 47 | 48 | # (b, h, w, 2N) 49 | p = p.contiguous().permute(0, 2, 3, 1) 50 | q_lt = Variable(p.data, requires_grad=False).floor() 51 | q_rb = q_lt + 1 52 | 53 | q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2) - 1), torch.clamp(q_lt[..., N:], 0, x.size(3) - 1)], 54 | dim=-1).long() 55 | q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), torch.clamp(q_rb[..., N:], 0, x.size(3) - 1)], 56 | dim=-1).long() 57 | q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], -1) 58 | q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], -1) 59 | 60 | # (b, h, w, N) 61 | mask = torch.cat([p[..., :N].lt(self.padding) + p[..., :N].gt(x.size(2) - 1 - self.padding), 62 | p[..., N:].lt(self.padding) + p[..., N:].gt(x.size(3) - 1 - self.padding)], dim=-1).type_as(p) 63 | mask = mask.detach() 64 | floor_p = p - (p - torch.floor(p)) 65 | p = p * (1 - mask) + floor_p * mask 66 | p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1) 67 | 68 | # bilinear kernel (b, h, w, N) 69 | g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:])) 70 | g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:])) 71 | g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:])) 72 | g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:])) 73 | 74 | # (b, c, h, w, N) 75 | x_q_lt = self._get_x_q(x, q_lt, N) 76 | x_q_rb = self._get_x_q(x, q_rb, N) 77 | x_q_lb = self._get_x_q(x, q_lb, N) 78 | x_q_rt = self._get_x_q(x, q_rt, N) 79 | 80 | # (b, c, h, w, N) 81 | x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \ 82 | g_rb.unsqueeze(dim=1) * x_q_rb + \ 83 | g_lb.unsqueeze(dim=1) * x_q_lb + \ 84 | g_rt.unsqueeze(dim=1) * x_q_rt 85 | 86 | x_offset = self._reshape_x_offset(x_offset, ks) 87 | out = self.conv_kernel(x_offset) 88 | 89 | return out 90 | 91 | def _get_p_n(self, N, dtype): 92 | p_n_x, p_n_y = np.meshgrid(range(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 
2 + 1), 93 | range(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1), indexing='ij') 94 | # (2N, 1) 95 | p_n = np.concatenate((p_n_x.flatten(), p_n_y.flatten())) 96 | p_n = np.reshape(p_n, (1, 2 * N, 1, 1)) 97 | p_n = Variable(torch.from_numpy(p_n).type(dtype), requires_grad=False) 98 | 99 | return p_n 100 | 101 | @staticmethod 102 | def _get_p_0(h, w, N, dtype): 103 | p_0_x, p_0_y = np.meshgrid(range(1, h + 1), range(1, w + 1), indexing='ij') 104 | p_0_x = p_0_x.flatten().reshape(1, 1, h, w).repeat(N, axis=1) 105 | p_0_y = p_0_y.flatten().reshape(1, 1, h, w).repeat(N, axis=1) 106 | p_0 = np.concatenate((p_0_x, p_0_y), axis=1) 107 | p_0 = Variable(torch.from_numpy(p_0).type(dtype), requires_grad=False) 108 | 109 | return p_0 110 | 111 | def _get_p(self, offset, dtype): 112 | N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3) 113 | 114 | # (1, 2N, 1, 1) 115 | p_n = self._get_p_n(N, dtype) 116 | # (1, 2N, h, w) 117 | p_0 = self._get_p_0(h, w, N, dtype) 118 | p = p_0 + p_n + offset 119 | return p 120 | 121 | def _get_x_q(self, x, q, N): 122 | b, h, w, _ = q.size() 123 | padded_w = x.size(3) 124 | c = x.size(1) 125 | # (b, c, h*w) 126 | x = x.contiguous().view(b, c, -1) 127 | 128 | # (b, h, w, N) 129 | index = q[..., :N] * padded_w + q[..., N:] # offset_x*w + offset_y 130 | # (b, c, h*w*N) 131 | index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1) 132 | 133 | x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N) 134 | 135 | return x_offset 136 | 137 | @staticmethod 138 | def _reshape_x_offset(x_offset, ks): 139 | b, c, h, w, N = x_offset.size() 140 | x_offset = torch.cat([x_offset[..., s:s + ks].contiguous().view(b, c, h, w * ks) for s in range(0, N, ks)], 141 | dim=-1) 142 | x_offset = x_offset.contiguous().view(b, c, h * ks, w * ks) 143 | 144 | return x_offset 145 | 146 | 147 | if __name__ == '__main__': 148 | x = torch.randn(4, 3, 5, 5) 149 | 150 | p_conv = nn.Conv2d(3, 2 * 3 * 3, kernel_size=3, padding=1, stride=1) 151 | d_conv2 = DeformConv2D(3, 64) 152 | 153 | offset = p_conv(x) 154 | y = d_conv2(x, offset) 155 | 156 | print(y.size()) 157 | print('y = ', y) 158 | print("offset:") 159 | print(offset) 160 | -------------------------------------------------------------------------------- /network/base/deform_conv/deform_conv_v2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/2/19 15:32 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | 9 | import torch 10 | from torch import nn 11 | 12 | """ 13 | https://github.com/4uiiurz1/pytorch-deform-conv-v2/blob/master/deform_conv_v2.py 14 | """ 15 | 16 | 17 | class DeformConv2d(nn.Module): 18 | def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False): 19 | """ 20 | Args: 21 | modulation (bool, optional): If True, Modulated Defomable Convolution (Deformable ConvNets v2). 
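            Offsets come from a plain convolution (p_conv, zero-initialized); when modulation
            is True, a second convolution (m_conv) predicts per-location weights, squashed
            into (0, 1) by a sigmoid, that scale each sampled value before the final conv.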
22 | """ 23 | super(DeformConv2d, self).__init__() 24 | self.kernel_size = kernel_size 25 | self.padding = padding 26 | self.stride = stride 27 | self.zero_padding = nn.ZeroPad2d(padding) 28 | self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias) 29 | 30 | self.p_conv = nn.Conv2d(inc, 2 * kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride) 31 | nn.init.constant_(self.p_conv.weight, 0) 32 | self.p_conv.register_backward_hook(self._set_lr) 33 | 34 | self.modulation = modulation 35 | if modulation: 36 | self.m_conv = nn.Conv2d(inc, kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride) 37 | nn.init.constant_(self.m_conv.weight, 0.5) 38 | self.m_conv.register_backward_hook(self._set_lr) 39 | 40 | @staticmethod 41 | def _set_lr(module, grad_input, grad_output): 42 | grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input))) 43 | grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output))) 44 | 45 | def forward(self, x): 46 | offset = self.p_conv(x) 47 | if self.modulation: 48 | m = torch.sigmoid(self.m_conv(x)) 49 | 50 | dtype = offset.data.type() 51 | ks = self.kernel_size 52 | N = offset.size(1) // 2 53 | 54 | if self.padding: 55 | x = self.zero_padding(x) 56 | 57 | # (b, 2N, h, w) 58 | p = self._get_p(offset, dtype) 59 | 60 | # (b, h, w, 2N) 61 | p = p.contiguous().permute(0, 2, 3, 1) 62 | q_lt = p.detach().floor() 63 | q_rb = q_lt + 1 64 | 65 | q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2) - 1), torch.clamp(q_lt[..., N:], 0, x.size(3) - 1)], 66 | dim=-1).long() 67 | q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), torch.clamp(q_rb[..., N:], 0, x.size(3) - 1)], 68 | dim=-1).long() 69 | q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1) 70 | q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1) 71 | 72 | # clip p 73 | p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1) 74 | 75 | # bilinear kernel (b, h, w, N) 76 | g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:])) 77 | g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:])) 78 | g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:])) 79 | g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:])) 80 | 81 | # (b, c, h, w, N) 82 | x_q_lt = self._get_x_q(x, q_lt, N) 83 | x_q_rb = self._get_x_q(x, q_rb, N) 84 | x_q_lb = self._get_x_q(x, q_lb, N) 85 | x_q_rt = self._get_x_q(x, q_rt, N) 86 | 87 | # (b, c, h, w, N) 88 | x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \ 89 | g_rb.unsqueeze(dim=1) * x_q_rb + \ 90 | g_lb.unsqueeze(dim=1) * x_q_lb + \ 91 | g_rt.unsqueeze(dim=1) * x_q_rt 92 | 93 | # modulation 94 | if self.modulation: 95 | m = m.contiguous().permute(0, 2, 3, 1) 96 | m = m.unsqueeze(dim=1) 97 | m = torch.cat([m for _ in range(x_offset.size(1))], dim=1) 98 | x_offset *= m 99 | 100 | x_offset = self._reshape_x_offset(x_offset, ks) 101 | out = self.conv(x_offset) 102 | 103 | return out 104 | 105 | def _get_p_n(self, N, dtype): 106 | p_n_x, p_n_y = torch.meshgrid( 107 | torch.arange(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1), 108 | torch.arange(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1)) 109 | # (2N, 1) 110 | p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0) 111 | p_n = p_n.view(1, 2 * N, 1, 1).type(dtype) 112 | 113 | return p_n 114 | 115 | 
def _get_p_0(self, h, w, N, dtype): 116 | p_0_x, p_0_y = torch.meshgrid( 117 | torch.arange(1, h * self.stride + 1, self.stride), 118 | torch.arange(1, w * self.stride + 1, self.stride)) 119 | p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1) 120 | p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1) 121 | p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype) 122 | 123 | return p_0 124 | 125 | def _get_p(self, offset, dtype): 126 | N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3) 127 | 128 | # (1, 2N, 1, 1) 129 | p_n = self._get_p_n(N, dtype) 130 | # (1, 2N, h, w) 131 | p_0 = self._get_p_0(h, w, N, dtype) 132 | p = p_0 + p_n + offset 133 | return p 134 | 135 | def _get_x_q(self, x, q, N): 136 | b, h, w, _ = q.size() 137 | padded_w = x.size(3) 138 | c = x.size(1) 139 | # (b, c, h*w) 140 | x = x.contiguous().view(b, c, -1) 141 | 142 | # (b, h, w, N) 143 | index = q[..., :N] * padded_w + q[..., N:] # offset_x*w + offset_y 144 | # (b, c, h*w*N) 145 | index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1) 146 | 147 | x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N) 148 | 149 | return x_offset 150 | 151 | @staticmethod 152 | def _reshape_x_offset(x_offset, ks): 153 | b, c, h, w, N = x_offset.size() 154 | x_offset = torch.cat([x_offset[..., s:s + ks].contiguous().view(b, c, h, w * ks) for s in range(0, N, ks)], 155 | dim=-1) 156 | x_offset = x_offset.contiguous().view(b, c, h * ks, w * ks) 157 | 158 | return x_offset 159 | 160 | 161 | if __name__ == '__main__': 162 | x = torch.randn(4, 3, 5, 5) 163 | 164 | d_conv = DeformConv2d(inc=3, outc=64, modulation=True) 165 | 166 | y = d_conv(x) 167 | 168 | print(y.size()) 169 | -------------------------------------------------------------------------------- /network/base/msc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/31 19:03 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class MSC(nn.Module): 14 | """Multi-scale inputs""" 15 | 16 | def __init__(self, scale, pyramids=[0.5, 0.75]): 17 | super(MSC, self).__init__() 18 | self.scale = scale 19 | self.pyramids = pyramids 20 | 21 | def forward(self, x): 22 | # Original 23 | logits = self.scale(x) 24 | interp = lambda l: F.interpolate( 25 | l, size=logits.shape[2:], mode="bilinear", align_corners=True 26 | ) 27 | 28 | # Scaled 29 | logits_pyramid = [] 30 | for p in self.pyramids: 31 | size = [int(s * p) for s in x.shape[2:]] 32 | h = F.interpolate(x, size=size, mode="bilinear", align_corners=True) 33 | logits_pyramid.append(self.scale(h)) 34 | 35 | # Pixel-wise max 36 | logits_all = [logits] + [interp(l) for l in logits_pyramid] 37 | logits_max = torch.max(torch.stack(logits_all), dim=0)[0] 38 | 39 | if self.training: 40 | return [logits] + logits_pyramid + [logits_max] 41 | else: 42 | return logits_max 43 | 44 | def freeze_bn(self): 45 | self.scale.freeze_bn() 46 | 47 | def freeze_backbone_bn(self): 48 | self.scale.freeze_backbone_bn() 49 | 50 | def get_1x_lr_params(self): 51 | return self.scale.get_1x_lr_params() 52 | 53 | def get_10x_lr_params(self): 54 | return self.scale.get_10x_lr_params() 55 | -------------------------------------------------------------------------------- /network/base/oprations.py: -------------------------------------------------------------------------------- 1 | # -*- 
coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/30 21:20 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | def fixed_padding(inputs, kernel_size, dilation): 14 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 15 | pad_total = kernel_size_effective - 1 16 | pad_beg = pad_total // 2 17 | pad_end = pad_total - pad_beg 18 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 19 | return padded_inputs 20 | 21 | 22 | class SeparableConv2d(nn.Module): 23 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None): 24 | super(SeparableConv2d, self).__init__() 25 | 26 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation, 27 | groups=inplanes, bias=bias) 28 | self.bn = BatchNorm(inplanes) 29 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 30 | 31 | def forward(self, x): 32 | x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0]) 33 | x = self.conv1(x) 34 | x = self.bn(x) 35 | x = self.pointwise(x) 36 | return x 37 | 38 | 39 | class ASPP_module(nn.Module): 40 | def __init__(self, inplanes, planes, rate): 41 | super(ASPP_module, self).__init__() 42 | if rate == 1: 43 | kernel_size = 1 44 | padding = 0 45 | else: 46 | kernel_size = 3 47 | padding = rate 48 | self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 49 | stride=1, padding=padding, dilation=rate, bias=False) 50 | self.bn = nn.BatchNorm2d(planes) 51 | self.relu = nn.ReLU() 52 | 53 | self._init_weight() 54 | 55 | def forward(self, x): 56 | x = self.atrous_convolution(x) 57 | x = self.bn(x) 58 | 59 | return self.relu(x) 60 | 61 | def _init_weight(self): 62 | for m in self.modules(): 63 | if isinstance(m, nn.Conv2d): 64 | torch.nn.init.kaiming_normal_(m.weight) 65 | elif isinstance(m, nn.BatchNorm2d): 66 | m.weight.data.fill_(1) 67 | m.bias.data.zero_() 68 | -------------------------------------------------------------------------------- /network/base/resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/31 16:39 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | } 18 | 19 | 20 | class Bottleneck(nn.Module): 21 | expansion = 4 22 | 23 | def __init__(self, inplanes, planes, stride=1, rate=1, downsample=None): 24 | super(Bottleneck, self).__init__() 25 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 26 | self.bn1 = nn.BatchNorm2d(planes) 27 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 28 | dilation=rate, padding=rate, bias=False) 29 | self.bn2 = nn.BatchNorm2d(planes) 30 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 31 | self.bn3 = nn.BatchNorm2d(planes * 4) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.downsample = downsample 34 | self.stride = stride 35 | self.rate = rate 36 | 37 | def forward(self, x): 38 | 
residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | out = self.relu(out) 47 | 48 | out = self.conv3(out) 49 | out = self.bn3(out) 50 | 51 | if self.downsample is not None: 52 | residual = self.downsample(x) 53 | 54 | out += residual 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class ResNet(nn.Module): 61 | 62 | def __init__(self): 63 | self.inplanes = 64 64 | super(ResNet, self).__init__() 65 | 66 | # Modules 67 | self.conv1 = None 68 | self.bn1 = None 69 | self.relu = None 70 | self.maxpool = None 71 | 72 | self.layer1 = None 73 | self.layer2 = None 74 | self.layer3 = None 75 | self.layer4 = None 76 | 77 | self.init_weight() 78 | 79 | def _make_layer(self, block, planes, blocks, stride=1, rate=1): 80 | downsample = None 81 | if stride != 1 or self.inplanes != planes * block.expansion: 82 | downsample = nn.Sequential( 83 | nn.Conv2d(self.inplanes, planes * block.expansion, 84 | kernel_size=1, stride=stride, bias=False), 85 | nn.BatchNorm2d(planes * block.expansion), 86 | ) 87 | 88 | layers = [] 89 | layers.append(block(self.inplanes, planes, stride, rate, downsample)) 90 | self.inplanes = planes * block.expansion 91 | for i in range(1, blocks): 92 | layers.append(block(self.inplanes, planes)) 93 | 94 | return nn.Sequential(*layers) 95 | 96 | def _make_MG_unit(self, block, planes, blocks=[1, 2, 4], stride=1, rate=1): 97 | downsample = None 98 | if stride != 1 or self.inplanes != planes * block.expansion: 99 | downsample = nn.Sequential( 100 | nn.Conv2d(self.inplanes, planes * block.expansion, 101 | kernel_size=1, stride=stride, bias=False), 102 | nn.BatchNorm2d(planes * block.expansion), 103 | ) 104 | 105 | layers = [] 106 | layers.append(block(self.inplanes, planes, stride, rate=blocks[0] * rate, downsample=downsample)) 107 | self.inplanes = planes * block.expansion 108 | for i in range(1, len(blocks)): 109 | layers.append(block(self.inplanes, planes, stride=1, rate=blocks[i] * rate)) 110 | 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, input): 114 | return input 115 | 116 | def init_weight(self): 117 | for m in self.modules(): 118 | if isinstance(m, nn.Conv2d): 119 | torch.nn.init.kaiming_normal_(m.weight) 120 | elif isinstance(m, nn.BatchNorm2d): 121 | m.weight.data.fill_(1) 122 | m.bias.data.zero_() 123 | -------------------------------------------------------------------------------- /network/base/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/2/17 22:55 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ -------------------------------------------------------------------------------- /network/base/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using same "device order" makes the ReduceAdd operation faster. 
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | .. math:: 132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 134 | standard-deviation are reduced across all devices during training. 135 | For example, when one uses `nn.DataParallel` to wrap the network during 136 | training, PyTorch's implementation normalize the tensor on each device using 137 | the statistics only on that device, which accelerated the computation and 138 | is also easy to implement, but the statistics might be inaccurate. 139 | Instead, in this synchronized version, the statistics will be computed 140 | over all training samples distributed on multiple devices. 141 | 142 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 143 | as the built-in PyTorch implementation. 144 | The mean and standard-deviation are calculated per-dimension over 145 | the mini-batches and gamma and beta are learnable parameter vectors 146 | of size C (where C is the input size). 147 | During training, this layer keeps a running estimate of its computed mean 148 | and variance. The running sum is kept with a default momentum of 0.1. 149 | During evaluation, this running mean/variance is used for normalization. 150 | Because the BatchNorm is done over the `C` dimension, computing statistics 151 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 152 | Args: 153 | num_features: num_features from an expected input of size 154 | `batch_size x num_features [x width]` 155 | eps: a value added to the denominator for numerical stability. 156 | Default: 1e-5 157 | momentum: the value used for the running_mean and running_var 158 | computation. Default: 0.1 159 | affine: a boolean value that when set to ``True``, gives the layer learnable 160 | affine parameters. 
Default: ``True`` 161 | Shape: 162 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 163 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 164 | Examples: 165 | >>> # With Learnable Parameters 166 | >>> m = SynchronizedBatchNorm1d(100) 167 | >>> # Without Learnable Parameters 168 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 169 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 170 | >>> output = m(input) 171 | """ 172 | 173 | def _check_input_dim(self, input): 174 | if input.dim() != 2 and input.dim() != 3: 175 | raise ValueError('expected 2D or 3D input (got {}D input)' 176 | .format(input.dim())) 177 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 178 | 179 | 180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 181 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 182 | of 3d inputs 183 | .. math:: 184 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 185 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 186 | standard-deviation are reduced across all devices during training. 187 | For example, when one uses `nn.DataParallel` to wrap the network during 188 | training, PyTorch's implementation normalize the tensor on each device using 189 | the statistics only on that device, which accelerated the computation and 190 | is also easy to implement, but the statistics might be inaccurate. 191 | Instead, in this synchronized version, the statistics will be computed 192 | over all training samples distributed on multiple devices. 193 | 194 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 195 | as the built-in PyTorch implementation. 196 | The mean and standard-deviation are calculated per-dimension over 197 | the mini-batches and gamma and beta are learnable parameter vectors 198 | of size C (where C is the input size). 199 | During training, this layer keeps a running estimate of its computed mean 200 | and variance. The running sum is kept with a default momentum of 0.1. 201 | During evaluation, this running mean/variance is used for normalization. 202 | Because the BatchNorm is done over the `C` dimension, computing statistics 203 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 204 | Args: 205 | num_features: num_features from an expected input of 206 | size batch_size x num_features x height x width 207 | eps: a value added to the denominator for numerical stability. 208 | Default: 1e-5 209 | momentum: the value used for the running_mean and running_var 210 | computation. Default: 0.1 211 | affine: a boolean value that when set to ``True``, gives the layer learnable 212 | affine parameters. 
Default: ``True`` 213 | Shape: 214 | - Input: :math:`(N, C, H, W)` 215 | - Output: :math:`(N, C, H, W)` (same shape as input) 216 | Examples: 217 | >>> # With Learnable Parameters 218 | >>> m = SynchronizedBatchNorm2d(100) 219 | >>> # Without Learnable Parameters 220 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 221 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 222 | >>> output = m(input) 223 | """ 224 | 225 | def _check_input_dim(self, input): 226 | if input.dim() != 4: 227 | raise ValueError('expected 4D input (got {}D input)' 228 | .format(input.dim())) 229 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 230 | 231 | 232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 233 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 234 | of 4d inputs 235 | .. math:: 236 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 237 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 238 | standard-deviation are reduced across all devices during training. 239 | For example, when one uses `nn.DataParallel` to wrap the network during 240 | training, PyTorch's implementation normalize the tensor on each device using 241 | the statistics only on that device, which accelerated the computation and 242 | is also easy to implement, but the statistics might be inaccurate. 243 | Instead, in this synchronized version, the statistics will be computed 244 | over all training samples distributed on multiple devices. 245 | 246 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 247 | as the built-in PyTorch implementation. 248 | The mean and standard-deviation are calculated per-dimension over 249 | the mini-batches and gamma and beta are learnable parameter vectors 250 | of size C (where C is the input size). 251 | During training, this layer keeps a running estimate of its computed mean 252 | and variance. The running sum is kept with a default momentum of 0.1. 253 | During evaluation, this running mean/variance is used for normalization. 254 | Because the BatchNorm is done over the `C` dimension, computing statistics 255 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 256 | or Spatio-temporal BatchNorm 257 | Args: 258 | num_features: num_features from an expected input of 259 | size batch_size x num_features x depth x height x width 260 | eps: a value added to the denominator for numerical stability. 261 | Default: 1e-5 262 | momentum: the value used for the running_mean and running_var 263 | computation. Default: 0.1 264 | affine: a boolean value that when set to ``True``, gives the layer learnable 265 | affine parameters. 
Default: ``True`` 266 | Shape: 267 | - Input: :math:`(N, C, D, H, W)` 268 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 269 | Examples: 270 | >>> # With Learnable Parameters 271 | >>> m = SynchronizedBatchNorm3d(100) 272 | >>> # Without Learnable Parameters 273 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 274 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 275 | >>> output = m(input) 276 | """ 277 | 278 | def _check_input_dim(self, input): 279 | if input.dim() != 5: 280 | raise ValueError('expected 5D input (got {}D input)' 281 | .format(input.dim())) 282 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) -------------------------------------------------------------------------------- /network/base/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 59 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 60 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine to message passed 63 | back to each slave devices. 64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register an slave device. 
85 | Args: 86 | identifier: an identifier, usually is the device id. 87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | The messages were first collected from each devices (including the master device), and then 101 | an callback will be invoked to compute the message to be sent back to each devices 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | Returns: the message to be sent back to the master device. 107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belongs to the master.' 116 | 117 | for i, res in results: 118 | if i == 0: 119 | continue 120 | self._registry[i].result.put(res) 121 | 122 | for i in range(self.nr_slaves): 123 | assert self._queue.get() is True 124 | 125 | return results[0][1] 126 | 127 | @property 128 | def nr_slaves(self): 129 | return len(self._registry) 130 | -------------------------------------------------------------------------------- /network/base/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphism, we assign each sub-module with a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 35 | of any slave copies. 
36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 51 | original `replicate` function. 52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have customized `DataParallel` implementation. 69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /network/base/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /network/base/xception.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/2/17 22:46 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | from network.base.oprations import SeparableConv2d 8 | from network.base.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 9 | 10 | """ 11 | https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/backbone/xception.py 12 | """ 13 | 14 | import math 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torch.utils.model_zoo as model_zoo 19 | 20 | 21 | class Block(nn.Module): 22 | def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None, 23 | start_with_relu=True, grow_first=True, is_last=False): 24 | super(Block, self).__init__() 25 | 26 | if planes != inplanes or stride != 1: 27 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False) 28 | self.skipbn = BatchNorm(planes) 29 | else: 30 | self.skip = None 31 | 32 | self.relu = nn.ReLU(inplace=True) 33 | rep = [] 34 | 35 | filters = inplanes 36 | if grow_first: 37 | rep.append(self.relu) 38 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 39 | rep.append(BatchNorm(planes)) 40 | filters = planes 41 | 42 | for i in range(reps - 1): 43 | rep.append(self.relu) 44 | rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm)) 45 | rep.append(BatchNorm(filters)) 46 | 47 | if not grow_first: 48 | rep.append(self.relu) 49 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 50 | rep.append(BatchNorm(planes)) 51 | 52 | if stride != 1: 53 | rep.append(self.relu) 54 | rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm)) 55 | rep.append(BatchNorm(planes)) 56 | 57 | if stride == 1 and is_last: 58 | rep.append(self.relu) 59 | rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm)) 60 | rep.append(BatchNorm(planes)) 61 | 62 | if not start_with_relu: 63 | rep = rep[1:] 64 | 65 | self.rep = nn.Sequential(*rep) 66 | 67 | def forward(self, inp): 68 | x = self.rep(inp) 69 | 70 | if self.skip is not None: 71 | skip = self.skip(inp) 72 | skip = self.skipbn(skip) 73 | else: 74 | skip = inp 75 | 76 | x = x + skip 77 | 78 | return x 79 | 80 | 81 | class AlignedXception(nn.Module): 82 | """ 83 | Modified Alighed Xception 84 | """ 85 | 86 | def __init__(self, output_stride, BatchNorm, 87 | pretrained=True): 88 | super(AlignedXception, self).__init__() 89 | 90 | if output_stride == 16: 91 | entry_block3_stride = 2 92 | middle_block_dilation = 1 93 | exit_block_dilations = (1, 2) 94 | elif output_stride == 8: 95 | entry_block3_stride = 1 96 | middle_block_dilation = 2 97 | exit_block_dilations = (2, 4) 98 | else: 99 | raise NotImplementedError 100 | 101 | # 
Entry flow 102 | self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False) 103 | self.bn1 = BatchNorm(32) 104 | self.relu = nn.ReLU(inplace=True) 105 | 106 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False) 107 | self.bn2 = BatchNorm(64) 108 | 109 | self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False) 110 | self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False, 111 | grow_first=True) 112 | self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm, 113 | start_with_relu=True, grow_first=True, is_last=True) 114 | 115 | # Middle flow 116 | self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 117 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 118 | self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 119 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 120 | self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 121 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 122 | self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 123 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 124 | self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 125 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 126 | self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 127 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 128 | self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 129 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 130 | self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 131 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 132 | self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 133 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 134 | self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 135 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 136 | self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 137 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 138 | self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 139 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 140 | self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 141 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 142 | self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 143 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 144 | self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 145 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 146 | self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 147 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 148 | 149 | # Exit flow 150 | self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0], 151 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True) 152 | 153 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 154 | self.bn3 = BatchNorm(1536) 155 | 156 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, 
dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
157 |         self.bn4 = BatchNorm(1536)
158 | 
159 |         self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
160 |         self.bn5 = BatchNorm(2048)
161 | 
162 |         # Init weights
163 |         self._init_weight()
164 | 
165 |         # Load pretrained model
166 |         if pretrained:
167 |             self._load_pretrained_model()
168 | 
169 |     def forward(self, x):
170 |         # Entry flow
171 |         x = self.conv1(x)
172 |         x = self.bn1(x)
173 |         x = self.relu(x)
174 | 
175 |         x = self.conv2(x)
176 |         x = self.bn2(x)
177 |         x = self.relu(x)
178 | 
179 |         x = self.block1(x)
180 |         # extra ReLU so the low-level features handed to the decoder are post-activation
181 |         x = self.relu(x)
182 |         low_level_feat = x
183 |         x = self.block2(x)
184 |         x = self.block3(x)
185 | 
186 |         # Middle flow
187 |         x = self.block4(x)
188 |         x = self.block5(x)
189 |         x = self.block6(x)
190 |         x = self.block7(x)
191 |         x = self.block8(x)
192 |         x = self.block9(x)
193 |         x = self.block10(x)
194 |         x = self.block11(x)
195 |         x = self.block12(x)
196 |         x = self.block13(x)
197 |         x = self.block14(x)
198 |         x = self.block15(x)
199 |         x = self.block16(x)
200 |         x = self.block17(x)
201 |         x = self.block18(x)
202 |         x = self.block19(x)
203 | 
204 |         # Exit flow
205 |         x = self.block20(x)
206 |         x = self.relu(x)
207 |         x = self.conv3(x)
208 |         x = self.bn3(x)
209 |         x = self.relu(x)
210 | 
211 |         x = self.conv4(x)
212 |         x = self.bn4(x)
213 |         x = self.relu(x)
214 | 
215 |         x = self.conv5(x)
216 |         x = self.bn5(x)
217 |         x = self.relu(x)
218 | 
219 |         return x, low_level_feat
220 | 
221 |     def _init_weight(self):
222 |         for m in self.modules():
223 |             if isinstance(m, nn.Conv2d):
224 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
225 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
226 |             elif isinstance(m, SynchronizedBatchNorm2d):
227 |                 m.weight.data.fill_(1)
228 |                 m.bias.data.zero_()
229 |             elif isinstance(m, nn.BatchNorm2d):
230 |                 m.weight.data.fill_(1)
231 |                 m.bias.data.zero_()
232 | 
233 |     def _load_pretrained_model(self):
234 |         pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
235 |         model_dict = {}
236 |         state_dict = self.state_dict()
237 | 
238 |         for k, v in pretrain_dict.items():
239 |             if k in state_dict:  # keep only keys that exist in this model (`model_dict` starts empty here)
240 |                 if 'pointwise' in k:
241 |                     v = v.unsqueeze(-1).unsqueeze(-1)
242 |                 if k.startswith('block11'):
243 |                     model_dict[k] = v
244 |                     model_dict[k.replace('block11', 'block12')] = v
245 |                     model_dict[k.replace('block11', 'block13')] = v
246 |                     model_dict[k.replace('block11', 'block14')] = v
247 |                     model_dict[k.replace('block11', 'block15')] = v
248 |                     model_dict[k.replace('block11', 'block16')] = v
249 |                     model_dict[k.replace('block11', 'block17')] = v
250 |                     model_dict[k.replace('block11', 'block18')] = v
251 |                     model_dict[k.replace('block11', 'block19')] = v
252 |                 elif k.startswith('block12'):
253 |                     model_dict[k.replace('block12', 'block20')] = v
254 |                 elif k.startswith('bn3'):
255 |                     model_dict[k] = v
256 |                     model_dict[k.replace('bn3', 'bn4')] = v
257 |                 elif k.startswith('conv4'):
258 |                     model_dict[k.replace('conv4', 'conv5')] = v
259 |                 elif k.startswith('bn4'):
260 |                     model_dict[k.replace('bn4', 'bn5')] = v
261 |                 else:
262 |                     model_dict[k] = v
263 |         state_dict.update(model_dict)
264 |         self.load_state_dict(state_dict)
265 | 
266 | 
267 | if __name__ == "__main__":
268 |     import torch
269 | 
270 |     model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16)
271 |     input = torch.rand(1, 3, 512, 512)
272 |     output, low_level_feat = model(input)
273 |     print(output.size())
274 |     print(low_level_feat.size())
275 | 
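A minimal multi-GPU sketch tying this backbone to the sync-BN utilities above (assumes two visible GPUs; `pretrained=False` skips the weight download; import paths follow this repo's layout):

import torch
from network.base.xception import AlignedXception
from network.base.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
from network.base.sync_batchnorm.replicate import DataParallelWithCallback

# output_stride=16 is the usual DeepLab v3+ encoder setting.
backbone = AlignedXception(output_stride=16, BatchNorm=SynchronizedBatchNorm2d,
                           pretrained=False).cuda()
# DataParallelWithCallback fires __data_parallel_replicate__ on every
# SynchronizedBatchNorm2d replica, wiring up cross-GPU statistics.
backbone = DataParallelWithCallback(backbone, device_ids=[0, 1])

feat, low_level_feat = backbone(torch.rand(2, 3, 512, 512).cuda())
print(feat.size())            # torch.Size([2, 2048, 32, 32])
print(low_level_feat.size())  # torch.Size([2, 128, 128, 128])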
-------------------------------------------------------------------------------- /network/deeplab.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/30 19:43 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ -------------------------------------------------------------------------------- /network/deeplab_deform_conv.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/2/18 21:05 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ -------------------------------------------------------------------------------- /network/deeplabv2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/30 16:59 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | from torch.utils import model_zoo 9 | 10 | from network.base.resnet import * 11 | from network.base.oprations import * 12 | 13 | 14 | class DeeplabV2(ResNet): 15 | def __init__(self, n_class, block, layers, pyramids): 16 | print("Constructing DeepLabv2 model...") 17 | print("Number of classes: {}".format(n_class)) 18 | super(DeeplabV2, self).__init__() 19 | 20 | self.inplanes = 64 21 | 22 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 23 | self.bn1 = nn.BatchNorm2d(self.inplanes, affine=True) 24 | self.relu = nn.ReLU(inplace=True) 25 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change 26 | 27 | self.layer1 = self._make_layer(block, 64, layers[0]) 28 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 29 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, rate=2) 30 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, rate=4) 31 | 32 | self.aspp1 = ASPP_module(2048, n_class, pyramids[0]) 33 | self.aspp2 = ASPP_module(2048, n_class, pyramids[1]) 34 | self.aspp3 = ASPP_module(2048, n_class, pyramids[2]) 35 | self.aspp4 = ASPP_module(2048, n_class, pyramids[3]) 36 | 37 | self.init_weight() 38 | 39 | def forward(self, input): 40 | x = self.conv1(input) 41 | x = self.bn1(x) 42 | x = self.relu(x) 43 | x = self.maxpool(x) 44 | 45 | x = self.layer1(x) 46 | x = self.layer2(x) 47 | x = self.layer3(x) 48 | x = self.layer4(x) 49 | 50 | x1 = self.aspp1(x) 51 | x2 = self.aspp2(x) 52 | x3 = self.aspp3(x) 53 | x4 = self.aspp4(x) 54 | 55 | x = x1 + x2 + x3 + x4 56 | 57 | x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True) 58 | return x 59 | 60 | def get_1x_lr_params(self): 61 | b = [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4] 62 | for i in range(len(b)): 63 | for k in b[i].parameters(): 64 | if k.requires_grad: 65 | yield k 66 | 67 | def get_10x_lr_params(self): 68 | b = [self.aspp1, self.aspp2, self.aspp3, self.aspp4] 69 | for j in range(len(b)): 70 | for k in b[j].parameters(): 71 | if k.requires_grad: 72 | yield k 73 | 74 | def freeze_bn(self): 75 | for m in self.modules(): 76 | if isinstance(m, nn.BatchNorm2d): 77 | m.eval() 78 | 79 | def freeze_backbone_bn(self): 80 | self.bn1.eval() 81 | 82 | for m in self.layer1: 83 | if isinstance(m, nn.BatchNorm2d): 84 | m.eval() 85 | 86 | for m in self.layer2: 87 | if isinstance(m, nn.BatchNorm2d): 88 | m.eval() 89 | 90 | for m in self.layer3: 91 | if isinstance(m, nn.BatchNorm2d): 92 | m.eval() 93 | 94 | for m in self.layer4: 95 | if isinstance(m, 
nn.BatchNorm2d): 96 | m.eval() 97 | 98 | 99 | def resnet101(n_class, pretrained=True): 100 | 101 | model = DeeplabV2(n_class=n_class, block=Bottleneck, layers=[3, 4, 23, 3], pyramids=[6, 12, 18, 24]) 102 | 103 | if pretrained: 104 | pretrain_dict = model_zoo.load_url(model_urls['resnet101']) 105 | model_dict = {} 106 | state_dict = model.state_dict() 107 | for k, v in pretrain_dict.items(): 108 | if k in state_dict: 109 | model_dict[k] = v 110 | state_dict.update(model_dict) 111 | model.load_state_dict(state_dict) 112 | 113 | return model 114 | 115 | 116 | if __name__ == '__main__': 117 | model = resnet101(n_class=21, pretrained=True) 118 | 119 | img = torch.randn(4, 3, 513, 513) 120 | 121 | with torch.no_grad(): 122 | output = model.forward(img) 123 | 124 | print(output.size()) 125 | -------------------------------------------------------------------------------- /network/deeplabv3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/31 18:10 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | import torch.nn.functional as F 9 | 10 | from torch.utils import model_zoo 11 | 12 | from network.base.oprations import ASPP_module 13 | from network.base.resnet import * 14 | 15 | 16 | class DeeplabV3(ResNet): 17 | 18 | def __init__(self, n_class, block, layers, pyramids, grids, output_stride=16): 19 | self.inplanes = 64 20 | super(DeeplabV3, self).__init__() 21 | if output_stride == 16: 22 | strides = [1, 2, 2, 1] 23 | rates = [1, 1, 1, 2] 24 | elif output_stride == 8: 25 | strides = [1, 2, 1, 1] 26 | rates = [1, 1, 2, 2] 27 | else: 28 | raise NotImplementedError 29 | 30 | # Backbone Modules 31 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 32 | bias=False) 33 | self.bn1 = nn.BatchNorm2d(64) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 36 | 37 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], rate=rates[0]) 38 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], rate=rates[1]) 39 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], rate=rates[2]) 40 | self.layer4 = self._make_MG_unit(block, 512, blocks=grids, stride=strides[3], rate=rates[3]) 41 | 42 | # Deeplab Modules 43 | self.aspp1 = ASPP_module(2048, 256, rate=pyramids[0]) 44 | self.aspp2 = ASPP_module(2048, 256, rate=pyramids[1]) 45 | self.aspp3 = ASPP_module(2048, 256, rate=pyramids[2]) 46 | self.aspp4 = ASPP_module(2048, 256, rate=pyramids[3]) 47 | 48 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 49 | nn.Conv2d(2048, 256, kernel_size=1, stride=1, bias=False), 50 | nn.BatchNorm2d(256), 51 | nn.ReLU()) 52 | 53 | # get result features from the concat 54 | self._conv1 = nn.Sequential(nn.Conv2d(1280, 256, kernel_size=1, stride=1, bias=False), 55 | nn.BatchNorm2d(256), 56 | nn.ReLU()) 57 | 58 | # generate the final logits 59 | self._conv2 = nn.Conv2d(256, n_class, kernel_size=1, bias=False) 60 | 61 | self.init_weight() 62 | 63 | def forward(self, input): 64 | x = self.conv1(input) 65 | x = self.bn1(x) 66 | x = self.relu(x) 67 | x = self.maxpool(x) 68 | 69 | x = self.layer1(x) 70 | x = self.layer2(x) 71 | x = self.layer3(x) 72 | x = self.layer4(x) 73 | 74 | x1 = self.aspp1(x) 75 | x2 = self.aspp2(x) 76 | x3 = self.aspp3(x) 77 | x4 = self.aspp4(x) 78 | 79 | # image-level features 80 | x5 = self.global_avg_pool(x) 81 | x5 = F.upsample(x5, size=x4.size()[2:], 
mode='bilinear', align_corners=True) 82 | 83 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 84 | 85 | x = self._conv1(x) 86 | x = self._conv2(x) 87 | 88 | x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True) 89 | 90 | return x 91 | 92 | def get_1x_lr_params(self): 93 | b = [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4] 94 | for i in range(len(b)): 95 | for k in b[i].parameters(): 96 | if k.requires_grad: 97 | yield k 98 | 99 | def get_10x_lr_params(self): 100 | b = [self.aspp1, self.aspp2, self.aspp3, self.aspp4, self.global_avg_pool, self._conv1, self._conv2] 101 | for j in range(len(b)): 102 | for k in b[j].parameters(): 103 | if k.requires_grad: 104 | yield k 105 | 106 | def freeze_bn(self): 107 | for m in self.modules(): 108 | if isinstance(m, nn.BatchNorm2d): 109 | m.eval() 110 | 111 | def freeze_backbone_bn(self): 112 | self.bn1.eval() 113 | 114 | for m in self.layer1: 115 | if isinstance(m, nn.BatchNorm2d): 116 | m.eval() 117 | 118 | for m in self.layer2: 119 | if isinstance(m, nn.BatchNorm2d): 120 | m.eval() 121 | 122 | for m in self.layer3: 123 | if isinstance(m, nn.BatchNorm2d): 124 | m.eval() 125 | 126 | for m in self.layer4: 127 | if isinstance(m, nn.BatchNorm2d): 128 | m.eval() 129 | 130 | 131 | def resnet101(n_class, output_stride=16, pretrained=True): 132 | if output_stride == 16: 133 | pyramids = [1, 6, 12, 18] 134 | grids = [1, 2, 4] 135 | elif output_stride == 8: 136 | pyramids = [1, 12, 24, 36] 137 | grids = [1, 2, 1] 138 | else: 139 | raise NotImplementedError 140 | 141 | model = DeeplabV3(n_class=n_class, block=Bottleneck, layers=[3, 4, 23, 3], 142 | pyramids=pyramids, grids=grids, output_stride=output_stride) 143 | 144 | if pretrained: 145 | pretrain_dict = model_zoo.load_url(model_urls['resnet101']) 146 | model_dict = {} 147 | state_dict = model.state_dict() 148 | for k, v in pretrain_dict.items(): 149 | if k in state_dict: 150 | model_dict[k] = v 151 | print(k) 152 | state_dict.update(model_dict) 153 | model.load_state_dict(state_dict) 154 | 155 | return model 156 | 157 | 158 | if __name__ == '__main__': 159 | model = resnet101(n_class=21, output_stride=16, pretrained=True) 160 | 161 | img = torch.randn(4, 3, 512, 512) 162 | 163 | with torch.no_grad(): 164 | output = model.forward(img) 165 | 166 | print(output.size()) 167 | -------------------------------------------------------------------------------- /network/deeplabv3plus_resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/31 16:56 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | import math 9 | 10 | import torch.nn.functional as F 11 | from torch.utils import model_zoo 12 | 13 | from network.base.oprations import ASPP_module 14 | from network.base.resnet import * 15 | 16 | 17 | class DeeplabV3Plus(ResNet): 18 | 19 | def __init__(self, n_class, block, layers, pyramids, grids, output_stride=16): 20 | self.inplanes = 64 21 | super(DeeplabV3Plus, self).__init__() 22 | if output_stride == 16: 23 | strides = [1, 2, 2, 1] 24 | rates = [1, 1, 1, 2] 25 | elif output_stride == 8: 26 | strides = [1, 2, 1, 1] 27 | rates = [1, 1, 2, 2] 28 | else: 29 | raise NotImplementedError 30 | 31 | # Backbone Modules 32 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 33 | bias=False) 34 | self.bn1 = nn.BatchNorm2d(64) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 37 | 38 | self.layer1 = 
self._make_layer(block, 64, layers[0], stride=strides[0], rate=rates[0]) 39 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], rate=rates[1]) 40 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], rate=rates[2]) 41 | self.layer4 = self._make_MG_unit(block, 512, blocks=grids, stride=strides[3], rate=rates[3]) 42 | 43 | # Deeplab Modules 44 | self.aspp1 = ASPP_module(2048, 256, rate=pyramids[0]) 45 | self.aspp2 = ASPP_module(2048, 256, rate=pyramids[1]) 46 | self.aspp3 = ASPP_module(2048, 256, rate=pyramids[2]) 47 | self.aspp4 = ASPP_module(2048, 256, rate=pyramids[3]) 48 | 49 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 50 | nn.Conv2d(2048, 256, 1, stride=1, bias=False), 51 | nn.BatchNorm2d(256), 52 | nn.ReLU()) 53 | 54 | # the resulting feature of the deeplabv3 encoder 55 | self._conv1 = nn.Sequential(nn.Conv2d(1280, 256, kernel_size=1, stride=1, bias=False), 56 | nn.BatchNorm2d(256), 57 | nn.ReLU()) 58 | 59 | # adopt [1x1, 48] for channel reduction. 60 | self._conv2 = nn.Sequential(nn.Conv2d(256, 48, 1, bias=False), 61 | nn.BatchNorm2d(48), 62 | nn.ReLU()) 63 | 64 | self._conv3 = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 65 | nn.BatchNorm2d(256), 66 | nn.ReLU(), 67 | # 3x3 conv to refine the features 68 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 69 | nn.BatchNorm2d(256), 70 | nn.ReLU(), 71 | nn.Conv2d(256, n_class, kernel_size=1, stride=1)) 72 | 73 | self.init_weight() 74 | 75 | def forward(self, input): 76 | x = self.conv1(input) 77 | x = self.bn1(x) 78 | x = self.relu(x) 79 | x = self.maxpool(x) 80 | 81 | x = self.layer1(x) 82 | low_level_features = x 83 | x = self.layer2(x) 84 | x = self.layer3(x) 85 | x = self.layer4(x) 86 | 87 | x1 = self.aspp1(x) 88 | x2 = self.aspp2(x) 89 | x3 = self.aspp3(x) 90 | x4 = self.aspp4(x) 91 | x5 = self.global_avg_pool(x) 92 | x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 93 | 94 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 95 | 96 | x = self._conv1(x) 97 | x = F.upsample(x, size=(int(math.ceil(input.size()[-2] / 4)), 98 | int(math.ceil(input.size()[-1] / 4))), mode='bilinear', align_corners=True) 99 | """ 100 | above is the encoder 101 | """ 102 | 103 | low_level_features = self._conv2(low_level_features) 104 | 105 | x = torch.cat((x, low_level_features), dim=1) 106 | x = self._conv3(x) 107 | x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True) 108 | 109 | return x 110 | 111 | def get_1x_lr_params(self): 112 | b = [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4] 113 | for i in range(len(b)): 114 | for k in b[i].parameters(): 115 | if k.requires_grad: 116 | yield k 117 | 118 | def get_10x_lr_params(self): 119 | b = [self.aspp1, self.aspp2, self.aspp3, self.aspp4, self.global_avg_pool, 120 | self._conv1, self._conv2, self._conv3] 121 | for i in range(len(b)): 122 | for k in b[i].parameters(): 123 | if k.requires_grad: 124 | yield k 125 | 126 | def freeze_bn(self): 127 | for m in self.modules(): 128 | if isinstance(m, nn.BatchNorm2d): 129 | m.eval() 130 | 131 | def freeze_backbone_bn(self): 132 | self.bn1.eval() 133 | 134 | for m in self.layer1: 135 | if isinstance(m, nn.BatchNorm2d): 136 | m.eval() 137 | 138 | for m in self.layer2: 139 | if isinstance(m, nn.BatchNorm2d): 140 | m.eval() 141 | 142 | for m in self.layer3: 143 | if isinstance(m, nn.BatchNorm2d): 144 | m.eval() 145 | 146 | for m in self.layer4: 147 | if isinstance(m, 
nn.BatchNorm2d): 148 | m.eval() 149 | 150 | 151 | def resnet101(n_class, output_stride=16, pretrained=True): 152 | if output_stride == 16: 153 | pyramids = [1, 6, 12, 18] 154 | grids = [1, 2, 4] 155 | elif output_stride == 8: 156 | pyramids = [1, 12, 24, 36] 157 | grids = [1, 2, 1] 158 | else: 159 | raise NotImplementedError 160 | 161 | model = DeeplabV3Plus(n_class=n_class, block=Bottleneck, layers=[3, 4, 23, 3], 162 | pyramids=pyramids, grids=grids, output_stride=output_stride) 163 | 164 | if pretrained: 165 | pretrain_dict = model_zoo.load_url(model_urls['resnet101']) 166 | model_dict = {} 167 | state_dict = model.state_dict() 168 | for k, v in pretrain_dict.items(): 169 | if k in state_dict: 170 | model_dict[k] = v 171 | # print(k) 172 | state_dict.update(model_dict) 173 | model.load_state_dict(state_dict) 174 | 175 | return model 176 | 177 | 178 | if __name__ == '__main__': 179 | model = resnet101(n_class=21, output_stride=16, pretrained=True) 180 | 181 | img = torch.randn(4, 3, 512, 512) 182 | 183 | with torch.no_grad(): 184 | output = model.forward(img) 185 | 186 | print(output.size()) 187 | -------------------------------------------------------------------------------- /network/deeplabv3plus_xception.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/31 16:57 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ -------------------------------------------------------------------------------- /network/get_models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/30 23:19 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | from network import deeplabv2, deeplabv3, deeplabv3plus_resnet 9 | from network.base.msc import MSC 10 | 11 | 12 | def get_models(args): 13 | if args.model == 'deeplabv2': 14 | if args.msc: 15 | return MSC(scale= deeplabv2.resnet101(n_class=21, pretrained=True), pyramids=[0.5, 0.75]) 16 | return deeplabv2.resnet101(n_class=21, pretrained=True) 17 | elif args.model == 'deeplabv3': 18 | if args.msc: 19 | return MSC(scale= deeplabv3.resnet101(n_class=21, output_stride=16, pretrained=True), 20 | pyramids=[0.5, 0.75]) 21 | return deeplabv3.resnet101(n_class=21, output_stride=16, pretrained=True) 22 | elif args.model == 'deeplabv3plus': 23 | if args.msc: 24 | return MSC(scale= deeplabv3plus_resnet.resnet101(n_class=21, output_stride=16, pretrained=True), 25 | pyramids=[0.5, 0.75]) 26 | return deeplabv3plus_resnet.resnet101(n_class=21, output_stride=16, pretrained=True) 27 | else: 28 | print('Model {} not implemented.'.format(args.model)) 29 | raise NotImplementedError 30 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/2/19 11:08 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | for i in range(1, 10): 9 | print(i) -------------------------------------------------------------------------------- /validation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/30 23:34 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | import time 11 | 12 | import joblib 13 | 14 | import torch.nn.functional as F 15 | 16 | from 
libs.DenseCRF import DenseCRF
17 | from libs.metrics import AverageMeter, Result
18 | from libs import utils
19 | 
20 | 
21 | def validate(args, val_loader, model, epoch, logger):
22 |     average_meter = AverageMeter()
23 |     model.eval()  # switch to evaluation mode
24 | 
25 |     output_directory = utils.get_output_directory(args, check=True)
26 |     skip = len(val_loader) // 9  # save images every skip iters
27 | 
28 |     if args.crf:
29 |         ITER_MAX = 10
30 |         POS_W = 3
31 |         POS_XY_STD = 1
32 |         BI_W = 4
33 |         BI_XY_STD = 67
34 |         BI_RGB_STD = 3
35 | 
36 |         postprocessor = DenseCRF(
37 |             iter_max=ITER_MAX,
38 |             pos_xy_std=POS_XY_STD,
39 |             pos_w=POS_W,
40 |             bi_xy_std=BI_XY_STD,
41 |             bi_rgb_std=BI_RGB_STD,
42 |             bi_w=BI_W,
43 |         )
44 | 
45 |     end = time.time()
46 | 
47 |     for i, samples in enumerate(val_loader):
48 | 
49 |         input = samples['image']
50 |         target = samples['label']
51 | 
52 |         # itr_count += 1
53 |         input, target = input.cuda(), target.cuda()
54 |         # print('input size = ', input.size())
55 |         # print('target size = ', target.size())
56 |         torch.cuda.synchronize()
57 |         data_time = time.time() - end
58 | 
59 |         # compute pred
60 |         end = time.time()
61 | 
62 |         with torch.no_grad():
63 |             pred = model(input)  # @wx: note the output here
64 | 
65 |         torch.cuda.synchronize()
66 |         gpu_time = time.time() - end
67 | 
68 |         # measure accuracy and record loss
69 |         result = Result()
70 | 
71 |         pred = F.softmax(pred, 1)
72 | 
73 |         if pred.size() != target.size():
74 |             pred = F.interpolate(pred, size=(target.size()[-2], target.size()[-1]), mode='bilinear', align_corners=True)
75 | 
76 |         pred = pred.data.cpu().numpy()
77 |         target = target.data.cpu().numpy()
78 | 
79 |         # Post Processing
80 |         if args.crf:
81 |             images = input.data.cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
82 |             pred = joblib.Parallel(n_jobs=-1)(
83 |                 [joblib.delayed(postprocessor)(*pair) for pair in zip(images, pred)]
84 |             )
85 | 
86 |         result.evaluate(pred, target, n_class=21)
87 |         average_meter.update(result, gpu_time, data_time, input.size(0))
88 |         end = time.time()
89 | 
90 |         # save 8 images for visualization
91 |         rgb = input.data.cpu().numpy()[0]
92 |         target = target[0]
93 |         pred = np.argmax(pred, axis=1)
94 |         pred = pred[0]
95 | 
96 |         if i == 0:
97 |             img_merge = utils.merge_into_row(rgb, target, pred)
98 |         elif (i < 8 * skip) and (i % skip == 0):
99 |             row = utils.merge_into_row(rgb, target, pred)
100 |             img_merge = utils.add_row(img_merge, row)
101 |         elif i == 8 * skip:
102 |             filename = output_directory + '/comparison_' + str(epoch) + '.png'
103 |             utils.save_image(img_merge, filename)
104 | 
105 |         if (i + 1) % args.print_freq == 0:
106 |             print('Test: [{0}/{1}]\t'
107 |                   't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
108 |                   'mean_acc={result.mean_acc:.3f}({average.mean_acc:.3f}) '
109 |                   'mean_iou={result.mean_iou:.3f}({average.mean_iou:.3f})'.format(
110 |                 i + 1, len(val_loader), gpu_time=gpu_time, result=result, average=average_meter.average()))
111 | 
112 |     avg = average_meter.average()
113 |     logger.add_scalar('Test/mean_acc', avg.mean_acc, epoch)
114 |     logger.add_scalar('Test/mean_iou', avg.mean_iou, epoch)
115 | 
116 |     print('\n*\n'
117 |           'mean_acc={average.mean_acc:.3f}\n'
118 |           'mean_iou={average.mean_iou:.3f}\n'
119 |           't_GPU={time:.3f}\n'.format(
120 |         average=avg, time=avg.gpu_time))
121 | 
122 |     return avg, img_merge
123 | 
--------------------------------------------------------------------------------
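`libs/metrics.py` is not included in this dump; for reference, the mean IoU that `result.evaluate(pred, target, n_class=21)` reports can be computed from hard label maps roughly as below (a sketch of the standard VOC protocol — `mean_iou` and the `ignore_index=255` void label are illustrative choices, not the repo's exact API):

import numpy as np

def mean_iou(pred, target, n_class=21, ignore_index=255):
    # Confusion matrix over non-void pixels: rows = ground truth, cols = prediction.
    pred, target = np.asarray(pred).ravel(), np.asarray(target).ravel()
    keep = target != ignore_index
    hist = np.bincount(n_class * target[keep].astype(int) + pred[keep].astype(int),
                       minlength=n_class ** 2).reshape(n_class, n_class)
    inter = np.diag(hist)                      # per-class true positives
    union = hist.sum(axis=0) + hist.sum(axis=1) - inter
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = inter / union                    # NaN for classes absent from both
    return float(np.nanmean(iou))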