├── cfgs
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── DenseASPP161.pyc
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   └── DenseASPP161.cpython-35.pyc
│   ├── MobileNetDenseASPP.py
│   ├── DenseASPP201.py
│   ├── DenseASPP121.py
│   ├── DenseASPP161.py
│   └── DenseASPP169.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   └── denseASPP.cpython-35.pyc
│   ├── MobileNetDenseASPP.py
│   └── DenseASPP.py
├── demo.py
├── utils
│   └── transfer.py
├── README.md
└── inference.py

/cfgs/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------

/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------

/cfgs/MobileNetDenseASPP.py:
--------------------------------------------------------------------------------
Model_CFG = {
    'dropout0': 0.1,
    'dropout1': 0.1,
    'd_feature0': 128,
    'd_feature1': 64,

    'pretrained_path': "./pretrained/mobilenetv2.pth.tar"
}
--------------------------------------------------------------------------------

/cfgs/DenseASPP201.py:
--------------------------------------------------------------------------------
Model_CFG = {
    'bn_size': 4,
    'drop_rate': 0,
    'growth_rate': 32,
    'num_init_features': 64,
    'block_config': (6, 12, 48, 32),

    'dropout0': 0.1,
    'dropout1': 0.1,
    'd_feature0': 480,
    'd_feature1': 240,
}
--------------------------------------------------------------------------------

/cfgs/DenseASPP121.py:
--------------------------------------------------------------------------------
Model_CFG = {
    'bn_size': 4,
    'drop_rate': 0,
    'growth_rate': 32,
    'num_init_features': 64,
    'block_config': (6, 12, 24, 16),

    'dropout0': 0.1,
    'dropout1': 0.1,
    'd_feature0': 128,
    'd_feature1': 64,

    'pretrained_path': "./pretrained/densenet121.pth"
}
--------------------------------------------------------------------------------

/cfgs/DenseASPP161.py:
--------------------------------------------------------------------------------
Model_CFG = {
    'bn_size': 4,
    'drop_rate': 0,
    'growth_rate': 48,
    'num_init_features': 96,
    'block_config': (6, 12, 36, 24),

    'dropout0': 0.1,
    'dropout1': 0.1,
    'd_feature0': 512,
    'd_feature1': 128,

    'pretrained_path': "./pretrained/densenet161.pth"
}
--------------------------------------------------------------------------------

/cfgs/DenseASPP169.py:
--------------------------------------------------------------------------------
Model_CFG = {
    'bn_size': 4,
    'drop_rate': 0,
    'growth_rate': 32,
    'num_init_features': 64,
    'block_config': (6, 12, 32, 32),

    'dropout0': 0.1,
    'dropout1': 0.1,
    'd_feature0': 336,
    'd_feature1': 168,

    'pretrained_path': "./pretrained/densenet169.pth"
}
--------------------------------------------------------------------------------
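Each of these modules exposes a single `Model_CFG` dict that `inference.py` hands straight to the `DenseASPP` constructor. As a sanity check on the numbers, here is a minimal sketch (plain arithmetic, mirroring the `_DenseBlock`/`_Transition` bookkeeping in `models/DenseASPP.py`) of the feature width the DenseASPP161 config produces at the backbone output:

```
# Sketch: channels entering the first DenseASPP block for DenseASPP161.
# Values come from cfgs/DenseASPP161.py; every transition halves the
# channel count, exactly as models/DenseASPP.py does after each block.
num_features = 96                    # num_init_features
for num_layers in (6, 12, 36, 24):   # block_config
    num_features += num_layers * 48  # growth_rate
    num_features //= 2               # _Transition compression
print(num_features)                  # 1104 channels feed ASPP_3
```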
/demo.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author: {maokeyang, kunyu, kuiyuanyang}@deepmotion.ai
# Inference code of DenseASPP based segmentation models.


import argparse
from inference import Inference


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='DenseASPP inference code.')
    parser.add_argument('--model_name', default='DenseASPP161', help='segmentation model.')
    parser.add_argument('--model_path', default='./weights/denseASPP161.pkl', help='weight path.')
    parser.add_argument('--img_dir', default='./Cityscapes/leftImg8bit/val', help='image dir.')
    args = parser.parse_args()

    infer = Inference(args.model_name, args.model_path)
    infer.folder_inference(args.img_dir, is_multiscale=False)
--------------------------------------------------------------------------------

/utils/transfer.py:
--------------------------------------------------------------------------------
import os
import cv2
import numpy

IMG_PATH = "../val/"
SAVE_PATH = "../results/"

# Cityscapes label IDs and their corresponding RGB colors, in matching order.
color_list = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
color_map = [(128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156), (190, 153, 153), (153, 153, 153),
             (250, 170, 30), (220, 220, 0), (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
             (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100), (0, 0, 230), (119, 11, 32)]


def change():
    """Convert color-coded predictions back to Cityscapes label-ID masks."""
    if not os.path.exists(SAVE_PATH):
        os.mkdir(SAVE_PATH)

    folders = sorted(os.listdir(IMG_PATH))
    for f in folders:
        folder_path = IMG_PATH + f + "/"
        save_path = SAVE_PATH + f + "/"
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        names = sorted(os.listdir(folder_path))
        for n in names:
            print(n)
            img = cv2.cvtColor(cv2.imread(folder_path + n, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
            R, G, B = cv2.split(img)
            mask = numpy.zeros_like(R, dtype=numpy.uint8)

            for i in range(len(color_list)):
                # A pixel belongs to class i only if all three channels match.
                tmp_mask = numpy.zeros_like(R, dtype=numpy.uint8)
                color = color_map[i]
                tmp_mask[R == color[0]] += 1
                tmp_mask[G == color[1]] += 1
                tmp_mask[B == color[2]] += 1

                mask[tmp_mask == 3] = color_list[i]
            cv2.imwrite(save_path + n, mask)


if __name__ == "__main__":
    change()
--------------------------------------------------------------------------------
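The per-class loop in `change()` above runs three full-image comparisons per class; a vectorized sketch of the same RGB-to-label-ID mapping (assuming the `color_list` and `color_map` globals defined above) needs only one:

```
import numpy

def rgb_to_label_ids(img_rgb, color_list, color_map):
    """Sketch: map an H x W x 3 RGB prediction to Cityscapes label IDs."""
    mask = numpy.zeros(img_rgb.shape[:2], dtype=numpy.uint8)
    for label_id, color in zip(color_list, color_map):
        # One boolean test covering all three channels at once.
        mask[numpy.all(img_rgb == color, axis=2)] = label_id
    return mask
```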
/README.md:
--------------------------------------------------------------------------------
# DenseASPP for Semantic Segmentation in Street Scenes [pdf](http://openaccess.thecvf.com/content_cvpr_2018/papers/Yang_DenseASPP_for_Semantic_CVPR_2018_paper.pdf)

## Introduction

Semantic image segmentation is a basic street-scene understanding task in autonomous driving, where each pixel in a high-resolution image is categorized into a set of semantic labels. Unlike other scenarios, objects in autonomous driving scenes exhibit very large scale changes, which poses great challenges for high-level feature representation in the sense that multi-scale information must be correctly encoded.

To remedy this problem, atrous convolution [2, 3] was introduced to generate features with larger receptive fields without sacrificing spatial resolution. Built upon atrous convolution, Atrous Spatial Pyramid Pooling (ASPP) [3] was proposed to concatenate multiple atrous-convolved features using different dilation rates into a final feature representation. Although ASPP is able to generate multi-scale features, we argue that the feature resolution along the scale axis is not dense enough for the autonomous driving scenario. To this end, we propose Densely connected Atrous Spatial Pyramid Pooling (DenseASPP), which connects a set of atrous convolutional layers in a dense way, such that it generates multi-scale features that not only cover a larger scale range, but also cover that scale range densely, without significantly increasing the model size. We evaluate DenseASPP on the street-scene benchmark Cityscapes [4] and achieve state-of-the-art performance.
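The dense connectivity is easiest to see in a minimal sketch (toy channel counts chosen for illustration, not the repository's configuration; the real implementation lives in `models/DenseASPP.py`), where each atrous branch consumes the backbone features concatenated with every earlier branch's output:

```
import torch
from torch import nn

backbone_ch, branch_ch = 64, 16
branches = nn.ModuleList([
    nn.Conv2d(backbone_ch + i * branch_ch, branch_ch,
              kernel_size=3, dilation=d, padding=d)
    for i, d in enumerate((3, 6, 12, 18, 24))  # DenseASPP's dilation rates
])

feature = torch.randn(1, backbone_ch, 32, 32)  # stand-in backbone features
for conv in branches:
    feature = torch.cat((conv(feature), feature), dim=1)
print(feature.shape)  # torch.Size([1, 144, 32, 32]) = 64 + 5 * 16
```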
## Usage

### 1. **Clone the repository:**

```
git clone https://github.com/DeepMotionAIResearch/DenseASPP.git
```

### 2. **Download a pretrained model:**

Put the downloaded model in the `weights` folder. We provide some checkpoints to run the code:

**DenseNet161 based model**: [GoogleDrive](https://drive.google.com/open?id=1kMKyboVGWlBxgYRYYnOXiA1mj_ufAXNJ)

**MobileNet v2 based model**: coming soon.

Performance of these checkpoints:

Checkpoint name | Multi-scale inference | Cityscapes mIOU (val) | Cityscapes mIOU (test) | File size
:--- | :---: | :---: | :---: | :---:
[DenseASPP161](https://drive.google.com/file/d/1sCr-OkMUayaHAijdQrzndKk2WW78MVZG/view?usp=sharing) | False<br>True | 79.9%<br>80.6% | -<br>79.5% | 142.7 MB
[MobileNetDenseASPP](*) | False<br>True | 74.5%<br>75.0% | -<br>- | 10.2 MB

Please note that the performance of these checkpoints can be further improved by fine-tuning. Also note that these models were trained with **PyTorch 0.3.1**.

### 3. **Inference**

First `cd` to your code root, then run:

```
python demo.py --model_name DenseASPP161 --model_path <path to checkpoint> --img_dir <path to images>
```
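The same entry point can also be driven from Python directly; a minimal sketch mirroring `demo.py`'s defaults (both paths are placeholders for your setup, and a CUDA device is required because the model is moved to the GPU):

```
from inference import Inference

# Same defaults as demo.py; adjust both paths to your checkout.
infer = Inference('DenseASPP161', './weights/denseASPP161.pkl')
infer.folder_inference('./Cityscapes/leftImg8bit/val', is_multiscale=False)
```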
### 4. **Evaluate the results**

Please `cd` to `./utils`, then run:

```
python transfer.py
```

Then evaluate the results with the official Cityscapes evaluation code, which can be found [here](https://github.com/mcordts/cityscapesScripts).

## References

1. **DenseASPP for Semantic Segmentation in Street Scenes**<br>
   Maoke Yang, Kun Yu, Chi Zhang, Zhiwei Li, Kuiyuan Yang.<br>
   [link](http://openaccess.thecvf.com/content_cvpr_2018/papers/Yang_DenseASPP_for_Semantic_CVPR_2018_paper.pdf). In CVPR, 2018.

2. **Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs**<br>
   Liang-Chieh Chen+, George Papandreou+, Iasonas Kokkinos, Kevin Murphy, Alan L. Yuille (+ equal contribution).<br>
   [link](https://arxiv.org/abs/1412.7062). In ICLR, 2015.

3. **DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs**<br>
   Liang-Chieh Chen+, George Papandreou+, Iasonas Kokkinos, Kevin Murphy, Alan L. Yuille (+ equal contribution).<br>
   [link](http://arxiv.org/abs/1606.00915). In TPAMI, 2017.

4. **The Cityscapes Dataset for Semantic Urban Scene Understanding**<br>
   Marius Cordts, Mohamed Omran, Sebastian Ramos, Timo Rehfeld, Markus Enzweiler, Rodrigo Benenson, Uwe Franke, Stefan Roth, Bernt Schiele.<br>
   [link](https://www.cityscapes-dataset.com/). In CVPR, 2016.
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import os
import cv2
import torch
import numpy
import torch.nn.functional as F

from PIL import Image
from torchvision import transforms
from torch.autograd import Variable
from collections import OrderedDict


IS_MULTISCALE = True
N_CLASS = 19
COLOR_MAP = [(128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156), (190, 153, 153), (153, 153, 153),
             (250, 170, 30), (220, 220, 0), (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
             (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100), (0, 0, 230), (119, 11, 32)]

inf_scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.8]
data_transforms = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize([0.290101, 0.328081, 0.286964],
                                                           [0.182954, 0.186566, 0.184475])])


class Inference(object):

    def __init__(self, model_name, model_path):
        self.seg_model = self.__init_model(model_name, model_path, is_local=False)

    def __init_model(self, model_name, model_path, is_local=False):
        if model_name == 'MobileNetDenseASPP':
            from cfgs.MobileNetDenseASPP import Model_CFG
            from models.MobileNetDenseASPP import DenseASPP
        elif model_name == 'DenseASPP121':
            from cfgs.DenseASPP121 import Model_CFG
            from models.DenseASPP import DenseASPP
        elif model_name == 'DenseASPP169':
            from cfgs.DenseASPP169 import Model_CFG
            from models.DenseASPP import DenseASPP
        elif model_name == 'DenseASPP201':
            from cfgs.DenseASPP201 import Model_CFG
            from models.DenseASPP import DenseASPP
        else:
            # DenseASPP161 is the default model.
            from cfgs.DenseASPP161 import Model_CFG
            from models.DenseASPP import DenseASPP

        seg_model = DenseASPP(Model_CFG, n_class=N_CLASS, output_stride=8)
        self.__load_weight(seg_model, model_path, is_local=is_local)
        seg_model.eval()
        seg_model = seg_model.cuda()

        return seg_model

    def folder_inference(self, img_dir, is_multiscale=True):
        folders = sorted(os.listdir(img_dir))
        for f in folders:
            read_path = os.path.join(img_dir, f)
            names = sorted(os.listdir(read_path))
            for n in names:
                if not n.endswith(".png"):
                    continue
                print(n)
                read_name = os.path.join(read_path, n)
                img = Image.open(read_name)
                if is_multiscale:
                    pre = self.multiscale_inference(img)
                else:
                    pre = self.single_inference(img)
                mask = self.__pre_to_img(pre)
                # Press any key to advance to the next image.
                cv2.imshow('DenseASPP', mask)
                cv2.waitKey(0)

    def multiscale_inference(self, test_img):
        # PIL's size is (width, height).
        w, h = test_img.size
        pre = []
        for scale in inf_scales:
            img_scaled = test_img.resize((int(w * scale), int(h * scale)), Image.BICUBIC)
            pre_scaled = self.single_inference(img_scaled, is_flip=False)
            pre.append(pre_scaled)

            # Average in the horizontally flipped prediction as well.
            img_scaled = img_scaled.transpose(Image.FLIP_LEFT_RIGHT)
            pre_scaled = self.single_inference(img_scaled, is_flip=True)
            pre.append(pre_scaled)

        pre_final = self.__fusion_avg(pre)

        return pre_final

    def single_inference(self, test_img, is_flip=False):
        image = Variable(data_transforms(test_img).unsqueeze(0).cuda(), volatile=True)
        pre = self.seg_model(image)

        # Resize every prediction to full Cityscapes resolution so that
        # outputs from different scales can be averaged together.
        if pre.size()[2] != 1024:
            pre = F.upsample(pre, size=(1024, 2048), mode='bilinear')

        pre = F.log_softmax(pre, dim=1)
        pre = pre.data.cpu().numpy()

        if is_flip:
            # Undo the horizontal flip so predictions align across scales.
            tem = pre[0]
            tem = tem.transpose(1, 2, 0)
            tem = numpy.fliplr(tem)
            tem = tem.transpose(2, 0, 1)
            pre[0] = tem

        return pre

    @staticmethod
    def __fusion_avg(pre):
        pre_final = 0
        for pre_scaled in pre:
            pre_final = pre_final + pre_scaled
        pre_final = pre_final / len(pre)
        return pre_final

    @staticmethod
    def __load_weight(seg_model, model_path, is_local=True):
        print("loading pre-trained weight")
        weight = torch.load(model_path, map_location=lambda storage, loc: storage)

        if is_local:
            seg_model.load_state_dict(weight)
        else:
            new_state_dict = OrderedDict()
            for k, v in weight.items():
                name = k[7:]  # remove the `module.` prefix
                new_state_dict[name] = v
            seg_model.load_state_dict(new_state_dict)

    @staticmethod
    def __pre_to_img(pre):
        result = pre.argmax(axis=1)[0]
        row, col = result.shape
        dst = numpy.zeros((row, col, 3), dtype=numpy.uint8)
        for i in range(N_CLASS):
            dst[result == i] = COLOR_MAP[i]
        dst = numpy.array(dst, dtype=numpy.uint8)
        dst = cv2.cvtColor(dst, cv2.COLOR_RGB2BGR)
        return dst
--------------------------------------------------------------------------------
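The `is_local=False` branch of `__load_weight` assumes every checkpoint key carries a `module.` prefix, the usual artifact of saving from an `nn.DataParallel`-wrapped model. A slightly more defensive sketch of the same cleanup, in case some keys lack the prefix:

```
# Hypothetical drop-in for the renaming loop: strip the prefix only when present.
new_state_dict = OrderedDict(
    (k[7:] if k.startswith('module.') else k, v) for k, v in weight.items()
)
```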
/models/MobileNetDenseASPP.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F

from torch import nn
from torch.nn import BatchNorm2d as bn


class DenseASPP(nn.Module):
    """
    * output_stride can only be set to 8 or 16
    """

    def __init__(self, model_cfg, n_class=19, output_stride=8):
        super(DenseASPP, self).__init__()
        dropout0 = model_cfg['dropout0']
        dropout1 = model_cfg['dropout1']
        d_feature0 = model_cfg['d_feature0']
        d_feature1 = model_cfg['d_feature1']

        feature_size = int(output_stride / 8)
        self.features = DilatedMobileNetV2(output_stride=output_stride)
        num_features = self.features.get_num_features()

        if feature_size > 1:
            # Append to the backbone's inner Sequential so the upsample
            # actually runs inside its forward pass.
            self.features.features.add_module('upsample', nn.Upsample(scale_factor=2, mode='bilinear'))

        self.ASPP_3 = _DenseAsppBlock(input_num=num_features, num1=d_feature0, num2=d_feature1,
                                      dilation_rate=3, drop_out=dropout0, bn_start=False)

        self.ASPP_6 = _DenseAsppBlock(input_num=num_features + d_feature1 * 1, num1=d_feature0, num2=d_feature1,
                                      dilation_rate=6, drop_out=dropout0, bn_start=True)

        self.ASPP_12 = _DenseAsppBlock(input_num=num_features + d_feature1 * 2, num1=d_feature0, num2=d_feature1,
                                       dilation_rate=12, drop_out=dropout0, bn_start=True)

        self.ASPP_18 = _DenseAsppBlock(input_num=num_features + d_feature1 * 3, num1=d_feature0, num2=d_feature1,
                                       dilation_rate=18, drop_out=dropout0, bn_start=True)

        self.ASPP_24 = _DenseAsppBlock(input_num=num_features + d_feature1 * 4, num1=d_feature0, num2=d_feature1,
                                       dilation_rate=24, drop_out=dropout0, bn_start=True)
        num_features = num_features + 5 * d_feature1

        self.classification = nn.Sequential(
            nn.Dropout2d(p=dropout1),
            nn.Conv2d(in_channels=num_features, out_channels=n_class, kernel_size=1, padding=0),
            nn.Upsample(scale_factor=8, mode='bilinear'),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform(m.weight.data)

            elif isinstance(m, bn):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, _input):
        feature = self.features(_input)

        aspp3 = self.ASPP_3(feature)
        feature = torch.cat((aspp3, feature), dim=1)

        aspp6 = self.ASPP_6(feature)
        feature = torch.cat((aspp6, feature), dim=1)

        aspp12 = self.ASPP_12(feature)
        feature = torch.cat((aspp12, feature), dim=1)

        aspp18 = self.ASPP_18(feature)
        feature = torch.cat((aspp18, feature), dim=1)

        aspp24 = self.ASPP_24(feature)
        feature = torch.cat((aspp24, feature), dim=1)

        cls = self.classification(feature)

        return cls


def conv_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True)
    )


def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True)
    )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio, dilation=1):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        self.use_res_connect = self.stride == 1 and inp == oup

        self.conv = nn.Sequential(
            # pw
            nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False),
            nn.BatchNorm2d(inp * expand_ratio),
            nn.ReLU6(inplace=True),
            # dw
            nn.Conv2d(inp * expand_ratio, inp * expand_ratio, kernel_size=3, stride=stride, padding=dilation,
                      dilation=dilation, groups=inp * expand_ratio, bias=False),
            nn.BatchNorm2d(inp * expand_ratio),
            nn.ReLU6(inplace=True),
            # pw-linear
            nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class DilatedMobileNetV2(nn.Module):

    def __init__(self, width_mult=1., output_stride=8):
        super(DilatedMobileNetV2, self).__init__()
        self.num_features = 320
        self.scale_factor = int(output_stride / 8)
        scale = self.scale_factor
        # setting of inverted residual blocks
        self.inverted_residual_setting = [
            # t, c, n, s, dilation
            [1, 16, 1, 1, 1],
            [6, 24, 2, 2, 1],
            [6, 32, 3, 2, 1],
            [6, 64, 4, int(scale), int(2 / scale)],
            [6, 96, 3, 1, int(2 / scale)],
            [6, 160, 3, 1, int(2 / scale)],
            [6, 320, 1, 1, int(2 / scale)],
        ]

        input_channel = int(32 * width_mult)
        self.features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s, dilate in self.inverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    self.features.append(InvertedResidual(input_channel, output_channel, s, t, dilation=dilate))
                else:
                    self.features.append(InvertedResidual(input_channel, output_channel, 1, t, dilation=dilate))
                input_channel = output_channel
        self.features = nn.Sequential(*self.features)

    def forward(self, x):
        return self.features(x)

    def get_num_features(self):
        return self.num_features


class _DenseAsppBlock(nn.Sequential):
    """ ConvNet block for building DenseASPP. """

    def __init__(self, input_num, num1, num2, dilation_rate, drop_out, bn_start=True):
        super(_DenseAsppBlock, self).__init__()
        if bn_start:
            self.add_module('norm.1', bn(input_num, momentum=0.0003))

        self.add_module('relu.1', nn.ReLU(inplace=True))
        self.add_module('conv.1', nn.Conv2d(in_channels=input_num, out_channels=num1, kernel_size=1))

        self.add_module('norm.2', bn(num1, momentum=0.0003))
        self.add_module('relu.2', nn.ReLU(inplace=True))
        self.add_module('conv.2', nn.Conv2d(in_channels=num1, out_channels=num2, kernel_size=3,
                                            dilation=dilation_rate, padding=dilation_rate))

        self.drop_rate = drop_out

    def forward(self, _input):
        feature = super(_DenseAsppBlock, self).forward(_input)

        if self.drop_rate > 0:
            feature = F.dropout2d(feature, p=self.drop_rate, training=self.training)

        return feature


if __name__ == "__main__":
    # Run from the repository root so `cfgs` is importable.
    from cfgs.MobileNetDenseASPP import Model_CFG

    model = DenseASPP(Model_CFG, n_class=19)
    print(model)
--------------------------------------------------------------------------------
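For a quick shape sanity check of the MobileNet variant, a sketch (it assumes the `Model_CFG` from `cfgs/MobileNetDenseASPP.py` and is written against a PyTorch 0.4+ style API, while the released weights target 0.3.1):

```
import torch
from cfgs.MobileNetDenseASPP import Model_CFG
from models.MobileNetDenseASPP import DenseASPP

model = DenseASPP(Model_CFG, n_class=19, output_stride=8).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 512, 1024))
print(out.shape)  # (1, 19, 512, 1024): /8 backbone, then 8x upsample
```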
/models/DenseASPP.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F

from torch import nn
from collections import OrderedDict
from torch.nn import BatchNorm2d as bn


class DenseASPP(nn.Module):
    """
    * output_stride can only be set to 8 or 16
    """
    def __init__(self, model_cfg, n_class=19, output_stride=8):
        super(DenseASPP, self).__init__()
        bn_size = model_cfg['bn_size']
        drop_rate = model_cfg['drop_rate']
        growth_rate = model_cfg['growth_rate']
        num_init_features = model_cfg['num_init_features']
        block_config = model_cfg['block_config']

        dropout0 = model_cfg['dropout0']
        dropout1 = model_cfg['dropout1']
        d_feature0 = model_cfg['d_feature0']
        d_feature1 = model_cfg['d_feature1']

        feature_size = int(output_stride / 8)

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', bn(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each dense block is followed by a transition that halves the channels.
        num_features = num_init_features
        # block 1
        block = _DenseBlock(num_layers=block_config[0], num_input_features=num_features,
                            bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
        self.features.add_module('denseblock%d' % 1, block)
        num_features = num_features + block_config[0] * growth_rate

        trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
        self.features.add_module('transition%d' % 1, trans)
        num_features = num_features // 2

        # block 2
        block = _DenseBlock(num_layers=block_config[1], num_input_features=num_features,
                            bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
        self.features.add_module('denseblock%d' % 2, block)
        num_features = num_features + block_config[1] * growth_rate

        trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2,
                            stride=feature_size)
        self.features.add_module('transition%d' % 2, trans)
        num_features = num_features // 2

        # block 3: dilation replaces striding to keep the output stride at 8
        block = _DenseBlock(num_layers=block_config[2], num_input_features=num_features, bn_size=bn_size,
                            growth_rate=growth_rate, drop_rate=drop_rate, dilation_rate=int(2 / feature_size))
        self.features.add_module('denseblock%d' % 3, block)
        num_features = num_features + block_config[2] * growth_rate

        trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2, stride=1)
        self.features.add_module('transition%d' % 3, trans)
        num_features = num_features // 2

        # block 4
        block = _DenseBlock(num_layers=block_config[3], num_input_features=num_features, bn_size=bn_size,
                            growth_rate=growth_rate, drop_rate=drop_rate, dilation_rate=int(4 / feature_size))
        self.features.add_module('denseblock%d' % 4, block)
        num_features = num_features + block_config[3] * growth_rate

        trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2, stride=1)
        self.features.add_module('transition%d' % 4, trans)
        num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', bn(num_features))
        if feature_size > 1:
            self.features.add_module('upsample', nn.Upsample(scale_factor=2, mode='bilinear'))

        self.ASPP_3 = _DenseAsppBlock(input_num=num_features, num1=d_feature0, num2=d_feature1,
                                      dilation_rate=3, drop_out=dropout0, bn_start=False)

        self.ASPP_6 = _DenseAsppBlock(input_num=num_features + d_feature1 * 1, num1=d_feature0, num2=d_feature1,
                                      dilation_rate=6, drop_out=dropout0, bn_start=True)

        self.ASPP_12 = _DenseAsppBlock(input_num=num_features + d_feature1 * 2, num1=d_feature0, num2=d_feature1,
                                       dilation_rate=12, drop_out=dropout0, bn_start=True)

        self.ASPP_18 = _DenseAsppBlock(input_num=num_features + d_feature1 * 3, num1=d_feature0, num2=d_feature1,
                                       dilation_rate=18, drop_out=dropout0, bn_start=True)

        self.ASPP_24 = _DenseAsppBlock(input_num=num_features + d_feature1 * 4, num1=d_feature0, num2=d_feature1,
                                       dilation_rate=24, drop_out=dropout0, bn_start=True)
        num_features = num_features + 5 * d_feature1

        self.classification = nn.Sequential(
            nn.Dropout2d(p=dropout1),
            nn.Conv2d(in_channels=num_features, out_channels=n_class, kernel_size=1, padding=0),
            nn.Upsample(scale_factor=8, mode='bilinear'),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform(m.weight.data)

            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, _input):
        feature = self.features(_input)

        aspp3 = self.ASPP_3(feature)
        feature = torch.cat((aspp3, feature), dim=1)

        aspp6 = self.ASPP_6(feature)
        feature = torch.cat((aspp6, feature), dim=1)

        aspp12 = self.ASPP_12(feature)
        feature = torch.cat((aspp12, feature), dim=1)

        aspp18 = self.ASPP_18(feature)
        feature = torch.cat((aspp18, feature), dim=1)

        aspp24 = self.ASPP_24(feature)
        feature = torch.cat((aspp24, feature), dim=1)

        cls = self.classification(feature)

        return cls


class _DenseAsppBlock(nn.Sequential):
    """ ConvNet block for building DenseASPP. """

    def __init__(self, input_num, num1, num2, dilation_rate, drop_out, bn_start=True):
        super(_DenseAsppBlock, self).__init__()
        if bn_start:
            self.add_module('norm.1', bn(input_num, momentum=0.0003))

        self.add_module('relu.1', nn.ReLU(inplace=True))
        self.add_module('conv.1', nn.Conv2d(in_channels=input_num, out_channels=num1, kernel_size=1))

        self.add_module('norm.2', bn(num1, momentum=0.0003))
        self.add_module('relu.2', nn.ReLU(inplace=True))
        self.add_module('conv.2', nn.Conv2d(in_channels=num1, out_channels=num2, kernel_size=3,
                                            dilation=dilation_rate, padding=dilation_rate))

        self.drop_rate = drop_out

    def forward(self, _input):
        feature = super(_DenseAsppBlock, self).forward(_input)

        if self.drop_rate > 0:
            feature = F.dropout2d(feature, p=self.drop_rate, training=self.training)

        return feature


class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, dilation_rate=1):
        super(_DenseLayer, self).__init__()
        self.add_module('norm.1', bn(num_input_features))
        self.add_module('relu.1', nn.ReLU(inplace=True))
        self.add_module('conv.1', nn.Conv2d(num_input_features, bn_size * growth_rate,
                                            kernel_size=1, stride=1, bias=False))
        self.add_module('norm.2', bn(bn_size * growth_rate))
        self.add_module('relu.2', nn.ReLU(inplace=True))
        self.add_module('conv.2', nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1,
                                            dilation=dilation_rate, padding=dilation_rate, bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, dilation_rate=1):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate,
                                bn_size, drop_rate, dilation_rate=dilation_rate)
            self.add_module('denselayer%d' % (i + 1), layer)


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features, stride=2):
        super(_Transition, self).__init__()
        self.add_module('norm', bn(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
        if stride == 2:
            self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=stride))


if __name__ == "__main__":
    # Run from the repository root so `cfgs` is importable.
    from cfgs.DenseASPP161 import Model_CFG

    model = DenseASPP(Model_CFG, n_class=19)
    print(model)
--------------------------------------------------------------------------------