├── LICENSE ├── README.md ├── pretrained ├── README.md ├── decoder │ ├── README.md │ ├── espnet_p_2_q_3.pth │ ├── espnet_p_2_q_5.pth │ ├── espnet_p_2_q_8.pth │ └── espnet_p_2_q_8_camvid.pth └── encoder │ ├── README.md │ ├── espnet_p_2_q_3.pth │ ├── espnet_p_2_q_5.pth │ └── espnet_p_2_q_8.pth ├── sample_video ├── ReadMe.md └── sample.png ├── test ├── Model.py ├── README.md ├── VisualizeResults.py └── data │ ├── README.md │ ├── frankfurt_000000_000294_leftImg8bit.png │ └── frankfurt_000000_000576_leftImg8bit.png └── train ├── Criteria.py ├── DataSet.py ├── IOUEval.py ├── Model.py ├── README.md ├── Transforms.py ├── VisualizeGraph.py ├── city ├── README.md ├── train.txt └── val.txt ├── loadData.py └── main.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Sachin Mehta 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation 2 | 3 | This repository contains the source code of our paper, [ESPNet](https://arxiv.org/abs/1803.06815) (accepted for publication in [ECCV'18](http://eccv2018.org/)). 4 | 5 | ## Sample results 6 | 7 | Check our [project page](https://sacmehta.github.io/ESPNet/) for more qualitative results (videos). 8 | 9 | Click on the below sample image to view the segmentation results on YouTube. 10 | 11 |

12 | [sample image (sample_video/sample.png) linking to the YouTube segmentation results video] 13 |

14 | 15 | 16 | ## Structure of this repository 17 | This repository is organized as: 18 | * [train](/train/) This directory contains the source code for training the ESPNet-C and ESPNet models. 19 | * [test](/test/) This directory contains the source code for evaluating our model on RGB images. 20 | * [pretrained](/pretrained/) This directory contains the models pre-trained on the Cityscapes dataset 21 | * [encoder](/pretrained/encoder/) This directory contains the pretrained **ESPNet-C** models 22 | * [decoder](/pretrained/decoder/) This directory contains the pretrained **ESPNet** models 23 | 24 | 25 | ## Performance on the Cityscapes dataset 26 | 27 | Our ESPNet model achieves a class-wise mIOU of **60.336** and a category-wise mIOU of **82.178** on the Cityscapes test set and runs at 28 | * 112 fps on the NVIDIA TitanX (30 fps faster than [ENet](https://arxiv.org/abs/1606.02147)) 29 | * 9 fps on the NVIDIA Jetson TX2 30 | * With the same number of parameters as [ENet](https://arxiv.org/abs/1606.02147), our model is **2%** more accurate 31 | 32 | ## Performance on the CamVid dataset 33 | 34 | Our model achieves an mIOU of 55.64 on the CamVid test set. We used the dataset splits (train/val/test) provided [here](https://github.com/alexgkendall/SegNet-Tutorial). We trained the models at a resolution of 480x360. For a comparison with other models, see the [SegNet paper](https://ieeexplore.ieee.org/document/7803544/). 35 | 36 | Note: We did not use the 3.5K-image training set that was used in the SegNet paper. 37 | 38 | | Model | mIOU | Class avg. | 39 | | -- | -- | -- | 40 | | ENet | 51.3 | 68.3 | 41 | | SegNet | 55.6 | 65.2 | 42 | | ESPNet | 55.64 | 68.30 | 43 | 44 | ## Prerequisites 45 | 46 | To run this code, you need the following libraries: 47 | * [OpenCV](https://opencv.org/) - We tested our code with version > 3.0. 48 | * [PyTorch](http://pytorch.org/) - We tested with v0.3.0 49 | * Python - We tested our code with Python 3. If you are using Python 2, please feel free to make the necessary changes to the code. 50 | 51 | We recommend using [Anaconda](https://conda.io/docs/user-guide/install/linux.html). We have tested our code on Ubuntu 16.04. 52 | 53 | ## Citation 54 | If ESPNet is useful for your research, then please cite our paper. 55 | ``` 56 | @inproceedings{mehta2018espnet, 57 | title={ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation}, 58 | author={Mehta, Sachin and Rastegari, Mohammad and Caspi, Anat and Shapiro, Linda and Hajishirzi, Hannaneh}, 59 | booktitle={ECCV}, 60 | year={2018} 61 | } 62 | ``` 63 | 64 | 65 | ## FAQs 66 | 67 | ### Assertion error with class labels (t >= 0 && t < n_classes). 68 | 69 | If you are getting an assertion error with class labels, then please check the label values present in your label images. You can do this as: 70 | 71 | ``` 72 | import cv2 73 | import numpy as np 74 | labelImg = cv2.imread(<path to label image>, 0) 75 | unique_val_arr = np.unique(labelImg) 76 | print(unique_val_arr) 77 | ``` 78 | The values inside *unique_val_arr* should lie between 0 and the number of classes in the dataset minus one. If this is not the case, then pre-process your label images.
For example, if the label image contains 255 as a value, then you can ignore these values by mapping them to an undefined or background class as: 79 | 80 | ``` 81 | labelImg[labelImg == 255] = <id of the undefined or background class> 82 | ``` 83 | -------------------------------------------------------------------------------- /pretrained/README.md: -------------------------------------------------------------------------------- 1 | # ESPNet: Towards Fast and Efficient Semantic Segmentation on the Embedded Devices 2 | 3 | This directory contains the pretrained models for ESPNet-C and ESPNet under three different settings. 4 | 5 | * [encoder](/pretrained/encoder/) - Check this folder for ESPNet-C pretrained models. 6 | * [decoder](/pretrained/decoder/) - Check this folder for ESPNet pretrained models. 7 | -------------------------------------------------------------------------------- /pretrained/decoder/README.md: -------------------------------------------------------------------------------- 1 | # ESPNet: Towards Fast and Efficient Semantic Segmentation on the Embedded Devices 2 | 3 | This directory contains the pretrained models for ESPNet under three different settings: 4 | 5 | * p=2, q=3 6 | * p=2, q=5 7 | * p=2, q=8 8 | 9 | 10 | ## Models trained on the CamVid dataset 11 | * espnet_p_2_q_8_camvid.pth 12 | 13 | ## Models trained on the CityScapes dataset 14 | * espnet_p_2_q_3.pth 15 | * espnet_p_2_q_5.pth 16 | * espnet_p_2_q_8.pth 17 | -------------------------------------------------------------------------------- /pretrained/decoder/espnet_p_2_q_3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/pretrained/decoder/espnet_p_2_q_3.pth -------------------------------------------------------------------------------- /pretrained/decoder/espnet_p_2_q_5.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/pretrained/decoder/espnet_p_2_q_5.pth -------------------------------------------------------------------------------- /pretrained/decoder/espnet_p_2_q_8.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/pretrained/decoder/espnet_p_2_q_8.pth -------------------------------------------------------------------------------- /pretrained/decoder/espnet_p_2_q_8_camvid.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/pretrained/decoder/espnet_p_2_q_8_camvid.pth -------------------------------------------------------------------------------- /pretrained/encoder/README.md: -------------------------------------------------------------------------------- 1 | # ESPNet: Towards Fast and Efficient Semantic Segmentation on the Embedded Devices 2 | 3 | This directory contains the pretrained models for ESPNet-C under three different settings: 4 | 5 | * p=2, q=3 6 | * p=2, q=5 7 | * p=2, q=8 8 | -------------------------------------------------------------------------------- /pretrained/encoder/espnet_p_2_q_3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/pretrained/encoder/espnet_p_2_q_3.pth
-------------------------------------------------------------------------------- /pretrained/encoder/espnet_p_2_q_5.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/pretrained/encoder/espnet_p_2_q_5.pth -------------------------------------------------------------------------------- /pretrained/encoder/espnet_p_2_q_8.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/pretrained/encoder/espnet_p_2_q_8.pth -------------------------------------------------------------------------------- /sample_video/ReadMe.md: -------------------------------------------------------------------------------- 1 | This directory contains a sample video demonstrating the segmentation performance of ESPNet. 2 | -------------------------------------------------------------------------------- /sample_video/sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/sample_video/sample.png -------------------------------------------------------------------------------- /test/Model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | __author__ = "Sachin Mehta" 5 | 6 | class CBR(nn.Module): 7 | ''' 8 | This class defines the convolution layer with batch normalization and PReLU activation 9 | ''' 10 | def __init__(self, nIn, nOut, kSize, stride=1): 11 | ''' 12 | 13 | :param nIn: number of input channels 14 | :param nOut: number of output channels 15 | :param kSize: kernel size 16 | :param stride: stride rate for down-sampling. 
Default is 1 17 | ''' 18 | super().__init__() 19 | padding = int((kSize - 1)/2) 20 | #self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False) 21 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False) 22 | #self.conv1 = nn.Conv2d(nOut, nOut, (1, kSize), stride=1, padding=(0, padding), bias=False) 23 | self.bn = nn.BatchNorm2d(nOut, eps=1e-03) 24 | self.act = nn.PReLU(nOut) 25 | 26 | def forward(self, input): 27 | ''' 28 | :param input: input feature map 29 | :return: transformed feature map 30 | ''' 31 | output = self.conv(input) 32 | #output = self.conv1(output) 33 | output = self.bn(output) 34 | output = self.act(output) 35 | return output 36 | 37 | 38 | class BR(nn.Module): 39 | ''' 40 | This class groups the batch normalization and PReLU activation 41 | ''' 42 | def __init__(self, nOut): 43 | ''' 44 | :param nOut: output feature maps 45 | ''' 46 | super().__init__() 47 | self.bn = nn.BatchNorm2d(nOut, eps=1e-03) 48 | self.act = nn.PReLU(nOut) 49 | 50 | def forward(self, input): 51 | ''' 52 | :param input: input feature map 53 | :return: normalized and thresholded feature map 54 | ''' 55 | output = self.bn(input) 56 | output = self.act(output) 57 | return output 58 | 59 | class CB(nn.Module): 60 | ''' 61 | This class groups the convolution and batch normalization 62 | ''' 63 | def __init__(self, nIn, nOut, kSize, stride=1): 64 | ''' 65 | :param nIn: number of input channels 66 | :param nOut: number of output channels 67 | :param kSize: kernel size 68 | :param stride: optinal stide for down-sampling 69 | ''' 70 | super().__init__() 71 | padding = int((kSize - 1)/2) 72 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False) 73 | self.bn = nn.BatchNorm2d(nOut, eps=1e-03) 74 | 75 | def forward(self, input): 76 | ''' 77 | 78 | :param input: input feature map 79 | :return: transformed feature map 80 | ''' 81 | output = self.conv(input) 82 | output = self.bn(output) 83 | return output 84 | 85 | class C(nn.Module): 86 | ''' 87 | This class is for a convolutional layer. 88 | ''' 89 | def __init__(self, nIn, nOut, kSize, stride=1): 90 | ''' 91 | 92 | :param nIn: number of input channels 93 | :param nOut: number of output channels 94 | :param kSize: kernel size 95 | :param stride: optional stride rate for down-sampling 96 | ''' 97 | super().__init__() 98 | padding = int((kSize - 1)/2) 99 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False) 100 | 101 | def forward(self, input): 102 | ''' 103 | :param input: input feature map 104 | :return: transformed feature map 105 | ''' 106 | output = self.conv(input) 107 | return output 108 | 109 | class CDilated(nn.Module): 110 | ''' 111 | This class defines the dilated convolution. 
112 | ''' 113 | def __init__(self, nIn, nOut, kSize, stride=1, d=1): 114 | ''' 115 | :param nIn: number of input channels 116 | :param nOut: number of output channels 117 | :param kSize: kernel size 118 | :param stride: optional stride rate for down-sampling 119 | :param d: optional dilation rate 120 | ''' 121 | super().__init__() 122 | padding = int((kSize - 1)/2) * d 123 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False, dilation=d) 124 | 125 | def forward(self, input): 126 | ''' 127 | :param input: input feature map 128 | :return: transformed feature map 129 | ''' 130 | output = self.conv(input) 131 | return output 132 | 133 | class DownSamplerB(nn.Module): 134 | def __init__(self, nIn, nOut): 135 | super().__init__() 136 | n = int(nOut/5) 137 | n1 = nOut - 4*n 138 | self.c1 = C(nIn, n, 3, 2) 139 | self.d1 = CDilated(n, n1, 3, 1, 1) 140 | self.d2 = CDilated(n, n, 3, 1, 2) 141 | self.d4 = CDilated(n, n, 3, 1, 4) 142 | self.d8 = CDilated(n, n, 3, 1, 8) 143 | self.d16 = CDilated(n, n, 3, 1, 16) 144 | self.bn = nn.BatchNorm2d(nOut, eps=1e-3) 145 | self.act = nn.PReLU(nOut) 146 | 147 | def forward(self, input): 148 | output1 = self.c1(input) 149 | d1 = self.d1(output1) 150 | d2 = self.d2(output1) 151 | d4 = self.d4(output1) 152 | d8 = self.d8(output1) 153 | d16 = self.d16(output1) 154 | 155 | add1 = d2 156 | add2 = add1 + d4 157 | add3 = add2 + d8 158 | add4 = add3 + d16 159 | 160 | combine = torch.cat([d1, add1, add2, add3, add4],1) 161 | #combine_in_out = input + combine 162 | output = self.bn(combine) 163 | output = self.act(output) 164 | return output 165 | 166 | class DilatedParllelResidualBlockB(nn.Module): 167 | ''' 168 | This class defines the ESP block, which is based on the following principle 169 | Reduce ---> Split ---> Transform --> Merge 170 | ''' 171 | def __init__(self, nIn, nOut, add=True): 172 | ''' 173 | :param nIn: number of input channels 174 | :param nOut: number of output channels 175 | :param add: if true, add a residual connection through identity operation. 
You can use projection too as 176 | in ResNet paper, but we avoid to use it if the dimensions are not the same because we do not want to 177 | increase the module complexity 178 | ''' 179 | super().__init__() 180 | n = int(nOut/5) 181 | n1 = nOut - 4*n 182 | self.c1 = C(nIn, n, 1, 1) 183 | self.d1 = CDilated(n, n1, 3, 1, 1) # dilation rate of 2^0 184 | self.d2 = CDilated(n, n, 3, 1, 2) # dilation rate of 2^1 185 | self.d4 = CDilated(n, n, 3, 1, 4) # dilation rate of 2^2 186 | self.d8 = CDilated(n, n, 3, 1, 8) # dilation rate of 2^3 187 | self.d16 = CDilated(n, n, 3, 1, 16) # dilation rate of 2^4 188 | self.bn = BR(nOut) 189 | self.add = add 190 | 191 | def forward(self, input): 192 | ''' 193 | :param input: input feature map 194 | :return: transformed feature map 195 | ''' 196 | # reduce 197 | output1 = self.c1(input) 198 | # split and transform 199 | d1 = self.d1(output1) 200 | d2 = self.d2(output1) 201 | d4 = self.d4(output1) 202 | d8 = self.d8(output1) 203 | d16 = self.d16(output1) 204 | 205 | # heirarchical fusion for de-gridding 206 | add1 = d2 207 | add2 = add1 + d4 208 | add3 = add2 + d8 209 | add4 = add3 + d16 210 | 211 | #merge 212 | combine = torch.cat([d1, add1, add2, add3, add4], 1) 213 | 214 | # if residual version 215 | if self.add: 216 | combine = input + combine 217 | output = self.bn(combine) 218 | return output 219 | 220 | class InputProjectionA(nn.Module): 221 | ''' 222 | This class projects the input image to the same spatial dimensions as the feature map. 223 | For example, if the input image is 512 x512 x3 and spatial dimensions of feature map size are 56x56xF, then 224 | this class will generate an output of 56x56x3 225 | ''' 226 | def __init__(self, samplingTimes): 227 | ''' 228 | :param samplingTimes: The rate at which you want to down-sample the image 229 | ''' 230 | super().__init__() 231 | self.pool = nn.ModuleList() 232 | for i in range(0, samplingTimes): 233 | #pyramid-based approach for down-sampling 234 | self.pool.append(nn.AvgPool2d(3, stride=2, padding=1)) 235 | 236 | def forward(self, input): 237 | ''' 238 | :param input: Input RGB Image 239 | :return: down-sampled image (pyramid-based approach) 240 | ''' 241 | for pool in self.pool: 242 | input = pool(input) 243 | return input 244 | 245 | 246 | class ESPNet_Encoder(nn.Module): 247 | ''' 248 | This class defines the ESPNet-C network in the paper 249 | ''' 250 | def __init__(self, classes=20, p=5, q=3): 251 | ''' 252 | :param classes: number of classes in the dataset. 
Default is 20 for the cityscapes 253 | :param p: depth multiplier 254 | :param q: depth multiplier 255 | ''' 256 | super().__init__() 257 | self.level1 = CBR(3, 16, 3, 2) 258 | self.sample1 = InputProjectionA(1) 259 | self.sample2 = InputProjectionA(2) 260 | 261 | self.b1 = BR(16 + 3) 262 | self.level2_0 = DownSamplerB(16 +3, 64) 263 | 264 | self.level2 = nn.ModuleList() 265 | for i in range(0, p): 266 | self.level2.append(DilatedParllelResidualBlockB(64 , 64)) 267 | self.b2 = BR(128 + 3) 268 | 269 | self.level3_0 = DownSamplerB(128 + 3, 128) 270 | self.level3 = nn.ModuleList() 271 | for i in range(0, q): 272 | self.level3.append(DilatedParllelResidualBlockB(128 , 128)) 273 | self.b3 = BR(256) 274 | 275 | self.classifier = C(256, classes, 1, 1) 276 | 277 | def forward(self, input): 278 | ''' 279 | :param input: Receives the input RGB image 280 | :return: the transformed feature map with spatial dimensions 1/8th of the input image 281 | ''' 282 | output0 = self.level1(input) 283 | inp1 = self.sample1(input) 284 | inp2 = self.sample2(input) 285 | 286 | output0_cat = self.b1(torch.cat([output0, inp1], 1)) 287 | output1_0 = self.level2_0(output0_cat) # down-sampled 288 | 289 | for i, layer in enumerate(self.level2): 290 | if i==0: 291 | output1 = layer(output1_0) 292 | else: 293 | output1 = layer(output1) 294 | 295 | output1_cat = self.b2(torch.cat([output1, output1_0, inp2], 1)) 296 | 297 | output2_0 = self.level3_0(output1_cat) # down-sampled 298 | for i, layer in enumerate(self.level3): 299 | if i==0: 300 | output2 = layer(output2_0) 301 | else: 302 | output2 = layer(output2) 303 | 304 | output2_cat = self.b3(torch.cat([output2_0, output2], 1)) 305 | 306 | classifier = self.classifier(output2_cat) 307 | 308 | return classifier 309 | 310 | class ESPNet(nn.Module): 311 | ''' 312 | This class defines the ESPNet network 313 | ''' 314 | 315 | def __init__(self, classes=20, p=2, q=3, encoderFile=None): 316 | ''' 317 | :param classes: number of classes in the dataset. Default is 20 for the cityscapes 318 | :param p: depth multiplier 319 | :param q: depth multiplier 320 | :param encoderFile: pretrained encoder weights. Recall that we first trained the ESPNet-C and then attached the 321 | RUM-based light weight decoder. See paper for more details. 
322 | ''' 323 | super().__init__() 324 | self.encoder = ESPNet_Encoder(classes, p, q) 325 | if encoderFile != None: 326 | self.encoder.load_state_dict(torch.load(encoderFile)) 327 | print('Encoder loaded!') 328 | # load the encoder modules 329 | self.modules = [] 330 | for i, m in enumerate(self.encoder.children()): 331 | self.modules.append(m) 332 | 333 | # light-weight decoder 334 | self.level3_C = C(128 + 3, classes, 1, 1) 335 | self.br = nn.BatchNorm2d(classes, eps=1e-03) 336 | self.conv = CBR(16 + classes, classes, 3, 1) 337 | 338 | self.up_l3 = nn.Sequential(nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False)) 339 | self.combine_l2_l3 = nn.Sequential(BR(2*classes), DilatedParllelResidualBlockB(2*classes , classes, add=False)) 340 | 341 | self.up_l2 = nn.Sequential(nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False), BR(classes)) 342 | 343 | self.classifier = nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False) 344 | 345 | def forward(self, input): 346 | ''' 347 | :param input: RGB image 348 | :return: transformed feature map 349 | ''' 350 | output0 = self.modules[0](input) 351 | inp1 = self.modules[1](input) 352 | inp2 = self.modules[2](input) 353 | 354 | output0_cat = self.modules[3](torch.cat([output0, inp1], 1)) 355 | output1_0 = self.modules[4](output0_cat) # down-sampled 356 | 357 | for i, layer in enumerate(self.modules[5]): 358 | if i == 0: 359 | output1 = layer(output1_0) 360 | else: 361 | output1 = layer(output1) 362 | 363 | output1_cat = self.modules[6](torch.cat([output1, output1_0, inp2], 1)) 364 | 365 | output2_0 = self.modules[7](output1_cat) # down-sampled 366 | for i, layer in enumerate(self.modules[8]): 367 | if i == 0: 368 | output2 = layer(output2_0) 369 | else: 370 | output2 = layer(output2) 371 | 372 | output2_cat = self.modules[9](torch.cat([output2_0, output2], 1)) # concatenate for feature map width expansion 373 | 374 | output2_c = self.up_l3(self.br(self.modules[10](output2_cat))) #RUM 375 | 376 | output1_C = self.level3_C(output1_cat) # project to C-dimensional space 377 | comb_l2_l3 = self.up_l2(self.combine_l2_l3(torch.cat([output1_C, output2_c], 1))) #RUM 378 | 379 | concat_features = self.conv(torch.cat([comb_l2_l3, output0], 1)) 380 | 381 | classifier = self.classifier(concat_features) 382 | return classifier 383 | -------------------------------------------------------------------------------- /test/README.md: -------------------------------------------------------------------------------- 1 | # ESPNet: Towards Fast and Efficient Semantic Segmentation on the Embedded Devices 2 | 3 | This folder contains the python scripts for running our pretrained models on the Cityscape dataset. 4 | 5 | ## Getting Started 6 | We provide the pretrained weights for ESPNet and ESPNet-C. Recall that ESPNet is the same as ESPNet-C, but with light weight decoder. 7 | 8 | Pre-requisites: 9 | * By default, we expect all images inside the ./data directory. If they are in different directory, please change the **data_dir** argument in the VisualizeResults.py file. 10 | 11 | * Also, if the image format is different (e.g. jpg), please change in the VisualizeResults.py file. 
12 | 13 | This can be done using the below command: 14 | 15 | ``` 16 | python VisualizeResults.py --data_dir --img_extn 17 | ``` 18 | 19 | 20 | ### Running ESPNet-C models 21 | To run the ESPNet-C models, execute the following commands 22 | 23 | ``` 24 | python VisualizeResults.py --modelType 2 --p 2 --q 3 25 | ``` 26 | 27 | Here, p and q are the depth multipliers. Our models only support p=2 and q=3,5,8 28 | 29 | 30 | ### Running ESPNet models 31 | To run the ESPNet models, execute the following commands 32 | 33 | ``` 34 | python VisualizeResults.py --modelType 1 --p 2 --q 3 35 | ``` 36 | 37 | Here, p and q are the depth multipliers. Our models only support p=2 and q=3,5,8 38 | -------------------------------------------------------------------------------- /test/VisualizeResults.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | import glob 5 | import cv2 6 | from PIL import Image as PILImage 7 | import Model as Net 8 | import os 9 | import time 10 | from argparse import ArgumentParser 11 | 12 | pallete = [[128, 64, 128], 13 | [244, 35, 232], 14 | [70, 70, 70], 15 | [102, 102, 156], 16 | [190, 153, 153], 17 | [153, 153, 153], 18 | [250, 170, 30], 19 | [220, 220, 0], 20 | [107, 142, 35], 21 | [152, 251, 152], 22 | [70, 130, 180], 23 | [220, 20, 60], 24 | [255, 0, 0], 25 | [0, 0, 142], 26 | [0, 0, 70], 27 | [0, 60, 100], 28 | [0, 80, 100], 29 | [0, 0, 230], 30 | [119, 11, 32], 31 | [0, 0, 0]] 32 | 33 | 34 | def relabel(img): 35 | ''' 36 | This function relabels the predicted labels so that cityscape dataset can process 37 | :param img: 38 | :return: 39 | ''' 40 | img[img == 19] = 255 41 | img[img == 18] = 33 42 | img[img == 17] = 32 43 | img[img == 16] = 31 44 | img[img == 15] = 28 45 | img[img == 14] = 27 46 | img[img == 13] = 26 47 | img[img == 12] = 25 48 | img[img == 11] = 24 49 | img[img == 10] = 23 50 | img[img == 9] = 22 51 | img[img == 8] = 21 52 | img[img == 7] = 20 53 | img[img == 6] = 19 54 | img[img == 5] = 17 55 | img[img == 4] = 13 56 | img[img == 3] = 12 57 | img[img == 2] = 11 58 | img[img == 1] = 8 59 | img[img == 0] = 7 60 | img[img == 255] = 0 61 | return img 62 | 63 | 64 | def evaluateModel(args, model, up, image_list): 65 | # gloabl mean and std values 66 | mean = [72.3923111, 82.90893555, 73.15840149] 67 | std = [45.3192215, 46.15289307, 44.91483307] 68 | 69 | for i, imgName in enumerate(image_list): 70 | img = cv2.imread(imgName) 71 | if args.overlay: 72 | img_orig = np.copy(img) 73 | 74 | img = img.astype(np.float32) 75 | for j in range(3): 76 | img[:, :, j] -= mean[j] 77 | for j in range(3): 78 | img[:, :, j] /= std[j] 79 | 80 | # resize the image to 1024x512x3 81 | img = cv2.resize(img, (1024, 512)) 82 | if args.overlay: 83 | img_orig = cv2.resize(img_orig, (1024, 512)) 84 | 85 | img /= 255 86 | img = img.transpose((2, 0, 1)) 87 | img_tensor = torch.from_numpy(img) 88 | img_tensor = torch.unsqueeze(img_tensor, 0) # add a batch dimension 89 | img_variable = Variable(img_tensor, volatile=True) 90 | if args.gpu: 91 | img_variable = img_variable.cuda() 92 | img_out = model(img_variable) 93 | 94 | if args.modelType == 2: 95 | img_out = up(img_out) 96 | 97 | classMap_numpy = img_out[0].max(0)[1].byte().cpu().data.numpy() 98 | 99 | if i % 100 == 0: 100 | print(i) 101 | 102 | name = imgName.split('/')[-1] 103 | 104 | if args.colored: 105 | classMap_numpy_color = np.zeros((img.shape[1], img.shape[2], img.shape[0]), dtype=np.uint8) 106 | for idx in 
range(len(pallete)): 107 | [r, g, b] = pallete[idx] 108 | classMap_numpy_color[classMap_numpy == idx] = [b, g, r] 109 | cv2.imwrite(args.savedir + os.sep + 'c_' + name.replace(args.img_extn, 'png'), classMap_numpy_color) 110 | if args.overlay: 111 | overlayed = cv2.addWeighted(img_orig, 0.5, classMap_numpy_color, 0.5, 0) 112 | cv2.imwrite(args.savedir + os.sep + 'over_' + name.replace(args.img_extn, 'jpg'), overlayed) 113 | 114 | if args.cityFormat: 115 | classMap_numpy = relabel(classMap_numpy.astype(np.uint8)) 116 | 117 | cv2.imwrite(args.savedir + os.sep + name.replace(args.img_extn, 'png'), classMap_numpy) 118 | 119 | 120 | def main(args): 121 | # read all the images in the folder 122 | image_list = glob.glob(args.data_dir + os.sep + '*.' + args.img_extn) 123 | 124 | up = None 125 | if args.modelType == 2: 126 | up = torch.nn.Upsample(scale_factor=8, mode='bilinear') 127 | if args.gpu: 128 | up = up.cuda() 129 | 130 | p = args.p 131 | q = args.q 132 | classes = args.classes 133 | if args.modelType == 2: 134 | modelA = Net.ESPNet_Encoder(classes, p, q) # Net.Mobile_SegNetDilatedIA_C_stage1(20) 135 | model_weight_file = args.weightsDir + os.sep + 'encoder' + os.sep + 'espnet_p_' + str(p) + '_q_' + str( 136 | q) + '.pth' 137 | if not os.path.isfile(model_weight_file): 138 | print('Pre-trained model file does not exist. Please check ../pretrained/encoder folder') 139 | exit(-1) 140 | modelA.load_state_dict(torch.load(model_weight_file)) 141 | elif args.modelType == 1: 142 | modelA = Net.ESPNet(classes, p, q) # Net.Mobile_SegNetDilatedIA_C_stage1(20) 143 | model_weight_file = args.weightsDir + os.sep + 'decoder' + os.sep + 'espnet_p_' + str(p) + '_q_' + str(q) + '.pth' 144 | if not os.path.isfile(model_weight_file): 145 | print('Pre-trained model file does not exist. Please check ../pretrained/decoder folder') 146 | exit(-1) 147 | modelA.load_state_dict(torch.load(model_weight_file)) 148 | else: 149 | print('Model not supported') 150 | # modelA = torch.nn.DataParallel(modelA) 151 | if args.gpu: 152 | modelA = modelA.cuda() 153 | 154 | # set to evaluation mode 155 | modelA.eval() 156 | 157 | if not os.path.isdir(args.savedir): 158 | os.mkdir(args.savedir) 159 | 160 | evaluateModel(args, modelA, up, image_list) 161 | 162 | 163 | if __name__ == '__main__': 164 | parser = ArgumentParser() 165 | parser.add_argument('--model', default="ESPNet", help='Model name') 166 | parser.add_argument('--data_dir', default="./data", help='Data directory') 167 | parser.add_argument('--img_extn', default="png", help='RGB Image format') 168 | parser.add_argument('--inWidth', type=int, default=1024, help='Width of RGB image') 169 | parser.add_argument('--inHeight', type=int, default=512, help='Height of RGB image') 170 | parser.add_argument('--scaleIn', type=int, default=1, help='For ESPNet-C, scaleIn=8. For ESPNet, scaleIn=1') 171 | parser.add_argument('--modelType', type=int, default=1, help='1=ESPNet, 2=ESPNet-C') 172 | parser.add_argument('--savedir', default='./results', help='directory to save the results') 173 | parser.add_argument('--gpu', default=True, type=bool, help='Run on CPU or GPU. If TRUE, then GPU.') 174 | parser.add_argument('--decoder', type=bool, default=True, 175 | help='True if ESPNet. False for ESPNet-C') # False for encoder 176 | parser.add_argument('--weightsDir', default='../pretrained/', help='Pretrained weights directory.') 177 | parser.add_argument('--p', default=2, type=int, help='depth multiplier. 
Supported only 2') 178 | parser.add_argument('--q', default=8, type=int, help='depth multiplier. Supported only 3, 5, 8') 179 | parser.add_argument('--cityFormat', default=True, type=bool, help='If you want to convert to cityscape ' 180 | 'original label ids') 181 | parser.add_argument('--colored', default=True, type=bool, help='If you want to visualize the ' 182 | 'segmentation masks in color') 183 | parser.add_argument('--overlay', default=True, type=bool, help='If you want to visualize the ' 184 | 'segmentation masks overlayed on top of RGB image') 185 | parser.add_argument('--classes', default=20, type=int, help='Number of classes in the dataset. 20 for Cityscapes') 186 | 187 | args = parser.parse_args() 188 | assert (args.modelType == 1) and args.decoder, 'Model type should be 2 for ESPNet-C and 1 for ESPNet' 189 | if args.overlay: 190 | args.colored = True # This has to be true if you want to overlay 191 | main(args) 192 | -------------------------------------------------------------------------------- /test/data/README.md: -------------------------------------------------------------------------------- 1 | This folder should contain all the images for which you want to generate the results 2 | -------------------------------------------------------------------------------- /test/data/frankfurt_000000_000294_leftImg8bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/test/data/frankfurt_000000_000294_leftImg8bit.png -------------------------------------------------------------------------------- /test/data/frankfurt_000000_000576_leftImg8bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sacmehta/ESPNet/afe71c38edaee3514ca44e0adcafdf36109bf437/test/data/frankfurt_000000_000576_leftImg8bit.png -------------------------------------------------------------------------------- /train/Criteria.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | __author__ = "Sachin Mehta" 5 | 6 | 7 | class CrossEntropyLoss2d(nn.Module): 8 | ''' 9 | This file defines a cross entropy loss for 2D images 10 | ''' 11 | def __init__(self, weight=None): 12 | ''' 13 | :param weight: 1D weight vector to deal with the class-imbalance 14 | ''' 15 | super().__init__() 16 | 17 | self.loss = nn.NLLLoss2d(weight) 18 | 19 | def forward(self, outputs, targets): 20 | return self.loss(F.log_softmax(outputs, 1), targets) 21 | -------------------------------------------------------------------------------- /train/DataSet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import torch.utils.data 4 | 5 | __author__ = "Sachin Mehta" 6 | 7 | 8 | class MyDataset(torch.utils.data.Dataset): 9 | ''' 10 | Class to load the dataset 11 | ''' 12 | def __init__(self, imList, labelList, transform=None): 13 | ''' 14 | :param imList: image list (Note that these lists have been processed and pickled using the loadData.py) 15 | :param labelList: label list (Note that these lists have been processed and pickled using the loadData.py) 16 | :param transform: Type of transformation. 
SEe Transforms.py for supported transformations 17 | ''' 18 | self.imList = imList 19 | self.labelList = labelList 20 | self.transform = transform 21 | 22 | def __len__(self): 23 | return len(self.imList) 24 | 25 | def __getitem__(self, idx): 26 | ''' 27 | 28 | :param idx: Index of the image file 29 | :return: returns the image and corresponding label file. 30 | ''' 31 | image_name = self.imList[idx] 32 | label_name = self.labelList[idx] 33 | image = cv2.imread(image_name) 34 | label = cv2.imread(label_name, 0) 35 | if self.transform: 36 | [image, label] = self.transform(image, label) 37 | return (image, label) 38 | -------------------------------------------------------------------------------- /train/IOUEval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | #adapted from https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/score.py 5 | 6 | class iouEval: 7 | def __init__(self, nClasses): 8 | self.nClasses = nClasses 9 | self.reset() 10 | 11 | def reset(self): 12 | self.overall_acc = 0 13 | self.per_class_acc = np.zeros(self.nClasses, dtype=np.float32) 14 | self.per_class_iu = np.zeros(self.nClasses, dtype=np.float32) 15 | self.mIOU = 0 16 | self.batchCount = 1 17 | 18 | def fast_hist(self, a, b): 19 | k = (a >= 0) & (a < self.nClasses) 20 | return np.bincount(self.nClasses * a[k].astype(int) + b[k], minlength=self.nClasses ** 2).reshape(self.nClasses, self.nClasses) 21 | 22 | def compute_hist(self, predict, gth): 23 | hist = self.fast_hist(gth, predict) 24 | return hist 25 | 26 | def addBatch(self, predict, gth): 27 | predict = predict.cpu().numpy().flatten() 28 | gth = gth.cpu().numpy().flatten() 29 | 30 | epsilon = 0.00000001 31 | hist = self.compute_hist(predict, gth) 32 | overall_acc = np.diag(hist).sum() / (hist.sum() + epsilon) 33 | per_class_acc = np.diag(hist) / (hist.sum(1) + epsilon) 34 | per_class_iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon) 35 | mIou = np.nanmean(per_class_iu) 36 | 37 | self.overall_acc +=overall_acc 38 | self.per_class_acc += per_class_acc 39 | self.per_class_iu += per_class_iu 40 | self.mIOU += mIou 41 | self.batchCount += 1 42 | 43 | def getMetric(self): 44 | overall_acc = self.overall_acc/self.batchCount 45 | per_class_acc = self.per_class_acc / self.batchCount 46 | per_class_iu = self.per_class_iu / self.batchCount 47 | mIOU = self.mIOU / self.batchCount 48 | 49 | return overall_acc, per_class_acc, per_class_iu, mIOU -------------------------------------------------------------------------------- /train/Model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | __author__ = "Sachin Mehta" 5 | 6 | class CBR(nn.Module): 7 | ''' 8 | This class defines the convolution layer with batch normalization and PReLU activation 9 | ''' 10 | def __init__(self, nIn, nOut, kSize, stride=1): 11 | ''' 12 | 13 | :param nIn: number of input channels 14 | :param nOut: number of output channels 15 | :param kSize: kernel size 16 | :param stride: stride rate for down-sampling. 
Default is 1 17 | ''' 18 | super().__init__() 19 | padding = int((kSize - 1)/2) 20 | #self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False) 21 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False) 22 | #self.conv1 = nn.Conv2d(nOut, nOut, (1, kSize), stride=1, padding=(0, padding), bias=False) 23 | self.bn = nn.BatchNorm2d(nOut, eps=1e-03) 24 | self.act = nn.PReLU(nOut) 25 | 26 | def forward(self, input): 27 | ''' 28 | :param input: input feature map 29 | :return: transformed feature map 30 | ''' 31 | output = self.conv(input) 32 | #output = self.conv1(output) 33 | output = self.bn(output) 34 | output = self.act(output) 35 | return output 36 | 37 | 38 | class BR(nn.Module): 39 | ''' 40 | This class groups the batch normalization and PReLU activation 41 | ''' 42 | def __init__(self, nOut): 43 | ''' 44 | :param nOut: output feature maps 45 | ''' 46 | super().__init__() 47 | self.bn = nn.BatchNorm2d(nOut, eps=1e-03) 48 | self.act = nn.PReLU(nOut) 49 | 50 | def forward(self, input): 51 | ''' 52 | :param input: input feature map 53 | :return: normalized and thresholded feature map 54 | ''' 55 | output = self.bn(input) 56 | output = self.act(output) 57 | return output 58 | 59 | class CB(nn.Module): 60 | ''' 61 | This class groups the convolution and batch normalization 62 | ''' 63 | def __init__(self, nIn, nOut, kSize, stride=1): 64 | ''' 65 | :param nIn: number of input channels 66 | :param nOut: number of output channels 67 | :param kSize: kernel size 68 | :param stride: optinal stide for down-sampling 69 | ''' 70 | super().__init__() 71 | padding = int((kSize - 1)/2) 72 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False) 73 | self.bn = nn.BatchNorm2d(nOut, eps=1e-03) 74 | 75 | def forward(self, input): 76 | ''' 77 | 78 | :param input: input feature map 79 | :return: transformed feature map 80 | ''' 81 | output = self.conv(input) 82 | output = self.bn(output) 83 | return output 84 | 85 | class C(nn.Module): 86 | ''' 87 | This class is for a convolutional layer. 88 | ''' 89 | def __init__(self, nIn, nOut, kSize, stride=1): 90 | ''' 91 | 92 | :param nIn: number of input channels 93 | :param nOut: number of output channels 94 | :param kSize: kernel size 95 | :param stride: optional stride rate for down-sampling 96 | ''' 97 | super().__init__() 98 | padding = int((kSize - 1)/2) 99 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False) 100 | 101 | def forward(self, input): 102 | ''' 103 | :param input: input feature map 104 | :return: transformed feature map 105 | ''' 106 | output = self.conv(input) 107 | return output 108 | 109 | class CDilated(nn.Module): 110 | ''' 111 | This class defines the dilated convolution. 
112 | ''' 113 | def __init__(self, nIn, nOut, kSize, stride=1, d=1): 114 | ''' 115 | :param nIn: number of input channels 116 | :param nOut: number of output channels 117 | :param kSize: kernel size 118 | :param stride: optional stride rate for down-sampling 119 | :param d: optional dilation rate 120 | ''' 121 | super().__init__() 122 | padding = int((kSize - 1)/2) * d 123 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False, dilation=d) 124 | 125 | def forward(self, input): 126 | ''' 127 | :param input: input feature map 128 | :return: transformed feature map 129 | ''' 130 | output = self.conv(input) 131 | return output 132 | 133 | class DownSamplerB(nn.Module): 134 | def __init__(self, nIn, nOut): 135 | super().__init__() 136 | n = int(nOut/5) 137 | n1 = nOut - 4*n 138 | self.c1 = C(nIn, n, 3, 2) 139 | self.d1 = CDilated(n, n1, 3, 1, 1) 140 | self.d2 = CDilated(n, n, 3, 1, 2) 141 | self.d4 = CDilated(n, n, 3, 1, 4) 142 | self.d8 = CDilated(n, n, 3, 1, 8) 143 | self.d16 = CDilated(n, n, 3, 1, 16) 144 | self.bn = nn.BatchNorm2d(nOut, eps=1e-3) 145 | self.act = nn.PReLU(nOut) 146 | 147 | def forward(self, input): 148 | output1 = self.c1(input) 149 | d1 = self.d1(output1) 150 | d2 = self.d2(output1) 151 | d4 = self.d4(output1) 152 | d8 = self.d8(output1) 153 | d16 = self.d16(output1) 154 | 155 | add1 = d2 156 | add2 = add1 + d4 157 | add3 = add2 + d8 158 | add4 = add3 + d16 159 | 160 | combine = torch.cat([d1, add1, add2, add3, add4],1) 161 | #combine_in_out = input + combine 162 | output = self.bn(combine) 163 | output = self.act(output) 164 | return output 165 | 166 | class DilatedParllelResidualBlockB(nn.Module): 167 | ''' 168 | This class defines the ESP block, which is based on the following principle 169 | Reduce ---> Split ---> Transform --> Merge 170 | ''' 171 | def __init__(self, nIn, nOut, add=True): 172 | ''' 173 | :param nIn: number of input channels 174 | :param nOut: number of output channels 175 | :param add: if true, add a residual connection through identity operation. 
You can use projection too as 176 | in ResNet paper, but we avoid to use it if the dimensions are not the same because we do not want to 177 | increase the module complexity 178 | ''' 179 | super().__init__() 180 | n = int(nOut/5) 181 | n1 = nOut - 4*n 182 | self.c1 = C(nIn, n, 1, 1) 183 | self.d1 = CDilated(n, n1, 3, 1, 1) # dilation rate of 2^0 184 | self.d2 = CDilated(n, n, 3, 1, 2) # dilation rate of 2^1 185 | self.d4 = CDilated(n, n, 3, 1, 4) # dilation rate of 2^2 186 | self.d8 = CDilated(n, n, 3, 1, 8) # dilation rate of 2^3 187 | self.d16 = CDilated(n, n, 3, 1, 16) # dilation rate of 2^4 188 | self.bn = BR(nOut) 189 | self.add = add 190 | 191 | def forward(self, input): 192 | ''' 193 | :param input: input feature map 194 | :return: transformed feature map 195 | ''' 196 | # reduce 197 | output1 = self.c1(input) 198 | # split and transform 199 | d1 = self.d1(output1) 200 | d2 = self.d2(output1) 201 | d4 = self.d4(output1) 202 | d8 = self.d8(output1) 203 | d16 = self.d16(output1) 204 | 205 | # heirarchical fusion for de-gridding 206 | add1 = d2 207 | add2 = add1 + d4 208 | add3 = add2 + d8 209 | add4 = add3 + d16 210 | 211 | #merge 212 | combine = torch.cat([d1, add1, add2, add3, add4], 1) 213 | 214 | # if residual version 215 | if self.add: 216 | combine = input + combine 217 | output = self.bn(combine) 218 | return output 219 | 220 | class InputProjectionA(nn.Module): 221 | ''' 222 | This class projects the input image to the same spatial dimensions as the feature map. 223 | For example, if the input image is 512 x512 x3 and spatial dimensions of feature map size are 56x56xF, then 224 | this class will generate an output of 56x56x3 225 | ''' 226 | def __init__(self, samplingTimes): 227 | ''' 228 | :param samplingTimes: The rate at which you want to down-sample the image 229 | ''' 230 | super().__init__() 231 | self.pool = nn.ModuleList() 232 | for i in range(0, samplingTimes): 233 | #pyramid-based approach for down-sampling 234 | self.pool.append(nn.AvgPool2d(3, stride=2, padding=1)) 235 | 236 | def forward(self, input): 237 | ''' 238 | :param input: Input RGB Image 239 | :return: down-sampled image (pyramid-based approach) 240 | ''' 241 | for pool in self.pool: 242 | input = pool(input) 243 | return input 244 | 245 | 246 | class ESPNet_Encoder(nn.Module): 247 | ''' 248 | This class defines the ESPNet-C network in the paper 249 | ''' 250 | def __init__(self, classes=20, p=5, q=3): 251 | ''' 252 | :param classes: number of classes in the dataset. 
Default is 20 for the cityscapes 253 | :param p: depth multiplier 254 | :param q: depth multiplier 255 | ''' 256 | super().__init__() 257 | self.level1 = CBR(3, 16, 3, 2) 258 | self.sample1 = InputProjectionA(1) 259 | self.sample2 = InputProjectionA(2) 260 | 261 | self.b1 = BR(16 + 3) 262 | self.level2_0 = DownSamplerB(16 +3, 64) 263 | 264 | self.level2 = nn.ModuleList() 265 | for i in range(0, p): 266 | self.level2.append(DilatedParllelResidualBlockB(64 , 64)) 267 | self.b2 = BR(128 + 3) 268 | 269 | self.level3_0 = DownSamplerB(128 + 3, 128) 270 | self.level3 = nn.ModuleList() 271 | for i in range(0, q): 272 | self.level3.append(DilatedParllelResidualBlockB(128 , 128)) 273 | self.b3 = BR(256) 274 | 275 | self.classifier = C(256, classes, 1, 1) 276 | 277 | def forward(self, input): 278 | ''' 279 | :param input: Receives the input RGB image 280 | :return: the transformed feature map with spatial dimensions 1/8th of the input image 281 | ''' 282 | output0 = self.level1(input) 283 | inp1 = self.sample1(input) 284 | inp2 = self.sample2(input) 285 | 286 | output0_cat = self.b1(torch.cat([output0, inp1], 1)) 287 | output1_0 = self.level2_0(output0_cat) # down-sampled 288 | 289 | for i, layer in enumerate(self.level2): 290 | if i==0: 291 | output1 = layer(output1_0) 292 | else: 293 | output1 = layer(output1) 294 | 295 | output1_cat = self.b2(torch.cat([output1, output1_0, inp2], 1)) 296 | 297 | output2_0 = self.level3_0(output1_cat) # down-sampled 298 | for i, layer in enumerate(self.level3): 299 | if i==0: 300 | output2 = layer(output2_0) 301 | else: 302 | output2 = layer(output2) 303 | 304 | output2_cat = self.b3(torch.cat([output2_0, output2], 1)) 305 | 306 | classifier = self.classifier(output2_cat) 307 | 308 | return classifier 309 | 310 | class ESPNet(nn.Module): 311 | ''' 312 | This class defines the ESPNet network 313 | ''' 314 | 315 | def __init__(self, classes=20, p=2, q=3, encoderFile=None): 316 | ''' 317 | :param classes: number of classes in the dataset. Default is 20 for the cityscapes 318 | :param p: depth multiplier 319 | :param q: depth multiplier 320 | :param encoderFile: pretrained encoder weights. Recall that we first trained the ESPNet-C and then attached the 321 | RUM-based light weight decoder. See paper for more details. 
322 | ''' 323 | super().__init__() 324 | self.encoder = ESPNet_Encoder(classes, p, q) 325 | if encoderFile != None: 326 | self.encoder.load_state_dict(torch.load(encoderFile)) 327 | print('Encoder loaded!') 328 | # load the encoder modules 329 | self.modules = [] 330 | for i, m in enumerate(self.encoder.children()): 331 | self.modules.append(m) 332 | 333 | # light-weight decoder 334 | self.level3_C = C(128 + 3, classes, 1, 1) 335 | self.br = nn.BatchNorm2d(classes, eps=1e-03) 336 | self.conv = CBR(19 + classes, classes, 3, 1) 337 | 338 | self.up_l3 = nn.Sequential(nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False)) 339 | self.combine_l2_l3 = nn.Sequential(BR(2*classes), DilatedParllelResidualBlockB(2*classes , classes, add=False)) 340 | 341 | self.up_l2 = nn.Sequential(nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False), BR(classes)) 342 | 343 | self.classifier = nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False) 344 | 345 | def forward(self, input): 346 | ''' 347 | :param input: RGB image 348 | :return: transformed feature map 349 | ''' 350 | output0 = self.modules[0](input) 351 | inp1 = self.modules[1](input) 352 | inp2 = self.modules[2](input) 353 | 354 | output0_cat = self.modules[3](torch.cat([output0, inp1], 1)) 355 | output1_0 = self.modules[4](output0_cat) # down-sampled 356 | 357 | for i, layer in enumerate(self.modules[5]): 358 | if i == 0: 359 | output1 = layer(output1_0) 360 | else: 361 | output1 = layer(output1) 362 | 363 | output1_cat = self.modules[6](torch.cat([output1, output1_0, inp2], 1)) 364 | 365 | output2_0 = self.modules[7](output1_cat) # down-sampled 366 | for i, layer in enumerate(self.modules[8]): 367 | if i == 0: 368 | output2 = layer(output2_0) 369 | else: 370 | output2 = layer(output2) 371 | 372 | output2_cat = self.modules[9](torch.cat([output2_0, output2], 1)) # concatenate for feature map width expansion 373 | 374 | output2_c = self.up_l3(self.br(self.modules[10](output2_cat))) #RUM 375 | 376 | output1_C = self.level3_C(output1_cat) # project to C-dimensional space 377 | comb_l2_l3 = self.up_l2(self.combine_l2_l3(torch.cat([output1_C, output2_c], 1))) #RUM 378 | 379 | concat_features = self.conv(torch.cat([comb_l2_l3, output0_cat], 1)) 380 | 381 | classifier = self.classifier(concat_features) 382 | return classifier 383 | -------------------------------------------------------------------------------- /train/README.md: -------------------------------------------------------------------------------- 1 | # ESPNet: Towards Fast and Efficient Semantic Segmentation on the Embedded Devices 2 | 3 | This folder contains the python scripts for training models on the Cityscape dataset. 4 | 5 | 6 | ## Getting Started 7 | 8 | ### Training ESPNet-C 9 | 10 | You can start training the model using below command: 11 | 12 | ``` 13 | python main.py 14 | ``` 15 | 16 | By default, **ESPNet-C** will be trained with p=2 and q=8. Since the spatial dimensions of the output of ESPNet-C are 1/8th of original image size, please set scaleIn parameter to 8. 
If you want to change the parameters, you can do so by using the below command: 17 | 18 | ``` 19 | python main.py --scaleIn 8 --p --q 20 | 21 | Example: 22 | 23 | python main.py --scaleIn 8 --p 2 --q 8 24 | ``` 25 | 26 | ### Training ESPNet 27 | Once you are done training the ESPNet-C, you can attach the light-weight decoder and train the ESPNet model 28 | 29 | ``` 30 | python main.py --scaleIn 1 --p --q --decoder True --pretrained 31 | 32 | Example: 33 | 34 | python main.py --scaleIn 1 --p 2 --q 8 --decoder True --pretrained ../pretrained/encoder/espnet_p_2_q_8.pth 35 | ``` 36 | 37 | **Note 1:** Currently, we support only single GPU training. If you want to train the model on multiple-GPUs, you can use **nn.DataParallel** api provided by PyTorch. 38 | 39 | **Note 2:** To train on a specific GPU (single), you can specify the GPU_ID using the CUDA_VISIBLE_DEVICES as: 40 | 41 | ``` 42 | CUDA_VISIBLE_DEVICES=2 python main.py 43 | ``` 44 | 45 | This will run the training program on GPU with ID 2. 46 | -------------------------------------------------------------------------------- /train/Transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | import cv2 5 | 6 | __author__ = "Sachin Mehta" 7 | 8 | 9 | class Scale(object): 10 | """ 11 | Randomly crop and resize the given PIL image with a probability of 0.5 12 | """ 13 | def __init__(self, wi, he): 14 | ''' 15 | 16 | :param wi: width after resizing 17 | :param he: height after reszing 18 | ''' 19 | self.w = wi 20 | self.h = he 21 | 22 | def __call__(self, img, label): 23 | ''' 24 | :param img: RGB image 25 | :param label: semantic label image 26 | :return: resized images 27 | ''' 28 | #bilinear interpolation for RGB image 29 | img = cv2.resize(img, (self.w, self.h)) 30 | # nearest neighbour interpolation for label image 31 | label = cv2.resize(label, (self.w, self.h), interpolation=cv2.INTER_NEAREST) 32 | 33 | return [img, label] 34 | 35 | 36 | 37 | class RandomCropResize(object): 38 | """ 39 | Randomly crop and resize the given PIL image with a probability of 0.5 40 | """ 41 | def __init__(self, crop_area): 42 | ''' 43 | :param crop_area: area to be cropped (this is the max value and we select between o and crop area 44 | ''' 45 | self.cw = crop_area 46 | self.ch = crop_area 47 | 48 | def __call__(self, img, label): 49 | if random.random() < 0.5: 50 | h, w = img.shape[:2] 51 | x1 = random.randint(0, self.ch) 52 | y1 = random.randint(0, self.cw) 53 | 54 | img_crop = img[y1:h-y1, x1:w-x1] 55 | label_crop = label[y1:h-y1, x1:w-x1] 56 | 57 | img_crop = cv2.resize(img_crop, (w, h)) 58 | label_crop = cv2.resize(label_crop, (w,h), interpolation=cv2.INTER_NEAREST) 59 | return img_crop, label_crop 60 | else: 61 | return [img, label] 62 | 63 | class RandomCrop(object): 64 | ''' 65 | This class if for random cropping 66 | ''' 67 | def __init__(self, cropArea): 68 | ''' 69 | :param cropArea: amount of cropping (in pixels) 70 | ''' 71 | self.crop = cropArea 72 | 73 | def __call__(self, img, label): 74 | 75 | if random.random() < 0.5: 76 | h, w = img.shape[:2] 77 | img_crop = img[self.crop:h-self.crop, self.crop:w-self.crop] 78 | label_crop = label[self.crop:h-self.crop, self.crop:w-self.crop] 79 | return img_crop, label_crop 80 | else: 81 | return [img, label] 82 | 83 | 84 | 85 | class RandomFlip(object): 86 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 87 | """ 88 | 89 | def __call__(self, image, label): 90 | if 
random.random() < 0.5: 91 | x1 = 0#random.randint(0, 1) #if you want to do vertical flip, uncomment this line 92 | if x1 == 0: 93 | image = cv2.flip(image, 0) # horizontal flip 94 | label = cv2.flip(label, 0) # horizontal flip 95 | else: 96 | image = cv2.flip(image, 1) # veritcal flip 97 | label = cv2.flip(label, 1) # veritcal flip 98 | return [image, label] 99 | 100 | 101 | class Normalize(object): 102 | """Given mean: (R, G, B) and std: (R, G, B), 103 | will normalize each channel of the torch.*Tensor, i.e. 104 | channel = (channel - mean) / std 105 | """ 106 | 107 | def __init__(self, mean, std): 108 | ''' 109 | :param mean: global mean computed from dataset 110 | :param std: global std computed from dataset 111 | ''' 112 | self.mean = mean 113 | self.std = std 114 | 115 | def __call__(self, image, label): 116 | image = image.astype(np.float32) 117 | for i in range(3): 118 | image[:,:,i] -= self.mean[i] 119 | for i in range(3): 120 | image[:,:, i] /= self.std[i] 121 | 122 | return [image, label] 123 | 124 | class ToTensor(object): 125 | ''' 126 | This class converts the data to tensor so that it can be processed by PyTorch 127 | ''' 128 | def __init__(self, scale=1): 129 | ''' 130 | :param scale: ESPNet-C's output is 1/8th of original image size, so set this parameter accordingly 131 | ''' 132 | self.scale = scale # original images are 2048 x 1024 133 | 134 | def __call__(self, image, label): 135 | 136 | if self.scale != 1: 137 | h, w = label.shape[:2] 138 | image = cv2.resize(image, (int(w), int(h))) 139 | label = cv2.resize(label, (int(w/self.scale), int(h/self.scale)), interpolation=cv2.INTER_NEAREST) 140 | 141 | image = image.transpose((2,0,1)) 142 | 143 | image_tensor = torch.from_numpy(image).div(255) 144 | label_tensor = torch.LongTensor(np.array(label, dtype=np.int)) #torch.from_numpy(label) 145 | 146 | return [image_tensor, label_tensor] 147 | 148 | class Compose(object): 149 | """Composes several transforms together. 
150 | """ 151 | 152 | def __init__(self, transforms): 153 | self.transforms = transforms 154 | 155 | def __call__(self, *args): 156 | for t in self.transforms: 157 | args = t(*args) 158 | return args 159 | -------------------------------------------------------------------------------- /train/VisualizeGraph.py: -------------------------------------------------------------------------------- 1 | from graphviz import Digraph 2 | import torch 3 | from torch.autograd import Variable 4 | 5 | ''' 6 | Not written by me 7 | Copied from here: https://github.com/szagoruyko/pytorchviz 8 | ''' 9 | 10 | def make_dot(var, params=None): 11 | """ Produces Graphviz representation of PyTorch autograd graph 12 | Blue nodes are the Variables that require grad, orange are Tensors 13 | saved for backward in torch.autograd.Function 14 | Args: 15 | var: output Variable 16 | params: dict of (name, Variable) to add names to node that 17 | require grad (TODO: make optional) 18 | """ 19 | if params is not None: 20 | assert isinstance(params.values()[0], Variable) 21 | param_map = {id(v): k for k, v in params.items()} 22 | 23 | node_attr = dict(style='filled', 24 | shape='box', 25 | align='left', 26 | fontsize='12', 27 | ranksep='0.1', 28 | height='0.2') 29 | dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12")) 30 | seen = set() 31 | 32 | def size_to_str(size): 33 | return '('+(', ').join(['%d' % v for v in size])+')' 34 | 35 | def add_nodes(var): 36 | if var not in seen: 37 | if torch.is_tensor(var): 38 | dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange') 39 | elif hasattr(var, 'variable'): 40 | u = var.variable 41 | name = param_map[id(u)] if params is not None else '' 42 | node_name = '%s\n %s' % (name, size_to_str(u.size())) 43 | dot.node(str(id(var)), node_name, fillcolor='lightblue') 44 | else: 45 | dot.node(str(id(var)), str(type(var).__name__)) 46 | seen.add(var) 47 | if hasattr(var, 'next_functions'): 48 | for u in var.next_functions: 49 | if u[0] is not None: 50 | dot.edge(str(id(u[0])), str(id(var))) 51 | add_nodes(u[0]) 52 | if hasattr(var, 'saved_tensors'): 53 | for t in var.saved_tensors: 54 | dot.edge(str(id(t)), str(id(var))) 55 | add_nodes(t) 56 | add_nodes(var.grad_fn) 57 | return dot -------------------------------------------------------------------------------- /train/city/README.md: -------------------------------------------------------------------------------- 1 | # ESPNet: Towards Fast and Efficient Semantic Segmentation on the Embedded Devices 2 | 3 | This folder contains the data. 4 | 5 | ## Change to custom data location 6 | If your data is saved in a different directory, no worries. You can pass the path of the directory and the files will load the data from the specified directory location. 7 | 8 | ``` 9 | python main.py --data_dir 10 | ``` 11 | 12 | Please make sure that your directory contains the **train.txt** and **val.txt** files. Our code expects the names of images in a particular format 13 | 14 | ``` 15 | ,