├── LICENSE
├── README.md
├── pretrained
│   ├── README.md
│   ├── decoder
│   │   ├── README.md
│   │   ├── espnet_p_2_q_3.pth
│   │   ├── espnet_p_2_q_5.pth
│   │   ├── espnet_p_2_q_8.pth
│   │   └── espnet_p_2_q_8_camvid.pth
│   └── encoder
│       ├── README.md
│       ├── espnet_p_2_q_3.pth
│       ├── espnet_p_2_q_5.pth
│       └── espnet_p_2_q_8.pth
├── sample_video
│   ├── ReadMe.md
│   └── sample.png
├── test
│   ├── Model.py
│   ├── README.md
│   ├── VisualizeResults.py
│   └── data
│       ├── README.md
│       ├── frankfurt_000000_000294_leftImg8bit.png
│       └── frankfurt_000000_000576_leftImg8bit.png
└── train
    ├── Criteria.py
    ├── DataSet.py
    ├── IOUEval.py
    ├── Model.py
    ├── README.md
    ├── Transforms.py
    ├── VisualizeGraph.py
    ├── city
    │   ├── README.md
    │   ├── train.txt
    │   └── val.txt
    ├── loadData.py
    └── main.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Sachin Mehta
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation
2 |
3 | This repository contains the source code of our paper, [ESPNet](https://arxiv.org/abs/1803.06815) (accepted for publication in [ECCV'18](http://eccv2018.org/)).
4 |
5 | ## Sample results
6 |
7 | Check our [project page](https://sacmehta.github.io/ESPNet/) for more qualitative results (videos).
8 |
9 | Click on the sample image below to view the segmentation results on YouTube.
10 |
11 |
12 |
13 |
14 |
15 |
16 | ## Structure of this repository
17 | This repository is organized as:
18 | * [train](/train/) This directory contains the source code for training the ESPNet-C and ESPNet models.
19 | * [test](/test/) This directory contains the source code for evaluating our model on RGB Images.
20 | * [pretrained](/pretrained/) This directory contains the pre-trained models on the Cityscapes dataset
21 | * [encoder](/pretrained/encoder/) This directory contains the pretrained **ESPNet-C** models
22 | * [decoder](/pretrained/decoder/) This directory contains the pretrained **ESPNet** models
23 |
24 |
25 | ## Performance on the Cityscapes dataset
26 |
27 | Our model ESPNet achieves a class-wise mIOU of **60.336** and a category-wise mIOU of **82.178** on the Cityscapes test set and runs at
28 | * 112 fps on the NVIDIA TitanX (30 fps faster than [ENet](https://arxiv.org/abs/1606.02147))
29 | * 9 fps on the NVIDIA Jetson TX2
30 | * With the same number of parameters as [ENet](https://arxiv.org/abs/1606.02147), our model is **2%** more accurate
31 |
32 | ## Performance on the CamVid dataset
33 |
34 | Our model achieves an mIOU of 55.64 on the CamVid test set. We used the dataset splits (train/val/test) provided [here](https://github.com/alexgkendall/SegNet-Tutorial). We trained the models at a resolution of 480x360. For comparison with other models, see [SegNet paper](https://ieeexplore.ieee.org/document/7803544/).
35 |
36 | Note: We did not train on the 3.5K-image dataset that was used in the SegNet paper.
37 |
38 | | Model | mIOU | Class avg. |
39 | | -- | -- | -- |
40 | | ENet | 51.3 | 68.3 |
41 | | SegNet | 55.6 | 65.2 |
42 | | ESPNet | 55.64 | 68.30 |
43 |
44 | ## Pre-requisites
45 |
46 | To run this code, you need the following libraries:
47 | * [OpenCV](https://opencv.org/) - We tested our code with version > 3.0.
48 | * [PyTorch](http://pytorch.org/) - We tested with v0.3.0
49 | * Python - We tested our code with Python 3. If you are using Python 2, please feel free to make the necessary changes to the code.
50 |
51 | We recommend using [Anaconda](https://conda.io/docs/user-guide/install/linux.html). We have tested our code on Ubuntu 16.04.
52 |
53 | ## Citation
54 | If ESPNet is useful for your research, then please cite our paper.
55 | ```
56 | @inproceedings{mehta2018espnet,
57 | title={ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation},
58 |   author={Sachin Mehta and Mohammad Rastegari and Anat Caspi and Linda Shapiro and Hannaneh Hajishirzi},
59 | booktitle={ECCV},
60 | year={2018}
61 | }
62 | ```
63 |
64 |
65 | ## FAQs
66 |
67 | ### Assertion error with class labels (t >= 0 && t < n_classes).
68 |
69 | If you are getting an assertion error with class labels, then please check the class label values present in your label images. You can do this as:
70 |
71 | ```
72 | import cv2
73 | import numpy as np
74 | labelImg = cv2.imread('<path_to_label_image>', 0)
75 | unique_val_arr = np.unique(labelImg)
76 | print(unique_val_arr)
77 | ```
78 | The values inside *unique_val_arr* should lie between 0 and the total number of classes in the dataset. If this is not the case, pre-process your label images. For example, if the label image contains 255 as a value, you can ignore these pixels by mapping them to an undefined or background class as:
79 |
80 | ```
81 | labelImg[labelImg == 255] = <undefined_or_background_class_id>
82 | ```
83 |
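84 | For convenience, here is a minimal, self-contained sketch that applies this remapping to every label image in a hypothetical `./labels` directory. It assumes the last class id (19 in the 20-class Cityscapes setup used here, see `relabel()` in test/VisualizeResults.py) serves as the background/ignore class; adjust the path and ids for your own dataset.
85 |
86 | ```
87 | import glob
88 |
89 | import cv2
90 | import numpy as np
91 |
92 | background_class = 19  # last class id in the 20-class setup; adjust for your dataset
93 |
94 | for label_path in glob.glob('./labels/*.png'):
95 |     labelImg = cv2.imread(label_path, 0)            # read the label image as grayscale
96 |     labelImg[labelImg == 255] = background_class    # map the undefined value 255 to the background class
97 |     print(label_path, np.unique(labelImg))          # values should now lie in [0, number of classes)
98 |     cv2.imwrite(label_path, labelImg)               # overwrite with the cleaned label map
99 | ```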
--------------------------------------------------------------------------------
/pretrained/README.md:
--------------------------------------------------------------------------------
1 | # ESPNet: Towards Fast and Efficient Semantic Segmentation on the Embedded Devices
2 |
3 | This directory contains the pretrained models for ESPNet-C and ESPNet under three different settings.
4 |
5 | * [encoder](/pretrained/encoder/) - Check this folder for ESPNet-C pretrained models.
6 | * [decoder](/pretrained/decoder/) - Check this folder for ESPNet pretrained models.
7 |
--------------------------------------------------------------------------------
/pretrained/decoder/README.md:
--------------------------------------------------------------------------------
1 | # ESPNet: Towards Fast and Efficient Semantic Segmentation on the Embedded Devices
2 |
3 | This directory contains the pretrained models for ESPNet under three different settings:
4 |
5 | * p=2, q=3
6 | * p=2, q=5
7 | * p=2, q=8
8 |
9 |
10 | ## Models trained on the CamVid dataset
11 | * espnet_p_2_q_8_camvid.pth
12 |
13 | ## Models trained on the CityScapes dataset
14 | * espnet_p_2_q_3.pth
15 | * espnet_p_2_q_5.pth
16 | * espnet_p_2_q_8.pth
17 |
--------------------------------------------------------------------------------
/pretrained/decoder/espnet_p_2_q_3.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/pretrained/decoder/espnet_p_2_q_3.pth
--------------------------------------------------------------------------------
/pretrained/decoder/espnet_p_2_q_5.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/pretrained/decoder/espnet_p_2_q_5.pth
--------------------------------------------------------------------------------
/pretrained/decoder/espnet_p_2_q_8.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/pretrained/decoder/espnet_p_2_q_8.pth
--------------------------------------------------------------------------------
/pretrained/decoder/espnet_p_2_q_8_camvid.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/pretrained/decoder/espnet_p_2_q_8_camvid.pth
--------------------------------------------------------------------------------
/pretrained/encoder/README.md:
--------------------------------------------------------------------------------
1 | # ESPNet: Towards Fast and Efficient Semantic Segmentation on the Embedded Devices
2 |
3 | This directory contains the pretrained models for ESPNet-C under three different settings:
4 |
5 | * p=2, q=3
6 | * p=2, q=5
7 | * p=2, q=8
8 |
--------------------------------------------------------------------------------
/pretrained/encoder/espnet_p_2_q_3.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/pretrained/encoder/espnet_p_2_q_3.pth
--------------------------------------------------------------------------------
/pretrained/encoder/espnet_p_2_q_5.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/pretrained/encoder/espnet_p_2_q_5.pth
--------------------------------------------------------------------------------
/pretrained/encoder/espnet_p_2_q_8.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/pretrained/encoder/espnet_p_2_q_8.pth
--------------------------------------------------------------------------------
/sample_video/ReadMe.md:
--------------------------------------------------------------------------------
1 | This directory contains a sample video demonstrating the segmentation performance of ESPNet.
2 |
--------------------------------------------------------------------------------
/sample_video/sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/sample_video/sample.png
--------------------------------------------------------------------------------
/test/Model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | __author__ = "Sachin Mehta"
5 |
6 | class CBR(nn.Module):
7 | '''
8 | This class defines the convolution layer with batch normalization and PReLU activation
9 | '''
10 | def __init__(self, nIn, nOut, kSize, stride=1):
11 | '''
12 |
13 | :param nIn: number of input channels
14 | :param nOut: number of output channels
15 | :param kSize: kernel size
16 | :param stride: stride rate for down-sampling. Default is 1
17 | '''
18 | super().__init__()
19 | padding = int((kSize - 1)/2)
20 | #self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False)
21 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False)
22 | #self.conv1 = nn.Conv2d(nOut, nOut, (1, kSize), stride=1, padding=(0, padding), bias=False)
23 | self.bn = nn.BatchNorm2d(nOut, eps=1e-03)
24 | self.act = nn.PReLU(nOut)
25 |
26 | def forward(self, input):
27 | '''
28 | :param input: input feature map
29 | :return: transformed feature map
30 | '''
31 | output = self.conv(input)
32 | #output = self.conv1(output)
33 | output = self.bn(output)
34 | output = self.act(output)
35 | return output
36 |
37 |
38 | class BR(nn.Module):
39 | '''
40 | This class groups the batch normalization and PReLU activation
41 | '''
42 | def __init__(self, nOut):
43 | '''
44 | :param nOut: output feature maps
45 | '''
46 | super().__init__()
47 | self.bn = nn.BatchNorm2d(nOut, eps=1e-03)
48 | self.act = nn.PReLU(nOut)
49 |
50 | def forward(self, input):
51 | '''
52 | :param input: input feature map
53 | :return: normalized and thresholded feature map
54 | '''
55 | output = self.bn(input)
56 | output = self.act(output)
57 | return output
58 |
59 | class CB(nn.Module):
60 | '''
61 | This class groups the convolution and batch normalization
62 | '''
63 | def __init__(self, nIn, nOut, kSize, stride=1):
64 | '''
65 | :param nIn: number of input channels
66 | :param nOut: number of output channels
67 | :param kSize: kernel size
68 | :param stride: optional stride for down-sampling
69 | '''
70 | super().__init__()
71 | padding = int((kSize - 1)/2)
72 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False)
73 | self.bn = nn.BatchNorm2d(nOut, eps=1e-03)
74 |
75 | def forward(self, input):
76 | '''
77 |
78 | :param input: input feature map
79 | :return: transformed feature map
80 | '''
81 | output = self.conv(input)
82 | output = self.bn(output)
83 | return output
84 |
85 | class C(nn.Module):
86 | '''
87 | This class is for a convolutional layer.
88 | '''
89 | def __init__(self, nIn, nOut, kSize, stride=1):
90 | '''
91 |
92 | :param nIn: number of input channels
93 | :param nOut: number of output channels
94 | :param kSize: kernel size
95 | :param stride: optional stride rate for down-sampling
96 | '''
97 | super().__init__()
98 | padding = int((kSize - 1)/2)
99 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False)
100 |
101 | def forward(self, input):
102 | '''
103 | :param input: input feature map
104 | :return: transformed feature map
105 | '''
106 | output = self.conv(input)
107 | return output
108 |
109 | class CDilated(nn.Module):
110 | '''
111 | This class defines the dilated convolution.
112 | '''
113 | def __init__(self, nIn, nOut, kSize, stride=1, d=1):
114 | '''
115 | :param nIn: number of input channels
116 | :param nOut: number of output channels
117 | :param kSize: kernel size
118 | :param stride: optional stride rate for down-sampling
119 | :param d: optional dilation rate
120 | '''
121 | super().__init__()
122 | padding = int((kSize - 1)/2) * d
123 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False, dilation=d)
124 |
125 | def forward(self, input):
126 | '''
127 | :param input: input feature map
128 | :return: transformed feature map
129 | '''
130 | output = self.conv(input)
131 | return output
132 |
133 | class DownSamplerB(nn.Module):
134 | def __init__(self, nIn, nOut):
135 | super().__init__()
136 | n = int(nOut/5)
137 | n1 = nOut - 4*n
138 | self.c1 = C(nIn, n, 3, 2)
139 | self.d1 = CDilated(n, n1, 3, 1, 1)
140 | self.d2 = CDilated(n, n, 3, 1, 2)
141 | self.d4 = CDilated(n, n, 3, 1, 4)
142 | self.d8 = CDilated(n, n, 3, 1, 8)
143 | self.d16 = CDilated(n, n, 3, 1, 16)
144 | self.bn = nn.BatchNorm2d(nOut, eps=1e-3)
145 | self.act = nn.PReLU(nOut)
146 |
147 | def forward(self, input):
148 | output1 = self.c1(input)
149 | d1 = self.d1(output1)
150 | d2 = self.d2(output1)
151 | d4 = self.d4(output1)
152 | d8 = self.d8(output1)
153 | d16 = self.d16(output1)
154 |
155 | add1 = d2
156 | add2 = add1 + d4
157 | add3 = add2 + d8
158 | add4 = add3 + d16
159 |
160 | combine = torch.cat([d1, add1, add2, add3, add4],1)
161 | #combine_in_out = input + combine
162 | output = self.bn(combine)
163 | output = self.act(output)
164 | return output
165 |
166 | class DilatedParllelResidualBlockB(nn.Module):
167 | '''
168 | This class defines the ESP block, which is based on the following principle
169 | Reduce ---> Split ---> Transform --> Merge
170 | '''
171 | def __init__(self, nIn, nOut, add=True):
172 | '''
173 | :param nIn: number of input channels
174 | :param nOut: number of output channels
175 | :param add: if true, add a residual connection through identity operation. You can use projection too as
176 | in ResNet paper, but we avoid to use it if the dimensions are not the same because we do not want to
177 | increase the module complexity
178 | '''
179 | super().__init__()
180 | n = int(nOut/5)
181 | n1 = nOut - 4*n
182 | self.c1 = C(nIn, n, 1, 1)
183 | self.d1 = CDilated(n, n1, 3, 1, 1) # dilation rate of 2^0
184 | self.d2 = CDilated(n, n, 3, 1, 2) # dilation rate of 2^1
185 | self.d4 = CDilated(n, n, 3, 1, 4) # dilation rate of 2^2
186 | self.d8 = CDilated(n, n, 3, 1, 8) # dilation rate of 2^3
187 | self.d16 = CDilated(n, n, 3, 1, 16) # dilation rate of 2^4
188 | self.bn = BR(nOut)
189 | self.add = add
190 |
191 | def forward(self, input):
192 | '''
193 | :param input: input feature map
194 | :return: transformed feature map
195 | '''
196 | # reduce
197 | output1 = self.c1(input)
198 | # split and transform
199 | d1 = self.d1(output1)
200 | d2 = self.d2(output1)
201 | d4 = self.d4(output1)
202 | d8 = self.d8(output1)
203 | d16 = self.d16(output1)
204 |
205 | # hierarchical fusion for de-gridding
206 | add1 = d2
207 | add2 = add1 + d4
208 | add3 = add2 + d8
209 | add4 = add3 + d16
210 |
211 | #merge
212 | combine = torch.cat([d1, add1, add2, add3, add4], 1)
213 |
214 | # if residual version
215 | if self.add:
216 | combine = input + combine
217 | output = self.bn(combine)
218 | return output
219 |
220 | class InputProjectionA(nn.Module):
221 | '''
222 | This class projects the input image to the same spatial dimensions as the feature map.
223 | For example, if the input image is 512 x512 x3 and spatial dimensions of feature map size are 56x56xF, then
224 | this class will generate an output of 56x56x3
225 | '''
226 | def __init__(self, samplingTimes):
227 | '''
228 | :param samplingTimes: The rate at which you want to down-sample the image
229 | '''
230 | super().__init__()
231 | self.pool = nn.ModuleList()
232 | for i in range(0, samplingTimes):
233 | #pyramid-based approach for down-sampling
234 | self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))
235 |
236 | def forward(self, input):
237 | '''
238 | :param input: Input RGB Image
239 | :return: down-sampled image (pyramid-based approach)
240 | '''
241 | for pool in self.pool:
242 | input = pool(input)
243 | return input
244 |
245 |
246 | class ESPNet_Encoder(nn.Module):
247 | '''
248 | This class defines the ESPNet-C network in the paper
249 | '''
250 | def __init__(self, classes=20, p=5, q=3):
251 | '''
252 | :param classes: number of classes in the dataset. Default is 20 for the cityscapes
253 | :param p: depth multiplier
254 | :param q: depth multiplier
255 | '''
256 | super().__init__()
257 | self.level1 = CBR(3, 16, 3, 2)
258 | self.sample1 = InputProjectionA(1)
259 | self.sample2 = InputProjectionA(2)
260 |
261 | self.b1 = BR(16 + 3)
262 | self.level2_0 = DownSamplerB(16 +3, 64)
263 |
264 | self.level2 = nn.ModuleList()
265 | for i in range(0, p):
266 | self.level2.append(DilatedParllelResidualBlockB(64 , 64))
267 | self.b2 = BR(128 + 3)
268 |
269 | self.level3_0 = DownSamplerB(128 + 3, 128)
270 | self.level3 = nn.ModuleList()
271 | for i in range(0, q):
272 | self.level3.append(DilatedParllelResidualBlockB(128 , 128))
273 | self.b3 = BR(256)
274 |
275 | self.classifier = C(256, classes, 1, 1)
276 |
277 | def forward(self, input):
278 | '''
279 | :param input: Receives the input RGB image
280 | :return: the transformed feature map with spatial dimensions 1/8th of the input image
281 | '''
282 | output0 = self.level1(input)
283 | inp1 = self.sample1(input)
284 | inp2 = self.sample2(input)
285 |
286 | output0_cat = self.b1(torch.cat([output0, inp1], 1))
287 | output1_0 = self.level2_0(output0_cat) # down-sampled
288 |
289 | for i, layer in enumerate(self.level2):
290 | if i==0:
291 | output1 = layer(output1_0)
292 | else:
293 | output1 = layer(output1)
294 |
295 | output1_cat = self.b2(torch.cat([output1, output1_0, inp2], 1))
296 |
297 | output2_0 = self.level3_0(output1_cat) # down-sampled
298 | for i, layer in enumerate(self.level3):
299 | if i==0:
300 | output2 = layer(output2_0)
301 | else:
302 | output2 = layer(output2)
303 |
304 | output2_cat = self.b3(torch.cat([output2_0, output2], 1))
305 |
306 | classifier = self.classifier(output2_cat)
307 |
308 | return classifier
309 |
310 | class ESPNet(nn.Module):
311 | '''
312 | This class defines the ESPNet network
313 | '''
314 |
315 | def __init__(self, classes=20, p=2, q=3, encoderFile=None):
316 | '''
317 | :param classes: number of classes in the dataset. Default is 20 for the cityscapes
318 | :param p: depth multiplier
319 | :param q: depth multiplier
320 | :param encoderFile: pretrained encoder weights. Recall that we first trained the ESPNet-C and then attached the
321 | RUM-based light weight decoder. See paper for more details.
322 | '''
323 | super().__init__()
324 | self.encoder = ESPNet_Encoder(classes, p, q)
325 | if encoderFile != None:
326 | self.encoder.load_state_dict(torch.load(encoderFile))
327 | print('Encoder loaded!')
328 | # load the encoder modules
329 | self.modules = []
330 | for i, m in enumerate(self.encoder.children()):
331 | self.modules.append(m)
332 |
333 | # light-weight decoder
334 | self.level3_C = C(128 + 3, classes, 1, 1)
335 | self.br = nn.BatchNorm2d(classes, eps=1e-03)
336 | self.conv = CBR(16 + classes, classes, 3, 1)
337 |
338 | self.up_l3 = nn.Sequential(nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False))
339 | self.combine_l2_l3 = nn.Sequential(BR(2*classes), DilatedParllelResidualBlockB(2*classes , classes, add=False))
340 |
341 | self.up_l2 = nn.Sequential(nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False), BR(classes))
342 |
343 | self.classifier = nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False)
344 |
345 | def forward(self, input):
346 | '''
347 | :param input: RGB image
348 | :return: transformed feature map
349 | '''
350 | output0 = self.modules[0](input)
351 | inp1 = self.modules[1](input)
352 | inp2 = self.modules[2](input)
353 |
354 | output0_cat = self.modules[3](torch.cat([output0, inp1], 1))
355 | output1_0 = self.modules[4](output0_cat) # down-sampled
356 |
357 | for i, layer in enumerate(self.modules[5]):
358 | if i == 0:
359 | output1 = layer(output1_0)
360 | else:
361 | output1 = layer(output1)
362 |
363 | output1_cat = self.modules[6](torch.cat([output1, output1_0, inp2], 1))
364 |
365 | output2_0 = self.modules[7](output1_cat) # down-sampled
366 | for i, layer in enumerate(self.modules[8]):
367 | if i == 0:
368 | output2 = layer(output2_0)
369 | else:
370 | output2 = layer(output2)
371 |
372 | output2_cat = self.modules[9](torch.cat([output2_0, output2], 1)) # concatenate for feature map width expansion
373 |
374 | output2_c = self.up_l3(self.br(self.modules[10](output2_cat))) #RUM
375 |
376 | output1_C = self.level3_C(output1_cat) # project to C-dimensional space
377 | comb_l2_l3 = self.up_l2(self.combine_l2_l3(torch.cat([output1_C, output2_c], 1))) #RUM
378 |
379 | concat_features = self.conv(torch.cat([comb_l2_l3, output0], 1))
380 |
381 | classifier = self.classifier(concat_features)
382 | return classifier
383 |
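384 |
385 | if __name__ == '__main__':
386 |     # Illustrative usage sketch (not part of the original file): run a forward pass on a
387 |     # dummy input to check the output resolutions of ESPNet-C (the encoder) and ESPNet.
388 |     from torch.autograd import Variable
389 |
390 |     dummy = Variable(torch.randn(1, 3, 512, 1024))  # N x C x H x W
391 |     encoder = ESPNet_Encoder(classes=20, p=2, q=3)
392 |     print(encoder(dummy).size())  # spatial dims are 1/8th of the input: 1 x 20 x 64 x 128
393 |     espnet = ESPNet(classes=20, p=2, q=3)
394 |     print(espnet(dummy).size())   # the decoder restores the full input resolution: 1 x 20 x 512 x 1024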
--------------------------------------------------------------------------------
/test/README.md:
--------------------------------------------------------------------------------
1 | # ESPNet: Towards Fast and Efficient Semantic Segmentation on the Embedded Devices
2 |
3 | This folder contains the Python scripts for running our pretrained models on the Cityscapes dataset.
4 |
5 | ## Getting Started
6 | We provide the pretrained weights for ESPNet and ESPNet-C. Recall that ESPNet is the same as ESPNet-C, but with a light-weight decoder.
7 |
8 | Pre-requisites:
9 | * By default, we expect all images inside the ./data directory. If they are in a different directory, please change the **data_dir** argument in the VisualizeResults.py file.
10 |
11 | * Also, if the image format is different (e.g. jpg), please change the **img_extn** argument in the VisualizeResults.py file.
12 |
13 | Both can be done using the command below:
14 |
15 | ```
16 | python VisualizeResults.py --data_dir <data_directory> --img_extn <image_extension>
17 | ```
18 |
19 |
20 | ### Running ESPNet-C models
21 | To run the ESPNet-C models, execute the following command:
22 |
23 | ```
24 | python VisualizeResults.py --modelType 2 --p 2 --q 3
25 | ```
26 |
27 | Here, p and q are the depth multipliers. Our models only support p=2 and q=3,5,8
28 |
29 |
30 | ### Running ESPNet models
31 | To run the ESPNet models, execute the following command:
32 |
33 | ```
34 | python VisualizeResults.py --modelType 1 --p 2 --q 3
35 | ```
36 |
37 | Here, p and q are the depth multipliers. Our models only support p=2 and q=3,5,8
38 |
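39 | For reference, a full invocation that combines the arguments above might look as follows (all of these flags are defined in VisualizeResults.py; the values are just an example):
40 |
41 | ```
42 | python VisualizeResults.py --modelType 1 --p 2 --q 8 --data_dir ./data --img_extn png --savedir ./results
43 | ```
44 |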
--------------------------------------------------------------------------------
/test/VisualizeResults.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.autograd import Variable
4 | import glob
5 | import cv2
6 | from PIL import Image as PILImage
7 | import Model as Net
8 | import os
9 | import time
10 | from argparse import ArgumentParser
11 |
12 | pallete = [[128, 64, 128],
13 | [244, 35, 232],
14 | [70, 70, 70],
15 | [102, 102, 156],
16 | [190, 153, 153],
17 | [153, 153, 153],
18 | [250, 170, 30],
19 | [220, 220, 0],
20 | [107, 142, 35],
21 | [152, 251, 152],
22 | [70, 130, 180],
23 | [220, 20, 60],
24 | [255, 0, 0],
25 | [0, 0, 142],
26 | [0, 0, 70],
27 | [0, 60, 100],
28 | [0, 80, 100],
29 | [0, 0, 230],
30 | [119, 11, 32],
31 | [0, 0, 0]]
32 |
33 |
34 | def relabel(img):
35 | '''
36 | This function relabels the predicted train ids to the official Cityscapes label ids so that the Cityscapes evaluation scripts can process them
37 | :param img:
38 | :return:
39 | '''
40 | img[img == 19] = 255
41 | img[img == 18] = 33
42 | img[img == 17] = 32
43 | img[img == 16] = 31
44 | img[img == 15] = 28
45 | img[img == 14] = 27
46 | img[img == 13] = 26
47 | img[img == 12] = 25
48 | img[img == 11] = 24
49 | img[img == 10] = 23
50 | img[img == 9] = 22
51 | img[img == 8] = 21
52 | img[img == 7] = 20
53 | img[img == 6] = 19
54 | img[img == 5] = 17
55 | img[img == 4] = 13
56 | img[img == 3] = 12
57 | img[img == 2] = 11
58 | img[img == 1] = 8
59 | img[img == 0] = 7
60 | img[img == 255] = 0
61 | return img
62 |
63 |
64 | def evaluateModel(args, model, up, image_list):
65 | # global mean and std values
66 | mean = [72.3923111, 82.90893555, 73.15840149]
67 | std = [45.3192215, 46.15289307, 44.91483307]
68 |
69 | for i, imgName in enumerate(image_list):
70 | img = cv2.imread(imgName)
71 | if args.overlay:
72 | img_orig = np.copy(img)
73 |
74 | img = img.astype(np.float32)
75 | for j in range(3):
76 | img[:, :, j] -= mean[j]
77 | for j in range(3):
78 | img[:, :, j] /= std[j]
79 |
80 | # resize the image to 1024x512x3
81 | img = cv2.resize(img, (1024, 512))
82 | if args.overlay:
83 | img_orig = cv2.resize(img_orig, (1024, 512))
84 |
85 | img /= 255
86 | img = img.transpose((2, 0, 1))
87 | img_tensor = torch.from_numpy(img)
88 | img_tensor = torch.unsqueeze(img_tensor, 0) # add a batch dimension
89 | img_variable = Variable(img_tensor, volatile=True)
90 | if args.gpu:
91 | img_variable = img_variable.cuda()
92 | img_out = model(img_variable)
93 |
94 | if args.modelType == 2:
95 | img_out = up(img_out)
96 |
97 | classMap_numpy = img_out[0].max(0)[1].byte().cpu().data.numpy()
98 |
99 | if i % 100 == 0:
100 | print(i)
101 |
102 | name = imgName.split('/')[-1]
103 |
104 | if args.colored:
105 | classMap_numpy_color = np.zeros((img.shape[1], img.shape[2], img.shape[0]), dtype=np.uint8)
106 | for idx in range(len(pallete)):
107 | [r, g, b] = pallete[idx]
108 | classMap_numpy_color[classMap_numpy == idx] = [b, g, r]
109 | cv2.imwrite(args.savedir + os.sep + 'c_' + name.replace(args.img_extn, 'png'), classMap_numpy_color)
110 | if args.overlay:
111 | overlayed = cv2.addWeighted(img_orig, 0.5, classMap_numpy_color, 0.5, 0)
112 | cv2.imwrite(args.savedir + os.sep + 'over_' + name.replace(args.img_extn, 'jpg'), overlayed)
113 |
114 | if args.cityFormat:
115 | classMap_numpy = relabel(classMap_numpy.astype(np.uint8))
116 |
117 | cv2.imwrite(args.savedir + os.sep + name.replace(args.img_extn, 'png'), classMap_numpy)
118 |
119 |
120 | def main(args):
121 | # read all the images in the folder
122 | image_list = glob.glob(args.data_dir + os.sep + '*.' + args.img_extn)
123 |
124 | up = None
125 | if args.modelType == 2:
126 | up = torch.nn.Upsample(scale_factor=8, mode='bilinear')
127 | if args.gpu:
128 | up = up.cuda()
129 |
130 | p = args.p
131 | q = args.q
132 | classes = args.classes
133 | if args.modelType == 2:
134 | modelA = Net.ESPNet_Encoder(classes, p, q) # Net.Mobile_SegNetDilatedIA_C_stage1(20)
135 | model_weight_file = args.weightsDir + os.sep + 'encoder' + os.sep + 'espnet_p_' + str(p) + '_q_' + str(
136 | q) + '.pth'
137 | if not os.path.isfile(model_weight_file):
138 | print('Pre-trained model file does not exist. Please check ../pretrained/encoder folder')
139 | exit(-1)
140 | modelA.load_state_dict(torch.load(model_weight_file))
141 | elif args.modelType == 1:
142 | modelA = Net.ESPNet(classes, p, q) # Net.Mobile_SegNetDilatedIA_C_stage1(20)
143 | model_weight_file = args.weightsDir + os.sep + 'decoder' + os.sep + 'espnet_p_' + str(p) + '_q_' + str(q) + '.pth'
144 | if not os.path.isfile(model_weight_file):
145 | print('Pre-trained model file does not exist. Please check ../pretrained/decoder folder')
146 | exit(-1)
147 | modelA.load_state_dict(torch.load(model_weight_file))
148 | else:
149 | print('Model not supported')
150 | # modelA = torch.nn.DataParallel(modelA)
151 | if args.gpu:
152 | modelA = modelA.cuda()
153 |
154 | # set to evaluation mode
155 | modelA.eval()
156 |
157 | if not os.path.isdir(args.savedir):
158 | os.mkdir(args.savedir)
159 |
160 | evaluateModel(args, modelA, up, image_list)
161 |
162 |
163 | if __name__ == '__main__':
164 | parser = ArgumentParser()
165 | parser.add_argument('--model', default="ESPNet", help='Model name')
166 | parser.add_argument('--data_dir', default="./data", help='Data directory')
167 | parser.add_argument('--img_extn', default="png", help='RGB Image format')
168 | parser.add_argument('--inWidth', type=int, default=1024, help='Width of RGB image')
169 | parser.add_argument('--inHeight', type=int, default=512, help='Height of RGB image')
170 | parser.add_argument('--scaleIn', type=int, default=1, help='For ESPNet-C, scaleIn=8. For ESPNet, scaleIn=1')
171 | parser.add_argument('--modelType', type=int, default=1, help='1=ESPNet, 2=ESPNet-C')
172 | parser.add_argument('--savedir', default='./results', help='directory to save the results')
173 | parser.add_argument('--gpu', default=True, type=bool, help='Run on CPU or GPU. If TRUE, then GPU.')
174 | parser.add_argument('--decoder', type=bool, default=True,
175 | help='True if ESPNet. False for ESPNet-C') # False for encoder
176 | parser.add_argument('--weightsDir', default='../pretrained/', help='Pretrained weights directory.')
177 | parser.add_argument('--p', default=2, type=int, help='depth multiplier. Supported only 2')
178 | parser.add_argument('--q', default=8, type=int, help='depth multiplier. Supported only 3, 5, 8')
179 | parser.add_argument('--cityFormat', default=True, type=bool, help='If you want to convert to cityscape '
180 | 'original label ids')
181 | parser.add_argument('--colored', default=True, type=bool, help='If you want to visualize the '
182 | 'segmentation masks in color')
183 | parser.add_argument('--overlay', default=True, type=bool, help='If you want to visualize the '
184 | 'segmentation masks overlayed on top of RGB image')
185 | parser.add_argument('--classes', default=20, type=int, help='Number of classes in the dataset. 20 for Cityscapes')
186 |
187 | args = parser.parse_args()
188 | assert args.modelType in (1, 2), 'Model type should be 2 for ESPNet-C and 1 for ESPNet'
189 | if args.overlay:
190 | args.colored = True # This has to be true if you want to overlay
191 | main(args)
192 |
--------------------------------------------------------------------------------
/test/data/README.md:
--------------------------------------------------------------------------------
1 | This folder should contain all the images for which you want to generate the results
2 |
--------------------------------------------------------------------------------
/test/data/frankfurt_000000_000294_leftImg8bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/test/data/frankfurt_000000_000294_leftImg8bit.png
--------------------------------------------------------------------------------
/test/data/frankfurt_000000_000576_leftImg8bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/test/data/frankfurt_000000_000576_leftImg8bit.png
--------------------------------------------------------------------------------
/train/Criteria.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | __author__ = "Sachin Mehta"
5 |
6 |
7 | class CrossEntropyLoss2d(nn.Module):
8 | '''
9 | This file defines a cross entropy loss for 2D images
10 | '''
11 | def __init__(self, weight=None):
12 | '''
13 | :param weight: 1D weight vector to deal with the class-imbalance
14 | '''
15 | super().__init__()
16 |
17 | self.loss = nn.NLLLoss2d(weight)
18 |
19 | def forward(self, outputs, targets):
20 | return self.loss(F.log_softmax(outputs, 1), targets)
21 |
--------------------------------------------------------------------------------
/train/DataSet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cv2
3 | import torch.utils.data
4 |
5 | __author__ = "Sachin Mehta"
6 |
7 |
8 | class MyDataset(torch.utils.data.Dataset):
9 | '''
10 | Class to load the dataset
11 | '''
12 | def __init__(self, imList, labelList, transform=None):
13 | '''
14 | :param imList: image list (Note that these lists have been processed and pickled using the loadData.py)
15 | :param labelList: label list (Note that these lists have been processed and pickled using the loadData.py)
16 | :param transform: Type of transformation. See Transforms.py for supported transformations
17 | '''
18 | self.imList = imList
19 | self.labelList = labelList
20 | self.transform = transform
21 |
22 | def __len__(self):
23 | return len(self.imList)
24 |
25 | def __getitem__(self, idx):
26 | '''
27 |
28 | :param idx: Index of the image file
29 | :return: returns the image and corresponding label file.
30 | '''
31 | image_name = self.imList[idx]
32 | label_name = self.labelList[idx]
33 | image = cv2.imread(image_name)
34 | label = cv2.imread(label_name, 0)
35 | if self.transform:
36 | [image, label] = self.transform(image, label)
37 | return (image, label)
38 |
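39 |
40 | # Illustrative usage sketch (not part of the original file): MyDataset plugs into a standard
41 | # PyTorch DataLoader. `imList`, `labelList` and `myTransform` are assumed to come from
42 | # loadData.py and Transforms.Compose, respectively:
43 | #
44 | #   dataset = MyDataset(imList, labelList, transform=myTransform)
45 | #   loader = torch.utils.data.DataLoader(dataset, batch_size=6, shuffle=True, num_workers=4)
46 | #   for images, labels in loader:
47 | #       ...  # feed the batch to the model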
--------------------------------------------------------------------------------
/train/IOUEval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | #adapted from https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/score.py
5 |
6 | class iouEval:
7 | def __init__(self, nClasses):
8 | self.nClasses = nClasses
9 | self.reset()
10 |
11 | def reset(self):
12 | self.overall_acc = 0
13 | self.per_class_acc = np.zeros(self.nClasses, dtype=np.float32)
14 | self.per_class_iu = np.zeros(self.nClasses, dtype=np.float32)
15 | self.mIOU = 0
16 | self.batchCount = 1
17 |
18 | def fast_hist(self, a, b):
19 | k = (a >= 0) & (a < self.nClasses)
20 | return np.bincount(self.nClasses * a[k].astype(int) + b[k], minlength=self.nClasses ** 2).reshape(self.nClasses, self.nClasses)
21 |
22 | def compute_hist(self, predict, gth):
23 | hist = self.fast_hist(gth, predict)
24 | return hist
25 |
26 | def addBatch(self, predict, gth):
27 | predict = predict.cpu().numpy().flatten()
28 | gth = gth.cpu().numpy().flatten()
29 |
30 | epsilon = 0.00000001
31 | hist = self.compute_hist(predict, gth)
32 | overall_acc = np.diag(hist).sum() / (hist.sum() + epsilon)
33 | per_class_acc = np.diag(hist) / (hist.sum(1) + epsilon)
34 | per_class_iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon)
35 | mIou = np.nanmean(per_class_iu)
36 |
37 | self.overall_acc +=overall_acc
38 | self.per_class_acc += per_class_acc
39 | self.per_class_iu += per_class_iu
40 | self.mIOU += mIou
41 | self.batchCount += 1
42 |
43 | def getMetric(self):
44 | overall_acc = self.overall_acc/self.batchCount
45 | per_class_acc = self.per_class_acc / self.batchCount
46 | per_class_iu = self.per_class_iu / self.batchCount
47 | mIOU = self.mIOU / self.batchCount
48 |
49 | return overall_acc, per_class_acc, per_class_iu, mIOU
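50 |
51 |
52 | if __name__ == '__main__':
53 |     # Illustrative usage sketch (not part of the original file): score a random prediction
54 |     # against a random ground truth just to exercise the API.
55 |     pred = torch.LongTensor(2, 360, 480).random_(0, 20)  # fake per-pixel class ids
56 |     gth = torch.LongTensor(2, 360, 480).random_(0, 20)
57 |     evaluator = iouEval(nClasses=20)
58 |     evaluator.addBatch(pred, gth)
59 |     overall_acc, per_class_acc, per_class_iu, mIOU = evaluator.getMetric()
60 |     print('mIOU: %.4f' % mIOU)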
--------------------------------------------------------------------------------
/train/Model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | __author__ = "Sachin Mehta"
5 |
6 | class CBR(nn.Module):
7 | '''
8 | This class defines the convolution layer with batch normalization and PReLU activation
9 | '''
10 | def __init__(self, nIn, nOut, kSize, stride=1):
11 | '''
12 |
13 | :param nIn: number of input channels
14 | :param nOut: number of output channels
15 | :param kSize: kernel size
16 | :param stride: stride rate for down-sampling. Default is 1
17 | '''
18 | super().__init__()
19 | padding = int((kSize - 1)/2)
20 | #self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False)
21 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False)
22 | #self.conv1 = nn.Conv2d(nOut, nOut, (1, kSize), stride=1, padding=(0, padding), bias=False)
23 | self.bn = nn.BatchNorm2d(nOut, eps=1e-03)
24 | self.act = nn.PReLU(nOut)
25 |
26 | def forward(self, input):
27 | '''
28 | :param input: input feature map
29 | :return: transformed feature map
30 | '''
31 | output = self.conv(input)
32 | #output = self.conv1(output)
33 | output = self.bn(output)
34 | output = self.act(output)
35 | return output
36 |
37 |
38 | class BR(nn.Module):
39 | '''
40 | This class groups the batch normalization and PReLU activation
41 | '''
42 | def __init__(self, nOut):
43 | '''
44 | :param nOut: output feature maps
45 | '''
46 | super().__init__()
47 | self.bn = nn.BatchNorm2d(nOut, eps=1e-03)
48 | self.act = nn.PReLU(nOut)
49 |
50 | def forward(self, input):
51 | '''
52 | :param input: input feature map
53 | :return: normalized and thresholded feature map
54 | '''
55 | output = self.bn(input)
56 | output = self.act(output)
57 | return output
58 |
59 | class CB(nn.Module):
60 | '''
61 | This class groups the convolution and batch normalization
62 | '''
63 | def __init__(self, nIn, nOut, kSize, stride=1):
64 | '''
65 | :param nIn: number of input channels
66 | :param nOut: number of output channels
67 | :param kSize: kernel size
68 | :param stride: optional stride for down-sampling
69 | '''
70 | super().__init__()
71 | padding = int((kSize - 1)/2)
72 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False)
73 | self.bn = nn.BatchNorm2d(nOut, eps=1e-03)
74 |
75 | def forward(self, input):
76 | '''
77 |
78 | :param input: input feature map
79 | :return: transformed feature map
80 | '''
81 | output = self.conv(input)
82 | output = self.bn(output)
83 | return output
84 |
85 | class C(nn.Module):
86 | '''
87 | This class is for a convolutional layer.
88 | '''
89 | def __init__(self, nIn, nOut, kSize, stride=1):
90 | '''
91 |
92 | :param nIn: number of input channels
93 | :param nOut: number of output channels
94 | :param kSize: kernel size
95 | :param stride: optional stride rate for down-sampling
96 | '''
97 | super().__init__()
98 | padding = int((kSize - 1)/2)
99 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False)
100 |
101 | def forward(self, input):
102 | '''
103 | :param input: input feature map
104 | :return: transformed feature map
105 | '''
106 | output = self.conv(input)
107 | return output
108 |
109 | class CDilated(nn.Module):
110 | '''
111 | This class defines the dilated convolution.
112 | '''
113 | def __init__(self, nIn, nOut, kSize, stride=1, d=1):
114 | '''
115 | :param nIn: number of input channels
116 | :param nOut: number of output channels
117 | :param kSize: kernel size
118 | :param stride: optional stride rate for down-sampling
119 | :param d: optional dilation rate
120 | '''
121 | super().__init__()
122 | padding = int((kSize - 1)/2) * d
123 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False, dilation=d)
124 |
125 | def forward(self, input):
126 | '''
127 | :param input: input feature map
128 | :return: transformed feature map
129 | '''
130 | output = self.conv(input)
131 | return output
132 |
133 | class DownSamplerB(nn.Module):
134 | def __init__(self, nIn, nOut):
135 | super().__init__()
136 | n = int(nOut/5)
137 | n1 = nOut - 4*n
138 | self.c1 = C(nIn, n, 3, 2)
139 | self.d1 = CDilated(n, n1, 3, 1, 1)
140 | self.d2 = CDilated(n, n, 3, 1, 2)
141 | self.d4 = CDilated(n, n, 3, 1, 4)
142 | self.d8 = CDilated(n, n, 3, 1, 8)
143 | self.d16 = CDilated(n, n, 3, 1, 16)
144 | self.bn = nn.BatchNorm2d(nOut, eps=1e-3)
145 | self.act = nn.PReLU(nOut)
146 |
147 | def forward(self, input):
148 | output1 = self.c1(input)
149 | d1 = self.d1(output1)
150 | d2 = self.d2(output1)
151 | d4 = self.d4(output1)
152 | d8 = self.d8(output1)
153 | d16 = self.d16(output1)
154 |
155 | add1 = d2
156 | add2 = add1 + d4
157 | add3 = add2 + d8
158 | add4 = add3 + d16
159 |
160 | combine = torch.cat([d1, add1, add2, add3, add4],1)
161 | #combine_in_out = input + combine
162 | output = self.bn(combine)
163 | output = self.act(output)
164 | return output
165 |
166 | class DilatedParllelResidualBlockB(nn.Module):
167 | '''
168 | This class defines the ESP block, which is based on the following principle
169 | Reduce ---> Split ---> Transform --> Merge
170 | '''
171 | def __init__(self, nIn, nOut, add=True):
172 | '''
173 | :param nIn: number of input channels
174 | :param nOut: number of output channels
175 | :param add: if true, add a residual connection through identity operation. You can use projection too as
176 | in ResNet paper, but we avoid to use it if the dimensions are not the same because we do not want to
177 | increase the module complexity
178 | '''
179 | super().__init__()
180 | n = int(nOut/5)
181 | n1 = nOut - 4*n
182 | self.c1 = C(nIn, n, 1, 1)
183 | self.d1 = CDilated(n, n1, 3, 1, 1) # dilation rate of 2^0
184 | self.d2 = CDilated(n, n, 3, 1, 2) # dilation rate of 2^1
185 | self.d4 = CDilated(n, n, 3, 1, 4) # dilation rate of 2^2
186 | self.d8 = CDilated(n, n, 3, 1, 8) # dilation rate of 2^3
187 | self.d16 = CDilated(n, n, 3, 1, 16) # dilation rate of 2^4
188 | self.bn = BR(nOut)
189 | self.add = add
190 |
191 | def forward(self, input):
192 | '''
193 | :param input: input feature map
194 | :return: transformed feature map
195 | '''
196 | # reduce
197 | output1 = self.c1(input)
198 | # split and transform
199 | d1 = self.d1(output1)
200 | d2 = self.d2(output1)
201 | d4 = self.d4(output1)
202 | d8 = self.d8(output1)
203 | d16 = self.d16(output1)
204 |
205 | # hierarchical fusion for de-gridding
206 | add1 = d2
207 | add2 = add1 + d4
208 | add3 = add2 + d8
209 | add4 = add3 + d16
210 |
211 | #merge
212 | combine = torch.cat([d1, add1, add2, add3, add4], 1)
213 |
214 | # if residual version
215 | if self.add:
216 | combine = input + combine
217 | output = self.bn(combine)
218 | return output
219 |
220 | class InputProjectionA(nn.Module):
221 | '''
222 | This class projects the input image to the same spatial dimensions as the feature map.
223 | For example, if the input image is 512 x512 x3 and spatial dimensions of feature map size are 56x56xF, then
224 | this class will generate an output of 56x56x3
225 | '''
226 | def __init__(self, samplingTimes):
227 | '''
228 | :param samplingTimes: The rate at which you want to down-sample the image
229 | '''
230 | super().__init__()
231 | self.pool = nn.ModuleList()
232 | for i in range(0, samplingTimes):
233 | #pyramid-based approach for down-sampling
234 | self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))
235 |
236 | def forward(self, input):
237 | '''
238 | :param input: Input RGB Image
239 | :return: down-sampled image (pyramid-based approach)
240 | '''
241 | for pool in self.pool:
242 | input = pool(input)
243 | return input
244 |
245 |
246 | class ESPNet_Encoder(nn.Module):
247 | '''
248 | This class defines the ESPNet-C network in the paper
249 | '''
250 | def __init__(self, classes=20, p=5, q=3):
251 | '''
252 | :param classes: number of classes in the dataset. Default is 20 for the cityscapes
253 | :param p: depth multiplier
254 | :param q: depth multiplier
255 | '''
256 | super().__init__()
257 | self.level1 = CBR(3, 16, 3, 2)
258 | self.sample1 = InputProjectionA(1)
259 | self.sample2 = InputProjectionA(2)
260 |
261 | self.b1 = BR(16 + 3)
262 | self.level2_0 = DownSamplerB(16 +3, 64)
263 |
264 | self.level2 = nn.ModuleList()
265 | for i in range(0, p):
266 | self.level2.append(DilatedParllelResidualBlockB(64 , 64))
267 | self.b2 = BR(128 + 3)
268 |
269 | self.level3_0 = DownSamplerB(128 + 3, 128)
270 | self.level3 = nn.ModuleList()
271 | for i in range(0, q):
272 | self.level3.append(DilatedParllelResidualBlockB(128 , 128))
273 | self.b3 = BR(256)
274 |
275 | self.classifier = C(256, classes, 1, 1)
276 |
277 | def forward(self, input):
278 | '''
279 | :param input: Receives the input RGB image
280 | :return: the transformed feature map with spatial dimensions 1/8th of the input image
281 | '''
282 | output0 = self.level1(input)
283 | inp1 = self.sample1(input)
284 | inp2 = self.sample2(input)
285 |
286 | output0_cat = self.b1(torch.cat([output0, inp1], 1))
287 | output1_0 = self.level2_0(output0_cat) # down-sampled
288 |
289 | for i, layer in enumerate(self.level2):
290 | if i==0:
291 | output1 = layer(output1_0)
292 | else:
293 | output1 = layer(output1)
294 |
295 | output1_cat = self.b2(torch.cat([output1, output1_0, inp2], 1))
296 |
297 | output2_0 = self.level3_0(output1_cat) # down-sampled
298 | for i, layer in enumerate(self.level3):
299 | if i==0:
300 | output2 = layer(output2_0)
301 | else:
302 | output2 = layer(output2)
303 |
304 | output2_cat = self.b3(torch.cat([output2_0, output2], 1))
305 |
306 | classifier = self.classifier(output2_cat)
307 |
308 | return classifier
309 |
310 | class ESPNet(nn.Module):
311 | '''
312 | This class defines the ESPNet network
313 | '''
314 |
315 | def __init__(self, classes=20, p=2, q=3, encoderFile=None):
316 | '''
317 | :param classes: number of classes in the dataset. Default is 20 for the cityscapes
318 | :param p: depth multiplier
319 | :param q: depth multiplier
320 | :param encoderFile: pretrained encoder weights. Recall that we first trained the ESPNet-C and then attached the
321 | RUM-based light weight decoder. See paper for more details.
322 | '''
323 | super().__init__()
324 | self.encoder = ESPNet_Encoder(classes, p, q)
325 | if encoderFile != None:
326 | self.encoder.load_state_dict(torch.load(encoderFile))
327 | print('Encoder loaded!')
328 | # load the encoder modules
329 | self.modules = []
330 | for i, m in enumerate(self.encoder.children()):
331 | self.modules.append(m)
332 |
333 | # light-weight decoder
334 | self.level3_C = C(128 + 3, classes, 1, 1)
335 | self.br = nn.BatchNorm2d(classes, eps=1e-03)
336 | self.conv = CBR(19 + classes, classes, 3, 1)
337 |
338 | self.up_l3 = nn.Sequential(nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False))
339 | self.combine_l2_l3 = nn.Sequential(BR(2*classes), DilatedParllelResidualBlockB(2*classes , classes, add=False))
340 |
341 | self.up_l2 = nn.Sequential(nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False), BR(classes))
342 |
343 | self.classifier = nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False)
344 |
345 | def forward(self, input):
346 | '''
347 | :param input: RGB image
348 | :return: transformed feature map
349 | '''
350 | output0 = self.modules[0](input)
351 | inp1 = self.modules[1](input)
352 | inp2 = self.modules[2](input)
353 |
354 | output0_cat = self.modules[3](torch.cat([output0, inp1], 1))
355 | output1_0 = self.modules[4](output0_cat) # down-sampled
356 |
357 | for i, layer in enumerate(self.modules[5]):
358 | if i == 0:
359 | output1 = layer(output1_0)
360 | else:
361 | output1 = layer(output1)
362 |
363 | output1_cat = self.modules[6](torch.cat([output1, output1_0, inp2], 1))
364 |
365 | output2_0 = self.modules[7](output1_cat) # down-sampled
366 | for i, layer in enumerate(self.modules[8]):
367 | if i == 0:
368 | output2 = layer(output2_0)
369 | else:
370 | output2 = layer(output2)
371 |
372 | output2_cat = self.modules[9](torch.cat([output2_0, output2], 1)) # concatenate for feature map width expansion
373 |
374 | output2_c = self.up_l3(self.br(self.modules[10](output2_cat))) #RUM
375 |
376 | output1_C = self.level3_C(output1_cat) # project to C-dimensional space
377 | comb_l2_l3 = self.up_l2(self.combine_l2_l3(torch.cat([output1_C, output2_c], 1))) #RUM
378 |
379 | concat_features = self.conv(torch.cat([comb_l2_l3, output0_cat], 1))
380 |
381 | classifier = self.classifier(concat_features)
382 | return classifier
383 |
--------------------------------------------------------------------------------
/train/README.md:
--------------------------------------------------------------------------------
1 | # ESPNet: Towards Fast and Efficient Semantic Segmentation on the Embedded Devices
2 |
3 | This folder contains the Python scripts for training models on the Cityscapes dataset.
4 |
5 |
6 | ## Getting Started
7 |
8 | ### Training ESPNet-C
9 |
10 | You can start training the model using the command below:
11 |
12 | ```
13 | python main.py
14 | ```
15 |
16 | By default, **ESPNet-C** will be trained with p=2 and q=8. Since the spatial dimensions of the output of ESPNet-C are 1/8th of the original image size, please set the scaleIn parameter to 8. If you want to change the parameters, you can do so using the command below:
17 |
18 | ```
19 | python main.py --scaleIn 8 --p <p_value> --q <q_value>
20 |
21 | Example:
22 |
23 | python main.py --scaleIn 8 --p 2 --q 8
24 | ```
25 |
26 | ### Training ESPNet
27 | Once you are done training ESPNet-C, you can attach the light-weight decoder and train the ESPNet model:
28 |
29 | ```
30 | python main.py --scaleIn 1 --p <p_value> --q <q_value> --decoder True --pretrained <path_to_pretrained_encoder>
31 |
32 | Example:
33 |
34 | python main.py --scaleIn 1 --p 2 --q 8 --decoder True --pretrained ../pretrained/encoder/espnet_p_2_q_8.pth
35 | ```
36 |
37 | **Note 1:** Currently, we support only single-GPU training. If you want to train the model on multiple GPUs, you can use the **nn.DataParallel** API provided by PyTorch (see the sketch at the end of this file).
38 |
39 | **Note 2:** To train on a specific (single) GPU, you can specify the GPU id using the CUDA_VISIBLE_DEVICES environment variable as:
40 |
41 | ```
42 | CUDA_VISIBLE_DEVICES=2 python main.py
43 | ```
44 |
45 | This will run the training program on GPU with ID 2.
46 |
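47 | A minimal sketch of the multi-GPU change mentioned in Note 1 is given below. It is illustrative only: main.py is not reproduced in this README, so building the model as `Net.ESPNet(classes=20, p=2, q=8)` is an assumption based on the defaults above.
48 |
49 | ```
50 | import torch
51 | import torch.nn as nn
52 |
53 | import Model as Net
54 |
55 | model = Net.ESPNet(classes=20, p=2, q=8)  # build ESPNet as for single-GPU training
56 | if torch.cuda.device_count() > 1:
57 |     model = nn.DataParallel(model)        # replicate the model across all visible GPUs
58 | model = model.cuda()
59 | ```
60 |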
--------------------------------------------------------------------------------
/train/Transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 | import cv2
5 |
6 | __author__ = "Sachin Mehta"
7 |
8 |
9 | class Scale(object):
10 | """
11 | Resize the given RGB image and label image to the specified width and height
12 | """
13 | def __init__(self, wi, he):
14 | '''
15 |
16 | :param wi: width after resizing
17 | :param he: height after resizing
18 | '''
19 | self.w = wi
20 | self.h = he
21 |
22 | def __call__(self, img, label):
23 | '''
24 | :param img: RGB image
25 | :param label: semantic label image
26 | :return: resized images
27 | '''
28 | #bilinear interpolation for RGB image
29 | img = cv2.resize(img, (self.w, self.h))
30 | # nearest neighbour interpolation for label image
31 | label = cv2.resize(label, (self.w, self.h), interpolation=cv2.INTER_NEAREST)
32 |
33 | return [img, label]
34 |
35 |
36 |
37 | class RandomCropResize(object):
38 | """
39 | Randomly crop and resize the given image (and label) with a probability of 0.5
40 | """
41 | def __init__(self, crop_area):
42 | '''
43 | :param crop_area: maximum crop offset in pixels (the actual offset is sampled between 0 and crop_area)
44 | '''
45 | self.cw = crop_area
46 | self.ch = crop_area
47 |
48 | def __call__(self, img, label):
49 | if random.random() < 0.5:
50 | h, w = img.shape[:2]
51 | x1 = random.randint(0, self.ch)
52 | y1 = random.randint(0, self.cw)
53 |
54 | img_crop = img[y1:h-y1, x1:w-x1]
55 | label_crop = label[y1:h-y1, x1:w-x1]
56 |
57 | img_crop = cv2.resize(img_crop, (w, h))
58 | label_crop = cv2.resize(label_crop, (w,h), interpolation=cv2.INTER_NEAREST)
59 | return img_crop, label_crop
60 | else:
61 | return [img, label]
62 |
63 | class RandomCrop(object):
64 | '''
65 | This class is for random cropping
66 | '''
67 | def __init__(self, cropArea):
68 | '''
69 | :param cropArea: amount of cropping (in pixels)
70 | '''
71 | self.crop = cropArea
72 |
73 | def __call__(self, img, label):
74 |
75 | if random.random() < 0.5:
76 | h, w = img.shape[:2]
77 | img_crop = img[self.crop:h-self.crop, self.crop:w-self.crop]
78 | label_crop = label[self.crop:h-self.crop, self.crop:w-self.crop]
79 | return img_crop, label_crop
80 | else:
81 | return [img, label]
82 |
83 |
84 |
85 | class RandomFlip(object):
86 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5
87 | """
88 |
89 | def __call__(self, image, label):
90 | if random.random() < 0.5:
91 | x1 = 0 #random.randint(0, 1) # uncomment to randomly pick between the two flip directions
92 | if x1 == 0:
93 | image = cv2.flip(image, 0)  # flip around the x-axis
94 | label = cv2.flip(label, 0)  # flip around the x-axis
95 | else:
96 | image = cv2.flip(image, 1)  # flip around the y-axis (horizontal mirror)
97 | label = cv2.flip(label, 1)  # flip around the y-axis (horizontal mirror)
98 | return [image, label]
99 |
100 |
101 | class Normalize(object):
102 | """Given mean: (R, G, B) and std: (R, G, B),
103 | will normalize each channel of the torch.*Tensor, i.e.
104 | channel = (channel - mean) / std
105 | """
106 |
107 | def __init__(self, mean, std):
108 | '''
109 | :param mean: global mean computed from dataset
110 | :param std: global std computed from dataset
111 | '''
112 | self.mean = mean
113 | self.std = std
114 |
115 | def __call__(self, image, label):
116 | image = image.astype(np.float32)
117 | for i in range(3):
118 | image[:,:,i] -= self.mean[i]
119 | for i in range(3):
120 | image[:,:, i] /= self.std[i]
121 |
122 | return [image, label]
123 |
124 | class ToTensor(object):
125 | '''
126 | This class converts the data to tensor so that it can be processed by PyTorch
127 | '''
128 | def __init__(self, scale=1):
129 | '''
130 | :param scale: ESPNet-C's output is 1/8th of original image size, so set this parameter accordingly
131 | '''
132 | self.scale = scale # original images are 2048 x 1024
133 |
134 | def __call__(self, image, label):
135 |
136 | if self.scale != 1:
137 | h, w = label.shape[:2]
138 | image = cv2.resize(image, (int(w), int(h)))
139 | label = cv2.resize(label, (int(w/self.scale), int(h/self.scale)), interpolation=cv2.INTER_NEAREST)
140 |
141 | image = image.transpose((2,0,1))
142 |
143 | image_tensor = torch.from_numpy(image).div(255)
144 | label_tensor = torch.LongTensor(np.array(label, dtype=np.int)) #torch.from_numpy(label)
145 |
146 | return [image_tensor, label_tensor]
147 |
148 | class Compose(object):
149 | """Composes several transforms together.
150 | """
151 |
152 | def __init__(self, transforms):
153 | self.transforms = transforms
154 |
155 | def __call__(self, *args):
156 | for t in self.transforms:
157 | args = t(*args)
158 | return args
159 |
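160 |
161 | if __name__ == '__main__':
162 |     # Illustrative usage sketch (not part of the original file): compose the transforms on a
163 |     # random image/label pair. The mean/std values below are the global Cityscapes values used
164 |     # in test/VisualizeResults.py.
165 |     image = np.random.randint(0, 256, (1024, 2048, 3)).astype(np.uint8)
166 |     label = np.random.randint(0, 20, (1024, 2048)).astype(np.uint8)
167 |     trainTransform = Compose([
168 |         Normalize(mean=[72.3923111, 82.90893555, 73.15840149], std=[45.3192215, 46.15289307, 44.91483307]),
169 |         Scale(1024, 512),
170 |         RandomCropResize(32),
171 |         RandomFlip(),
172 |         ToTensor(scale=8),  # use scale=8 when training ESPNet-C, scale=1 for ESPNet
173 |     ])
174 |     image_tensor, label_tensor = trainTransform(image, label)
175 |     print(image_tensor.size(), label_tensor.size())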
--------------------------------------------------------------------------------
/train/VisualizeGraph.py:
--------------------------------------------------------------------------------
1 | from graphviz import Digraph
2 | import torch
3 | from torch.autograd import Variable
4 |
5 | '''
6 | Not written by me
7 | Copied from here: https://github.com/szagoruyko/pytorchviz
8 | '''
9 |
10 | def make_dot(var, params=None):
11 | """ Produces Graphviz representation of PyTorch autograd graph
12 | Blue nodes are the Variables that require grad, orange are Tensors
13 | saved for backward in torch.autograd.Function
14 | Args:
15 | var: output Variable
16 | params: dict of (name, Variable) to add names to node that
17 | require grad (TODO: make optional)
18 | """
19 | if params is not None:
20 | assert isinstance(list(params.values())[0], Variable)  # list() is needed under Python 3
21 | param_map = {id(v): k for k, v in params.items()}
22 |
23 | node_attr = dict(style='filled',
24 | shape='box',
25 | align='left',
26 | fontsize='12',
27 | ranksep='0.1',
28 | height='0.2')
29 | dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
30 | seen = set()
31 |
32 | def size_to_str(size):
33 | return '('+(', ').join(['%d' % v for v in size])+')'
34 |
35 | def add_nodes(var):
36 | if var not in seen:
37 | if torch.is_tensor(var):
38 | dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
39 | elif hasattr(var, 'variable'):
40 | u = var.variable
41 | name = param_map[id(u)] if params is not None else ''
42 | node_name = '%s\n %s' % (name, size_to_str(u.size()))
43 | dot.node(str(id(var)), node_name, fillcolor='lightblue')
44 | else:
45 | dot.node(str(id(var)), str(type(var).__name__))
46 | seen.add(var)
47 | if hasattr(var, 'next_functions'):
48 | for u in var.next_functions:
49 | if u[0] is not None:
50 | dot.edge(str(id(u[0])), str(id(var)))
51 | add_nodes(u[0])
52 | if hasattr(var, 'saved_tensors'):
53 | for t in var.saved_tensors:
54 | dot.edge(str(id(t)), str(id(var)))
55 | add_nodes(t)
56 | add_nodes(var.grad_fn)
57 | return dot
--------------------------------------------------------------------------------
/train/city/README.md:
--------------------------------------------------------------------------------
1 | # ESPNet: Towards Fast and Efficient Semantic Segmentation on the Embedded Devices
2 |
3 | This folder contains the dataset file lists (train.txt and val.txt).
4 |
5 | ## Change to custom data location
6 | If your data is saved in a different directory, no worries. You can pass the path of that directory and the scripts will load the data from the specified location.
7 |
8 | ```
9 | python main.py --data_dir <data_directory>
10 | ```
11 |
12 | Please make sure that your directory contains the **train.txt** and **val.txt** files. Our code expects the image names to be listed in a particular format
13 |
14 | ```
15 | ,