├── .gitignore ├── ENet-Real_Time_Semantic_Segmentation.ipynb ├── LICENSE ├── README.md ├── init.py ├── models │   ├── ASNeck.py │   ├── ENet.py │   ├── InitialBlock.py │   ├── RDDNeck.py │   ├── UBNeck.py │   └── __init__.py ├── test.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.swo 2 | *.png 3 | *.jpg 4 | *.zip 5 | *.swp 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | .idea/ 112 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Arunava 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ENet - Real Time Semantic Segmentation 2 | 3 | A neural net architecture for real-time semantic segmentation.
4 | In this repository we have reproduced the ENet paper, which can be used on 5 | mobile devices for real-time semantic segmentation. The link to the paper can be found here: [ENet](https://arxiv.org/pdf/1606.02147.pdf) 6 | 7 | ## How to use? 8 | 9 | 0. This repository comes with a handy notebook which you can use with Colab.
10 | You can find a link to the notebook here: 11 | [ENet - Real Time Semantic Segmentation](https://github.com/iArunava/ENet-Real-Time-Semantic-Segmentation/blob/master/ENet-Real%20Time%20Semantic%20Segmentation.ipynb)
12 | Open it in colab: [Open in Colab](https://colab.research.google.com/github/iArunava/ENet-Real-Time-Semantic-Segmentation/blob/master/ENet-Real%20Time%20Semantic%20Segmentation.ipynb) 13 | 14 | --- 15 | 16 | 17 | 0. Clone the repository and cd into it 18 | ``` 19 | git clone https://github.com/iArunava/ENet-Real-Time-Semantic-Segmentation.git 20 | cd ENet-Real-Time-Semantic-Segmentation/ 21 | ``` 22 | 23 | 1. Use this command to train the model 24 | ``` 25 | python3 init.py --mode train -iptr /path/to/train/input/set/ -lptr /path/to/label/set/ 26 | ``` 27 | 28 | 2. Use this command to test the model 29 | ``` 30 | python3 init.py --mode test -m /path/to/the/pretrained/model.pth -i /path/to/image/to/infer.png 31 | ``` 32 | 33 | 3. Use `--help` to see all the available options 34 | ``` 35 | python3 init.py --help 36 | ``` 37 | 
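4. Or run inference directly from Python. A minimal sketch (the checkpoint name below is illustrative; use any `.pth` file saved by `train.py`, and run this from the repository root):
```
import torch
from models.ENet import ENet
from utils import decode_segmap

enet = ENet(12)  # 12 CamVid classes, the --num-classes default
checkpoint = torch.load('./ckpt-enet-100-1.23.pth', map_location='cpu')  # illustrative file name
enet.load_state_dict(checkpoint['state_dict'])
enet.eval()  # disable dropout for inference

img = torch.rand(1, 3, 512, 512)  # stand-in for a preprocessed (1, 3, H, W) image tensor
with torch.no_grad():
    out = enet(img)                        # (1, 12, 512, 512) raw class scores
labels = out.squeeze(0).argmax(0).numpy()  # (512, 512) per-pixel class indices
rgb = decode_segmap(labels)                # (512, 512, 3) color-coded segmentation
```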
38 | ## Some results 39 | 40 | ![enet infer 1](https://user-images.githubusercontent.com/26242097/51782315-4b88d300-214c-11e9-9c92-3444c6582a80.png) 41 | ![enet infer 4](https://user-images.githubusercontent.com/26242097/51782341-a02c4e00-214c-11e9-8566-f2092ddad086.png) 42 | ![enet infer 6](https://user-images.githubusercontent.com/26242097/51782371-01542180-214d-11e9-80b8-55807f83f776.png) 43 | ![enet infer 5](https://user-images.githubusercontent.com/26242097/51782353-c3ef9400-214c-11e9-8c66-276795c83f08.png) 44 | ![enet infer 2](https://user-images.githubusercontent.com/26242097/51782324-6b1ffb80-214c-11e9-9f92-741954699f4d.png) 45 | 46 | ## References 47 | 1. A. Paszke, A. Chaurasia, S. Kim, and E. Culurciello. 48 | Enet: A deep neural network architecture 49 | for real-time semantic segmentation. arXiv preprint 50 | arXiv:1606.02147, 2016. 51 | 52 | ## Citations 53 | 54 | ``` 55 | @inproceedings{ BrostowSFC:ECCV08, 56 | author = {Gabriel J. Brostow and Jamie Shotton and Julien Fauqueur and Roberto Cipolla}, 57 | title = {Segmentation and Recognition Using Structure from Motion Point Clouds}, 58 | booktitle = {ECCV (1)}, 59 | year = {2008}, 60 | pages = {44-57} 61 | } 62 | 63 | @article{ BrostowFC:PRL2008, 64 | author = "Gabriel J. Brostow and Julien Fauqueur and Roberto Cipolla", 65 | title = "Semantic Object Classes in Video: A High-Definition Ground Truth Database", 66 | journal = "Pattern Recognition Letters", 67 | volume = "xx", 68 | number = "x", 69 | pages = "xx-xx", 70 | year = "2008" 71 | } 72 | ``` 73 | 74 | ## License 75 | 76 | The code in this repository is distributed under the BSD 3-Clause License. 77 | Feel free to fork and enjoy :) 78 | -------------------------------------------------------------------------------- /init.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | from train import * 4 | from test import * 5 | 6 | color_map = { # Cityscapes-style palette, kept for reference; the CamVid test path uses decode_segmap from utils.py 7 | 'unlabeled' : ( 0, 0, 0), 8 | 'dynamic' : (111, 74, 0), 9 | 'ground' : ( 81, 0, 81), 10 | 'road' : (128, 64,128), 11 | 'sidewalk' : (244, 35,232), 12 | 'parking' : (250,170,160), 13 | 'rail track' : (230,150,140), 14 | 'building' : ( 70, 70, 70), 15 | 'wall' : (102,102,156), 16 | 'fence' : (190,153,153), 17 | 'guard rail' : (180,165,180), 18 | 'bridge' : (150,100,100), 19 | 'tunnel' : (150,120, 90), 20 | 'pole' : (153,153,153), 21 | 'traffic light' : (250,170, 30), 22 | 'traffic sign' : (220,220, 0), 23 | 'vegetation' : (107,142, 35), 24 | 'terrain' : (152,251,152), 25 | 'sky' : ( 70,130,180), 26 | 'person' : (220, 20, 60), 27 | 'rider' : (255, 0, 0), 28 | 'car' : ( 0, 0,142), 29 | 'truck' : ( 0, 0, 70), 30 | 'bus' : ( 0, 60,100), 31 | 'caravan' : ( 0, 0, 90), 32 | 'trailer' : ( 0, 0,110), 33 | 'train' : ( 0, 80,100), 34 | 'motorcycle' : ( 0, 0,230), 35 | 'bicycle' : (119, 11, 32) 36 | } 37 | 38 | if __name__ == '__main__': 39 | parser = argparse.ArgumentParser() 40 | 41 | parser.add_argument('-m', 42 | type=str, 43 | default='./datasets/CamVid/ckpt-camvid-enet.pth', 44 | help='The path to the pretrained enet model') 45 | 46 | parser.add_argument('-i', '--image-path', 47 | type=str, 48 | help='The path to the image on which to perform semantic segmentation') 49 | 50 | parser.add_argument('-rh', '--resize-height', 51 | type=int, 52 | default=512, 53 | help='The height for the resized image') 54 | 55 | parser.add_argument('-rw', '--resize-width', 56 | type=int, 57 | default=512, 58 | help='The width for the resized image') 59 | 60 | parser.add_argument('-lr', '--learning-rate', 61 | type=float, 62 | default=5e-4, 63 | help='The learning rate') 64 | 65 | parser.add_argument('-bs', '--batch-size', 66 | type=int, 67 | default=10, 68 | help='The batch size') 69 | 70 | parser.add_argument('-wd', '--weight-decay', 71 | type=float, 72 | default=2e-4, 73 | help='The weight decay') 74 | 75 | parser.add_argument('-c', '--constant', 76 | type=float, 77 | default=1.02, 78 | help='The constant used for calculating the class weights') 79 | 80 | parser.add_argument('-e', '--epochs', 81 | type=int, 82 | default=102, 83 | help='The number of epochs') 84 | 85 | parser.add_argument('-nc', '--num-classes', 86 | type=int, 87 | default=12, 88 | help='Number of unique classes') 89 | 90 | parser.add_argument('-se', '--save-every', 91 | type=int, 92 | default=10, 93 | help='The number of epochs after which to save a model') 94 | 95 | parser.add_argument('-iptr', '--input-path-train', 96 | type=str, 97 | default='./datasets/CamVid/train/', 98 | help='The path to the training input dataset') 99 | 100 | parser.add_argument('-lptr', '--label-path-train', 101 | type=str, 102 | default='./datasets/CamVid/trainannot/', 103 | help='The path to the training label dataset') 104 | 105 | parser.add_argument('-ipv', '--input-path-val', 106 | type=str, 107 | default='./datasets/CamVid/val/', 108 | help='The path to the validation input dataset') 109 | 110 | parser.add_argument('-lpv', '--label-path-val', 111 | type=str, 112 | default='./datasets/CamVid/valannot/', 113 | help='The path to the validation label dataset') 114 | 115 | parser.add_argument('-iptt', '--input-path-test', 116 | type=str, 117 | default='./datasets/CamVid/test/',
118 | help='The path to the test input dataset') 119 | 120 | parser.add_argument('-lptt', '--label-path-test', 121 | type=str, 122 | default='./datasets/CamVid/testannot/', 123 | help='The path to the test label dataset') 124 | 125 | parser.add_argument('-pe', '--print-every', 126 | type=int, 127 | default=1, 128 | help='The number of epochs after which to print the training loss') 129 | 130 | parser.add_argument('-ee', '--eval-every', 131 | type=int, 132 | default=10, 133 | help='The number of epochs after which to print the validation loss') 134 | 135 | parser.add_argument('--cuda', 136 | action='store_true', # type=bool silently parsed any non-empty string, even 'False', as True 137 | default=False, 138 | help='Whether to use cuda or not') 139 | 140 | parser.add_argument('--mode', 141 | choices=['train', 'test'], 142 | default='train', 143 | help='Whether to train or test') 144 | 145 | FLAGS, unparsed = parser.parse_known_args() 146 | # torch is available here via the star imports from train and test 147 | FLAGS.cuda = torch.device('cuda:0' if torch.cuda.is_available() and FLAGS.cuda \ 148 | else 'cpu') 149 | 150 | if FLAGS.mode.lower() == 'train': 151 | train(FLAGS) 152 | elif FLAGS.mode.lower() == 'test': 153 | test(FLAGS) 154 | else: 155 | raise RuntimeError('Unknown mode passed. \n Mode passed should be either \ 156 | of "train" or "test"') 157 | -------------------------------------------------------------------------------- /models/ASNeck.py: -------------------------------------------------------------------------------- 1 | ################################################### 2 | # Copyright (c) 2019 # 3 | # Authors: @iArunava # 4 | # @AvivSham # 5 | # # 6 | # License: BSD License 3.0 # 7 | # # 8 | # The Code in this file is distributed for free # 9 | # usage and modification with proper linkage back # 10 | # to this repository. # 11 | ################################################### 12 | 13 | import torch 14 | import torch.nn as nn 15 | 16 | class ASNeck(nn.Module): 17 | def __init__(self, in_channels, out_channels, projection_ratio=4): 18 | 19 | super().__init__() 20 | 21 | # Define class variables 22 | self.in_channels = in_channels 23 | self.reduced_depth = int(in_channels / projection_ratio) 24 | self.out_channels = out_channels 25 | 26 | self.dropout = nn.Dropout2d(p=0.1) 27 | 28 | self.conv1 = nn.Conv2d(in_channels = self.in_channels, 29 | out_channels = self.reduced_depth, 30 | kernel_size = 1, 31 | stride = 1, 32 | padding = 0, 33 | bias = False) 34 | 35 | self.prelu1 = nn.PReLU() 36 | 37 | self.conv21 = nn.Conv2d(in_channels = self.reduced_depth, 38 | out_channels = self.reduced_depth, 39 | kernel_size = (1, 5), 40 | stride = 1, 41 | padding = (0, 2), 42 | bias = False) 43 | 44 | self.conv22 = nn.Conv2d(in_channels = self.reduced_depth, 45 | out_channels = self.reduced_depth, 46 | kernel_size = (5, 1), 47 | stride = 1, 48 | padding = (2, 0), 49 | bias = False) 50 | 51 | self.prelu2 = nn.PReLU() 52 | 53 | self.conv3 = nn.Conv2d(in_channels = self.reduced_depth, 54 | out_channels = self.out_channels, 55 | kernel_size = 1, 56 | stride = 1, 57 | padding = 0, 58 | bias = False) 59 | 60 | self.prelu3 = nn.PReLU() 61 | 62 | self.batchnorm1 = nn.BatchNorm2d(self.reduced_depth) 63 | self.batchnorm2 = nn.BatchNorm2d(self.reduced_depth) 64 | self.batchnorm3 = nn.BatchNorm2d(self.out_channels) 65 | 66 | def forward(self, x): 67 | bs = x.size()[0] 68 | x_copy = x 69 | 70 | # Side Branch 71 | x = self.conv1(x) 72 | x = self.batchnorm1(x) 73 | x = self.prelu1(x) 74 | 75 | x = self.conv21(x) 76 | x = self.conv22(x) 77 | x = self.batchnorm2(x) 78 | x = self.prelu2(x) 79 | 80 | x = self.conv3(x) 81 | 82 | x = self.dropout(x) 83 | x = self.batchnorm3(x) 84 | 85
| # Main Branch 86 | 87 | if self.in_channels != self.out_channels: 88 | out_shape = self.out_channels - self.in_channels 89 | extras = torch.zeros((bs, out_shape, x.shape[2], x.shape[3])) 90 | if torch.cuda.is_available(): 91 | extras = extras.cuda() 92 | x_copy = torch.cat((x_copy, extras), dim = 1) 93 | 94 | # Sum of main and side branches 95 | x = x + x_copy 96 | x = self.prelu3(x) 97 | 98 | return x 99 | -------------------------------------------------------------------------------- /models/ENet.py: -------------------------------------------------------------------------------- 1 | ################################################################## 2 | # Reproducing the paper # 3 | # ENet - Real Time Semantic Segmentation # 4 | # Paper: https://arxiv.org/pdf/1606.02147.pdf # 5 | # # 6 | # Copyright (c) 2019 # 7 | # Authors: @iArunava # 8 | # @AvivSham # 9 | # # 10 | # License: BSD License 3.0 # 11 | # # 12 | # The Code in this file is distributed for free # 13 | # usage and modification with proper credits # 14 | # directing back to this repository. # 15 | ################################################################## 16 | 17 | import torch 18 | import torch.nn as nn 19 | from .InitialBlock import InitialBlock 20 | from .RDDNeck import RDDNeck 21 | from .UBNeck import UBNeck 22 | from .ASNeck import ASNeck 23 | 24 | class ENet(nn.Module): 25 | def __init__(self, C): 26 | super().__init__() 27 | 28 | # Define class variables 29 | self.C = C 30 | 31 | # The initial block 32 | self.init = InitialBlock() 33 | 34 | 35 | # The first bottleneck 36 | self.b10 = RDDNeck(dilation=1, 37 | in_channels=16, 38 | out_channels=64, 39 | down_flag=True, 40 | p=0.01) 41 | 42 | self.b11 = RDDNeck(dilation=1, 43 | in_channels=64, 44 | out_channels=64, 45 | down_flag=False, 46 | p=0.01) 47 | 48 | self.b12 = RDDNeck(dilation=1, 49 | in_channels=64, 50 | out_channels=64, 51 | down_flag=False, 52 | p=0.01) 53 | 54 | self.b13 = RDDNeck(dilation=1, 55 | in_channels=64, 56 | out_channels=64, 57 | down_flag=False, 58 | p=0.01) 59 | 60 | self.b14 = RDDNeck(dilation=1, 61 | in_channels=64, 62 | out_channels=64, 63 | down_flag=False, 64 | p=0.01) 65 | 66 | 67 | # The second bottleneck 68 | self.b20 = RDDNeck(dilation=1, 69 | in_channels=64, 70 | out_channels=128, 71 | down_flag=True) 72 | 73 | self.b21 = RDDNeck(dilation=1, 74 | in_channels=128, 75 | out_channels=128, 76 | down_flag=False) 77 | 78 | self.b22 = RDDNeck(dilation=2, 79 | in_channels=128, 80 | out_channels=128, 81 | down_flag=False) 82 | 83 | self.b23 = ASNeck(in_channels=128, 84 | out_channels=128) 85 | 86 | self.b24 = RDDNeck(dilation=4, 87 | in_channels=128, 88 | out_channels=128, 89 | down_flag=False) 90 | 91 | self.b25 = RDDNeck(dilation=1, 92 | in_channels=128, 93 | out_channels=128, 94 | down_flag=False) 95 | 96 | self.b26 = RDDNeck(dilation=8, 97 | in_channels=128, 98 | out_channels=128, 99 | down_flag=False) 100 | 101 | self.b27 = ASNeck(in_channels=128, 102 | out_channels=128) 103 | 104 | self.b28 = RDDNeck(dilation=16, 105 | in_channels=128, 106 | out_channels=128, 107 | down_flag=False) 108 | 109 | 110 | # The third bottleneck 111 | self.b31 = RDDNeck(dilation=1, 112 | in_channels=128, 113 | out_channels=128, 114 | down_flag=False) 115 | 116 | self.b32 = RDDNeck(dilation=2, 117 | in_channels=128, 118 | out_channels=128, 119 | down_flag=False) 120 | 121 | self.b33 = ASNeck(in_channels=128, 122 | out_channels=128) 123 | 124 | self.b34 = RDDNeck(dilation=4, 125 | in_channels=128, 126 | out_channels=128, 127 | down_flag=False) 128 | 129 | 
self.b35 = RDDNeck(dilation=1, 130 | in_channels=128, 131 | out_channels=128, 132 | down_flag=False) 133 | 134 | self.b36 = RDDNeck(dilation=8, 135 | in_channels=128, 136 | out_channels=128, 137 | down_flag=False) 138 | 139 | self.b37 = ASNeck(in_channels=128, 140 | out_channels=128) 141 | 142 | self.b38 = RDDNeck(dilation=16, 143 | in_channels=128, 144 | out_channels=128, 145 | down_flag=False) 146 | 147 | 148 | # The fourth bottleneck 149 | self.b40 = UBNeck(in_channels=128, 150 | out_channels=64, 151 | relu=True) 152 | 153 | self.b41 = RDDNeck(dilation=1, 154 | in_channels=64, 155 | out_channels=64, 156 | down_flag=False, 157 | relu=True) 158 | 159 | self.b42 = RDDNeck(dilation=1, 160 | in_channels=64, 161 | out_channels=64, 162 | down_flag=False, 163 | relu=True) 164 | 165 | 166 | # The fifth bottleneck 167 | self.b50 = UBNeck(in_channels=64, 168 | out_channels=16, 169 | relu=True) 170 | 171 | self.b51 = RDDNeck(dilation=1, 172 | in_channels=16, 173 | out_channels=16, 174 | down_flag=False, 175 | relu=True) 176 | 177 | 178 | # Final ConvTranspose Layer 179 | self.fullconv = nn.ConvTranspose2d(in_channels=16, 180 | out_channels=self.C, 181 | kernel_size=3, 182 | stride=2, 183 | padding=1, 184 | output_padding=1, 185 | bias=False) 186 | 187 | 188 | def forward(self, x): 189 | 190 | # The initial block 191 | x = self.init(x) 192 | 193 | # The first bottleneck 194 | x, i1 = self.b10(x) 195 | x = self.b11(x) 196 | x = self.b12(x) 197 | x = self.b13(x) 198 | x = self.b14(x) 199 | 200 | # The second bottleneck 201 | x, i2 = self.b20(x) 202 | x = self.b21(x) 203 | x = self.b22(x) 204 | x = self.b23(x) 205 | x = self.b24(x) 206 | x = self.b25(x) 207 | x = self.b26(x) 208 | x = self.b27(x) 209 | x = self.b28(x) 210 | 211 | # The third bottleneck 212 | x = self.b31(x) 213 | x = self.b32(x) 214 | x = self.b33(x) 215 | x = self.b34(x) 216 | x = self.b35(x) 217 | x = self.b36(x) 218 | x = self.b37(x) 219 | x = self.b38(x) 220 | 221 | # The fourth bottleneck 222 | x = self.b40(x, i2) 223 | x = self.b41(x) 224 | x = self.b42(x) 225 | 226 | # The fifth bottleneck 227 | x = self.b50(x, i1) 228 | x = self.b51(x) 229 | 230 | # Final ConvTranspose Layer 231 | x = self.fullconv(x) 232 | 233 | return x 234 | -------------------------------------------------------------------------------- /models/InitialBlock.py: -------------------------------------------------------------------------------- 1 | ################################################### 2 | # Copyright (c) 2019 # 3 | # Authors: @iArunava # 4 | # @AvivSham # 5 | # # 6 | # License: BSD License 3.0 # 7 | # # 8 | # The Code in this file is distributed for free # 9 | # usage and modification with proper linkage back # 10 | # to this repository. 
# 11 | ################################################### 12 | 13 | import torch 14 | import torch.nn as nn 15 | 16 | class InitialBlock(nn.Module): 17 | def __init__ (self,in_channels = 3,out_channels = 13): 18 | super().__init__() 19 | 20 | 21 | self.maxpool = nn.MaxPool2d(kernel_size=2, 22 | stride = 2, 23 | padding = 0) 24 | 25 | self.conv = nn.Conv2d(in_channels, 26 | out_channels, 27 | kernel_size = 3, 28 | stride = 2, 29 | padding = 1) 30 | 31 | self.prelu = nn.PReLU(16) 32 | 33 | self.batchnorm = nn.BatchNorm2d(out_channels) 34 | 35 | def forward(self, x): 36 | 37 | main = self.conv(x) 38 | main = self.batchnorm(main) 39 | 40 | side = self.maxpool(x) 41 | 42 | x = torch.cat((main, side), dim=1) 43 | x = self.prelu(x) 44 | 45 | return x 46 | -------------------------------------------------------------------------------- /models/RDDNeck.py: -------------------------------------------------------------------------------- 1 | ################################################### 2 | # Copyright (c) 2019 # 3 | # Authors: @iArunava # 4 | # @AvivSham # 5 | # # 6 | # License: BSD License 3.0 # 7 | # # 8 | # The Code in this file is distributed for free # 9 | # usage and modification with proper linkage back # 10 | # to this repository. # 11 | ################################################### 12 | 13 | import torch 14 | import torch.nn as nn 15 | 16 | 17 | class RDDNeck(nn.Module): 18 | def __init__(self, dilation, in_channels, out_channels, down_flag, relu=False, projection_ratio=4, p=0.1): 19 | 20 | super().__init__() 21 | 22 | # Define class variables 23 | self.in_channels = in_channels 24 | 25 | self.out_channels = out_channels 26 | self.dilation = dilation 27 | self.down_flag = down_flag 28 | 29 | if down_flag: 30 | self.stride = 2 31 | self.reduced_depth = int(in_channels // projection_ratio) 32 | else: 33 | self.stride = 1 34 | self.reduced_depth = int(out_channels // projection_ratio) 35 | 36 | if relu: 37 | activation = nn.ReLU() 38 | else: 39 | activation = nn.PReLU() 40 | 41 | self.maxpool = nn.MaxPool2d(kernel_size = 2, 42 | stride = 2, 43 | padding = 0, return_indices=True) 44 | 45 | 46 | 47 | self.dropout = nn.Dropout2d(p=p) 48 | 49 | self.conv1 = nn.Conv2d(in_channels = self.in_channels, 50 | out_channels = self.reduced_depth, 51 | kernel_size = 1, 52 | stride = 1, 53 | padding = 0, 54 | bias = False, 55 | dilation = 1) 56 | 57 | self.prelu1 = activation 58 | 59 | self.conv2 = nn.Conv2d(in_channels = self.reduced_depth, 60 | out_channels = self.reduced_depth, 61 | kernel_size = 3, 62 | stride = self.stride, 63 | padding = self.dilation, 64 | bias = True, 65 | dilation = self.dilation) 66 | 67 | self.prelu2 = activation 68 | 69 | self.conv3 = nn.Conv2d(in_channels = self.reduced_depth, 70 | out_channels = self.out_channels, 71 | kernel_size = 1, 72 | stride = 1, 73 | padding = 0, 74 | bias = False, 75 | dilation = 1) 76 | 77 | self.prelu3 = activation 78 | 79 | self.batchnorm1 = nn.BatchNorm2d(self.reduced_depth) 80 | self.batchnorm2 = nn.BatchNorm2d(self.reduced_depth) 81 | self.batchnorm3 = nn.BatchNorm2d(self.out_channels) 82 | 83 | 84 | def forward(self, x): 85 | 86 | bs = x.size()[0] 87 | x_copy = x 88 | 89 | # Side Branch 90 | x = self.conv1(x) 91 | x = self.batchnorm1(x) 92 | x = self.prelu1(x) 93 | 94 | x = self.conv2(x) 95 | x = self.batchnorm2(x) 96 | x = self.prelu2(x) 97 | 98 | x = self.conv3(x) 99 | x = self.batchnorm3(x) 100 | 101 | x = self.dropout(x) 102 | 103 | # Main Branch 104 | if self.down_flag: 105 | x_copy, indices = self.maxpool(x_copy) 106 | 107 | if 
self.in_channels != self.out_channels: 108 | out_shape = self.out_channels - self.in_channels 109 | extras = torch.zeros((bs, out_shape, x.shape[2], x.shape[3])) 110 | # keep the zero padding on the same device as the input, not just on CUDA whenever it happens to be available 111 | extras = extras.to(x_copy.device) 112 | x_copy = torch.cat((x_copy, extras), dim = 1) 113 | 114 | # Sum of main and side branches 115 | x = x + x_copy 116 | x = self.prelu3(x) 117 | 118 | if self.down_flag: 119 | return x, indices 120 | else: 121 | return x 122 | -------------------------------------------------------------------------------- /models/UBNeck.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class UBNeck(nn.Module): 5 | def __init__(self, in_channels, out_channels, relu=False, projection_ratio=4): 6 | 7 | super().__init__() 8 | 9 | # Define class variables 10 | self.in_channels = in_channels 11 | self.reduced_depth = int(in_channels / projection_ratio) 12 | self.out_channels = out_channels 13 | 14 | 15 | if relu: 16 | activation = nn.ReLU() 17 | else: 18 | activation = nn.PReLU() 19 | 20 | self.unpool = nn.MaxUnpool2d(kernel_size = 2, 21 | stride = 2) 22 | 23 | self.main_conv = nn.Conv2d(in_channels = self.in_channels, 24 | out_channels = self.out_channels, 25 | kernel_size = 1) 26 | 27 | self.dropout = nn.Dropout2d(p=0.1) 28 | 29 | self.convt1 = nn.ConvTranspose2d(in_channels = self.in_channels, 30 | out_channels = self.reduced_depth, 31 | kernel_size = 1, 32 | padding = 0, 33 | bias = False) 34 | 35 | 36 | self.prelu1 = activation 37 | 38 | self.convt2 = nn.ConvTranspose2d(in_channels = self.reduced_depth, 39 | out_channels = self.reduced_depth, 40 | kernel_size = 3, 41 | stride = 2, 42 | padding = 1, 43 | output_padding = 1, 44 | bias = False) 45 | 46 | self.prelu2 = activation 47 | 48 | self.convt3 = nn.ConvTranspose2d(in_channels = self.reduced_depth, 49 | out_channels = self.out_channels, 50 | kernel_size = 1, 51 | padding = 0, 52 | bias = False) 53 | 54 | self.prelu3 = activation 55 | 56 | self.batchnorm1 = nn.BatchNorm2d(self.reduced_depth) 57 | self.batchnorm2 = nn.BatchNorm2d(self.reduced_depth) 58 | self.batchnorm3 = nn.BatchNorm2d(self.out_channels) 59 | 60 | def forward(self, x, indices): 61 | x_copy = x 62 | 63 | # Side Branch 64 | x = self.convt1(x) 65 | x = self.batchnorm1(x) 66 | x = self.prelu1(x) 67 | 68 | x = self.convt2(x) 69 | x = self.batchnorm2(x) 70 | x = self.prelu2(x) 71 | 72 | x = self.convt3(x) 73 | x = self.batchnorm3(x) 74 | 75 | x = self.dropout(x) 76 | 77 | # Main Branch 78 | 79 | x_copy = self.main_conv(x_copy) 80 | x_copy = self.unpool(x_copy, indices, output_size=x.size()) 81 | 82 | # Sum of main and side branches 83 | x = x + x_copy 84 | x = self.prelu3(x) 85 | 86 | return x 87 | --------------------------------------------------------------------------------
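UBNeck upsamples by reversing the encoder's max-pooling with `nn.MaxUnpool2d`, consuming the pooling indices that RDDNeck's downsampling blocks return (threaded through ENet's forward pass as `i1` and `i2`). A tiny self-contained sketch of that round trip:
```
import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

x = torch.rand(1, 1, 4, 4)
y, idx = pool(x)                          # y halves the resolution; idx records each max position
z = unpool(y, idx, output_size=x.size())  # maxima restored in place, zeros everywhere else
```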
/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iArunava/ENet-Real-Time-Semantic-Segmentation/8e3e86c4c4eb8392d72962e393d992294d8fc8ae/models/__init__.py -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils import * 4 | from models.ENet import ENet 5 | import cv2 # used below for resizing; this slot previously held an unused `import sys` 6 | import os 7 | from tqdm import tqdm 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | def test(FLAGS): 12 | # Check if the pretrained model is available 13 | if not FLAGS.m.endswith('.pth'): 14 | raise RuntimeError('Unknown file passed. Must end with .pth') 15 | if FLAGS.image_path is None or not os.path.exists(FLAGS.image_path): 16 | raise RuntimeError('An image file path must be passed') 17 | 18 | h = FLAGS.resize_height 19 | w = FLAGS.resize_width 20 | 21 | checkpoint = torch.load(FLAGS.m, map_location=FLAGS.cuda) 22 | 23 | # Assuming the dataset is camvid 24 | enet = ENet(FLAGS.num_classes) 25 | enet.load_state_dict(checkpoint['state_dict']) 26 | enet.eval() # switch off dropout; torch.no_grad() alone does not do this 27 | tmg_ = plt.imread(FLAGS.image_path) 28 | tmg_ = cv2.resize(tmg_, (w, h), interpolation=cv2.INTER_NEAREST) # cv2 expects (width, height); interpolation must be passed by keyword 29 | tmg = torch.tensor(tmg_).unsqueeze(0).float() 30 | tmg = tmg.transpose(2, 3).transpose(1, 2) 31 | 32 | with torch.no_grad(): 33 | out1 = enet(tmg.float()).squeeze(0) 34 | 35 | #smg_ = Image.open('/content/training/semantic/' + fname) 36 | #smg_ = cv2.resize(np.array(smg_), (512, 512), cv2.INTER_NEAREST) 37 | 38 | b_ = out1.data.max(0)[1].cpu().numpy() 39 | 40 | decoded_segmap = decode_segmap(b_) 41 | 42 | images = { 43 | 0 : ['Input Image', tmg_], 44 | 1 : ['Predicted Segmentation', decoded_segmap], # show the color-coded map, not the raw class indices 45 | } 46 | 47 | show_images(images) 48 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils import * 4 | from models.ENet import ENet 5 | import sys 6 | from tqdm import tqdm 7 | 8 | def train(FLAGS): 9 | 10 | # Defining the hyperparameters 11 | device = FLAGS.cuda 12 | batch_size = FLAGS.batch_size 13 | epochs = FLAGS.epochs 14 | lr = FLAGS.learning_rate 15 | print_every = FLAGS.print_every 16 | eval_every = FLAGS.eval_every 17 | save_every = FLAGS.save_every 18 | nc = FLAGS.num_classes 19 | wd = FLAGS.weight_decay 20 | ip = FLAGS.input_path_train 21 | lp = FLAGS.label_path_train 22 | ipv = FLAGS.input_path_val 23 | lpv = FLAGS.label_path_val 24 | print ('[INFO]Defined all the hyperparameters successfully!') 25 | 26 | # Get the class weights 27 | print ('[INFO]Starting to define the class weights...') 28 | pipe = loader(ip, lp, batch_size='all') 29 | class_weights = get_class_weights(pipe, nc) 30 | print ('[INFO]Fetched all class weights successfully!') 31 | 32 | # Get an instance of the model 33 | enet = ENet(nc) 34 | print ('[INFO]Model Instantiated!') 35 | 36 | # Move the model to cuda if available 37 | enet = enet.to(device) 38 | 39 | # Define the criterion and the optimizer 40 | criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(class_weights).to(device)) 41 | optimizer = torch.optim.Adam(enet.parameters(), 42 | lr=lr, 43 | weight_decay=wd) 44 | print ('[INFO]Defined the loss function and the optimizer') 45 | 46 | # Training Loop starts 47 | print ('[INFO]Starting Training...') 48 | print () 49 | 50 | train_losses = [] 51 | eval_losses = [] 52 | 53 | # Assuming we are using the CamVid Dataset, which has 367 training and 101 validation images 54 | bc_train = 367 // batch_size 55 | bc_eval = 101 // batch_size 56 | 57 | pipe = loader(ip, lp, batch_size) 58 | eval_pipe = loader(ipv, lpv, batch_size) 59 | 60 | 61 | 62 | for e in range(1, epochs+1): 63 | 64 | train_loss = 0 65 | print ('-'*15,'Epoch %d' % e, '-'*15) 66 | 67 | enet.train() 68 | 69 | for _ in tqdm(range(bc_train)): 70 | X_batch, mask_batch = next(pipe) 71 | 
72 | #assert (X_batch >= 0. and X_batch <= 1.0).all() 73 | 74 | X_batch, mask_batch = X_batch.to(device), mask_batch.to(device) 75 | 76 | optimizer.zero_grad() 77 | 78 | out = enet(X_batch.float()) 79 | 80 | loss = criterion(out, mask_batch.long()) 81 | loss.backward() 82 | optimizer.step() 83 | 84 | train_loss += loss.item() 85 | 86 | 87 | print () 88 | train_losses.append(train_loss) 89 | 90 | if e % print_every == 0: 91 | print ('Epoch {}/{}...'.format(e, epochs), 92 | 'Loss {:.6f}'.format(train_loss)) 93 | 94 | if e % eval_every == 0: 95 | with torch.no_grad(): 96 | enet.eval() 97 | 98 | eval_loss = 0 99 | 100 | for _ in tqdm(range(bc_eval)): 101 | inputs, labels = next(eval_pipe) 102 | 103 | inputs, labels = inputs.to(device), labels.to(device) 104 | out = enet(inputs.float()) 105 | 106 | loss = criterion(out, labels.long()) 107 | 108 | eval_loss += loss.item() 109 | 110 | print () 111 | print ('Loss {:.6f}'.format(eval_loss)) 112 | 113 | eval_losses.append(eval_loss) 114 | 115 | if e % save_every == 0: 116 | checkpoint = { 117 | 'epochs' : e, 118 | 'state_dict' : enet.state_dict() 119 | } 120 | torch.save(checkpoint, './ckpt-enet-{}-{}.pth'.format(e, train_loss)) 121 | print ('Model saved!') 122 | 123 | print ('Epoch {}/{}...'.format(e, epochs), 124 | 'Total Mean Loss: {:.6f}'.format(sum(train_losses) / epochs)) 125 | 126 | print ('[INFO]Training Process complete!') 127 | --------------------------------------------------------------------------------
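The criterion above uses the inverse-log class weighting scheme from the ENet paper, w_class = 1 / ln(c + p_class), which `get_class_weights` in utils.py below computes from the label pixel frequencies. A tiny worked example with illustrative frequencies:
```
import numpy as np

p = np.array([0.80, 0.15, 0.05])  # illustrative fraction of all pixels belonging to each class
c = 1.02                          # the --constant flag in init.py

w = 1 / np.log(c + p)             # -> roughly [1.67, 6.37, 14.78]: rarer classes weigh more
```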
/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import matplotlib.pyplot as plt 4 | import os 5 | from PIL import Image 6 | import torch 7 | 8 | def create_class_mask(img, color_map, is_normalized_img=True, is_normalized_map=False, show_masks=False): 9 | """ 10 | Function to create C matrices from the segmented image, where each of the C matrices is for one class 11 | with all ones at the pixel positions where that class is present 12 | 13 | img = The segmented image 14 | 15 | color_map = A list with tuples that contains all the RGB values for each color that represents 16 | some class in that image 17 | 18 | is_normalized_img = Boolean - Whether the image is normalized or not 19 | If normalized, then the image is multiplied with 255 20 | 21 | is_normalized_map = Boolean - Represents whether the color map is normalized or not, if so 22 | then the color map values are multiplied with 255 23 | 24 | show_masks = Whether to show the created masks or not 25 | """ 26 | 27 | if is_normalized_img and (not is_normalized_map): 28 | img *= 255 29 | 30 | if is_normalized_map and (not is_normalized_img): 31 | img = img / 255 32 | 33 | mask = [] 34 | hw_tuple = img.shape[:-1] 35 | for color in color_map: 36 | color_img = [] 37 | for idx in range(3): 38 | color_img.append(np.ones(hw_tuple) * color[idx]) 39 | 40 | color_img = np.array(color_img, dtype=np.uint8).transpose(1, 2, 0) 41 | 42 | mask.append(np.uint8((color_img == img).sum(axis = -1) == 3)) 43 | 44 | return np.array(mask) 45 | 46 | 47 | def loader(training_path, segmented_path, batch_size, h=512, w=512): 48 | """ 49 | The Loader to generate inputs and labels from the Image and Segmented Directory 50 | 51 | Arguments: 52 | 53 | training_path - str - Path to the directory that contains the training images 54 | 55 | segmented_path - str - Path to the directory that contains the segmented images 56 | 57 | batch_size - int - the batch size 58 | 59 | yields inputs and labels of the batch size 60 | """ 61 | 62 | filenames_t = sorted(os.listdir(training_path)) # sort so that inputs and labels pair up by name 63 | total_files_t = len(filenames_t) 64 | 65 | filenames_s = sorted(os.listdir(segmented_path)) 66 | total_files_s = len(filenames_s) 67 | 68 | assert(total_files_t == total_files_s) 69 | 70 | if str(batch_size).lower() == 'all': 71 | batch_size = total_files_s 72 | 73 | idx = 0 74 | while(1): 75 | batch_idxs = np.random.permutation(total_files_s)[:batch_size] # sample without replacement so batch_size='all' really covers every file 76 | 77 | 78 | inputs = [] 79 | labels = [] 80 | 81 | for jj in batch_idxs: 82 | img = plt.imread(training_path + filenames_t[jj]) 83 | img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST) 84 | inputs.append(img) 85 | 86 | img = Image.open(segmented_path + filenames_s[jj]) 87 | img = np.array(img) 88 | img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST) 89 | labels.append(img) 90 | 91 | inputs = np.stack(inputs, axis=2) # (H, W, B, C) 92 | inputs = torch.tensor(inputs).transpose(0, 2).transpose(1, 3) # -> (B, C, H, W) 93 | 94 | labels = torch.tensor(labels) 95 | 96 | yield inputs, labels 97 | 98 | 99 | def decode_segmap(image): 100 | Sky = [128, 128, 128] 101 | Building = [128, 0, 0] 102 | Pole = [192, 192, 128] 103 | Road_marking = [255, 69, 0] 104 | Road = [128, 64, 128] 105 | Pavement = [60, 40, 222] 106 | Tree = [128, 128, 0] 107 | SignSymbol = [192, 128, 128] 108 | Fence = [64, 64, 128] 109 | Car = [64, 0, 128] 110 | Pedestrian = [64, 64, 0] 111 | Bicyclist = [0, 128, 192] 112 | 113 | label_colors = np.array([Sky, Building, Pole, Road_marking, Road, 114 | Pavement, Tree, SignSymbol, Fence, Car, 115 | Pedestrian, Bicyclist]).astype(np.uint8) 116 | 117 | r = np.zeros_like(image).astype(np.uint8) 118 | g = np.zeros_like(image).astype(np.uint8) 119 | b = np.zeros_like(image).astype(np.uint8) 120 | 121 | for label in range(len(label_colors)): 122 | r[image == label] = label_colors[label, 0] 123 | g[image == label] = label_colors[label, 1] 124 | b[image == label] = label_colors[label, 2] 125 | 126 | rgb = np.zeros((image.shape[0], image.shape[1], 3)).astype(np.uint8) 127 | rgb[:, :, 0] = r 128 | rgb[:, :, 1] = g 129 | rgb[:, :, 2] = b 130 | 131 | return rgb 132 | 133 | def show_images(images, in_row=True): 134 | ''' 135 | Helper function to display the given images, in a row or in a column 136 | ''' 137 | total_images = len(images) 138 | 139 | rc_tuple = (1, total_images) 140 | if not in_row: 141 | rc_tuple = (total_images, 1) 142 | 143 | #figure = plt.figure(figsize=(20, 10)) 144 | for ii in range(len(images)): 145 | plt.subplot(*rc_tuple, ii+1) 146 | plt.title(images[ii][0]) 147 | plt.axis('off') 148 | plt.imshow(images[ii][1]) 149 | plt.show() 150 | 151 | def get_class_weights(loader, num_classes, c=1.02): 152 | ''' 153 | This function returns the class weights for each class 154 | 155 | Arguments: 156 | - loader : The generator object which returns all the labels in one iteration 157 | Do note: this function expects all the labels to be returned in 158 | one iteration 159 | 160 | - num_classes : The number of classes 161 | 162 | Return: 163 | - class_weights : An array equal in length to the number of classes 164 | containing the class weights for each class 165 | ''' 166 | 167 | _, labels = next(loader) 168 | all_labels = labels.flatten() 169 | each_class = np.bincount(all_labels, minlength=num_classes) 170 | propensity_score = each_class / len(all_labels) 171 | class_weights = 1 / (np.log(c + propensity_score)) 172 | return class_weights 173 | --------------------------------------------------------------------------------
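For reference, a minimal sketch of how the two utilities above are consumed when setting up training, mirroring the calls in train.py (the paths are the CamVid defaults from init.py):
```
from utils import loader, get_class_weights

# batch_size='all' makes the generator yield every label image in one batch,
# which is what get_class_weights expects
pipe = loader('./datasets/CamVid/train/', './datasets/CamVid/trainannot/', batch_size='all')
class_weights = get_class_weights(pipe, 12)  # one weight per CamVid class

# a regular training pipe yields inputs of shape (10, 3, 512, 512) and labels of shape (10, 512, 512)
train_pipe = loader('./datasets/CamVid/train/', './datasets/CamVid/trainannot/', 10)
inputs, labels = next(train_pipe)
```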