├── cfgs
├── __init__.py
├── __init__.pyc
├── DenseASPP161.pyc
├── __pycache__
│ ├── __init__.cpython-35.pyc
│ └── DenseASPP161.cpython-35.pyc
├── MobileNetDenseASPP.py
├── DenseASPP201.py
├── DenseASPP121.py
├── DenseASPP161.py
└── DenseASPP169.py
├── models
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-35.pyc
│ └── denseASPP.cpython-35.pyc
├── MobileNetDenseASPP.py
└── DenseASPP.py
├── demo.py
├── utils
└── transfer.py
├── README.md
└── inference.py
/cfgs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cfgs/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepMotionAIResearch/DenseASPP/HEAD/cfgs/__init__.pyc
--------------------------------------------------------------------------------
/cfgs/DenseASPP161.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepMotionAIResearch/DenseASPP/HEAD/cfgs/DenseASPP161.pyc
--------------------------------------------------------------------------------
/cfgs/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepMotionAIResearch/DenseASPP/HEAD/cfgs/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepMotionAIResearch/DenseASPP/HEAD/models/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/denseASPP.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepMotionAIResearch/DenseASPP/HEAD/models/__pycache__/denseASPP.cpython-35.pyc
--------------------------------------------------------------------------------
/cfgs/__pycache__/DenseASPP161.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepMotionAIResearch/DenseASPP/HEAD/cfgs/__pycache__/DenseASPP161.cpython-35.pyc
--------------------------------------------------------------------------------
/cfgs/MobileNetDenseASPP.py:
--------------------------------------------------------------------------------
# Hyper-parameters of the MobileNetV2-based DenseASPP model.
Model_CFG = dict(
    # Dropout probability used inside every DenseASPP block.
    dropout0=0.1,
    # Dropout probability applied right before the 1x1 classifier.
    dropout1=0.1,
    # Output channels of the 1x1 convolution in every DenseASPP block.
    d_feature0=128,
    # Output channels of the dilated 3x3 convolution in every DenseASPP block.
    d_feature1=64,
    # Backbone checkpoint location; not read by the model code shown in this repository.
    pretrained_path="./pretrained/mobilenetv2.pth.tar",
)
9 |
--------------------------------------------------------------------------------
/cfgs/DenseASPP201.py:
--------------------------------------------------------------------------------
# Hyper-parameters of the DenseASPP model with block layout (6, 12, 48, 32).
# NOTE: unlike the sibling configs, this one defines no 'pretrained_path'.
Model_CFG = dict(
    # --- backbone (dense blocks) ---
    bn_size=4,                      # bottleneck width multiplier (bn_size * growth_rate channels)
    drop_rate=0,                    # dropout inside each dense layer
    growth_rate=32,                 # channels added by each dense layer
    num_init_features=64,           # output channels of the first convolution
    block_config=(6, 12, 48, 32),   # number of dense layers in each of the four blocks
    # --- DenseASPP head ---
    dropout0=0.1,                   # dropout inside every DenseASPP block
    dropout1=0.1,                   # dropout before the 1x1 classifier
    d_feature0=480,                 # 1x1 conv output channels in every DenseASPP block
    d_feature1=240,                 # dilated 3x3 conv output channels in every DenseASPP block
)
--------------------------------------------------------------------------------
/cfgs/DenseASPP121.py:
--------------------------------------------------------------------------------
# Hyper-parameters of the DenseASPP model with block layout (6, 12, 24, 16).
Model_CFG = dict(
    # --- backbone (dense blocks) ---
    bn_size=4,                      # bottleneck width multiplier (bn_size * growth_rate channels)
    drop_rate=0,                    # dropout inside each dense layer
    growth_rate=32,                 # channels added by each dense layer
    num_init_features=64,           # output channels of the first convolution
    block_config=(6, 12, 24, 16),   # number of dense layers in each of the four blocks
    # --- DenseASPP head ---
    dropout0=0.1,                   # dropout inside every DenseASPP block
    dropout1=0.1,                   # dropout before the 1x1 classifier
    d_feature0=128,                 # 1x1 conv output channels in every DenseASPP block
    d_feature1=64,                  # dilated 3x3 conv output channels in every DenseASPP block
    # Backbone checkpoint location; not read by the model code shown in this repository.
    pretrained_path="./pretrained/densenet121.pth",
)
--------------------------------------------------------------------------------
/cfgs/DenseASPP161.py:
--------------------------------------------------------------------------------
# Hyper-parameters of the DenseASPP model with block layout (6, 12, 36, 24).
Model_CFG = dict(
    # --- backbone (dense blocks) ---
    bn_size=4,                      # bottleneck width multiplier (bn_size * growth_rate channels)
    drop_rate=0,                    # dropout inside each dense layer
    growth_rate=48,                 # channels added by each dense layer
    num_init_features=96,           # output channels of the first convolution
    block_config=(6, 12, 36, 24),   # number of dense layers in each of the four blocks
    # --- DenseASPP head ---
    dropout0=0.1,                   # dropout inside every DenseASPP block
    dropout1=0.1,                   # dropout before the 1x1 classifier
    d_feature0=512,                 # 1x1 conv output channels in every DenseASPP block
    d_feature1=128,                 # dilated 3x3 conv output channels in every DenseASPP block
    # Backbone checkpoint location; not read by the model code shown in this repository.
    pretrained_path="./pretrained/densenet161.pth",
)
15 |
--------------------------------------------------------------------------------
/cfgs/DenseASPP169.py:
--------------------------------------------------------------------------------
# Hyper-parameters of the DenseASPP model with block layout (6, 12, 32, 32).
Model_CFG = dict(
    # --- backbone (dense blocks) ---
    bn_size=4,                      # bottleneck width multiplier (bn_size * growth_rate channels)
    drop_rate=0,                    # dropout inside each dense layer
    growth_rate=32,                 # channels added by each dense layer
    num_init_features=64,           # output channels of the first convolution
    block_config=(6, 12, 32, 32),   # number of dense layers in each of the four blocks
    # --- DenseASPP head ---
    dropout0=0.1,                   # dropout inside every DenseASPP block
    dropout1=0.1,                   # dropout before the 1x1 classifier
    d_feature0=336,                 # 1x1 conv output channels in every DenseASPP block
    d_feature1=168,                 # dilated 3x3 conv output channels in every DenseASPP block
    # Backbone checkpoint location; not read by the model code shown in this repository.
    pretrained_path="./pretrained/densenet169.pth",
)
15 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # Author: {maokeyang, kunyu, kuiyuanyang}@deepmotion.ai
4 | # Inference code of DenseASPP based segmentation models.
5 |
6 |
7 | import argparse
8 | from inference import Inference
9 |
10 |
if __name__ == "__main__":
    # Command-line entry point: build the requested model, load its weights
    # and display the prediction for every image found under --img_dir.
    parser = argparse.ArgumentParser(description='DenseASPP inference code.')
    parser.add_argument('--model_name', default='DenseASPP161', help='segmentation model.')
    parser.add_argument('--model_path', default='./weights/denseASPP161.pkl', help='weight path.')
    parser.add_argument('--img_dir', default='./Cityscapes/leftImg8bit/val', help='image dir.')
    # Off by default, which matches the previously hard-coded single-scale run.
    parser.add_argument('--multiscale', action='store_true',
                        help='average predictions over several scales and flips.')
    args = parser.parse_args()

    infer = Inference(args.model_name, args.model_path)
    infer.folder_inference(args.img_dir, is_multiscale=args.multiscale)
20 |
--------------------------------------------------------------------------------
/utils/transfer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy
4 |
# Folder that holds one sub-folder of colour-coded images per sequence.
IMG_PATH = "../val/"
# Folder that receives the converted label-id images (same layout).
SAVE_PATH = "../results/"

# (label id, RGB colour) pairs: a pixel painted with the colour is written
# out as the label id.
_LABEL_COLORS = [
    (7, (128, 64, 128)),
    (8, (244, 35, 232)),
    (11, (70, 70, 70)),
    (12, (102, 102, 156)),
    (13, (190, 153, 153)),
    (17, (153, 153, 153)),
    (19, (250, 170, 30)),
    (20, (220, 220, 0)),
    (21, (107, 142, 35)),
    (22, (152, 251, 152)),
    (23, (70, 130, 180)),
    (24, (220, 20, 60)),
    (25, (255, 0, 0)),
    (26, (0, 0, 142)),
    (27, (0, 0, 70)),
    (28, (0, 60, 100)),
    (31, (0, 80, 100)),
    (32, (0, 0, 230)),
    (33, (119, 11, 32)),
]

# Parallel lists: color_list[i] is the label id of colour color_map[i].
color_list = [pair[0] for pair in _LABEL_COLORS]
color_map = [pair[1] for pair in _LABEL_COLORS]
12 |
13 |
def change():
    """Convert colour-coded segmentation images into label-id images.

    Every sub-folder of ``IMG_PATH`` is mirrored under ``SAVE_PATH``.  Each
    image is read, and a single-channel ``uint8`` image of the same name is
    written in which a pixel whose RGB value equals ``color_map[i]`` holds
    ``color_list[i]``; all other pixels are 0.

    Entries of ``IMG_PATH`` that are not directories and files that OpenCV
    cannot decode are skipped instead of aborting the whole run.
    """
    if not os.path.isdir(SAVE_PATH):
        os.makedirs(SAVE_PATH)

    for folder in sorted(os.listdir(IMG_PATH)):
        src_dir = os.path.join(IMG_PATH, folder)
        if not os.path.isdir(src_dir):
            continue

        dst_dir = os.path.join(SAVE_PATH, folder)
        if not os.path.isdir(dst_dir):
            os.makedirs(dst_dir)

        for name in sorted(os.listdir(src_dir)):
            print(name)
            bgr = cv2.imread(os.path.join(src_dir, name), cv2.IMREAD_COLOR)
            if bgr is None:
                # cv2.imread signals an unreadable file by returning None.
                print("skipped (not a readable image): " + name)
                continue
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

            mask = numpy.zeros(rgb.shape[:2], dtype=numpy.uint8)
            for label_id, color in zip(color_list, color_map):
                # A pixel matches only when all three channels agree.
                hit = numpy.all(rgb == numpy.array(color, dtype=rgb.dtype), axis=-1)
                mask[hit] = label_id

            # No window is ever shown here, so no cv2.waitKey() call is needed.
            cv2.imwrite(os.path.join(dst_dir, name), mask)
41 |
42 |
# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    change()
45 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DenseASPP for Semantic Segmentation in Street Scenes [pdf](http://openaccess.thecvf.com/content_cvpr_2018/papers/Yang_DenseASPP_for_Semantic_CVPR_2018_paper.pdf)
2 |
3 | ## Introduction
4 |
5 | Semantic image segmentation is a basic street scene understanding task in autonomous driving, where each pixel in a high-resolution image is assigned one of a set of semantic labels. Unlike other scenarios, objects in autonomous driving scenes exhibit very large scale changes, which poses great challenges for high-level feature representation, in the sense that multi-scale information must be correctly encoded.
6 |
7 | To remedy this problem, atrous convolution[2, 3] was introduced to generate features with larger receptive fields without sacrificing spatial resolution. Built upon atrous convolution, Atrous Spatial Pyramid Pooling (ASPP)[3] was proposed to concatenate multiple atrous-convolved features using different dilation rates into a final feature representation. Although ASPP is able to generate multi-scale features, we argue the feature resolution in the scale-axis is not dense enough for the autonomous driving scenario. To this end, we propose Densely connected Atrous Spatial Pyramid Pooling (DenseASPP), which connects a set of atrous convolutional layers in a dense way, such that it generates multi-scale features that not only cover a larger scale range, but also cover that scale range densely, without significantly increasing the model size. We evaluate DenseASPP on the street scene benchmark Cityscapes[4] and achieve state-of-the-art performance.
8 |
9 | ## Usage
10 |
11 | ### 1. **Clone the repository:**
12 |
13 | ```
14 | git clone https://github.com/DeepMotionAIResearch/DenseASPP.git
15 | ```
16 |
17 | ### 2. **Download pretrained model:**
18 | Put the model at the folder `weights`. We provide some checkpoints to run the code:
19 |
20 | **DenseNet161 based model**: [GoogleDrive](https://drive.google.com/open?id=1kMKyboVGWlBxgYRYYnOXiA1mj_ufAXNJ)
21 |
22 | **Mobilenet v2 based model**: Coming soon.
23 |
24 | Performance of these checkpoints:
25 |
26 | Checkpoint name | Multi-scale inference | Cityscapes mIOU (val) | Cityscapes mIOU (test) | File Size
27 | ------------------------------------------------------------------------- | :-------------------------: | :----------------------------: | :----------------------------: |:-------: |
28 | [DenseASPP161](https://drive.google.com/file/d/1sCr-OkMUayaHAijdQrzndKk2WW78MVZG/view?usp=sharing) | False <br> True | 79.9% <br> 80.6% | - <br> 79.5% | 142.7 MB
29 | [MobileNetDenseASPP](*) | False <br> True | 74.5% <br> 75.0% | - <br> - | 10.2 MB
30 |
31 | Please note that the performance of these checkpoints can be further improved by fine-tuning. Besides, these models were trained with **PyTorch 0.3.1**.
32 |
33 | ### 3. **Inference**
34 |
35 | First cd to your code root, then run:
36 |
37 | ```
38 | python demo.py --model_name DenseASPP161 --model_path <path/to/checkpoint> --img_dir <path/to/image/dir>
39 | ```
40 |
41 | ### 4. **Evaluate the results**
42 | Please cd to `./utils`, then run:
43 |
44 | ```
45 | python transfer.py
46 | ```
47 |
48 | Then evaluate the results with the official Cityscapes evaluation code, which can be found [here](https://github.com/mcordts/cityscapesScripts).
49 |
50 | ## References
51 |
52 | 1. **DenseASPP for Semantic Segmentation in Street Scenes**
53 | Maoke Yang, Kun Yu, Chi Zhang, Zhiwei Li, Kuiyuan Yang.
54 | [link](http://openaccess.thecvf.com/content_cvpr_2018/papers/Yang_DenseASPP_for_Semantic_CVPR_2018_paper.pdf). In CVPR, 2018.
55 |
56 | 2. **Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs**
57 | Liang-Chieh Chen+, George Papandreou+, Iasonas Kokkinos, Kevin Murphy, Alan L. Yuille (+ equal
58 | contribution).
59 | [link](https://arxiv.org/abs/1412.7062). In ICLR, 2015.
60 |
61 | 3. **DeepLab: Semantic Image Segmentation with Deep Convolutional Nets,**
62 | **Atrous Convolution, and Fully Connected CRFs**
63 | Liang-Chieh Chen+, George Papandreou+, Iasonas Kokkinos, Kevin Murphy, and Alan L Yuille (+ equal
64 | contribution).
65 | [link](http://arxiv.org/abs/1606.00915). TPAMI 2017.
66 |
67 | 4. **The Cityscapes Dataset for Semantic Urban Scene Understanding**
68 | Cordts, Marius, Mohamed Omran, Sebastian Ramos, Timo Rehfeld, Markus Enzweiler, Rodrigo Benenson, Uwe Franke, Stefan Roth, Bernt Schiele.
69 | [link](https://www.cityscapes-dataset.com/). In CVPR, 2016.
70 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import numpy
5 | import torch.nn.functional as F
6 |
7 | from PIL import Image
8 | from torchvision import transforms
9 | from torch.autograd import Variable
10 | from collections import OrderedDict
11 |
12 |
# Multi-scale switch; not referenced by the code in this file.
IS_MULTISCALE = True
# Number of classes the network predicts.
N_CLASS = 19
# RGB colour used to paint class i in the visualisation.
COLOR_MAP = [
    (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
    (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
    (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
    (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100),
    (0, 80, 100), (0, 0, 230), (119, 11, 32),
]

# Scale factors evaluated by multi-scale inference.
inf_scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.8]
# Input pipeline: PIL image -> tensor -> per-channel normalisation.
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.290101, 0.328081, 0.286964],
                         [0.182954, 0.186566, 0.184475]),
])
23 |
24 |
class Inference(object):
    """Run a DenseASPP segmentation model on images and display the result.

    The model is built from one of the configs under ``cfgs``, loaded with
    the given checkpoint, switched to eval mode and moved to the GPU, so a
    CUDA device is required.

    Args:
        model_name: 'MobileNetDenseASPP', 'DenseASPP121', 'DenseASPP161',
            'DenseASPP169' or 'DenseASPP201'.  Any other value falls back to
            'DenseASPP161'.
        model_path: checkpoint accepted by ``torch.load``.
    """

    def __init__(self, model_name, model_path):
        self.seg_model = self.__init_model(model_name, model_path, is_local=False)

    def __init_model(self, model_name, model_path, is_local=False):
        """Build the network for ``model_name`` and load ``model_path`` into it."""
        if model_name == 'MobileNetDenseASPP':
            from cfgs.MobileNetDenseASPP import Model_CFG
            from models.MobileNetDenseASPP import DenseASPP
        else:
            from models.DenseASPP import DenseASPP
            if model_name == 'DenseASPP121':
                from cfgs.DenseASPP121 import Model_CFG
            elif model_name == 'DenseASPP169':
                from cfgs.DenseASPP169 import Model_CFG
            elif model_name == 'DenseASPP201':
                from cfgs.DenseASPP201 import Model_CFG
            else:
                # 'DenseASPP161' and, as before, every unrecognised name.
                if model_name != 'DenseASPP161':
                    print("unknown model name %r, falling back to DenseASPP161" % (model_name,))
                from cfgs.DenseASPP161 import Model_CFG

        seg_model = DenseASPP(Model_CFG, n_class=N_CLASS, output_stride=8)
        self.__load_weight(seg_model, model_path, is_local=is_local)
        seg_model.eval()
        seg_model = seg_model.cuda()

        return seg_model

    def folder_inference(self, img_dir, is_multiscale=True):
        """Predict and display every ``.png`` in each sub-folder of ``img_dir``.

        Each result is shown in an OpenCV window and the method blocks on a
        key press before moving on to the next image.
        """
        for folder in sorted(os.listdir(img_dir)):
            read_path = os.path.join(img_dir, folder)
            if not os.path.isdir(read_path):
                # Stray files next to the image folders are ignored.
                continue
            for name in sorted(os.listdir(read_path)):
                if not name.endswith(".png"):
                    continue
                print(name)
                # The normalisation expects exactly three channels.
                img = Image.open(os.path.join(read_path, name)).convert('RGB')
                if is_multiscale:
                    pre = self.multiscale_inference(img)
                else:
                    pre = self.single_inference(img)
                mask = self.__pre_to_img(pre)
                cv2.imshow('DenseASPP', mask)
                cv2.waitKey(0)

    def multiscale_inference(self, test_img):
        """Average the predictions over all ``inf_scales``, each with and without a horizontal flip."""
        width, height = test_img.size  # PIL reports (width, height)
        pre = []
        for scale in inf_scales:
            # Image.BICUBIC is the same filter as the legacy alias Image.CUBIC.
            img_scaled = test_img.resize((int(width * scale), int(height * scale)), Image.BICUBIC)
            pre.append(self.single_inference(img_scaled, is_flip=False))

            img_scaled = img_scaled.transpose(Image.FLIP_LEFT_RIGHT)
            pre.append(self.single_inference(img_scaled, is_flip=True))

        return self.__fushion_avg(pre)

    def single_inference(self, test_img, is_flip=False):
        """Return per-class log-probabilities for one PIL image.

        The result is a numpy array of shape (1, classes, 1024, 2048).  With
        ``is_flip=True`` the prediction is mirrored back horizontally, which
        undoes a left-right flip applied to the input.
        """
        batch = data_transforms(test_img).unsqueeze(0).cuda()
        if hasattr(torch, 'no_grad'):
            # On newer PyTorch, ``volatile`` no longer disables autograd.
            with torch.no_grad():
                pre = self.__predict(batch)
        else:
            pre = self.__predict(Variable(batch, volatile=True))

        if is_flip:
            # Mirror along the width axis (the last one).
            pre = pre[:, :, :, ::-1].copy()

        return pre

    def __predict(self, batch):
        """Forward ``batch``, resize the scores to 1024x2048 and return log-softmax as numpy."""
        pre = self.seg_model(batch)

        # Every prediction is brought to the same size so that scales can be averaged.
        if tuple(pre.size()[2:]) != (1024, 2048):
            resize = getattr(F, 'interpolate', None) or F.upsample
            pre = resize(pre, size=(1024, 2048), mode='bilinear')

        pre = F.log_softmax(pre, dim=1)
        return pre.data.cpu().numpy()

    @staticmethod
    def __fushion_avg(pre):
        """Return the element-wise mean of the predictions in ``pre``."""
        return sum(pre) / len(pre)

    @staticmethod
    def __load_weight(seg_model, model_path, is_local=True):
        """Load the state dict stored at ``model_path`` into ``seg_model``.

        With ``is_local=False`` a leading ``module.`` is removed from every
        key that has one; keys without that prefix are left untouched.
        """
        print("loading pre-trained weight")
        # Tensors are mapped to CPU storage regardless of where they were saved.
        weight = torch.load(model_path, map_location=lambda storage, loc: storage)

        if is_local:
            seg_model.load_state_dict(weight)
            return

        prefix = 'module.'
        new_state_dict = OrderedDict()
        for k, v in weight.items():
            name = k[len(prefix):] if k.startswith(prefix) else k
            new_state_dict[name] = v
        seg_model.load_state_dict(new_state_dict)

    @staticmethod
    def __pre_to_img(pre):
        """Turn a (1, classes, H, W) score array into a BGR colour image."""
        result = pre.argmax(axis=1)[0]
        row, col = result.shape
        dst = numpy.zeros((row, col, 3), dtype=numpy.uint8)
        for i in range(N_CLASS):
            dst[result == i] = COLOR_MAP[i]
        # COLOR_MAP is RGB while OpenCV displays BGR.
        return cv2.cvtColor(dst, cv2.COLOR_RGB2BGR)
143 |
--------------------------------------------------------------------------------
/models/MobileNetDenseASPP.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from torch import nn
5 | from torch.nn import BatchNorm2d as bn
6 |
7 |
class DenseASPP(nn.Module):
    """Dilated MobileNetV2 backbone followed by a DenseASPP head and a 1x1 classifier.

    Args:
        model_cfg: mapping with the keys ``dropout0``, ``dropout1``,
            ``d_feature0`` and ``d_feature1``.
        n_class: number of output channels of the classifier.
        output_stride: 8 or 16.  With 16 the backbone output is upsampled by
            2 before the DenseASPP head.

    Raises:
        ValueError: if ``output_stride`` is neither 8 nor 16.
    """

    def __init__(self, model_cfg, n_class=19, output_stride=8):
        super(DenseASPP, self).__init__()
        if output_stride not in (8, 16):
            raise ValueError("output_stride must be 8 or 16, got %r" % (output_stride,))

        dropout0 = model_cfg['dropout0']
        dropout1 = model_cfg['dropout1']
        d_feature0 = model_cfg['d_feature0']
        d_feature1 = model_cfg['d_feature1']

        feature_size = int(output_stride / 8)
        self.features = DilatedMobileNetV2(output_stride=output_stride)
        num_features = self.features.get_num_features()

        # The backbone's forward() only runs its own layer stack, so the
        # upsampling has to be applied here rather than attached to it.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear') if feature_size > 1 else None

        # Each block sees the backbone features plus the outputs of all earlier blocks.
        self.ASPP_3 = _DenseAsppBlock(input_num=num_features, num1=d_feature0, num2=d_feature1,
                                      dilation_rate=3, drop_out=dropout0, bn_start=False)

        self.ASPP_6 = _DenseAsppBlock(input_num=num_features + d_feature1 * 1, num1=d_feature0, num2=d_feature1,
                                      dilation_rate=6, drop_out=dropout0, bn_start=True)

        self.ASPP_12 = _DenseAsppBlock(input_num=num_features + d_feature1 * 2, num1=d_feature0, num2=d_feature1,
                                       dilation_rate=12, drop_out=dropout0, bn_start=True)

        self.ASPP_18 = _DenseAsppBlock(input_num=num_features + d_feature1 * 3, num1=d_feature0, num2=d_feature1,
                                       dilation_rate=18, drop_out=dropout0, bn_start=True)

        self.ASPP_24 = _DenseAsppBlock(input_num=num_features + d_feature1 * 4, num1=d_feature0, num2=d_feature1,
                                       dilation_rate=24, drop_out=dropout0, bn_start=True)
        num_features = num_features + 5 * d_feature1

        self.classification = nn.Sequential(
            nn.Dropout2d(p=dropout1),
            nn.Conv2d(in_channels=num_features, out_channels=n_class, kernel_size=1, padding=0),
            nn.Upsample(scale_factor=8, mode='bilinear'),
        )

        # Prefer the in-place initialiser name where it exists.
        kaiming_uniform = getattr(nn.init, 'kaiming_uniform_', None) or nn.init.kaiming_uniform
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_uniform(m.weight.data)
            elif isinstance(m, bn):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, _input):
        """Return class scores with ``n_class`` channels for ``_input``."""
        feature = self.features(_input)
        if self.upsample is not None:
            feature = self.upsample(feature)

        # Dense connectivity: every block output is stacked in front of its input.
        for aspp in (self.ASPP_3, self.ASPP_6, self.ASPP_12, self.ASPP_18, self.ASPP_24):
            feature = torch.cat((aspp(feature), feature), dim=1)

        return self.classification(feature)
78 |
79 |
def conv_bn(inp, oup, stride):
    """Build a bias-free 3x3 convolution (padding 1) followed by BatchNorm and ReLU."""
    layers = [
        nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
86 |
87 |
def conv_1x1_bn(inp, oup):
    """Build a bias-free 1x1 convolution followed by BatchNorm and ReLU."""
    layers = [
        nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
94 |
95 |
class InvertedResidual(nn.Module):
    """Expand (1x1) -> depthwise 3x3 -> project (1x1) block with an optional skip connection.

    The skip connection is used only when ``stride`` is 1 and the input and
    output channel counts are equal.
    """

    def __init__(self, inp, oup, stride, expand_ratio, dilation=1):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        self.use_res_connect = self.stride == 1 and inp == oup

        hidden = inp * expand_ratio
        expand = [
            nn.Conv2d(inp, hidden, 1, 1, 0, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
        ]
        # groups == channels makes this convolution depthwise.
        depthwise = [
            nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride, padding=dilation,
                      dilation=dilation, groups=hidden, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
        ]
        # Linear projection: no activation after the last BatchNorm.
        project = [
            nn.Conv2d(hidden, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*(expand + depthwise + project))

    def forward(self, x):
        out = self.conv(x)
        if self.use_res_connect:
            return x + out
        return out
124 |
125 |
class DilatedMobileNetV2(nn.Module):
    """Stack of a 3x3 stem convolution and inverted-residual blocks with dilated late stages.

    Args:
        width_mult: multiplier applied to every channel count.
        output_stride: 8 or 16.  With 8 the fourth stage keeps stride 1 and
            the last four stages use dilation 2; with 16 the fourth stage
            uses stride 2 and no dilation is applied.
    """

    def __init__(self, width_mult=1., output_stride=8):
        super(DilatedMobileNetV2, self).__init__()
        self.scale_factor = int(output_stride / 8)
        scale = self.scale_factor
        # setting of inverted residual blocks
        self.interverted_residual_setting = [
            # t (expansion), c (channels), n (repeats), s (first stride), dilation
            [1, 16, 1, 1, 1],
            [6, 24, 2, 2, 1],
            [6, 32, 3, 2, 1],
            [6, 64, 4, int(scale), int(2 / scale)],
            [6, 96, 3, 1, int(2 / scale)],
            [6, 160, 3, 1, int(2 / scale)],
            [6, 320, 1, 1, int(2 / scale)],
        ]

        input_channel = int(32 * width_mult)
        layers = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s, dilate in self.interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                # Only the first block of a stage may change the resolution.
                stride = s if i == 0 else 1
                layers.append(InvertedResidual(input_channel, output_channel, stride, t, dilation=dilate))
                input_channel = output_channel
        self.features = nn.Sequential(*layers)

        # Channel count of the last block; 320 for the default width_mult of 1.
        self.num_features = input_channel

    def forward(self, x):
        return self.features(x)

    def get_num_features(self):
        """Return the number of channels produced by ``forward``."""
        return self.num_features
163 |
164 |
165 | class _DenseAsppBlock(nn.Sequential):
166 | """ ConvNet block for building DenseASPP. """
167 |
168 | def __init__(self, input_num, num1, num2, dilation_rate, drop_out, bn_start=True):
169 | super(_DenseAsppBlock, self).__init__()
170 | if bn_start:
171 | self.add_module('norm.1', bn(input_num, momentum=0.0003)),
172 |
173 | self.add_module('relu.1', nn.ReLU(inplace=True)),
174 | self.add_module('conv.1', nn.Conv2d(in_channels=input_num, out_channels=num1, kernel_size=1)),
175 |
176 | self.add_module('norm.2', bn(num1, momentum=0.0003)),
177 | self.add_module('relu.2', nn.ReLU(inplace=True)),
178 | self.add_module('conv.2', nn.Conv2d(in_channels=num1, out_channels=num2, kernel_size=3,
179 | dilation=dilation_rate, padding=dilation_rate)),
180 |
181 | self.drop_rate = drop_out
182 |
183 | def forward(self, _input):
184 | feature = super(_DenseAsppBlock, self).forward(_input)
185 |
186 | if self.drop_rate > 0:
187 | feature = F.dropout2d(feature, p=self.drop_rate, training=self.training)
188 |
189 | return feature
190 |
191 |
if __name__ == "__main__":
    # Smoke test: build the network and print its structure.  The constructor
    # indexes its first argument as a config mapping, so a bare integer fails.
    # The values mirror cfgs/MobileNetDenseASPP.py; the 2 is kept as the class
    # count (presumably what the old call meant -- confirm).
    demo_cfg = {
        'dropout0': 0.1,
        'dropout1': 0.1,
        'd_feature0': 128,
        'd_feature1': 64,
    }
    model = DenseASPP(demo_cfg, n_class=2)
    print(model)
195 |
--------------------------------------------------------------------------------
/models/DenseASPP.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from torch import nn
5 | from collections import OrderedDict
6 | from torch.nn import BatchNorm2d as bn
7 |
8 |
class DenseASPP(nn.Module):
    """Dilated DenseNet-style backbone followed by a DenseASPP head and a 1x1 classifier.

    Args:
        model_cfg: mapping with the backbone keys ``bn_size``, ``drop_rate``,
            ``growth_rate``, ``num_init_features``, ``block_config`` and the
            head keys ``dropout0``, ``dropout1``, ``d_feature0``,
            ``d_feature1``.
        n_class: number of output channels of the classifier.
        output_stride: 8 or 16.  With 16 the second transition pools, the
            dilation of blocks 3 and 4 is halved and the backbone output is
            upsampled by 2.

    Raises:
        ValueError: if ``output_stride`` is neither 8 nor 16.
    """

    def __init__(self, model_cfg, n_class=19, output_stride=8):
        super(DenseASPP, self).__init__()
        if output_stride not in (8, 16):
            raise ValueError("output_stride must be 8 or 16, got %r" % (output_stride,))

        bn_size = model_cfg['bn_size']
        drop_rate = model_cfg['drop_rate']
        growth_rate = model_cfg['growth_rate']
        num_init_features = model_cfg['num_init_features']
        block_config = model_cfg['block_config']

        dropout0 = model_cfg['dropout0']
        dropout1 = model_cfg['dropout1']
        d_feature0 = model_cfg['d_feature0']
        d_feature1 = model_cfg['d_feature1']

        feature_size = int(output_stride / 8)

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', bn(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # (dilation of the dense block, stride of the transition after it)
        stage_settings = [
            (1, 2),
            (1, feature_size),
            (int(2 / feature_size), 1),
            (int(4 / feature_size), 1),
        ]

        num_features = num_init_features
        for index, (dilation, stride) in enumerate(stage_settings):
            block = _DenseBlock(num_layers=block_config[index], num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate,
                                dilation_rate=dilation)
            self.features.add_module('denseblock%d' % (index + 1), block)
            num_features = num_features + block_config[index] * growth_rate

            # Every transition halves the channel count.
            trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2,
                                stride=stride)
            self.features.add_module('transition%d' % (index + 1), trans)
            num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', bn(num_features))
        if feature_size > 1:
            self.features.add_module('upsample', nn.Upsample(scale_factor=2, mode='bilinear'))

        # Each block sees the backbone features plus the outputs of all earlier blocks.
        self.ASPP_3 = _DenseAsppBlock(input_num=num_features, num1=d_feature0, num2=d_feature1,
                                      dilation_rate=3, drop_out=dropout0, bn_start=False)

        self.ASPP_6 = _DenseAsppBlock(input_num=num_features + d_feature1 * 1, num1=d_feature0, num2=d_feature1,
                                      dilation_rate=6, drop_out=dropout0, bn_start=True)

        self.ASPP_12 = _DenseAsppBlock(input_num=num_features + d_feature1 * 2, num1=d_feature0, num2=d_feature1,
                                       dilation_rate=12, drop_out=dropout0, bn_start=True)

        self.ASPP_18 = _DenseAsppBlock(input_num=num_features + d_feature1 * 3, num1=d_feature0, num2=d_feature1,
                                       dilation_rate=18, drop_out=dropout0, bn_start=True)

        self.ASPP_24 = _DenseAsppBlock(input_num=num_features + d_feature1 * 4, num1=d_feature0, num2=d_feature1,
                                       dilation_rate=24, drop_out=dropout0, bn_start=True)
        num_features = num_features + 5 * d_feature1

        self.classification = nn.Sequential(
            nn.Dropout2d(p=dropout1),
            nn.Conv2d(in_channels=num_features, out_channels=n_class, kernel_size=1, padding=0),
            nn.Upsample(scale_factor=8, mode='bilinear'),
        )

        # Prefer the in-place initialiser name where it exists.
        kaiming_uniform = getattr(nn.init, 'kaiming_uniform_', None) or nn.init.kaiming_uniform
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_uniform(m.weight.data)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, _input):
        """Return class scores with ``n_class`` channels for ``_input``."""
        feature = self.features(_input)

        # Dense connectivity: every block output is stacked in front of its input.
        for aspp in (self.ASPP_3, self.ASPP_6, self.ASPP_12, self.ASPP_18, self.ASPP_24):
            feature = torch.cat((aspp(feature), feature), dim=1)

        return self.classification(feature)
134 |
135 |
136 | class _DenseAsppBlock(nn.Sequential):
137 | """ ConvNet block for building DenseASPP. """
138 |
139 | def __init__(self, input_num, num1, num2, dilation_rate, drop_out, bn_start=True):
140 | super(_DenseAsppBlock, self).__init__()
141 | if bn_start:
142 | self.add_module('norm.1', bn(input_num, momentum=0.0003)),
143 |
144 | self.add_module('relu.1', nn.ReLU(inplace=True)),
145 | self.add_module('conv.1', nn.Conv2d(in_channels=input_num, out_channels=num1, kernel_size=1)),
146 |
147 | self.add_module('norm.2', bn(num1, momentum=0.0003)),
148 | self.add_module('relu.2', nn.ReLU(inplace=True)),
149 | self.add_module('conv.2', nn.Conv2d(in_channels=num1, out_channels=num2, kernel_size=3,
150 | dilation=dilation_rate, padding=dilation_rate)),
151 |
152 | self.drop_rate = drop_out
153 |
154 | def forward(self, _input):
155 | feature = super(_DenseAsppBlock, self).forward(_input)
156 |
157 | if self.drop_rate > 0:
158 | feature = F.dropout2d(feature, p=self.drop_rate, training=self.training)
159 |
160 | return feature
161 |
162 |
163 | class _DenseLayer(nn.Sequential):
164 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, dilation_rate=1):
165 | super(_DenseLayer, self).__init__()
166 | self.add_module('norm.1', bn(num_input_features)),
167 | self.add_module('relu.1', nn.ReLU(inplace=True)),
168 | self.add_module('conv.1', nn.Conv2d(num_input_features, bn_size *
169 | growth_rate, kernel_size=1, stride=1, bias=False)),
170 | self.add_module('norm.2', bn(bn_size * growth_rate)),
171 | self.add_module('relu.2', nn.ReLU(inplace=True)),
172 | self.add_module('conv.2', nn.Conv2d(bn_size * growth_rate, growth_rate,
173 | kernel_size=3, stride=1, dilation=dilation_rate, padding=dilation_rate, bias=False)),
174 | self.drop_rate = drop_rate
175 |
176 | def forward(self, x):
177 | new_features = super(_DenseLayer, self).forward(x)
178 | if self.drop_rate > 0:
179 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
180 | return torch.cat([x, new_features], 1)
181 |
182 |
class _DenseBlock(nn.Sequential):
    """Chain of ``num_layers`` dense layers named ``denselayer1`` .. ``denselayerN``.

    Each layer receives ``growth_rate`` more input channels than the one before it.
    """

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, dilation_rate=1):
        super(_DenseBlock, self).__init__()
        in_channels = num_input_features
        for position in range(1, num_layers + 1):
            self.add_module(
                'denselayer%d' % position,
                _DenseLayer(in_channels, growth_rate, bn_size, drop_rate, dilation_rate=dilation_rate),
            )
            in_channels = in_channels + growth_rate
190 |
191 |
192 | class _Transition(nn.Sequential):
193 | def __init__(self, num_input_features, num_output_features, stride=2):
194 | super(_Transition, self).__init__()
195 | self.add_module('norm', bn(num_input_features))
196 | self.add_module('relu', nn.ReLU(inplace=True))
197 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
198 | if stride == 2:
199 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=stride))
200 |
201 |
if __name__ == "__main__":
    # Smoke test: build the network and print its structure.  The constructor
    # indexes its first argument as a config mapping, so a bare integer fails.
    # The values mirror cfgs/DenseASPP121.py; the 2 is kept as the class count
    # (presumably what the old call meant -- confirm).
    demo_cfg = {
        'bn_size': 4,
        'drop_rate': 0,
        'growth_rate': 32,
        'num_init_features': 64,
        'block_config': (6, 12, 24, 16),
        'dropout0': 0.1,
        'dropout1': 0.1,
        'd_feature0': 128,
        'd_feature1': 64,
    }
    model = DenseASPP(demo_cfg, n_class=2)
    print(model)
205 |
--------------------------------------------------------------------------------