├── .gitignore
├── ENet-Real_Time_Semantic_Segmentation.ipynb
├── LICENSE
├── README.md
├── init.py
├── models
│   ├── ASNeck.py
│   ├── ENet.py
│   ├── InitialBlock.py
│   ├── RDDNeck.py
│   ├── UBNeck.py
│   └── __init__.py
├── test.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swo
2 | *.png
3 | *.jpg
4 | *.zip
5 | *.swp
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # pyenv
82 | .python-version
83 |
84 | # celery beat schedule file
85 | celerybeat-schedule
86 |
87 | # SageMath parsed files
88 | *.sage.py
89 |
90 | # Environments
91 | .env
92 | .venv
93 | env/
94 | venv/
95 | ENV/
96 | env.bak/
97 | venv.bak/
98 |
99 | # Spyder project settings
100 | .spyderproject
101 | .spyproject
102 |
103 | # Rope project settings
104 | .ropeproject
105 |
106 | # mkdocs documentation
107 | /site
108 |
109 | # mypy
110 | .mypy_cache/
111 | .idea/
112 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2019, Arunava
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ENet - Real Time Semantic Segmentation
2 |
3 | A neural network architecture for real-time semantic segmentation.
4 | This repository reproduces the ENet paper, which can be used on
5 | mobile devices for real-time semantic segmentation. The paper is available here: [ENet](https://arxiv.org/pdf/1606.02147.pdf)
6 |
7 | ## How to use?
8 |
9 | 0. This repository comes with a handy notebook which you can use with Colab.
10 |    You can find the notebook here:
11 |    [ENet - Real Time Semantic Segmentation](https://github.com/iArunava/ENet-Real-Time-Semantic-Segmentation/blob/master/ENet-Real%20Time%20Semantic%20Segmentation.ipynb)
12 |    Or open it directly in Colab: [Open in Colab](https://colab.research.google.com/github/iArunava/ENet-Real-Time-Semantic-Segmentation/blob/master/ENet-Real%20Time%20Semantic%20Segmentation.ipynb)
13 |
14 | ---
15 |
16 |
17 | 0. Clone the repository and cd into it
18 | ```
19 | git clone https://github.com/iArunava/ENet-Real-Time-Semantic-Segmentation.git
20 | cd ENet-Real-Time-Semantic-Segmentation/
21 | ```
22 |
23 | 1. Use this command to train the model
24 | ```
25 | python3 init.py --mode train -iptr /path/to/train/input/set/ -lptr /path/to/label/set/
26 | ```
27 |
28 | 2. Use this command to test the model
29 | ```
30 | python3 init.py --mode test -m /path/to/the/pretrained/model.pth -i /path/to/image/to/infer.png
31 | ```
32 |
33 | 3. Use `--help` to get more commands
34 | ```
35 | python3 init.py --help
36 | ```
37 |
38 | ## Some results
39 |
40 | 
41 | 
42 | 
43 | 
44 | 
45 |
46 | ## References
47 | 1. A. Paszke, A. Chaurasia, S. Kim, and E. Culurciello.
48 | Enet: A deep neural network architecture
49 | for real-time semantic segmentation. arXiv preprint
50 | arXiv:1606.02147, 2016.
51 |
52 | ## Citations
53 |
54 | ```
55 | @inproceedings{ BrostowSFC:ECCV08,
56 | author = {Gabriel J. Brostow and Jamie Shotton and Julien Fauqueur and Roberto Cipolla},
57 | title = {Segmentation and Recognition Using Structure from Motion Point Clouds},
58 | booktitle = {ECCV (1)},
59 | year = {2008},
60 | pages = {44-57}
61 | }
62 |
63 | @article{ BrostowFC:PRL2008,
64 | author = "Gabriel J. Brostow and Julien Fauqueur and Roberto Cipolla",
65 | title = "Semantic Object Classes in Video: A High-Definition Ground Truth Database",
66 | journal = "Pattern Recognition Letters",
67 | volume = "xx",
68 | number = "x",
69 | pages = "xx-xx",
70 | year = "2008"
71 | }
72 | ```
73 |
74 | ## License
75 |
76 | The code in this repository is distributed under the BSD 3-Clause License.
77 | Feel free to fork and enjoy :)
78 |
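79 | ## Using the model directly
80 | 
81 | A minimal sketch (assuming PyTorch is installed and you run from the repo root)
82 | of instantiating the architecture defined under `models/` and running a forward pass:
83 | 
84 | ```
85 | import torch
86 | from models.ENet import ENet
87 | 
88 | enet = ENet(12)   # 12 CamVid classes
89 | enet.eval()
90 | with torch.no_grad():
91 |     # (batch, channels, height, width) -> per-pixel class logits
92 |     logits = enet(torch.randn(1, 3, 512, 512))
93 | print(logits.shape)   # torch.Size([1, 12, 512, 512])
94 | ```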
--------------------------------------------------------------------------------
/init.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | from train import *
4 | from test import *
5 |
6 | color_map = {
7 | 'unlabeled' : ( 0, 0, 0),
8 | 'dynamic' : (111, 74, 0),
9 | 'ground' : ( 81, 0, 81),
10 | 'road' : (128, 64,128),
11 | 'sidewalk' : (244, 35,232),
12 | 'parking' : (250,170,160),
13 | 'rail track' : (230,150,140),
14 | 'building' : ( 70, 70, 70),
15 | 'wall' : (102,102,156),
16 | 'fence' : (190,153,153),
17 | 'guard rail' : (180,165,180),
18 | 'bridge' : (150,100,100),
19 | 'tunnel' : (150,120, 90),
20 | 'pole' : (153,153,153),
21 | 'traffic light' : (250,170, 30),
22 | 'traffic sign' : (220,220, 0),
23 | 'vegetation' : (107,142, 35),
24 | 'terrain' : (152,251,152),
25 | 'sky' : ( 70,130,180),
26 | 'person' : (220, 20, 60),
27 | 'rider' : (255, 0, 0),
28 | 'car' : ( 0, 0,142),
29 | 'truck' : ( 0, 0, 70),
30 | 'bus' : ( 0, 60,100),
31 | 'caravan' : ( 0, 0, 90),
32 | 'trailer' : ( 0, 0,110),
33 | 'train' : ( 0, 80,100),
34 | 'motorcycle' : ( 0, 0,230),
35 | 'bicycle' : (119, 11, 32)
36 | }
37 |
38 | if __name__ == '__main__':
39 | parser = argparse.ArgumentParser()
40 |
41 | parser.add_argument('-m',
42 | type=str,
43 | default='./datasets/CamVid/ckpt-camvid-enet.pth',
44 | help='The path to the pretrained enet model')
45 |
46 | parser.add_argument('-i', '--image-path',
47 | type=str,
48 | help='The path to the image to perform semantic segmentation')
49 |
50 | parser.add_argument('-rh', '--resize-height',
51 | type=int,
52 | default=512,
53 | help='The height for the resized image')
54 |
55 | parser.add_argument('-rw', '--resize-width',
56 | type=int,
57 | default=512,
58 | help='The width for the resized image')
59 |
60 | parser.add_argument('-lr', '--learning-rate',
61 | type=float,
62 | default=5e-4,
63 | help='The learning rate')
64 |
65 | parser.add_argument('-bs', '--batch-size',
66 | type=int,
67 | default=10,
68 | help='The batch size')
69 |
70 | parser.add_argument('-wd', '--weight-decay',
71 | type=float,
72 | default=2e-4,
73 | help='The weight decay')
74 |
75 | parser.add_argument('-c', '--constant',
76 | type=float,
77 | default=1.02,
78 | help='The constant used for calculating the class weights')
79 |
80 | parser.add_argument('-e', '--epochs',
81 | type=int,
82 | default=102,
83 | help='The number of epochs')
84 |
85 | parser.add_argument('-nc', '--num-classes',
86 | type=int,
87 | default=12,
88 | help='Number of unique classes')
89 |
90 | parser.add_argument('-se', '--save-every',
91 | type=int,
92 | default=10,
93 | help='The number of epochs after which to save a model')
94 |
95 | parser.add_argument('-iptr', '--input-path-train',
96 | type=str,
97 | default='./datasets/CamVid/train/',
98 | help='The path to the input dataset')
99 |
100 | parser.add_argument('-lptr', '--label-path-train',
101 | type=str,
102 | default='./datasets/CamVid/trainannot/',
103 | help='The path to the label dataset')
104 |
105 | parser.add_argument('-ipv', '--input-path-val',
106 | type=str,
107 | default='./datasets/CamVid/val/',
108 | help='The path to the input dataset')
109 |
110 | parser.add_argument('-lpv', '--label-path-val',
111 | type=str,
112 | default='./datasets/CamVid/valannot/',
113 | help='The path to the label dataset')
114 |
115 | parser.add_argument('-iptt', '--input-path-test',
116 | type=str,
117 | default='./datasets/CamVid/test/',
118 | help='The path to the input dataset')
119 |
120 | parser.add_argument('-lptt', '--label-path-test',
121 | type=str,
122 | default='./datasets/CamVid/testannot/',
123 | help='The path to the label dataset')
124 |
125 | parser.add_argument('-pe', '--print-every',
126 | type=int,
127 | default=1,
128 | help='The number of epochs after which to print the training loss')
129 |
130 | parser.add_argument('-ee', '--eval-every',
131 | type=int,
132 | default=10,
133 | help='The number of epochs after which to print the validation loss')
134 |
135 |     parser.add_argument('--cuda',
136 |                         action='store_true',
137 |                         default=False,
138 |                         help='Whether to use cuda or not')
139 |
140 | parser.add_argument('--mode',
141 | choices=['train', 'test'],
142 | default='train',
143 | help='Whether to train or test')
144 |
145 | FLAGS, unparsed = parser.parse_known_args()
146 |
147 | FLAGS.cuda = torch.device('cuda:0' if torch.cuda.is_available() and FLAGS.cuda \
148 | else 'cpu')
149 |
150 | if FLAGS.mode.lower() == 'train':
151 | train(FLAGS)
152 | elif FLAGS.mode.lower() == 'test':
153 | test(FLAGS)
154 | else:
155 | raise RuntimeError('Unknown mode passed. \n Mode passed should be either \
156 | of "train" or "test"')
157 |
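158 | # Example invocations (sketch; paths are placeholders matching the defaults above):
159 | #   python3 init.py --mode train -iptr ./datasets/CamVid/train/ -lptr ./datasets/CamVid/trainannot/
160 | #   python3 init.py --mode test -m ./datasets/CamVid/ckpt-camvid-enet.pth -i ./image.png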
--------------------------------------------------------------------------------
/models/ASNeck.py:
--------------------------------------------------------------------------------
1 | ###################################################
2 | # Copyright (c) 2019 #
3 | # Authors: @iArunava #
4 | # @AvivSham #
5 | # #
6 | # License: BSD License 3.0 #
7 | # #
8 | # The Code in this file is distributed for free #
9 | # usage and modification with proper linkage back #
10 | # to this repository. #
11 | ###################################################
12 |
13 | import torch
14 | import torch.nn as nn
15 |
16 | class ASNeck(nn.Module):
17 | def __init__(self, in_channels, out_channels, projection_ratio=4):
18 |
19 | super().__init__()
20 |
21 | # Define class variables
22 | self.in_channels = in_channels
23 | self.reduced_depth = int(in_channels / projection_ratio)
24 | self.out_channels = out_channels
25 |
26 | self.dropout = nn.Dropout2d(p=0.1)
27 |
28 | self.conv1 = nn.Conv2d(in_channels = self.in_channels,
29 | out_channels = self.reduced_depth,
30 | kernel_size = 1,
31 | stride = 1,
32 | padding = 0,
33 | bias = False)
34 |
35 | self.prelu1 = nn.PReLU()
36 |
37 | self.conv21 = nn.Conv2d(in_channels = self.reduced_depth,
38 | out_channels = self.reduced_depth,
39 | kernel_size = (1, 5),
40 | stride = 1,
41 | padding = (0, 2),
42 | bias = False)
43 |
44 | self.conv22 = nn.Conv2d(in_channels = self.reduced_depth,
45 | out_channels = self.reduced_depth,
46 | kernel_size = (5, 1),
47 | stride = 1,
48 | padding = (2, 0),
49 | bias = False)
50 |
51 | self.prelu2 = nn.PReLU()
52 |
53 | self.conv3 = nn.Conv2d(in_channels = self.reduced_depth,
54 | out_channels = self.out_channels,
55 | kernel_size = 1,
56 | stride = 1,
57 | padding = 0,
58 | bias = False)
59 |
60 | self.prelu3 = nn.PReLU()
61 |
62 | self.batchnorm1 = nn.BatchNorm2d(self.reduced_depth)
63 | self.batchnorm2 = nn.BatchNorm2d(self.reduced_depth)
64 | self.batchnorm3 = nn.BatchNorm2d(self.out_channels)
65 |
66 | def forward(self, x):
67 | bs = x.size()[0]
68 | x_copy = x
69 |
70 | # Side Branch
71 | x = self.conv1(x)
72 | x = self.batchnorm1(x)
73 | x = self.prelu1(x)
74 |
75 | x = self.conv21(x)
76 | x = self.conv22(x)
77 | x = self.batchnorm2(x)
78 | x = self.prelu2(x)
79 |
80 | x = self.conv3(x)
81 |
82 | x = self.dropout(x)
83 | x = self.batchnorm3(x)
84 |
85 | # Main Branch
86 |
87 |         if self.in_channels != self.out_channels:
88 |             out_shape = self.out_channels - self.in_channels
89 |             # Zero-pad the skip connection on the same device/dtype as x
90 |             # so the channel counts match before the residual addition.
91 |             extras = torch.zeros((bs, out_shape, x.shape[2], x.shape[3]), device=x.device, dtype=x.dtype)
92 |             x_copy = torch.cat((x_copy, extras), dim = 1)
93 |
94 | # Sum of main and side branches
95 | x = x + x_copy
96 | x = self.prelu3(x)
97 |
98 | return x
99 |
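100 | if __name__ == '__main__':
101 |     # Minimal smoke test (sketch): the asymmetric (5x1 / 1x5) bottleneck
102 |     # preserves spatial size, and with equal in/out channels it is a pure
103 |     # residual refinement block, as used in ENet's stages 2 and 3.
104 |     block = ASNeck(in_channels=128, out_channels=128)
105 |     block.eval()
106 |     with torch.no_grad():
107 |         y = block(torch.randn(1, 128, 64, 64))
108 |     assert y.shape == (1, 128, 64, 64)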
--------------------------------------------------------------------------------
/models/ENet.py:
--------------------------------------------------------------------------------
1 | ##################################################################
2 | # Reproducing the paper #
3 | # ENet - Real Time Semantic Segmentation #
4 | # Paper: https://arxiv.org/pdf/1606.02147.pdf #
5 | # #
6 | # Copyright (c) 2019 #
7 | # Authors: @iArunava #
8 | # @AvivSham #
9 | # #
10 | # License: BSD License 3.0 #
11 | # #
12 | # The Code in this file is distributed for free #
13 | # usage and modification with proper credits #
14 | # directing back to this repository. #
15 | ##################################################################
16 |
17 | import torch
18 | import torch.nn as nn
19 | from .InitialBlock import InitialBlock
20 | from .RDDNeck import RDDNeck
21 | from .UBNeck import UBNeck
22 | from .ASNeck import ASNeck
23 |
24 | class ENet(nn.Module):
25 | def __init__(self, C):
26 | super().__init__()
27 |
28 | # Define class variables
29 | self.C = C
30 |
31 | # The initial block
32 | self.init = InitialBlock()
33 |
34 |
35 | # The first bottleneck
36 | self.b10 = RDDNeck(dilation=1,
37 | in_channels=16,
38 | out_channels=64,
39 | down_flag=True,
40 | p=0.01)
41 |
42 | self.b11 = RDDNeck(dilation=1,
43 | in_channels=64,
44 | out_channels=64,
45 | down_flag=False,
46 | p=0.01)
47 |
48 | self.b12 = RDDNeck(dilation=1,
49 | in_channels=64,
50 | out_channels=64,
51 | down_flag=False,
52 | p=0.01)
53 |
54 | self.b13 = RDDNeck(dilation=1,
55 | in_channels=64,
56 | out_channels=64,
57 | down_flag=False,
58 | p=0.01)
59 |
60 | self.b14 = RDDNeck(dilation=1,
61 | in_channels=64,
62 | out_channels=64,
63 | down_flag=False,
64 | p=0.01)
65 |
66 |
67 | # The second bottleneck
68 | self.b20 = RDDNeck(dilation=1,
69 | in_channels=64,
70 | out_channels=128,
71 | down_flag=True)
72 |
73 | self.b21 = RDDNeck(dilation=1,
74 | in_channels=128,
75 | out_channels=128,
76 | down_flag=False)
77 |
78 | self.b22 = RDDNeck(dilation=2,
79 | in_channels=128,
80 | out_channels=128,
81 | down_flag=False)
82 |
83 | self.b23 = ASNeck(in_channels=128,
84 | out_channels=128)
85 |
86 | self.b24 = RDDNeck(dilation=4,
87 | in_channels=128,
88 | out_channels=128,
89 | down_flag=False)
90 |
91 | self.b25 = RDDNeck(dilation=1,
92 | in_channels=128,
93 | out_channels=128,
94 | down_flag=False)
95 |
96 | self.b26 = RDDNeck(dilation=8,
97 | in_channels=128,
98 | out_channels=128,
99 | down_flag=False)
100 |
101 | self.b27 = ASNeck(in_channels=128,
102 | out_channels=128)
103 |
104 | self.b28 = RDDNeck(dilation=16,
105 | in_channels=128,
106 | out_channels=128,
107 | down_flag=False)
108 |
109 |
110 | # The third bottleneck
111 | self.b31 = RDDNeck(dilation=1,
112 | in_channels=128,
113 | out_channels=128,
114 | down_flag=False)
115 |
116 | self.b32 = RDDNeck(dilation=2,
117 | in_channels=128,
118 | out_channels=128,
119 | down_flag=False)
120 |
121 | self.b33 = ASNeck(in_channels=128,
122 | out_channels=128)
123 |
124 | self.b34 = RDDNeck(dilation=4,
125 | in_channels=128,
126 | out_channels=128,
127 | down_flag=False)
128 |
129 | self.b35 = RDDNeck(dilation=1,
130 | in_channels=128,
131 | out_channels=128,
132 | down_flag=False)
133 |
134 | self.b36 = RDDNeck(dilation=8,
135 | in_channels=128,
136 | out_channels=128,
137 | down_flag=False)
138 |
139 | self.b37 = ASNeck(in_channels=128,
140 | out_channels=128)
141 |
142 | self.b38 = RDDNeck(dilation=16,
143 | in_channels=128,
144 | out_channels=128,
145 | down_flag=False)
146 |
147 |
148 | # The fourth bottleneck
149 | self.b40 = UBNeck(in_channels=128,
150 | out_channels=64,
151 | relu=True)
152 |
153 | self.b41 = RDDNeck(dilation=1,
154 | in_channels=64,
155 | out_channels=64,
156 | down_flag=False,
157 | relu=True)
158 |
159 | self.b42 = RDDNeck(dilation=1,
160 | in_channels=64,
161 | out_channels=64,
162 | down_flag=False,
163 | relu=True)
164 |
165 |
166 | # The fifth bottleneck
167 | self.b50 = UBNeck(in_channels=64,
168 | out_channels=16,
169 | relu=True)
170 |
171 | self.b51 = RDDNeck(dilation=1,
172 | in_channels=16,
173 | out_channels=16,
174 | down_flag=False,
175 | relu=True)
176 |
177 |
178 | # Final ConvTranspose Layer
179 | self.fullconv = nn.ConvTranspose2d(in_channels=16,
180 | out_channels=self.C,
181 | kernel_size=3,
182 | stride=2,
183 | padding=1,
184 | output_padding=1,
185 | bias=False)
186 |
187 |
188 | def forward(self, x):
189 |
190 | # The initial block
191 | x = self.init(x)
192 |
193 | # The first bottleneck
194 | x, i1 = self.b10(x)
195 | x = self.b11(x)
196 | x = self.b12(x)
197 | x = self.b13(x)
198 | x = self.b14(x)
199 |
200 | # The second bottleneck
201 | x, i2 = self.b20(x)
202 | x = self.b21(x)
203 | x = self.b22(x)
204 | x = self.b23(x)
205 | x = self.b24(x)
206 | x = self.b25(x)
207 | x = self.b26(x)
208 | x = self.b27(x)
209 | x = self.b28(x)
210 |
211 | # The third bottleneck
212 | x = self.b31(x)
213 | x = self.b32(x)
214 | x = self.b33(x)
215 | x = self.b34(x)
216 | x = self.b35(x)
217 | x = self.b36(x)
218 | x = self.b37(x)
219 | x = self.b38(x)
220 |
221 | # The fourth bottleneck
222 | x = self.b40(x, i2)
223 | x = self.b41(x)
224 | x = self.b42(x)
225 |
226 | # The fifth bottleneck
227 | x = self.b50(x, i1)
228 | x = self.b51(x)
229 |
230 | # Final ConvTranspose Layer
231 | x = self.fullconv(x)
232 |
233 | return x
234 |
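235 | if __name__ == '__main__':
236 |     # Minimal smoke test (sketch; run from the repo root as `python -m models.ENet`,
237 |     # since this file uses relative imports). The encoder downsamples by 8 and
238 |     # the decoder restores the input resolution with C-class per-pixel logits.
239 |     net = ENet(C=12)
240 |     net.eval()
241 |     with torch.no_grad():
242 |         out = net(torch.randn(1, 3, 512, 512))
243 |     assert out.shape == (1, 12, 512, 512)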
--------------------------------------------------------------------------------
/models/InitialBlock.py:
--------------------------------------------------------------------------------
1 | ###################################################
2 | # Copyright (c) 2019 #
3 | # Authors: @iArunava #
4 | # @AvivSham #
5 | # #
6 | # License: BSD License 3.0 #
7 | # #
8 | # The Code in this file is distributed for free #
9 | # usage and modification with proper linkage back #
10 | # to this repository. #
11 | ###################################################
12 |
13 | import torch
14 | import torch.nn as nn
15 |
16 | class InitialBlock(nn.Module):
17 | def __init__ (self,in_channels = 3,out_channels = 13):
18 | super().__init__()
19 |
20 |
21 | self.maxpool = nn.MaxPool2d(kernel_size=2,
22 | stride = 2,
23 | padding = 0)
24 |
25 | self.conv = nn.Conv2d(in_channels,
26 | out_channels,
27 | kernel_size = 3,
28 | stride = 2,
29 | padding = 1)
30 |
31 | self.prelu = nn.PReLU(16)
32 |
33 | self.batchnorm = nn.BatchNorm2d(out_channels)
34 |
35 | def forward(self, x):
36 |
37 | main = self.conv(x)
38 | main = self.batchnorm(main)
39 |
40 | side = self.maxpool(x)
41 |
42 | x = torch.cat((main, side), dim=1)
43 | x = self.prelu(x)
44 |
45 | return x
46 |
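47 | if __name__ == '__main__':
48 |     # Minimal smoke test (sketch): the conv branch (13 maps) and the
49 |     # max-pooled input (3 maps) concatenate to 16 channels at half resolution.
50 |     block = InitialBlock()
51 |     block.eval()
52 |     with torch.no_grad():
53 |         y = block(torch.randn(1, 3, 512, 512))
54 |     assert y.shape == (1, 16, 256, 256)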
--------------------------------------------------------------------------------
/models/RDDNeck.py:
--------------------------------------------------------------------------------
1 | ###################################################
2 | # Copyright (c) 2019 #
3 | # Authors: @iArunava #
4 | # @AvivSham #
5 | # #
6 | # License: BSD License 3.0 #
7 | # #
8 | # The Code in this file is distributed for free #
9 | # usage and modification with proper linkage back #
10 | # to this repository. #
11 | ###################################################
12 |
13 | import torch
14 | import torch.nn as nn
15 |
16 |
17 | class RDDNeck(nn.Module):
18 | def __init__(self, dilation, in_channels, out_channels, down_flag, relu=False, projection_ratio=4, p=0.1):
19 |
20 | super().__init__()
21 |
22 | # Define class variables
23 | self.in_channels = in_channels
24 |
25 | self.out_channels = out_channels
26 | self.dilation = dilation
27 | self.down_flag = down_flag
28 |
29 | if down_flag:
30 | self.stride = 2
31 | self.reduced_depth = int(in_channels // projection_ratio)
32 | else:
33 | self.stride = 1
34 | self.reduced_depth = int(out_channels // projection_ratio)
35 |
36 | if relu:
37 | activation = nn.ReLU()
38 | else:
39 | activation = nn.PReLU()
40 |
41 | self.maxpool = nn.MaxPool2d(kernel_size = 2,
42 | stride = 2,
43 | padding = 0, return_indices=True)
44 |
45 |
46 |
47 | self.dropout = nn.Dropout2d(p=p)
48 |
49 | self.conv1 = nn.Conv2d(in_channels = self.in_channels,
50 | out_channels = self.reduced_depth,
51 | kernel_size = 1,
52 | stride = 1,
53 | padding = 0,
54 | bias = False,
55 | dilation = 1)
56 |
57 | self.prelu1 = activation
58 |
59 | self.conv2 = nn.Conv2d(in_channels = self.reduced_depth,
60 | out_channels = self.reduced_depth,
61 | kernel_size = 3,
62 | stride = self.stride,
63 | padding = self.dilation,
64 | bias = True,
65 | dilation = self.dilation)
66 |
67 | self.prelu2 = activation
68 |
69 | self.conv3 = nn.Conv2d(in_channels = self.reduced_depth,
70 | out_channels = self.out_channels,
71 | kernel_size = 1,
72 | stride = 1,
73 | padding = 0,
74 | bias = False,
75 | dilation = 1)
76 |
77 | self.prelu3 = activation
78 |
79 | self.batchnorm1 = nn.BatchNorm2d(self.reduced_depth)
80 | self.batchnorm2 = nn.BatchNorm2d(self.reduced_depth)
81 | self.batchnorm3 = nn.BatchNorm2d(self.out_channels)
82 |
83 |
84 | def forward(self, x):
85 |
86 | bs = x.size()[0]
87 | x_copy = x
88 |
89 | # Side Branch
90 | x = self.conv1(x)
91 | x = self.batchnorm1(x)
92 | x = self.prelu1(x)
93 |
94 | x = self.conv2(x)
95 | x = self.batchnorm2(x)
96 | x = self.prelu2(x)
97 |
98 | x = self.conv3(x)
99 | x = self.batchnorm3(x)
100 |
101 | x = self.dropout(x)
102 |
103 | # Main Branch
104 | if self.down_flag:
105 | x_copy, indices = self.maxpool(x_copy)
106 |
107 |         if self.in_channels != self.out_channels:
108 |             out_shape = self.out_channels - self.in_channels
109 |             # Zero-pad the skip connection on the same device/dtype as x
110 |             # so the channel counts match before the residual addition.
111 |             extras = torch.zeros((bs, out_shape, x.shape[2], x.shape[3]), device=x.device, dtype=x.dtype)
112 |             x_copy = torch.cat((x_copy, extras), dim = 1)
113 |
114 | # Sum of main and side branches
115 | x = x + x_copy
116 | x = self.prelu3(x)
117 |
118 | if self.down_flag:
119 | return x, indices
120 | else:
121 | return x
122 |
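123 | if __name__ == '__main__':
124 |     # Minimal smoke test (sketch): a downsampling bottleneck halves the
125 |     # resolution and also returns the max-pool indices that UBNeck later
126 |     # consumes for unpooling in the decoder.
127 |     down = RDDNeck(dilation=1, in_channels=16, out_channels=64, down_flag=True)
128 |     down.eval()
129 |     with torch.no_grad():
130 |         y, idx = down(torch.randn(1, 16, 128, 128))
131 |     assert y.shape == (1, 64, 64, 64)
132 |     assert idx.shape == (1, 16, 64, 64)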
--------------------------------------------------------------------------------
/models/UBNeck.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class UBNeck(nn.Module):
5 | def __init__(self, in_channels, out_channels, relu=False, projection_ratio=4):
6 |
7 | super().__init__()
8 |
9 | # Define class variables
10 | self.in_channels = in_channels
11 | self.reduced_depth = int(in_channels / projection_ratio)
12 | self.out_channels = out_channels
13 |
14 |
15 | if relu:
16 | activation = nn.ReLU()
17 | else:
18 | activation = nn.PReLU()
19 |
20 | self.unpool = nn.MaxUnpool2d(kernel_size = 2,
21 | stride = 2)
22 |
23 | self.main_conv = nn.Conv2d(in_channels = self.in_channels,
24 | out_channels = self.out_channels,
25 | kernel_size = 1)
26 |
27 | self.dropout = nn.Dropout2d(p=0.1)
28 |
29 | self.convt1 = nn.ConvTranspose2d(in_channels = self.in_channels,
30 | out_channels = self.reduced_depth,
31 | kernel_size = 1,
32 | padding = 0,
33 | bias = False)
34 |
35 |
36 | self.prelu1 = activation
37 |
38 | self.convt2 = nn.ConvTranspose2d(in_channels = self.reduced_depth,
39 | out_channels = self.reduced_depth,
40 | kernel_size = 3,
41 | stride = 2,
42 | padding = 1,
43 | output_padding = 1,
44 | bias = False)
45 |
46 | self.prelu2 = activation
47 |
48 | self.convt3 = nn.ConvTranspose2d(in_channels = self.reduced_depth,
49 | out_channels = self.out_channels,
50 | kernel_size = 1,
51 | padding = 0,
52 | bias = False)
53 |
54 | self.prelu3 = activation
55 |
56 | self.batchnorm1 = nn.BatchNorm2d(self.reduced_depth)
57 | self.batchnorm2 = nn.BatchNorm2d(self.reduced_depth)
58 | self.batchnorm3 = nn.BatchNorm2d(self.out_channels)
59 |
60 | def forward(self, x, indices):
61 | x_copy = x
62 |
63 | # Side Branch
64 | x = self.convt1(x)
65 | x = self.batchnorm1(x)
66 | x = self.prelu1(x)
67 |
68 | x = self.convt2(x)
69 | x = self.batchnorm2(x)
70 | x = self.prelu2(x)
71 |
72 | x = self.convt3(x)
73 | x = self.batchnorm3(x)
74 |
75 | x = self.dropout(x)
76 |
77 | # Main Branch
78 |
79 | x_copy = self.main_conv(x_copy)
80 | x_copy = self.unpool(x_copy, indices, output_size=x.size())
81 |
82 |         # Sum of main and side branches
83 | x = x + x_copy
84 | x = self.prelu3(x)
85 |
86 | return x
87 |
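88 | if __name__ == '__main__':
89 |     # Minimal smoke test (sketch): feed pooling indices such as a
90 |     # downsampling bottleneck would produce, and check that UBNeck doubles
91 |     # the spatial resolution while halving the channel count.
92 |     pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
93 |     _, idx = pool(torch.randn(1, 64, 64, 64))
94 |     up = UBNeck(in_channels=128, out_channels=64)
95 |     up.eval()
96 |     with torch.no_grad():
97 |         y = up(torch.randn(1, 128, 32, 32), idx)
98 |     assert y.shape == (1, 64, 64, 64)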
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iArunava/ENet-Real-Time-Semantic-Segmentation/8e3e86c4c4eb8392d72962e393d992294d8fc8ae/models/__init__.py
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from utils import *
4 | from models.ENet import ENet
5 | import sys
6 | import os
7 | from tqdm import tqdm
8 | import matplotlib.pyplot as plt
9 | import cv2
10 |
11 | def test(FLAGS):
12 | # Check if the pretrained model is available
13 | if not FLAGS.m.endswith('.pth'):
14 | raise RuntimeError('Unknown file passed. Must end with .pth')
15 | if FLAGS.image_path is None or not os.path.exists(FLAGS.image_path):
16 | raise RuntimeError('An image file path must be passed')
17 |
18 | h = FLAGS.resize_height
19 | w = FLAGS.resize_width
20 |
21 | checkpoint = torch.load(FLAGS.m, map_location=FLAGS.cuda)
22 |
23 | # Assuming the dataset is camvid
24 | enet = ENet(FLAGS.num_classes)
25 |     enet.load_state_dict(checkpoint['state_dict'])
26 |     enet.eval()  # inference mode: disable dropout, use BatchNorm running stats
27 | tmg_ = plt.imread(FLAGS.image_path)
28 |     tmg_ = cv2.resize(tmg_, (w, h), interpolation=cv2.INTER_NEAREST)  # cv2 dsize is (width, height)
29 | tmg = torch.tensor(tmg_).unsqueeze(0).float()
30 | tmg = tmg.transpose(2, 3).transpose(1, 2)
31 |
32 | with torch.no_grad():
33 | out1 = enet(tmg.float()).squeeze(0)
34 |
35 | #smg_ = Image.open('/content/training/semantic/' + fname)
36 | #smg_ = cv2.resize(np.array(smg_), (512, 512), cv2.INTER_NEAREST)
37 |
38 | b_ = out1.data.max(0)[1].cpu().numpy()
39 |
40 |     decoded_segmap = decode_segmap(b_)
41 | 
42 |     images = {
43 |         0 : ['Input Image', tmg_],
44 |         1 : ['Predicted Segmentation', decoded_segmap],
45 |     }
46 |
47 | show_images(images)
48 |
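49 | if __name__ == '__main__':
50 |     # Sketch: test() is normally driven by init.py, which supplies the parsed
51 |     # FLAGS namespace. The equivalent by hand (paths are placeholders):
52 |     from argparse import Namespace
53 |     flags = Namespace(m='./datasets/CamVid/ckpt-camvid-enet.pth',
54 |                       image_path='./image.png',
55 |                       resize_height=512, resize_width=512,
56 |                       num_classes=12, cuda=torch.device('cpu'))
57 |     # test(flags)  # uncomment once the checkpoint and image exist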
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from utils import *
4 | from models.ENet import ENet
5 | import sys
6 | from tqdm import tqdm
7 |
8 | def train(FLAGS):
9 |
10 | # Defining the hyperparameters
11 | device = FLAGS.cuda
12 | batch_size = FLAGS.batch_size
13 | epochs = FLAGS.epochs
14 | lr = FLAGS.learning_rate
15 | print_every = FLAGS.print_every
16 | eval_every = FLAGS.eval_every
17 | save_every = FLAGS.save_every
18 | nc = FLAGS.num_classes
19 | wd = FLAGS.weight_decay
20 | ip = FLAGS.input_path_train
21 | lp = FLAGS.label_path_train
22 | ipv = FLAGS.input_path_val
23 | lpv = FLAGS.label_path_val
24 | print ('[INFO]Defined all the hyperparameters successfully!')
25 |
26 | # Get the class weights
27 | print ('[INFO]Starting to define the class weights...')
28 | pipe = loader(ip, lp, batch_size='all')
29 | class_weights = get_class_weights(pipe, nc)
30 | print ('[INFO]Fetched all class weights successfully!')
31 |
32 | # Get an instance of the model
33 | enet = ENet(nc)
34 | print ('[INFO]Model Instantiated!')
35 |
36 | # Move the model to cuda if available
37 | enet = enet.to(device)
38 |
39 | # Define the criterion and the optimizer
40 | criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(class_weights).to(device))
41 | optimizer = torch.optim.Adam(enet.parameters(),
42 | lr=lr,
43 | weight_decay=wd)
44 | print ('[INFO]Defined the loss function and the optimizer')
45 |
46 | # Training Loop starts
47 |     print ('[INFO]Starting Training...')
48 | print ()
49 |
50 | train_losses = []
51 | eval_losses = []
52 |
53 | # Assuming we are using the CamVid Dataset
54 | bc_train = 367 // batch_size
55 | bc_eval = 101 // batch_size
56 |
57 | pipe = loader(ip, lp, batch_size)
58 | eval_pipe = loader(ipv, lpv, batch_size)
59 |
61 |
62 | for e in range(1, epochs+1):
63 |
64 | train_loss = 0
65 | print ('-'*15,'Epoch %d' % e, '-'*15)
66 |
67 | enet.train()
68 |
69 | for _ in tqdm(range(bc_train)):
70 | X_batch, mask_batch = next(pipe)
71 |
72 | #assert (X_batch >= 0. and X_batch <= 1.0).all()
73 |
74 | X_batch, mask_batch = X_batch.to(device), mask_batch.to(device)
75 |
76 | optimizer.zero_grad()
77 |
78 | out = enet(X_batch.float())
79 |
80 | loss = criterion(out, mask_batch.long())
81 | loss.backward()
82 | optimizer.step()
83 |
84 | train_loss += loss.item()
85 |
86 |
87 | print ()
88 | train_losses.append(train_loss)
89 |
90 |         if e % print_every == 0:
91 |             print ('Epoch {}/{}...'.format(e, epochs),
92 |                    'Loss {:6f}'.format(train_loss))
93 |
94 | if e % eval_every == 0:
95 | with torch.no_grad():
96 | enet.eval()
97 |
98 | eval_loss = 0
99 |
100 | for _ in tqdm(range(bc_eval)):
101 | inputs, labels = next(eval_pipe)
102 |
103 | inputs, labels = inputs.to(device), labels.to(device)
104 |                     out = enet(inputs.float())
105 |
106 | loss = criterion(out, labels.long())
107 |
108 | eval_loss += loss.item()
109 |
110 | print ()
111 | print ('Loss {:6f}'.format(eval_loss))
112 |
113 | eval_losses.append(eval_loss)
114 |
115 | if e % save_every == 0:
116 | checkpoint = {
117 | 'epochs' : e,
118 | 'state_dict' : enet.state_dict()
119 | }
120 | torch.save(checkpoint, './ckpt-enet-{}-{}.pth'.format(e, train_loss))
121 | print ('Model saved!')
122 |
123 |         print ('Epoch {}/{}...'.format(e, epochs),
124 |                'Total Mean Loss: {:6f}'.format(sum(train_losses) / len(train_losses)))
125 |
126 | print ('[INFO]Training Process complete!')
127 |
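128 | if __name__ == '__main__':
129 |     # Sketch: restoring a checkpoint written by the loop above
130 |     # (the filename is a placeholder for an actual saved file).
131 |     import os
132 |     path = './ckpt-enet-10-0.0.pth'
133 |     if os.path.exists(path):
134 |         ckpt = torch.load(path, map_location='cpu')
135 |         net = ENet(12)
136 |         net.load_state_dict(ckpt['state_dict'])
137 |         print('Restored a model trained for {} epochs'.format(ckpt['epochs']))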
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import matplotlib.pyplot as plt
4 | import os
5 | from PIL import Image
6 | import torch
7 |
8 | def create_class_mask(img, color_map, is_normalized_img=True, is_normalized_map=False, show_masks=False):
9 | """
10 | Function to create C matrices from the segmented image, where each of the C matrices is for one class
11 | with all ones at the pixel positions where that class is present
12 |
13 | img = The segmented image
14 |
15 | color_map = A list with tuples that contains all the RGB values for each color that represents
16 | some class in that image
17 |
18 | is_normalized_img = Boolean - Whether the image is normalized or not
19 | If normalized, then the image is multiplied with 255
20 |
21 | is_normalized_map = Boolean - Represents whether the color map is normalized or not, if so
22 | then the color map values are multiplied with 255
23 |
24 |     show_masks = Whether to show the created masks or not
25 | """
26 |
27 | if is_normalized_img and (not is_normalized_map):
28 | img *= 255
29 |
30 | if is_normalized_map and (not is_normalized_img):
31 | img = img / 255
32 |
33 | mask = []
34 | hw_tuple = img.shape[:-1]
35 | for color in color_map:
36 | color_img = []
37 | for idx in range(3):
38 | color_img.append(np.ones(hw_tuple) * color[idx])
39 |
40 | color_img = np.array(color_img, dtype=np.uint8).transpose(1, 2, 0)
41 |
42 | mask.append(np.uint8((color_img == img).sum(axis = -1) == 3))
43 |
44 | return np.array(mask)
45 |
46 |
47 | def loader(training_path, segmented_path, batch_size, h=512, w=512):
48 | """
49 | The Loader to generate inputs and labels from the Image and Segmented Directory
50 |
51 | Arguments:
52 |
53 | training_path - str - Path to the directory that contains the training images
54 |
55 | segmented_path - str - Path to the directory that contains the segmented images
56 |
57 | batch_size - int - the batch size
58 |
59 | yields inputs and labels of the batch size
60 | """
61 |
62 | filenames_t = os.listdir(training_path)
63 | total_files_t = len(filenames_t)
64 |
65 | filenames_s = os.listdir(segmented_path)
66 | total_files_s = len(filenames_s)
67 |
68 | assert(total_files_t == total_files_s)
69 |
70 | if str(batch_size).lower() == 'all':
71 | batch_size = total_files_s
72 |
73 |     # Loop forever, yielding one randomly sampled batch per iteration
74 |     while True:
75 | batch_idxs = np.random.randint(0, total_files_s, batch_size)
76 |
77 |
78 | inputs = []
79 | labels = []
80 |
81 | for jj in batch_idxs:
82 | img = plt.imread(training_path + filenames_t[jj])
83 |             img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST)
84 | inputs.append(img)
85 |
86 | img = Image.open(segmented_path + filenames_s[jj])
87 | img = np.array(img)
88 |             img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST)
89 | labels.append(img)
90 |
91 |         # (h, w, batch, 3) -> (batch, 3, h, w)
92 |         inputs = torch.tensor(np.stack(inputs, axis=2)).transpose(0, 2).transpose(1, 3)
93 |
94 | labels = torch.tensor(labels)
95 |
96 | yield inputs, labels
97 |
98 |
99 | def decode_segmap(image):
100 | Sky = [128, 128, 128]
101 | Building = [128, 0, 0]
102 | Pole = [192, 192, 128]
103 | Road_marking = [255, 69, 0]
104 | Road = [128, 64, 128]
105 | Pavement = [60, 40, 222]
106 | Tree = [128, 128, 0]
107 | SignSymbol = [192, 128, 128]
108 | Fence = [64, 64, 128]
109 | Car = [64, 0, 128]
110 | Pedestrian = [64, 64, 0]
111 | Bicyclist = [0, 128, 192]
112 |
113 | label_colors = np.array([Sky, Building, Pole, Road_marking, Road,
114 | Pavement, Tree, SignSymbol, Fence, Car,
115 | Pedestrian, Bicyclist]).astype(np.uint8)
116 |
117 | r = np.zeros_like(image).astype(np.uint8)
118 | g = np.zeros_like(image).astype(np.uint8)
119 | b = np.zeros_like(image).astype(np.uint8)
120 |
121 | for label in range(len(label_colors)):
122 | r[image == label] = label_colors[label, 0]
123 | g[image == label] = label_colors[label, 1]
124 | b[image == label] = label_colors[label, 2]
125 |
126 | rgb = np.zeros((image.shape[0], image.shape[1], 3)).astype(np.uint8)
127 | rgb[:, :, 0] = r
128 | rgb[:, :, 1] = g
129 | rgb[:, :, 2] = b
130 |
131 | return rgb
132 |
133 | def show_images(images, in_row=True):
134 | '''
135 |     Helper function to display the given images in a row or a column
136 | '''
137 | total_images = len(images)
138 |
139 | rc_tuple = (1, total_images)
140 | if not in_row:
141 | rc_tuple = (total_images, 1)
142 |
143 | #figure = plt.figure(figsize=(20, 10))
144 | for ii in range(len(images)):
145 | plt.subplot(*rc_tuple, ii+1)
146 | plt.title(images[ii][0])
147 | plt.axis('off')
148 | plt.imshow(images[ii][1])
149 | plt.show()
150 |
151 | def get_class_weights(loader, num_classes, c=1.02):
152 | '''
153 |     This function returns the class weights for each class, computed as w_class = 1 / ln(c + p_class)
154 |
155 | Arguments:
156 | - loader : The generator object which return all the labels at one iteration
157 | Do Note: That this class expects all the labels to be returned in
158 | one iteration
159 |
160 | - num_classes : The number of classes
161 |
162 | Return:
163 | - class_weights : An array equal in length to the number of classes
164 | containing the class weights for each class
165 | '''
166 |
167 | _, labels = next(loader)
168 | all_labels = labels.flatten()
169 | each_class = np.bincount(all_labels, minlength=num_classes)
170 |     propensity_score = each_class / len(all_labels)
171 |     class_weights = 1 / (np.log(c + propensity_score))
172 | return class_weights
173 |
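174 | if __name__ == '__main__':
175 |     # Sketch of the weighting scheme on synthetic labels: rarer classes
176 |     # get larger weights, since w_class = 1 / ln(c + p_class).
177 |     fake_labels = torch.tensor(np.random.randint(0, 12, size=(10, 64, 64)))
178 |     fake_loader = iter([(None, fake_labels)])
179 |     print(get_class_weights(fake_loader, 12))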
--------------------------------------------------------------------------------