├── .idea
│   ├── ESPNetv2.iml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── LICENSE
├── README.md
├── imagenet
│   ├── LRSchedule.py
│   ├── Model.py
│   ├── README.md
│   ├── cnn_utils.py
│   ├── evaluate.py
│   ├── main.py
│   ├── pretrained_weights
│   │   ├── espnetv2_s_0.5.pth
│   │   ├── espnetv2_s_1.0.pth
│   │   ├── espnetv2_s_1.25.pth
│   │   ├── espnetv2_s_1.5.pth
│   │   └── espnetv2_s_2.0.pth
│   └── utils.py
├── images
│   ├── ReadMe.md
│   ├── effCompare.png
│   └── powerTX2.png
└── segmentation
    ├── DataSet.py
    ├── IOUEval.py
    ├── README.md
    ├── Transforms.py
    ├── cnn
    │   ├── Model.py
    │   ├── SegmentationModel.py
    │   └── cnn_utils.py
    ├── gen_cityscapes.py
    ├── loadData.py
    ├── main.py
    └── train_utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Sachin Mehta
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ESPNetv2: A Light-weight, Power Efficient, and General Purpose Convolutional Neural Network
2 |
3 | **IMPORTANT NOTE 1 (7 June, 2019)**: We have released a new code base that supports several datasets and models, including ESPNetv2. Please see [here](https://github.com/sacmehta/EdgeNets) for more details.
4 |
5 | **IMPORTANT NOTE 2 (7 June, 2019)**: This repository is obsolete and we are not maintaining it anymore.
6 |
7 | This repository contains the source code of our paper, [ESPNetv2](https://arxiv.org/abs/1811.11431), which was accepted for publication at CVPR'19.
8 |
9 | ***Note:*** New segmentation models for the PASCAL VOC and the Cityscapes are coming soon. Our new models achieve an mIOU of [68.0](http://host.robots.ox.ac.uk:8080/anonymous/DAMVRR.html) and [66.15](https://www.cityscapes-dataset.com/anonymous-results/?id=2267c613d55dd75d5301850c913b1507bf2f10586ca73eb8ebcf357cdcf3e036) on the PASCAL VOC and the Cityscapes test sets, respectively.
10 |
11 | 
12 | *(demo)* Real-time semantic segmentation using ESPNetv2 on iPhone7 (see EdgeNets for details)
13 | 
25 | ## Comparison with SOTA methods
26 | Compared to state-of-the-art efficient networks, our network delivers competitive performance while being much more **power efficient**. Sample results are shown in the figures below. For more details, please read our paper.
27 | 
28 | *(figure)* FLOPs vs. accuracy on the ImageNet dataset
29 | 
30 | *(figure)* Power consumption on TX2 device
31 | 
49 | If you find our project useful in your research, please consider citing:
50 |
51 | ```
52 | @inproceedings{mehta2019espnetv2,
53 | title={ESPNetv2: A Light-weight, Power Efficient, and General Purpose Convolutional Neural Network},
54 | author={Sachin Mehta and Mohammad Rastegari and Linda Shapiro and Hannaneh Hajishirzi},
55 | booktitle={CVPR},
56 | year={2019}
57 | }
58 |
59 | @inproceedings{mehta2018espnet,
60 | title={ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation},
61 | author={Sachin Mehta and Mohammad Rastegari and Anat Caspi and Linda Shapiro and Hannaneh Hajishirzi},
62 | booktitle={ECCV},
63 | year={2018}
64 | }
65 | ```
66 |
67 | ## Structure
68 | This repository contains source code and pretrained models for the following:
69 | * **Object classification:** We provide source code along with pre-trained models at different network complexities
70 | for the ImageNet dataset. Click [here](imagenet) for more details.
71 | * **Semantic segmentation:** We provide source code along with pre-trained models on the Cityscapes dataset. Check [here](segmentation) for more details.
72 |
73 | ## Requirements
74 |
75 | To run this repository, you should have the following software installed:
76 | * PyTorch - We tested with v0.4.1
77 | * OpenCV - We tested with version 3.4.3
78 | * Python3 - Our code is written in Python3. We recommend using [Anaconda](https://www.anaconda.com/).
79 |
80 | ## Instructions to install Pytorch and OpenCV with Anaconda
81 |
82 | Assuming that you have installed Anaconda successfully, you can follow the instructions below to install the packages:
83 |
84 | ### PyTorch
85 | ```
86 | conda install pytorch torchvision -c pytorch
87 | ```
88 |
89 | Once installed, run the following commands in a Python interpreter to verify the version:
90 | ```
91 | import torch
92 | torch.__version__
93 | ```
94 | This should print something like `0.4.1.post2`.
95 |
96 | If your version is different, then follow PyTorch website [here](https://pytorch.org/) for more details.
97 |
98 | ### OpenCV
99 | ```
100 | conda install pip
101 | pip install --upgrade pip
102 | pip install opencv-python
103 | ```
104 |
105 | Once installed, run the following commands in a Python interpreter to verify the version:
106 | ```
107 | import cv2
108 | cv2.__version__
109 | ```
110 | This should print something like `3.4.3`.
111 |
112 |
113 | ## Implementation note
114 |
115 | You will see that the `EESP` unit, the core building block of the ESPNetv2 architecture, has a `for` loop to process the input at different dilation rates.
116 | You can parallelize it using CUDA **Streams** in PyTorch, which improves inference speed.
117 |
118 | A snippet to parallelize a `for` loop in PyTorch is shown below:
119 | ```
120 | # Sequential version
121 | output = []
122 | a = torch.randn(1, 3, 10, 10)
123 | for i in range(4):
124 | output.append(a)
125 | torch.cat(output, 1)
126 | ```
127 |
128 | ```
129 | # Parallel version
130 | num_branches = 4
131 | streams = [(idx, torch.cuda.Stream()) for idx in range(num_branches)]
132 | output = []
133 | a = torch.randn(1, 3, 10, 10)
134 | for idx, s in streams:
135 | with torch.cuda.stream(s):
136 | output.append(a)
137 | torch.cuda.synchronize()
138 | torch.cat(output, 1)
139 | ```
140 |
141 | **Note:**
142 | * We used the above strategy to measure inference-related statistics, including power consumption and run time, on a single GPU.
143 | * We have not tested it (for either training or inference) across multiple GPUs. If you want to use Streams and are facing issues, please use the PyTorch forums to resolve your queries.
144 |
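For reference, below is a minimal sketch of how the same idea could map onto an EESP-style branch loop. The module list, channel sizes, and tensor names here are illustrative placeholders, not the repository's actual code:

```
import torch
import torch.nn as nn

# Illustrative stand-in for EESP's spp_dw: four depth-wise dilated 3x3 branches
branches = nn.ModuleList(
    [nn.Conv2d(16, 16, 3, padding=d, dilation=d, groups=16, bias=False) for d in (1, 2, 3, 4)]
).cuda()
x = torch.randn(1, 16, 56, 56).cuda()

streams = [torch.cuda.Stream() for _ in branches]
outputs = [None] * len(branches)
for idx, (branch, s) in enumerate(zip(branches, streams)):
    with torch.cuda.stream(s):  # launch each branch on its own stream
        outputs[idx] = branch(x)
torch.cuda.synchronize()        # wait for all branches before merging
merged = torch.cat(outputs, 1)  # hierarchical fusion (HFF) is omitted in this sketch
```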
--------------------------------------------------------------------------------
/imagenet/LRSchedule.py:
--------------------------------------------------------------------------------
1 |
2 | #============================================
3 | __author__ = "Sachin Mehta"
4 | __license__ = "MIT"
5 | __maintainer__ = "Sachin Mehta"
6 | #============================================
7 |
8 | class MyLRScheduler(object):
9 | '''
10 | Class that defines a cyclic learning rate that decays the learning rate linearly till the end of the cycle and then restarts
11 | at the maximum value.
12 | '''
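# With the defaults (initial=0.1, cycle_len=5, gamma=2), epoch 0 uses the warm-up value 0.1;
# after that, each 5-epoch cycle starts at max_lr = initial * cycle_len = 0.5 and decays
# linearly back to 0.1. Whenever the epoch hits a milestone in `steps`, the base learning
# rate is divided by `gamma` and the cycles restart from the new (smaller) maximum.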
13 | def __init__(self, initial=0.1, cycle_len=5, steps=[51, 101, 131, 161, 191, 221, 251, 281], gamma=2):
14 | super(MyLRScheduler, self).__init__()
15 | assert len(steps) > 1, 'Please specify step intervals.'
16 | self.min_lr = initial # minimum learning rate
17 | self.m = cycle_len
18 | self.steps = steps
19 | self.warm_up_interval = 1 # we do not start from the max value for the first epoch, because sometimes it diverges
20 | self.counter = 0
21 | self.decayFactor = gamma # factor by which we should decay learning rate
22 | self.count_cycles = 0
23 | self.step_counter = 0
24 | self.stepping = True
25 | print('Using Cyclic LR Scheduler with warm restarts')
26 |
27 | def get_lr(self, epoch):
28 | if epoch%self.steps[self.step_counter] == 0 and epoch > 1 and self.stepping:
29 | self.min_lr = self.min_lr / self.decayFactor
30 | self.count_cycles = 0
31 | if self.step_counter < len(self.steps) - 1:
32 | self.step_counter += 1
33 | else:
34 | self.stepping = False
35 | current_lr = self.min_lr
36 | # warm-up or cool-down phase
37 | if self.count_cycles < self.warm_up_interval:
38 | self.count_cycles += 1
39 | # We do not need warm up after first step.
40 | # so, we set warm up interval to 0 after first step
41 | if self.count_cycles == self.warm_up_interval:
42 | self.warm_up_interval = 0
43 | else:
44 | #Cyclic learning rate with warm restarts
45 | # max_lr (= min_lr * step_size) is decreased to min_lr using linear decay before
46 | # it is set to max value at the end of cycle.
47 | if self.counter >= self.m:
48 | self.counter = 0
49 | current_lr = round((self.min_lr * self.m) - (self.counter * self.min_lr), 5)
50 | self.counter += 1
51 | self.count_cycles += 1
52 | return current_lr
53 |
54 |
55 | if __name__ == '__main__':
56 | lrSched = MyLRScheduler(0.1)#MyLRScheduler(0.1, 5, [51, 101, 131, 161, 191, 221, 251, 281, 311, 341, 371])
57 | max_epochs = 300
58 | for i in range(max_epochs):
59 | print(i, lrSched.get_lr(i))
60 |
--------------------------------------------------------------------------------
/imagenet/Model.py:
--------------------------------------------------------------------------------
1 | from torch.nn import init
2 | import torch.nn.functional as F
3 | from cnn_utils import *
4 | import math
5 | import torch
6 |
7 | #============================================
8 | __author__ = "Sachin Mehta"
9 | __license__ = "MIT"
10 | __maintainer__ = "Sachin Mehta"
11 | #============================================
12 |
13 | class EESP(nn.Module):
14 | '''
15 | This class defines the EESP block, which is based on the following principle
16 | REDUCE ---> SPLIT ---> TRANSFORM --> MERGE
17 | '''
18 |
19 | def __init__(self, nIn, nOut, stride=1, k=4, r_lim=7, down_method='esp'): #down_method --> ['avg' or 'esp']
20 | '''
21 | :param nIn: number of input channels
22 | :param nOut: number of output channels
23 | :param stride: factor by which we should skip (useful for down-sampling). If 2, then down-samples the feature map by 2
24 | :param k: # of parallel branches
25 | :param r_lim: A maximum value of receptive field allowed for EESP block
26 | :param down_method: down-sampling method ('avg' or 'esp'); relevant when stride is 2
27 | '''
28 | super().__init__()
29 | self.stride = stride
30 | n = int(nOut / k)
31 | n1 = nOut - (k - 1) * n
32 | assert down_method in ['avg', 'esp'], 'One of these is supported (avg or esp)'
33 | assert n == n1, "n(={}) and n1(={}) should be equal for Depth-wise Convolution ".format(n, n1)
34 | self.proj_1x1 = CBR(nIn, n, 1, stride=1, groups=k)
35 |
36 | # (For convenience) Mapping between dilation rate and receptive field for a 3x3 kernel
37 | map_receptive_ksize = {3: 1, 5: 2, 7: 3, 9: 4, 11: 5, 13: 6, 15: 7, 17: 8}
38 | self.k_sizes = list()
39 | for i in range(k):
40 | ksize = int(3 + 2 * i)
41 | # After reaching the receptive field limit, fall back to the base kernel size of 3 with a dilation rate of 1
42 | ksize = ksize if ksize <= r_lim else 3
43 | self.k_sizes.append(ksize)
44 | # sort (in ascending order) these kernel sizes based on their receptive field
45 | # This enables us to ignore the kernels (3x3 in our case) with the same effective receptive field in hierarchical
46 | # feature fusion because kernels with a 3x3 receptive field do not have gridding artifacts.
47 | self.k_sizes.sort()
48 | self.spp_dw = nn.ModuleList()
49 | for i in range(k):
50 | d_rate = map_receptive_ksize[self.k_sizes[i]]
51 | self.spp_dw.append(CDilated(n, n, kSize=3, stride=stride, groups=n, d=d_rate))
52 | # Performing a group convolution with K groups is the same as performing K point-wise convolutions
53 | self.conv_1x1_exp = CB(nOut, nOut, 1, 1, groups=k)
54 | self.br_after_cat = BR(nOut)
55 | self.module_act = nn.PReLU(nOut)
56 | self.downAvg = True if down_method == 'avg' else False
57 |
58 | def forward(self, input):
59 | '''
60 | :param input: input feature map
61 | :return: transformed feature map
62 | '''
63 |
64 | # Reduce --> project high-dimensional feature maps to low-dimensional space
65 | output1 = self.proj_1x1(input)
66 | output = [self.spp_dw[0](output1)]
67 | # compute the output for each branch and hierarchically fuse them
68 | # i.e. Split --> Transform --> HFF
69 | for k in range(1, len(self.spp_dw)):
70 | out_k = self.spp_dw[k](output1)
71 | # HFF
72 | out_k = out_k + output[k - 1]
73 | output.append(out_k)
74 | # Merge
75 | expanded = self.conv_1x1_exp( # learn linear combinations using group point-wise convolutions
76 | self.br_after_cat( # apply batch normalization followed by activation function (PRelu in this case)
77 | torch.cat(output, 1) # concatenate the output of different branches
78 | )
79 | )
80 | del output
81 | # if down-sampling, then return the concatenated vector
82 | # because Downsampling function will combine it with avg. pooled feature map and then threshold it
83 | if self.stride == 2 and self.downAvg:
84 | return expanded
85 |
86 | # if dimensions of input and concatenated vector are the same, add them (RESIDUAL LINK)
87 | if expanded.size() == input.size():
88 | expanded = expanded + input
89 |
90 | # Threshold the feature map using activation function (PReLU in this case)
91 | return self.module_act(expanded)
92 |
93 |
94 | class DownSampler(nn.Module):
95 | '''
96 | Down-sampling function that has three parallel branches: (1) avg pooling,
97 | (2) EESP block with stride of 2 and (3) efficient long-range connection with the input.
98 | The output feature maps of branches from (1) and (2) are concatenated and then additively fused with (3) to produce
99 | the final output.
100 | '''
101 |
102 | def __init__(self, nin, nout, k=4, r_lim=9, reinf=True):
103 | '''
104 | :param nin: number of input channels
105 | :param nout: number of output channels
106 | :param k: # of parallel branches
107 | :param r_lim: A maximum value of receptive field allowed for EESP block
108 | :param reinf: Use long range shortcut connection with the input or not.
109 | '''
110 | super().__init__()
111 | nout_new = nout - nin
112 | self.eesp = EESP(nin, nout_new, stride=2, k=k, r_lim=r_lim, down_method='avg')
113 | self.avg = nn.AvgPool2d(kernel_size=3, padding=1, stride=2)
114 | if reinf:
115 | self.inp_reinf = nn.Sequential(
116 | CBR(config_inp_reinf, config_inp_reinf, 3, 1),
117 | CB(config_inp_reinf, nout, 1, 1)
118 | )
119 | self.act = nn.PReLU(nout)
120 |
121 | def forward(self, input, input2=None):
122 | '''
123 | :param input: input feature map
124 | :return: feature map down-sampled by a factor of 2
125 | '''
126 | avg_out = self.avg(input)
127 | eesp_out = self.eesp(input)
128 | output = torch.cat([avg_out, eesp_out], 1)
129 |
130 | if input2 is not None:
131 | #assuming the input is a square image
132 | # Shortcut connection with the input image
133 | w1 = avg_out.size(2)
134 | while True:
135 | input2 = F.avg_pool2d(input2, kernel_size=3, padding=1, stride=2)
136 | w2 = input2.size(2)
137 | if w2 == w1:
138 | break
139 | output = output + self.inp_reinf(input2)
140 |
141 | return self.act(output)
142 |
143 | class EESPNet(nn.Module):
144 | '''
145 | This class defines the ESPNetv2 architecture for the ImageNet classification
146 | '''
147 |
148 | def __init__(self, classes=1000, s=1):
149 | '''
150 | :param classes: number of classes in the dataset. Default is 1000 for the ImageNet dataset
151 | :param s: factor that scales the number of output feature maps
152 | '''
153 | super().__init__()
154 | reps = [0, 3, 7, 3] # how many times EESP blocks should be repeated at each spatial level.
155 | channels = 3
156 |
157 | r_lim = [13, 11, 9, 7, 5] # receptive field at each spatial level
158 | K = [4]*len(r_lim) # No. of parallel branches at different levels
159 |
160 | base = 32 #base configuration
161 | config_len = 5
162 | config = [base] * config_len
163 | base_s = 0
164 | for i in range(config_len):
165 | if i== 0:
166 | base_s = int(base * s)
167 | base_s = math.ceil(base_s / K[0]) * K[0]
168 | config[i] = base if base_s > base else base_s
169 | else:
170 | config[i] = base_s * pow(2, i)
171 | if s <= 1.5:
172 | config.append(1024)
173 | elif s <= 2.0:
174 | config.append(1280)
175 | else:
176 | raise ValueError('Configuration not supported')
177 |
178 |
179 | global config_inp_reinf
180 | config_inp_reinf = 3
181 | self.input_reinforcement = True # True for the shortcut connection with input
182 |
183 | assert len(K) == len(r_lim), 'Length of branching factor array and receptive field array should be the same.'
184 |
185 | self.level1 = CBR(channels, config[0], 3, 2) # 112 L1
186 |
187 | self.level2_0 = DownSampler(config[0], config[1], k=K[0], r_lim=r_lim[0], reinf=self.input_reinforcement) # out = 56
188 |
189 | self.level3_0 = DownSampler(config[1], config[2], k=K[1], r_lim=r_lim[1], reinf=self.input_reinforcement) # out = 28
190 | self.level3 = nn.ModuleList()
191 | for i in range(reps[1]):
192 | self.level3.append(EESP(config[2], config[2], stride=1, k=K[2], r_lim=r_lim[2]))
193 |
194 | self.level4_0 = DownSampler(config[2], config[3], k=K[2], r_lim=r_lim[2], reinf=self.input_reinforcement) #out = 14
195 | self.level4 = nn.ModuleList()
196 | for i in range(reps[2]):
197 | self.level4.append(EESP(config[3], config[3], stride=1, k=K[3], r_lim=r_lim[3]))
198 |
199 | self.level5_0 = DownSampler(config[3], config[4], k=K[3], r_lim=r_lim[3]) #7
200 | self.level5 = nn.ModuleList()
201 | for i in range(reps[3]):
202 | self.level5.append(EESP(config[4], config[4], stride=1, k=K[4], r_lim=r_lim[4]))
203 |
204 | # expand the feature maps using depth-wise convolution followed by group point-wise convolution
205 | self.level5.append(CBR(config[4], config[4], 3, 1, groups=config[4]))
206 | self.level5.append(CBR(config[4], config[5], 1, 1, groups=K[4]))
207 |
208 | self.classifier = nn.Linear(config[5], classes)
209 | self.init_params()
210 |
211 | def init_params(self):
212 | '''
213 | Function to initialize the parameters
214 | '''
215 | for m in self.modules():
216 | if isinstance(m, nn.Conv2d):
217 | init.kaiming_normal_(m.weight, mode='fan_out')
218 | if m.bias is not None:
219 | init.constant_(m.bias, 0)
220 | elif isinstance(m, nn.BatchNorm2d):
221 | init.constant_(m.weight, 1)
222 | init.constant_(m.bias, 0)
223 | elif isinstance(m, nn.Linear):
224 | init.normal_(m.weight, std=0.001)
225 | if m.bias is not None:
226 | init.constant_(m.bias, 0)
227 |
228 | def forward(self, input, p=0.2):
229 | '''
230 | :param input: Receives the input RGB image
231 | :return: a C-dimensional vector, C=# of classes
232 | '''
233 | out_l1 = self.level1(input) # 112
234 | if not self.input_reinforcement:
235 | del input
236 | input = None
237 |
238 | out_l2 = self.level2_0(out_l1, input) # 56
239 |
240 | out_l3_0 = self.level3_0(out_l2, input) # down-sample
241 | for i, layer in enumerate(self.level3):
242 | if i == 0:
243 | out_l3 = layer(out_l3_0)
244 | else:
245 | out_l3 = layer(out_l3)
246 |
247 | out_l4_0 = self.level4_0(out_l3, input) # down-sample
248 | for i, layer in enumerate(self.level4):
249 | if i == 0:
250 | out_l4 = layer(out_l4_0)
251 | else:
252 | out_l4 = layer(out_l4)
253 |
254 | out_l5_0 = self.level5_0(out_l4) # down-sample
255 | for i, layer in enumerate(self.level5):
256 | if i == 0:
257 | out_l5 = layer(out_l5_0)
258 | else:
259 | out_l5 = layer(out_l5)
260 |
261 | output_g = F.adaptive_avg_pool2d(out_l5, output_size=1)
262 | output_g = F.dropout(output_g, p=p, training=self.training)
263 | output_1x1 = output_g.view(output_g.size(0), -1)
264 |
265 | return self.classifier(output_1x1)
266 |
267 |
268 | if __name__ == '__main__':
269 | input = torch.Tensor(1, 3, 224, 224).cuda()
270 | model = EESPNet(classes=1000, s=1.0).cuda() # move the model to the GPU to match the CUDA input
271 | out = model(input)
272 | print('Output size')
273 | print(out.size())
274 |
--------------------------------------------------------------------------------
/imagenet/README.md:
--------------------------------------------------------------------------------
1 | # [ESPNetv2: A Light-weight, Power Efficient, and General Purpose Convolutional Neural Network](https://arxiv.org/abs/1811.11431)
2 |
3 | This repository contains the source code for training on the ImageNet dataset along with the pre-trained models
4 |
5 | ## Training and Evaluation on the ImageNet dataset
6 |
7 | Below are the commands to train and test the network at scale `s=1.0`.
8 |
9 | ### Training
10 | To train the network, you can use the following command:
11 |
12 | ```
13 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --batch-size 512 --s 1.0 --data <imagenet-data-dir>
14 | ```
15 |
16 | ### Evaluation
17 | To evaluate our pretrained models (or the ones trained by you), you can use the `evaluate.py` file.
18 |
19 | Use the command below to evaluate the performance of our model at scale `s=1.0` on the ImageNet dataset.
20 | ```
21 | CUDA_VISIBLE_DEVICES=0 python evaluate.py --batch-size 512 --s 1.0 --weightFile ./pretrained_weights/espnetv2_s_1.0.pth --data <imagenet-data-dir>
22 | ```
23 |
24 | ## Results and pre-trained models
25 | We release pre-trained models at different computational complexities. Following state-of-the-art methods, we measure top-1 accuracy on a
26 | center-cropped view of size 224x224.
27 |
28 | The table below provides details about the performance of our model on the ImageNet validation set at different computational complexities, along with links to download the pre-trained weights.
29 |
30 |
31 | | s | Params (million) | FLOPs (million) | top-1 (val) | Link |
32 | | -------- |--------|--------|-------| -------|
33 | | 0.5 | 1.24 | 28.37 | 57.7 | [here](pretrained_weights/espnetv2_s_0.5.pth) |
34 | | 1.0 | 1.67 | 85.72 | 66.1 | [here](pretrained_weights/espnetv2_s_1.0.pth) |
35 | | 1.25 | 1.96 | 123.39 | 67.9 | [here](pretrained_weights/espnetv2_s_1.25.pth) |
36 | | 1.5 | 2.31 | 168.6 | 69.2 | [here](pretrained_weights/espnetv2_s_1.5.pth) |
37 | | 2.0 | 3.49 | 284.8 | 72.1 | [here](pretrained_weights/espnetv2_s_2.0.pth) |
38 |
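If you want to use one of these checkpoints outside of `evaluate.py`, below is a minimal sketch of loading the `s=1.0` weights, following what `evaluate.py` does (the released weights were saved from a `DataParallel`-wrapped model, hence the wrapping; paths are assumed relative to the `imagenet` directory):

```
import torch
import Model as Net

# Build the network at the matching scale and wrap it the same way it was trained
model = Net.EESPNet(classes=1000, s=1.0)
model = torch.nn.DataParallel(model).cuda()

# Load the released weights and switch to evaluation mode
model.load_state_dict(torch.load('./pretrained_weights/espnetv2_s_1.0.pth'))
model.eval()
```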
39 |
40 | ## ImageNet dataset preparation
41 | To prepare the dataset, follow instructions [here](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset).
42 |
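Since the training and evaluation scripts load images with `torchvision.datasets.ImageFolder`, the directory passed to `--data` is expected to contain `train` and `val` sub-directories with one folder per class, roughly like this (class and file names below are examples):

```
<imagenet-data-dir>/
├── train/
│   ├── n01440764/
│   │   ├── image_0001.JPEG
│   │   └── ...
│   └── ...
└── val/
    ├── n01440764/
    │   └── ...
    └── ...
```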
43 |
--------------------------------------------------------------------------------
/imagenet/cnn_utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | #============================================
4 | __author__ = "Sachin Mehta"
5 | __license__ = "MIT"
6 | __maintainer__ = "Sachin Mehta"
7 | #============================================
8 |
9 | class CBR(nn.Module):
10 | '''
11 | This class defines the convolution layer with batch normalization and PReLU activation
12 | '''
13 |
14 | def __init__(self, nIn, nOut, kSize, stride=1, groups=1):
15 | '''
16 |
17 | :param nIn: number of input channels
18 | :param nOut: number of output channels
19 | :param kSize: kernel size
20 | :param stride: stride rate for down-sampling. Default is 1
21 | '''
22 | super().__init__()
23 | padding = int((kSize - 1) / 2)
24 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, groups=groups)
25 | self.bn = nn.BatchNorm2d(nOut)
26 | self.act = nn.PReLU(nOut)
27 |
28 | def forward(self, input):
29 | '''
30 | :param input: input feature map
31 | :return: transformed feature map
32 | '''
33 | output = self.conv(input)
34 | # output = self.conv1(output)
35 | output = self.bn(output)
36 | output = self.act(output)
37 | return output
38 |
39 |
40 | class BR(nn.Module):
41 | '''
42 | This class groups the batch normalization and PReLU activation
43 | '''
44 |
45 | def __init__(self, nOut):
46 | '''
47 | :param nOut: output feature maps
48 | '''
49 | super().__init__()
50 | self.bn = nn.BatchNorm2d(nOut)
51 | self.act = nn.PReLU(nOut)
52 |
53 | def forward(self, input):
54 | '''
55 | :param input: input feature map
56 | :return: normalized and thresholded feature map
57 | '''
58 | output = self.bn(input)
59 | output = self.act(output)
60 | return output
61 |
62 |
63 | class CB(nn.Module):
64 | '''
65 | This class groups the convolution and batch normalization
66 | '''
67 |
68 | def __init__(self, nIn, nOut, kSize, stride=1, groups=1):
69 | '''
70 | :param nIn: number of input channels
71 | :param nOut: number of output channels
72 | :param kSize: kernel size
73 | :param stride: optional stride for down-sampling
74 | '''
75 | super().__init__()
76 | padding = int((kSize - 1) / 2)
77 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False,
78 | groups=groups)
79 | self.bn = nn.BatchNorm2d(nOut)
80 |
81 | def forward(self, input):
82 | '''
83 |
84 | :param input: input feature map
85 | :return: transformed feature map
86 | '''
87 | output = self.conv(input)
88 | output = self.bn(output)
89 | return output
90 |
91 |
92 | class C(nn.Module):
93 | '''
94 | This class is for a convolutional layer.
95 | '''
96 |
97 | def __init__(self, nIn, nOut, kSize, stride=1, groups=1):
98 | '''
99 |
100 | :param nIn: number of input channels
101 | :param nOut: number of output channels
102 | :param kSize: kernel size
103 | :param stride: optional stride rate for down-sampling
104 | '''
105 | super().__init__()
106 | padding = int((kSize - 1) / 2)
107 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False,
108 | groups=groups)
109 |
110 | def forward(self, input):
111 | '''
112 | :param input: input feature map
113 | :return: transformed feature map
114 | '''
115 | output = self.conv(input)
116 | return output
117 |
118 |
119 | class CDilated(nn.Module):
120 | '''
121 | This class defines the dilated convolution.
122 | '''
123 |
124 | def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1):
125 | '''
126 | :param nIn: number of input channels
127 | :param nOut: number of output channels
128 | :param kSize: kernel size
129 | :param stride: optional stride rate for down-sampling
130 | :param d: optional dilation rate
131 | '''
132 | super().__init__()
133 | padding = int((kSize - 1) / 2) * d
134 | self.conv = nn.Conv2d(nIn, nOut,kSize, stride=stride, padding=padding, bias=False,
135 | dilation=d, groups=groups)
136 |
137 | def forward(self, input):
138 | '''
139 | :param input: input feature map
140 | :return: transformed feature map
141 | '''
142 | output = self.conv(input)
143 | return output
144 |
145 | class CDilatedB(nn.Module):
146 | '''
147 | This class defines the dilated convolution with batch normalization.
148 | '''
149 |
150 | def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1):
151 | '''
152 | :param nIn: number of input channels
153 | :param nOut: number of output channels
154 | :param kSize: kernel size
155 | :param stride: optional stride rate for down-sampling
156 | :param d: optional dilation rate
157 | '''
158 | super().__init__()
159 | padding = int((kSize - 1) / 2) * d
160 | self.conv = nn.Conv2d(nIn, nOut,kSize, stride=stride, padding=padding, bias=False,
161 | dilation=d, groups=groups)
162 | self.bn = nn.BatchNorm2d(nOut)
163 |
164 | def forward(self, input):
165 | '''
166 | :param input: input feature map
167 | :return: transformed feature map
168 | '''
169 | return self.bn(self.conv(input))
170 |
--------------------------------------------------------------------------------
/imagenet/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from torch.autograd import Variable
4 | import torch.nn.parallel
5 | import torch.backends.cudnn as cudnn
6 | import torch.optim
7 | import torch.utils.data
8 | import torch.utils.data.distributed
9 | import torchvision.transforms as transforms
10 | import torchvision.datasets as datasets
11 | import Model as Net
12 | import numpy as np
13 | from utils import *
14 | import os
15 | import torchvision.models as preModels
16 |
17 | cudnn.benchmark = True
18 |
19 | '''
20 | This file is mostly adapted from the PyTorch ImageNet example
21 | '''
22 |
23 | #============================================
24 | __author__ = "Sachin Mehta"
25 | __license__ = "MIT"
26 | __maintainer__ = "Sachin Mehta"
27 | #============================================
28 |
29 | def main(args):
30 |
31 | model = Net.EESPNet(classes=1000, s=args.s)
32 | model = torch.nn.DataParallel(model).cuda()
33 | if not os.path.isfile(args.weightFile):
34 | print('Weight file does not exist')
35 | exit(-1)
36 | dict_model = torch.load(args.weightFile)
37 | model.load_state_dict(dict_model)
38 |
39 |
40 | n_params = sum([np.prod(p.size()) for p in model.parameters()])
41 | print('Parameters: ' + str(n_params))
42 |
43 | # Data loading code
44 | valdir = os.path.join(args.data, 'val')
45 | traindir = os.path.join(args.data, 'train')
46 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
47 | std=[0.229, 0.224, 0.225])
48 |
49 | val_loader = torch.utils.data.DataLoader(
50 | datasets.ImageFolder(valdir, transforms.Compose([
51 | transforms.Resize(int(args.inpSize/0.875)),
52 | transforms.CenterCrop(args.inpSize),
53 | transforms.ToTensor(),
54 | normalize,
55 | ])),
56 | batch_size=args.batch_size, shuffle=False,
57 | num_workers=args.workers, pin_memory=True)
58 |
59 | validate(val_loader, model)
60 | return
61 |
62 |
63 | if __name__ == '__main__':
64 | parser = argparse.ArgumentParser(description='ESPNetv2 Training on the ImageNet')
65 | parser.add_argument('--data', default='/home/ubuntu/ILSVRC2015/Data/CLS-LOC/', help='path to dataset')
66 | parser.add_argument('--workers', default=12, type=int, help='number of data loading workers (default: 12)')
67 | parser.add_argument('-b', '--batch-size', default=512, type=int,
68 | metavar='N', help='mini-batch size (default: 512)')
69 | parser.add_argument('--print-freq', '-p', default=10, type=int,
70 | metavar='N', help='print frequency (default: 10)')
71 | parser.add_argument('--s', default=1, type=float,
72 | help='Factor by which the network width is scaled (s > 1 increases the number of channels, s < 1 decreases it)')
73 | parser.add_argument('--weightFile', type=str, default='', help='weight file')
74 | parser.add_argument('--inpSize', default=224, type=int,
75 | help='Input size')
76 |
77 |
78 | args = parser.parse_args()
79 | args.parallel = True
80 |
81 | main(args)
82 |
--------------------------------------------------------------------------------
/imagenet/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch.nn.parallel
3 | import torch.backends.cudnn as cudnn
4 | import torch.optim
5 | import torch.utils.data
6 | import torch.utils.data.distributed
7 | import torchvision.transforms as transforms
8 | import torchvision.datasets as datasets
9 | import Model as Net
10 | import numpy as np
11 | from utils import *
12 | import random
13 | import os
14 | from LRSchedule import MyLRScheduler
15 |
16 | cudnn.benchmark = True
17 |
18 |
19 | #============================================
20 | __author__ = "Sachin Mehta"
21 | __license__ = "MIT"
22 | __maintainer__ = "Sachin Mehta"
23 | #============================================
24 |
25 | def compute_params(model):
26 | return sum([np.prod(p.size()) for p in model.parameters()])
27 |
28 | def main(args):
29 | best_prec1 = 0.0
30 |
31 | if not os.path.isdir(args.savedir):
32 | os.mkdir(args.savedir)
33 |
34 | # create model
35 | model = Net.EESPNet(classes=1000, s=args.s)
36 | print('Network Parameters: ' + str(compute_params(model)))
37 |
38 | #check if the cuda is available or not
39 | cuda_available = torch.cuda.is_available()
40 |
41 | num_gpus = torch.cuda.device_count()
42 | if num_gpus >= 1:
43 | model = torch.nn.DataParallel(model)
44 |
45 | if cuda_available:
46 | model = model.cuda()
47 |
48 |
49 | logFileLoc = args.savedir + 'logs.txt'
50 | if os.path.isfile(logFileLoc):
51 | logger = open(logFileLoc, 'a')
52 | else:
53 | logger = open(logFileLoc, 'w')
54 | logger.write("\n%s\t%s\t%s\t%s\t%s\t%s" % ('Epoch', 'Loss(Tr)', 'Loss(val)', 'top1 (tr)', 'top1 (val)', 'lr'))
55 |
56 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
57 | momentum=args.momentum,
58 | weight_decay=args.weight_decay, nesterov=True)
59 |
60 | # optionally resume from a checkpoint
61 | if args.resume:
62 | if os.path.isfile(args.resume):
63 | print("=> loading checkpoint '{}'".format(args.resume))
64 | checkpoint = torch.load(args.resume)
65 | args.start_epoch = checkpoint['epoch']
66 | best_prec1 = checkpoint['best_prec1']
67 | model.load_state_dict(checkpoint['state_dict'])
68 | optimizer.load_state_dict(checkpoint['optimizer'])
69 | print("=> loaded checkpoint '{}' (epoch {})"
70 | .format(args.resume, checkpoint['epoch']))
71 | else:
72 | print("=> no checkpoint found at '{}'".format(args.resume))
73 |
74 | # Data loading code
75 | traindir = os.path.join(args.data, 'train')
76 | valdir = os.path.join(args.data, 'val')
77 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
78 | std=[0.229, 0.224, 0.225])
79 |
80 | train_loader1 = torch.utils.data.DataLoader(
81 | datasets.ImageFolder(
82 | traindir,
83 | transforms.Compose([
84 | transforms.RandomResizedCrop(args.inpSize),
85 | transforms.RandomHorizontalFlip(),
86 | transforms.ToTensor(),
87 | normalize,
88 | ])),
89 | batch_size=args.batch_size, shuffle=True,
90 | num_workers=args.workers, pin_memory=True)
91 |
92 |
93 | val_loader = torch.utils.data.DataLoader(
94 | datasets.ImageFolder(valdir, transforms.Compose([
95 | transforms.Resize(int(args.inpSize/0.875)),
96 | transforms.CenterCrop(args.inpSize),
97 | transforms.ToTensor(),
98 | normalize,
99 | ])),
100 | batch_size=args.batch_size, shuffle=False,
101 | num_workers=args.workers, pin_memory=True)
102 |
103 | # global customLR
104 | # steps at which we should decrease the learning rate
105 | step_sizes = [51, 101, 131, 161, 191, 221, 251, 281]
106 |
107 | #ImageNet experiments consume a lot of time
108 | # Just for safety, store the checkpoint before we decrease the learning rate
109 | # i.e. store the model at step -1
110 | step_store = list()
111 | for step in step_sizes:
112 | step_store.append(step-1)
113 |
114 |
115 | customLR = MyLRScheduler(args.lr, 5, step_sizes)
116 | #set up the variables in case of resuming
117 | if args.start_epoch != 0:
118 | for epoch in range(args.start_epoch):
119 | customLR.get_lr(epoch)
120 |
121 | for epoch in range(args.start_epoch, args.epochs):
122 | lr_log = customLR.get_lr(epoch)
123 | # set the optimizer with the learning rate
124 | # This can be done inside the MyLRScheduler
125 | for param_group in optimizer.param_groups:
126 | param_group['lr'] = lr_log
127 | print("LR for epoch {} = {:.5f}".format(epoch, lr_log))
128 | train_prec1, train_loss = train(train_loader1, model, optimizer, epoch)
129 | # evaluate on validation set
130 | val_prec1, val_loss = validate(val_loader, model)
131 |
132 | # remember best prec@1 and save checkpoint
133 | is_best = val_prec1.item() > best_prec1
134 | best_prec1 = max(val_prec1.item(), best_prec1)
135 | back_check = True if epoch in step_store else False #backup checkpoint or not
136 | save_checkpoint({
137 | 'epoch': epoch + 1,
138 | 'arch': 'ESPNet',
139 | 'state_dict': model.state_dict(),
140 | 'best_prec1': best_prec1,
141 | 'optimizer': optimizer.state_dict(),
142 | }, is_best, back_check, epoch, args.savedir)
143 |
144 | logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.7f" % (epoch, train_loss, val_loss, train_prec1,
145 | val_prec1, lr_log))
146 | logger.flush()
147 |
148 |
149 | if __name__ == '__main__':
150 | parser = argparse.ArgumentParser(description='ESPNetv2 Training on the ImageNet')
151 | parser.add_argument('--data', default='/home/ubuntu/ILSVRC2015/Data/CLS-LOC/', help='path to dataset')
152 | parser.add_argument('--workers', default=12, type=int, help='number of data loading workers (default: 12)')
153 | parser.add_argument('--epochs', default=300, type=int, help='number of total epochs to run')
154 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)')
155 | parser.add_argument('--batch-size', default=512, type=int, help='mini-batch size (default: 512)')
156 | parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
157 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
158 | parser.add_argument('--weight-decay', default=4e-5, type=float, help='weight decay (default: 4e-5)')
159 | parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint (default: none)')
160 | parser.add_argument('--savedir', type=str, default='./results', help='Location to save the results')
161 | parser.add_argument('--s', default=1, type=float, help='Factor by which the network width is scaled (s > 1 '
162 | 'increases the number of channels, s < 1 decreases it)')
163 | parser.add_argument('--inpSize', default=224, type=int, help='Input image size (default: 224 x 224)')
164 |
165 |
166 | args = parser.parse_args()
167 | args.parallel = True
168 | cudnn.deterministic = True
169 | random.seed(1882)
170 | torch.manual_seed(1882)
171 |
172 | args.savedir = args.savedir + '_s_' + str(args.s) + '_inp_' + str(args.inpSize) + os.sep
173 |
174 | main(args)
175 |
--------------------------------------------------------------------------------
/imagenet/pretrained_weights/espnetv2_s_0.5.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sacmehta/ESPNetv2/b78e323039908f31347d8ca17f49d5502ef1a594/imagenet/pretrained_weights/espnetv2_s_0.5.pth
--------------------------------------------------------------------------------
/imagenet/pretrained_weights/espnetv2_s_1.0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sacmehta/ESPNetv2/b78e323039908f31347d8ca17f49d5502ef1a594/imagenet/pretrained_weights/espnetv2_s_1.0.pth
--------------------------------------------------------------------------------
/imagenet/pretrained_weights/espnetv2_s_1.25.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sacmehta/ESPNetv2/b78e323039908f31347d8ca17f49d5502ef1a594/imagenet/pretrained_weights/espnetv2_s_1.25.pth
--------------------------------------------------------------------------------
/imagenet/pretrained_weights/espnetv2_s_1.5.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sacmehta/ESPNetv2/b78e323039908f31347d8ca17f49d5502ef1a594/imagenet/pretrained_weights/espnetv2_s_1.5.pth
--------------------------------------------------------------------------------
/imagenet/pretrained_weights/espnetv2_s_2.0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sacmehta/ESPNetv2/b78e323039908f31347d8ca17f49d5502ef1a594/imagenet/pretrained_weights/espnetv2_s_2.0.pth
--------------------------------------------------------------------------------
/imagenet/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import shutil
4 | import torch.nn.functional as F
5 | import time
6 |
7 |
8 | #============================================
9 | __author__ = "Sachin Mehta"
10 | __license__ = "MIT"
11 | __maintainer__ = "Sachin Mehta"
12 | #============================================
13 |
14 | '''
15 | This file is mostly adapted from the PyTorch ImageNet example
16 | '''
17 |
18 | class AverageMeter(object):
19 | """Computes and stores the average and current value"""
20 | def __init__(self):
21 | self.reset()
22 |
23 | def reset(self):
24 | self.val = 0
25 | self.avg = 0
26 | self.sum = 0
27 | self.count = 0
28 |
29 | def update(self, val, n=1):
30 | self.val = val
31 | self.sum += val * n
32 | self.count += n
33 | self.avg = self.sum / self.count
34 |
35 |
36 | def accuracy(output, target, topk=(1,)):
37 | """Computes the precision@k for the specified values of k"""
38 | maxk = max(topk)
39 | batch_size = target.size(0)
40 |
41 | _, pred = output.topk(maxk, 1, True, True)
42 | pred = pred.t()
43 | correct = pred.eq(target.view(1, -1).expand_as(pred))
44 |
45 | res = []
46 | for k in topk:
47 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
48 | res.append(correct_k.mul_(100.0 / batch_size))
49 | return res
50 |
51 | '''
52 | Utility to save checkpoint or not
53 | '''
54 | def save_checkpoint(state, is_best, back_check, epoch, dir):
55 | check_pt_file = dir + os.sep + 'checkpoint.pth.tar'
56 | torch.save(state, check_pt_file)
57 | if is_best:
58 | #We only need best models weight and not check point states, etc.
59 | torch.save(state['state_dict'], dir + os.sep + 'model_best.pth')
60 | if back_check:
61 | shutil.copyfile(check_pt_file, dir + os.sep + 'checkpoint_back' + str(epoch) + '.pth.tar')
62 |
63 | '''
64 | Cross entropy loss function
65 | '''
66 | def loss_fn(outputs, labels):
67 |
68 | return F.cross_entropy(outputs, labels)
69 |
70 | '''
71 | Training loop
72 | '''
73 | def train(train_loader, model, optimizer, epoch):
74 |
75 | batch_time = AverageMeter()
76 | data_time = AverageMeter()
77 | losses = AverageMeter()
78 | top1 = AverageMeter()
79 | top5 = AverageMeter()
80 |
81 | # switch to train mode
82 | model.train()
83 |
84 | end = time.time()
85 | for i, (input, target) in enumerate(train_loader):
86 |
87 | # measure data loading time
88 | data_time.update(time.time() - end)
89 |
90 | input = input.cuda(non_blocking=True)
91 | target = target.cuda(non_blocking=True)
92 |
93 | # compute output
94 | output = model(input)
95 |
96 | # compute loss
97 | loss = loss_fn(output, target)
98 |
99 | # measure accuracy and record loss
100 | prec1, prec5 = accuracy(output, target, topk=(1, 5))
101 | #losses.update(loss.data[0], input.size(0))
102 | losses.update(loss.item(), input.size(0))
103 | top1.update(prec1[0], input.size(0))
104 | top5.update(prec5[0], input.size(0))
105 |
106 | # compute gradient and do SGD step
107 | optimizer.zero_grad()
108 | loss.backward()
109 | optimizer.step()
110 |
111 | # measure elapsed time
112 | batch_time.update(time.time() - end)
113 | end = time.time()
114 |
115 | if i % 100 == 0: #print after every 100 batches
116 | print("Epoch: %d[%d/%d]\t\tBatch Time:%.4f\t\tLoss:%.4f\t\ttop1:%.4f (%.4f)\t\ttop5:%.4f (%.4f)" %
117 | (epoch, i, len(train_loader), batch_time.avg, losses.avg, top1.val, top1.avg, top5.val, top5.avg))
118 |
119 |
120 | return top1.avg, losses.avg
121 |
122 | '''
123 | Validation loop
124 | '''
125 | def validate(val_loader, model):
126 | batch_time = AverageMeter()
127 | losses = AverageMeter()
128 | top1 = AverageMeter()
129 | top5 = AverageMeter()
130 |
131 | # switch to evaluate mode
132 | model.eval()
133 |
134 | # with torch.no_grad():
135 | end = time.time()
136 | with torch.no_grad():
137 | for i, (input, target) in enumerate(val_loader):
138 | input = input.cuda(non_blocking=True)
139 | target = target.cuda(non_blocking=True)
140 |
141 | # compute output
142 | output = model(input)
143 | loss = loss_fn(output, target)
144 |
145 | # measure accuracy and record loss
146 | prec1, prec5 = accuracy(output, target, topk=(1, 5))
147 |
148 | # replace if using pytorch version < 0.4
149 | #losses.update(loss.data[0], input.size(0))
150 | losses.update(loss.item(), input.size(0))
151 | top1.update(prec1[0], input.size(0))
152 | top5.update(prec5[0], input.size(0))
153 |
154 | # measure elapsed time
155 | batch_time.update(time.time() - end)
156 | end = time.time()
157 |
158 | if i % 100 == 0: # print after every 100 batches
159 | print("Batch:[%d/%d]\t\tBatchTime:%.3f\t\tLoss:%.3f\t\ttop1:%.3f (%.3f)\t\ttop5:%.3f(%.3f)" %
160 | (i, len(val_loader), batch_time.avg, losses.avg, top1.val, top1.avg, top5.val, top5.avg))
161 |
162 | print(' * Prec@1:%.3f Prec@5:%.3f' % (top1.avg, top5.avg))
163 |
164 | return top1.avg, losses.avg
--------------------------------------------------------------------------------
/images/ReadMe.md:
--------------------------------------------------------------------------------
1 | Directory for sample images
2 |
--------------------------------------------------------------------------------
/images/effCompare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sacmehta/ESPNetv2/b78e323039908f31347d8ca17f49d5502ef1a594/images/effCompare.png
--------------------------------------------------------------------------------
/images/powerTX2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sacmehta/ESPNetv2/b78e323039908f31347d8ca17f49d5502ef1a594/images/powerTX2.png
--------------------------------------------------------------------------------
/segmentation/DataSet.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch.utils.data
3 | import numpy as np
4 |
5 |
6 | #============================================
7 | __author__ = "Sachin Mehta"
8 | __license__ = "MIT"
9 | __maintainer__ = "Sachin Mehta"
10 | #============================================
11 |
12 |
13 | class MyDataset(torch.utils.data.Dataset):
14 | '''
15 | Class to load the dataset
16 | '''
17 | def __init__(self, imList, labelList, transform=None):
18 | '''
19 | :param imList: image list (Note that these lists have been processed and pickled using the loadData.py)
20 | :param labelList: label list (Note that these lists have been processed and pickled using the loadData.py)
21 | :param transform: Type of transformation. See Transforms.py for supported transformations
22 | '''
23 | self.imList = imList
24 | self.labelList = labelList
25 | self.transform = transform
26 |
27 | def __len__(self):
28 | return len(self.imList)
29 |
30 | def __getitem__(self, idx):
31 | '''
32 |
33 | :param idx: Index of the image file
34 | :return: returns the image and corresponding label file.
35 | '''
36 | image_name = self.imList[idx]
37 | label_name = self.labelList[idx]
38 | image = cv2.imread(image_name)
39 | label = cv2.imread(label_name, 0)
40 | # if you have 255 label in your label files, map it to the background class (19) in the Cityscapes dataset
41 | if 255 in np.unique(label):
42 | label[label==255] = 19
43 |
44 | if self.transform:
45 | [image, label] = self.transform(image, label)
46 | return (image, label)
47 |
--------------------------------------------------------------------------------
/segmentation/IOUEval.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | #adapted from https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/score.py
4 |
5 | #============================================
6 | __author__ = "Sachin Mehta"
7 | __license__ = "MIT"
8 | __maintainer__ = "Sachin Mehta"
9 | #============================================
10 |
11 | class iouEval:
12 | def __init__(self, nClasses):
13 | self.nClasses = nClasses
14 | self.reset()
15 |
16 | def reset(self):
17 | self.overall_acc = 0
18 | self.per_class_acc = np.zeros(self.nClasses, dtype=np.float32)
19 | self.per_class_iu = np.zeros(self.nClasses, dtype=np.float32)
20 | self.mIOU = 0
21 | self.batchCount = 0
22 |
23 | def fast_hist(self, a, b):
24 | k = (a >= 0) & (a < self.nClasses)
25 | return np.bincount(self.nClasses * a[k].astype(int) + b[k], minlength=self.nClasses ** 2).reshape(self.nClasses, self.nClasses)
26 |
27 | def compute_hist(self, predict, gth):
28 | hist = self.fast_hist(gth, predict)
29 | return hist
30 |
31 | def addBatch(self, predict, gth):
32 | predict = predict.cpu().numpy().flatten()
33 | gth = gth.cpu().numpy().flatten()
34 |
35 | epsilon = 0.00000001
36 | hist = self.compute_hist(predict, gth)
37 | overall_acc = np.diag(hist).sum() / (hist.sum() + epsilon)
38 | per_class_acc = np.diag(hist) / (hist.sum(1) + epsilon)
39 | per_class_iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon)
40 | mIou = np.nanmean(per_class_iu)
41 |
42 | self.overall_acc +=overall_acc
43 | self.per_class_acc += per_class_acc
44 | self.per_class_iu += per_class_iu
45 | self.mIOU += mIou
46 | self.batchCount += 1
47 |
48 | def getMetric(self):
49 | overall_acc = self.overall_acc/self.batchCount
50 | per_class_acc = self.per_class_acc / self.batchCount
51 | per_class_iu = self.per_class_iu / self.batchCount
52 | mIOU = self.mIOU / self.batchCount
53 |
54 | return overall_acc, per_class_acc, per_class_iu, mIOU
55 |
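# Usage sketch (illustrative, not part of the original file): accumulate per-batch
# statistics during validation and read the averaged metrics at the end. `pred` and
# `gt` stand for torch tensors of per-pixel class indices.
#
#   evaluator = iouEval(nClasses=20)   # 19 Cityscapes classes + background (label 19)
#   for pred, gt in batches:
#       evaluator.addBatch(pred, gt)
#   overall_acc, per_class_acc, per_class_iu, mIOU = evaluator.getMetric()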
--------------------------------------------------------------------------------
/segmentation/README.md:
--------------------------------------------------------------------------------
1 | # ESPNetv2: A Light-weight, Power Efficient, and General Purpose Convolutional Neural Network
2 |
3 | This repository contains the source code that we used for semantic segmentation in our paper, [ESPNetv2](https://arxiv.org/abs/1811.11431).
4 |
5 | ***Note:*** New segmentation models for the PASCAL VOC and the Cityscapes are coming soon. Our new models achieve an mIOU of [68.0](http://host.robots.ox.ac.uk:8080/anonymous/DAMVRR.html) and [66.15](https://www.cityscapes-dataset.com/anonymous-results/?id=2267c613d55dd75d5301850c913b1507bf2f10586ca73eb8ebcf357cdcf3e036) on the PASCAL VOC and the Cityscapes test sets, respectively.
6 |
7 | **IMPORTANT NOTE**: We have released a new repository, EdgeNets, that contains our work on efficient network design. See [EdgeNets](https://github.com/sacmehta/EdgeNets/) for performance comparisons and pretrained models.
8 |
9 | **THIS REPO IS OBSOLETE**
10 |
--------------------------------------------------------------------------------
/segmentation/Transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 | import cv2
5 |
6 |
7 | #============================================
8 | __author__ = "Sachin Mehta"
9 | __license__ = "MIT"
10 | __maintainer__ = "Sachin Mehta"
11 | #============================================
12 |
13 |
14 | class Scale(object):
15 | """
16 | Resize the given image (and its label) to the specified width and height
17 | """
18 | def __init__(self, wi, he):
19 | '''
20 |
21 | :param wi: width after resizing
22 | :param he: height after resizing
23 | '''
24 | self.w = wi
25 | self.h = he
26 |
27 | def __call__(self, img, label):
28 | '''
29 | :param img: RGB image
30 | :param label: semantic label image
31 | :return: resized images
32 | '''
33 | #bilinear interpolation for RGB image
34 | img = cv2.resize(img, (self.w, self.h))
35 | # nearest neighbour interpolation for label image
36 | label = cv2.resize(label, (self.w, self.h), interpolation=cv2.INTER_NEAREST)
37 |
38 | return [img, label]
39 |
40 |
41 |
42 | class RandomCropResize(object):
43 | """
44 | Randomly crop the given image (up to 10% of its width/height) and resize it to the given size
45 | """
46 | def __init__(self, size):
47 | '''
48 | :param size: output size (width, height) after the random crop and resize
49 | '''
50 | self.size = size
51 |
52 | def __call__(self, img, label):
53 | h, w = img.shape[:2]
54 | x1 = random.randint(0, int(w*0.1)) # crop up to 10% of the width
55 | y1 = random.randint(0, int(h*0.1))
56 |
57 | img_crop = img[y1:h-y1, x1:w-x1]
58 | label_crop = label[y1:h-y1, x1:w-x1]
59 |
60 | img_crop = cv2.resize(img_crop, self.size)
61 | label_crop = cv2.resize(label_crop, self.size, interpolation=cv2.INTER_NEAREST)
62 | return img_crop, label_crop
63 |
64 | class RandomCrop(object):
65 | '''
66 | This class is for random cropping
67 | '''
68 | def __init__(self, cropArea):
69 | '''
70 | :param cropArea: amount of cropping (in pixels)
71 | '''
72 | self.crop = cropArea
73 |
74 | def __call__(self, img, label):
75 |
76 | if random.random() < 0.5:
77 | h, w = img.shape[:2]
78 | img_crop = img[self.crop:h-self.crop, self.crop:w-self.crop]
79 | label_crop = label[self.crop:h-self.crop, self.crop:w-self.crop]
80 | return img_crop, label_crop
81 | else:
82 | return [img, label]
83 |
84 |
85 |
86 | class RandomFlip(object):
87 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5
88 | """
89 |
90 | def __call__(self, image, label):
91 | if random.random() < 0.5:
92 | x1 = 0#random.randint(0, 1) #if you want to do vertical flip, uncomment this line
93 | if x1 == 0:
94 | image = cv2.flip(image, 1) # horizontal flip (OpenCV flip code 1 flips around the y-axis)
95 | label = cv2.flip(label, 1) # horizontal flip
96 | else:
97 | image = cv2.flip(image, 0) # vertical flip (flip code 0 flips around the x-axis)
98 | label = cv2.flip(label, 0) # vertical flip
99 | return [image, label]
100 |
101 |
102 | class Normalize(object):
103 | """Given mean: (R, G, B) and std: (R, G, B),
104 | will normalize each channel of the torch.*Tensor, i.e.
105 | channel = (channel - mean) / std
106 | """
107 |
108 | def __init__(self, mean, std):
109 | '''
110 | :param mean: global mean computed from dataset
111 | :param std: global std computed from dataset
112 | '''
113 | self.mean = mean
114 | self.std = std
115 |
116 | def __call__(self, image, label):
117 | image = image.astype(np.float32)
118 | for i in range(3):
119 | image[:,:,i] -= self.mean[i]
120 | for i in range(3):
121 | image[:,:, i] /= self.std[i]
122 |
123 | return [image, label]
124 |
125 | class ToTensor(object):
126 | '''
127 | This class converts the data to tensor so that it can be processed by PyTorch
128 | '''
129 | def __init__(self, scale=1):
130 | '''
131 | :param scale: ESPNet-C's output is 1/8th of original image size, so set this parameter accordingly
132 | '''
133 | self.scale = scale # original images are 2048 x 1024
134 |
135 | def __call__(self, image, label):
136 |
137 | if self.scale != 1:
138 | h, w = label.shape[:2]
139 | image = cv2.resize(image, (int(w), int(h)))
140 | label = cv2.resize(label, (int(w/self.scale), int(h/self.scale)), interpolation=cv2.INTER_NEAREST)
141 |
142 | image = image.transpose((2,0,1))
143 |
144 | image_tensor = torch.from_numpy(image).div(255)
145 | label_tensor = torch.LongTensor(np.array(label, dtype=np.int)) #torch.from_numpy(label)
146 |
147 | return [image_tensor, label_tensor]
148 |
149 | class Compose(object):
150 | """Composes several transforms together.
151 | """
152 |
153 | def __init__(self, transforms):
154 | self.transforms = transforms
155 |
156 | def __call__(self, *args):
157 | for t in self.transforms:
158 | args = t(*args)
159 | return args
160 |
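# Usage sketch (illustrative, not part of the original file): composing the transforms
# above for Cityscapes-style (image, label) pairs loaded with OpenCV. The mean/std values
# below are placeholders; in this project they are computed over the dataset by loadData.py.
#
#   train_transform = Compose([
#       Normalize(mean=[72.0, 83.0, 73.0], std=[45.0, 46.0, 45.0]),
#       Scale(1024, 512),
#       RandomCropResize((1024, 512)),
#       RandomFlip(),
#       ToTensor(scale=1),
#   ])
#   image_tensor, label_tensor = train_transform(image, label)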
--------------------------------------------------------------------------------
/segmentation/cnn/Model.py:
--------------------------------------------------------------------------------
1 | from torch.nn import init
2 | import torch.nn.functional as F
3 | from cnn.cnn_utils import *
4 | import math
5 | import torch
6 |
7 | __author__ = "Sachin Mehta"
8 | __version__ = "1.0.1"
9 | __maintainer__ = "Sachin Mehta"
10 |
11 | class EESP(nn.Module):
12 | '''
13 | This class defines the EESP block, which is based on the following principle
14 | REDUCE ---> SPLIT ---> TRANSFORM --> MERGE
15 | '''
16 |
17 | def __init__(self, nIn, nOut, stride=1, k=4, r_lim=7, down_method='esp'): #down_method --> ['avg' or 'esp']
18 | '''
19 | :param nIn: number of input channels
20 | :param nOut: number of output channels
21 | :param stride: stride of the depth-wise convolutions (useful for down-sampling). If 2, the feature map is down-sampled by a factor of 2
22 | :param k: # of parallel branches
23 | :param r_lim: A maximum value of receptive field allowed for EESP block
24 | :param down_method: down-sampling method, either 'esp' (strided EESP) or 'avg' (average pooling)
25 | '''
26 | super().__init__()
27 | self.stride = stride
28 | n = int(nOut / k)
29 | n1 = nOut - (k - 1) * n
30 | assert down_method in ['avg', 'esp'], 'One of these is supported (avg or esp)'
31 | assert n == n1, "n(={}) and n1(={}) should be equal for Depth-wise Convolution ".format(n, n1)
32 | #assert nIn%k == 0, "Number of input channels ({}) should be divisible by # of branches ({})".format(nIn, k)
33 | #assert n % k == 0, "Number of output channels ({}) should be divisible by # of branches ({})".format(n, k)
34 | self.proj_1x1 = CBR(nIn, n, 1, stride=1, groups=k)
35 |
36 | # (For convenience) Mapping between dilation rate and receptive field for a 3x3 kernel
37 | map_receptive_ksize = {3: 1, 5: 2, 7: 3, 9: 4, 11: 5, 13: 6, 15: 7, 17: 8}
38 | self.k_sizes = list()
39 | for i in range(k):
40 | ksize = int(3 + 2 * i)
41 | # After reaching the receptive field limit, fall back to the base kernel size of 3 with a dilation rate of 1
42 | ksize = ksize if ksize <= r_lim else 3
43 | self.k_sizes.append(ksize)
44 | # sort (in ascending order) these kernel sizes based on their receptive field
45 | # This enables us to ignore the kernels (3x3 in our case) with the same effective receptive field in hierarchical
46 | # feature fusion, because kernels with a 3x3 receptive field do not produce gridding artifacts.
47 | self.k_sizes.sort()
48 | self.spp_dw = nn.ModuleList()
49 | #self.bn = nn.ModuleList()
50 | for i in range(k):
51 | d_rate = map_receptive_ksize[self.k_sizes[i]]
52 | self.spp_dw.append(CDilated(n, n, kSize=3, stride=stride, groups=n, d=d_rate))
53 | #self.bn.append(nn.BatchNorm2d(n))
54 | self.conv_1x1_exp = CB(nOut, nOut, 1, 1, groups=k)
55 | self.br_after_cat = BR(nOut)
56 | self.module_act = nn.PReLU(nOut)
57 | self.downAvg = True if down_method == 'avg' else False
58 |
59 | def forward(self, input):
60 | '''
61 | :param input: input feature map
62 | :return: transformed feature map
63 | '''
64 |
65 | # Reduce --> project high-dimensional feature maps to low-dimensional space
66 | output1 = self.proj_1x1(input)
67 | output = [self.spp_dw[0](output1)]
68 | # compute the output for each branch and hierarchically fuse them
69 | # i.e. Split --> Transform --> HFF
70 | for k in range(1, len(self.spp_dw)):
71 | out_k = self.spp_dw[k](output1)
72 | # HFF
73 | # We do not combine the branches that have the same effective receptive field (3x3 in our case)
74 | # because there are no holes in those kernels.
75 | out_k = out_k + output[k - 1]
76 | #apply batch norm after fusion and then append to the list
77 | output.append(out_k)
78 | # Merge
79 | expanded = self.conv_1x1_exp( # Aggregate the feature maps using point-wise convolution
80 | self.br_after_cat( # apply batch normalization followed by activation function (PRelu in this case)
81 | torch.cat(output, 1) # concatenate the output of different branches
82 | )
83 | )
84 | del output
85 | # if down-sampling, then return the concatenated vector
86 | # because the down-sampling function will combine it with the avg. pooled feature map and then threshold it
87 | if self.stride == 2 and self.downAvg:
88 | return expanded
89 |
90 | # if dimensions of input and concatenated vector are the same, add them (RESIDUAL LINK)
91 | if expanded.size() == input.size():
92 | expanded = expanded + input
93 |
94 | # Threshold the feature map using activation function (PReLU in this case)
95 | return self.module_act(expanded)
96 |
97 |
98 | class DownSampler(nn.Module):
99 | '''
100 | Down-sampling function that has two parallel branches: (1) avg pooling
101 | and (2) EESP block with stride of 2. The output feature maps of these branches
102 | are then concatenated and thresholded using an activation function (PReLU in our
103 | case) to produce the final output.
104 | '''
105 |
106 | def __init__(self, nin, nout, k=4, r_lim=9, reinf=True):
107 | '''
108 | :param nin: number of input channels
109 | :param nout: number of output channels
110 | :param k: # of parallel branches
111 | :param r_lim: A maximum value of receptive field allowed for EESP block
112 | :param reinf: if True, reinforce the down-sampled feature maps with a down-sampled version of the input image
113 | '''
114 | super().__init__()
115 | nout_new = nout - nin
116 | self.eesp = EESP(nin, nout_new, stride=2, k=k, r_lim=r_lim, down_method='avg')
117 | self.avg = nn.AvgPool2d(kernel_size=3, padding=1, stride=2)
118 | if reinf:
119 | self.inp_reinf = nn.Sequential(
120 | CBR(config_inp_reinf, config_inp_reinf, 3, 1),
121 | CB(config_inp_reinf, nout, 1, 1)
122 | )
123 | self.act = nn.PReLU(nout)
124 |
125 | def forward(self, input, input2=None):
126 | '''
127 | :param input: input feature map
128 | :return: feature map down-sampled by a factor of 2
129 | '''
130 | avg_out = self.avg(input)
131 | eesp_out = self.eesp(input)
132 | output = torch.cat([avg_out, eesp_out], 1)
133 | if input2 is not None:
134 | #assuming the input is a square image
135 | w1 = avg_out.size(2)
136 | while True:
137 | input2 = F.avg_pool2d(input2, kernel_size=3, padding=1, stride=2)
138 | w2 = input2.size(2)
139 | if w2 == w1:
140 | break
141 | output = output + self.inp_reinf(input2)
142 |
143 | return self.act(output)
144 |
145 | class EESPNet(nn.Module):
146 | '''
147 | This class defines the ESPNetv2 architecture for ImageNet classification; it also serves as the encoder for segmentation
148 | '''
149 |
150 | def __init__(self, classes=20, s=1):
151 | '''
152 | :param classes: number of classes in the dataset. Default is 20 for the Cityscapes dataset
153 | :param s: factor that scales the number of output feature maps
154 | '''
155 | super().__init__()
156 | reps = [0, 3, 7, 3] # how many times EESP blocks should be repeated.
157 | channels = 3
158 |
159 | r_lim = [13, 11, 9, 7, 5] # receptive field at each spatial level
160 | K = [4]*len(r_lim) # No. of parallel branches at different levels
161 |
162 | base = 32 #base configuration
163 | config_len = 5
164 | config = [base] * config_len
165 | base_s = 0
166 | for i in range(config_len):
167 | if i == 0:
168 | base_s = int(base * s)
169 | base_s = math.ceil(base_s / K[0]) * K[0]
170 | config[i] = base if base_s > base else base_s
171 | else:
172 | config[i] = base_s * pow(2, i)
173 | if s <= 1.5:
174 | config.append(1024)
175 | elif s <= 2:
176 | config.append(1280)
177 | else:
178 | raise ValueError('Configuration not supported')
179 |
180 | #print('Config: ', config)
181 |
182 | global config_inp_reinf
183 | config_inp_reinf = 3
184 | self.input_reinforcement = True
185 | assert len(K) == len(r_lim), 'Length of branching factor array and receptive field array should be the same.'
186 |
187 | self.level1 = CBR(channels, config[0], 3, 2) # 112 L1
188 |
189 | self.level2_0 = DownSampler(config[0], config[1], k=K[0], r_lim=r_lim[0], reinf=self.input_reinforcement) # out = 56
190 | self.level3_0 = DownSampler(config[1], config[2], k=K[1], r_lim=r_lim[1], reinf=self.input_reinforcement) # out = 28
191 | self.level3 = nn.ModuleList()
192 | for i in range(reps[1]):
193 | self.level3.append(EESP(config[2], config[2], stride=1, k=K[2], r_lim=r_lim[2]))
194 |
195 | self.level4_0 = DownSampler(config[2], config[3], k=K[2], r_lim=r_lim[2], reinf=self.input_reinforcement) #out = 14
196 | self.level4 = nn.ModuleList()
197 | for i in range(reps[2]):
198 | self.level4.append(EESP(config[3], config[3], stride=1, k=K[3], r_lim=r_lim[3]))
199 |
200 | self.level5_0 = DownSampler(config[3], config[4], k=K[3], r_lim=r_lim[3]) #7
201 | self.level5 = nn.ModuleList()
202 | for i in range(reps[3]):
203 | self.level5.append(EESP(config[4], config[4], stride=1, k=K[4], r_lim=r_lim[4]))
204 |
205 | # expand the feature maps using depth-wise separable convolution
206 | self.level5.append(CBR(config[4], config[4], 3, 1, groups=config[4]))
207 | self.level5.append(CBR(config[4], config[5], 1, 1, groups=K[4]))
208 |
209 |
210 |
211 | #self.level5_exp = nn.ModuleList()
212 | #assert config[5]%config[4] == 0, '{} should be divisible by {}'.format(config[5], config[4])
213 | #gr = int(config[5]/config[4])
214 | #for i in range(gr):
215 | # self.level5_exp.append(CBR(config[4], config[4], 1, 1, groups=pow(2, i)))
216 |
217 | self.classifier = nn.Linear(config[5], classes)
218 | self.init_params()
219 |
220 | def init_params(self):
221 | '''
222 | Function to initialize the parameters
223 | '''
224 | for m in self.modules():
225 | if isinstance(m, nn.Conv2d):
226 | init.kaiming_normal_(m.weight, mode='fan_out')
227 | if m.bias is not None:
228 | init.constant_(m.bias, 0)
229 | elif isinstance(m, nn.BatchNorm2d):
230 | init.constant_(m.weight, 1)
231 | init.constant_(m.bias, 0)
232 | elif isinstance(m, nn.Linear):
233 | init.normal_(m.weight, std=0.001)
234 | if m.bias is not None:
235 | init.constant_(m.bias, 0)
236 |
237 | def forward(self, input, p=0.2, seg=True):
238 | '''
239 | :param input: Receives the input RGB image
240 | :return: a C-dimensional vector (C = # of classes) if seg is False; otherwise the encoder feature maps
241 | '''
242 | out_l1 = self.level1(input) # 112
243 | if not self.input_reinforcement:
244 | del input
245 | input = None
246 |
247 | out_l2 = self.level2_0(out_l1, input) # 56
248 |
249 | out_l3_0 = self.level3_0(out_l2, input) # out_l2_inp_rein
250 | for i, layer in enumerate(self.level3):
251 | if i == 0:
252 | out_l3 = layer(out_l3_0)
253 | else:
254 | out_l3 = layer(out_l3)
255 |
256 | out_l4_0 = self.level4_0(out_l3, input) # down-sampled
257 | for i, layer in enumerate(self.level4):
258 | if i == 0:
259 | out_l4 = layer(out_l4_0)
260 | else:
261 | out_l4 = layer(out_l4)
262 |
263 | if not seg:
264 | out_l5_0 = self.level5_0(out_l4) # down-sampled
265 | for i, layer in enumerate(self.level5):
266 | if i == 0:
267 | out_l5 = layer(out_l5_0)
268 | else:
269 | out_l5 = layer(out_l5)
270 |
271 | #out_e = []
272 | #for layer in self.level5_exp:
273 | # out_e.append(layer(out_l5))
274 | #out_exp = torch.cat(out_e, dim=1)
275 |
276 |
277 |
278 | output_g = F.adaptive_avg_pool2d(out_l5, output_size=1)
279 | output_g = F.dropout(output_g, p=p, training=self.training)
280 | output_1x1 = output_g.view(output_g.size(0), -1)
281 |
282 | return self.classifier(output_1x1)
283 | return out_l1, out_l2, out_l3, out_l4
284 |
285 |
286 |
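
As a quick sanity check of the EESP arithmetic above: with nOut = 128, k = 4 and r_lim = 9 (the level-3 setting at s = 1), each branch gets n = 128 / 4 = 32 channels, the branch kernel sizes come out as 3, 5, 7, 9, and map_receptive_ksize maps them to dilation rates 1, 2, 3, 4. The sketch below is an assumed, CPU-only shape check (run from the segmentation directory so that cnn.Model is importable); it is not part of the repository.

# Assumed shape check for EESPNet (illustrative only).
import torch
from cnn.Model import EESPNet

model = EESPNet(classes=20, s=1)
model.eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    l1, l2, l3, l4 = model(x, seg=True)   # encoder features at 1/2, 1/4, 1/8 and 1/16 resolution
    logits = model(x, seg=False)          # full classification path, shape (1, 20)
print(l1.shape, l2.shape, l3.shape, l4.shape, logits.shape)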
--------------------------------------------------------------------------------
/segmentation/cnn/SegmentationModel.py:
--------------------------------------------------------------------------------
1 | #============================================
2 | __author__ = "Sachin Mehta"
3 | __license__ = "MIT"
4 | __maintainer__ = "Sachin Mehta"
5 | #============================================
6 | import torch
7 | from torch import nn
8 |
9 | from cnn.Model import EESPNet, EESP
10 | import os
11 | import torch.nn.functional as F
12 | from cnn.cnn_utils import *
13 |
14 | class EESPNet_Seg(nn.Module):
15 | def __init__(self, classes=20, s=1, pretrained=None, gpus=1):
16 | super().__init__()
17 | classificationNet = EESPNet(classes=1000, s=s)
18 | if gpus >=1:
19 | classificationNet = nn.DataParallel(classificationNet)
20 | # load the pretrained weights
21 | if pretrained:
22 | if not os.path.isfile(pretrained):
23 | print('Weight file does not exist. Training without pre-trained weights')
24 | else:
25 | classificationNet.load_state_dict(torch.load(pretrained))
26 |
27 | self.net = classificationNet.module
28 |
29 | del classificationNet
30 | # delete last few layers
31 | del self.net.classifier
32 | del self.net.level5
33 | del self.net.level5_0
34 | if s <=0.5:
35 | p = 0.1
36 | else:
37 | p=0.2
38 |
39 | self.proj_L4_C = CBR(self.net.level4[-1].module_act.num_parameters, self.net.level3[-1].module_act.num_parameters, 1, 1)
40 | pspSize = 2*self.net.level3[-1].module_act.num_parameters
41 | self.pspMod = nn.Sequential(EESP(pspSize, pspSize //2, stride=1, k=4, r_lim=7),
42 | PSPModule(pspSize // 2, pspSize //2))
43 | self.project_l3 = nn.Sequential(nn.Dropout2d(p=p), C(pspSize // 2, classes, 1, 1))
44 | self.act_l3 = BR(classes)
45 | self.project_l2 = CBR(self.net.level2_0.act.num_parameters + classes, classes, 1, 1)
46 | self.project_l1 = nn.Sequential(nn.Dropout2d(p=p), C(self.net.level1.act.num_parameters + classes, classes, 1, 1))
47 |
48 | def hierarchicalUpsample(self, x, factor=3):
49 | for i in range(factor):
50 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
51 | return x
52 |
53 |
54 | def forward(self, input):
55 | out_l1, out_l2, out_l3, out_l4 = self.net(input, seg=True)
56 | out_l4_proj = self.proj_L4_C(out_l4)
57 | up_l4_to_l3 = F.interpolate(out_l4_proj, scale_factor=2, mode='bilinear', align_corners=True)
58 | merged_l3_upl4 = self.pspMod(torch.cat([out_l3, up_l4_to_l3], 1))
59 | proj_merge_l3_bef_act = self.project_l3(merged_l3_upl4)
60 | proj_merge_l3 = self.act_l3(proj_merge_l3_bef_act)
61 | out_up_l3 = F.interpolate(proj_merge_l3, scale_factor=2, mode='bilinear', align_corners=True)
62 | merge_l2 = self.project_l2(torch.cat([out_l2, out_up_l3], 1))
63 | out_up_l2 = F.interpolate(merge_l2, scale_factor=2, mode='bilinear', align_corners=True)
64 | merge_l1 = self.project_l1(torch.cat([out_l1, out_up_l2], 1))
65 | if self.training:
66 | return F.interpolate(merge_l1, scale_factor=2, mode='bilinear', align_corners=True), self.hierarchicalUpsample(proj_merge_l3_bef_act)
67 | else:
68 | return F.interpolate(merge_l1, scale_factor=2, mode='bilinear', align_corners=True)
69 |
70 |
71 | if __name__ == '__main__':
72 | input = torch.Tensor(1, 3, 512, 1024).cuda()
73 | net = EESPNet_Seg(classes=20, s=2).cuda()
74 | out_x_8 = net(input)
75 | print(out_x_8.size())
76 |
77 |
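
For reference, the resolution bookkeeping in the decoder above for a 512x1024 input: out_l4 is at 1/16 (32x64), out_l3 at 1/8 (64x128), out_l2 at 1/4 and out_l1 at 1/2; each F.interpolate call doubles the spatial size, so hierarchicalUpsample with factor=3 scales the 1/8-resolution auxiliary prediction by 2^3 = 8 back to full resolution. Below is an assumed CPU-only variant of the shape check in the __main__ block, for machines without a GPU; it is a sketch, not repository code.

# Assumed CPU-only shape check for EESPNet_Seg.
import torch
from cnn.SegmentationModel import EESPNet_Seg

net = EESPNet_Seg(classes=20, s=1.0)      # no pretrained weights
net.eval()
x = torch.randn(1, 3, 512, 1024)
with torch.no_grad():
    out = net(x)                          # eval mode returns a single (1, 20, 512, 1024) map
print(out.size())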
--------------------------------------------------------------------------------
/segmentation/cnn/cnn_utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | __author__ = "Sachin Mehta"
7 | __version__ = "1.0.1"
8 | __maintainer__ = "Sachin Mehta"
9 |
10 |
11 | class PSPModule(nn.Module):
12 | def __init__(self, features, out_features=1024, sizes=(1, 2, 4, 8)):
13 | super().__init__()
14 | self.stages = []
15 | self.stages = nn.ModuleList([C(features, features, 3, 1, groups=features) for size in sizes])
16 | self.project = CBR(features * (len(sizes) + 1), out_features, 1, 1)
17 |
18 | def forward(self, feats):
19 | h, w = feats.size(2), feats.size(3)
20 | out = [feats]
21 | for stage in self.stages:
22 | feats = F.avg_pool2d(feats, kernel_size=3, stride=2, padding=1)
23 | upsampled = F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True)
24 | out.append(upsampled)
25 | return self.project(torch.cat(out, dim=1))
26 |
27 | class CBR(nn.Module):
28 | '''
29 | This class defines the convolution layer with batch normalization and PReLU activation
30 | '''
31 |
32 | def __init__(self, nIn, nOut, kSize, stride=1, groups=1):
33 | '''
34 |
35 | :param nIn: number of input channels
36 | :param nOut: number of output channels
37 | :param kSize: kernel size
38 | :param stride: stride rate for down-sampling. Default is 1
39 | '''
40 | super().__init__()
41 | padding = int((kSize - 1) / 2)
42 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, groups=groups)
43 | self.bn = nn.BatchNorm2d(nOut)
44 | self.act = nn.PReLU(nOut)
45 |
46 | def forward(self, input):
47 | '''
48 | :param input: input feature map
49 | :return: transformed feature map
50 | '''
51 | output = self.conv(input)
52 | # output = self.conv1(output)
53 | output = self.bn(output)
54 | output = self.act(output)
55 | return output
56 |
57 |
58 | class BR(nn.Module):
59 | '''
60 | This class groups the batch normalization and PReLU activation
61 | '''
62 |
63 | def __init__(self, nOut):
64 | '''
65 | :param nOut: output feature maps
66 | '''
67 | super().__init__()
68 | self.bn = nn.BatchNorm2d(nOut)
69 | self.act = nn.PReLU(nOut)
70 |
71 | def forward(self, input):
72 | '''
73 | :param input: input feature map
74 | :return: normalized and thresholded feature map
75 | '''
76 | output = self.bn(input)
77 | output = self.act(output)
78 | return output
79 |
80 |
81 | class CB(nn.Module):
82 | '''
83 | This class groups the convolution and batch normalization
84 | '''
85 |
86 | def __init__(self, nIn, nOut, kSize, stride=1, groups=1):
87 | '''
88 | :param nIn: number of input channels
89 | :param nOut: number of output channels
90 | :param kSize: kernel size
91 | :param stride: optional stride for down-sampling
92 | '''
93 | super().__init__()
94 | padding = int((kSize - 1) / 2)
95 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False,
96 | groups=groups)
97 | self.bn = nn.BatchNorm2d(nOut)
98 |
99 | def forward(self, input):
100 | '''
101 |
102 | :param input: input feature map
103 | :return: transformed feature map
104 | '''
105 | output = self.conv(input)
106 | output = self.bn(output)
107 | return output
108 |
109 |
110 | class C(nn.Module):
111 | '''
112 | This class is for a convolutional layer.
113 | '''
114 |
115 | def __init__(self, nIn, nOut, kSize, stride=1, groups=1):
116 | '''
117 |
118 | :param nIn: number of input channels
119 | :param nOut: number of output channels
120 | :param kSize: kernel size
121 | :param stride: optional stride rate for down-sampling
122 | '''
123 | super().__init__()
124 | padding = int((kSize - 1) / 2)
125 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False,
126 | groups=groups)
127 |
128 | def forward(self, input):
129 | '''
130 | :param input: input feature map
131 | :return: transformed feature map
132 | '''
133 | output = self.conv(input)
134 | return output
135 |
136 |
137 | class CDilated(nn.Module):
138 | '''
139 | This class defines the dilated convolution.
140 | '''
141 |
142 | def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1):
143 | '''
144 | :param nIn: number of input channels
145 | :param nOut: number of output channels
146 | :param kSize: kernel size
147 | :param stride: optional stride rate for down-sampling
148 | :param d: optional dilation rate
149 | '''
150 | super().__init__()
151 | padding = int((kSize - 1) / 2) * d
152 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False,
153 | dilation=d, groups=groups)
154 |
155 | def forward(self, input):
156 | '''
157 | :param input: input feature map
158 | :return: transformed feature map
159 | '''
160 | output = self.conv(input)
161 | return output
162 |
163 | class CDilatedB(nn.Module):
164 | '''
165 | This class defines the dilated convolution with batch normalization.
166 | '''
167 |
168 | def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1):
169 | '''
170 | :param nIn: number of input channels
171 | :param nOut: number of output channels
172 | :param kSize: kernel size
173 | :param stride: optional stride rate for down-sampling
174 | :param d: optional dilation rate
175 | '''
176 | super().__init__()
177 | padding = int((kSize - 1) / 2) * d
178 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False,
179 | dilation=d, groups=groups)
180 | self.bn = nn.BatchNorm2d(nOut)
181 |
182 | def forward(self, input):
183 | '''
184 | :param input: input feature map
185 | :return: transformed feature map
186 | '''
187 | return self.bn(self.conv(input))
188 |
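
The padding arithmetic used by these blocks keeps the spatial size unchanged at stride 1: a kSize x kSize convolution with dilation d has an effective kernel size of kSize + (kSize - 1)(d - 1), and padding = (kSize - 1) / 2 * d compensates for it exactly. A small, assumed sanity check (not part of the repository):

# Illustrative check: CDilated with stride 1 preserves the spatial size for any dilation rate.
import torch
from cnn.cnn_utils import CDilated

x = torch.randn(1, 16, 32, 32)
for d in [1, 2, 3, 4]:
    conv = CDilated(nIn=16, nOut=16, kSize=3, stride=1, d=d, groups=16)
    # effective kernel size = 3 + 2 * (d - 1); padding = d, so the 32x32 size is preserved
    assert conv(x).shape == x.shape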
--------------------------------------------------------------------------------
/segmentation/gen_cityscapes.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import glob
4 |
5 | import cv2
6 | import os
7 | from argparse import ArgumentParser
8 | from cnn import SegmentationModel as net
9 | from torch import nn
10 |
11 |
12 | #============================================
13 | __author__ = "Sachin Mehta"
14 | __license__ = "MIT"
15 | __maintainer__ = "Sachin Mehta"
16 | #============================================
17 |
18 | pallete = [[128, 64, 128],
19 | [244, 35, 232],
20 | [70, 70, 70],
21 | [102, 102, 156],
22 | [190, 153, 153],
23 | [153, 153, 153],
24 | [250, 170, 30],
25 | [220, 220, 0],
26 | [107, 142, 35],
27 | [152, 251, 152],
28 | [70, 130, 180],
29 | [220, 20, 60],
30 | [255, 0, 0],
31 | [0, 0, 142],
32 | [0, 0, 70],
33 | [0, 60, 100],
34 | [0, 80, 100],
35 | [0, 0, 230],
36 | [119, 11, 32],
37 | [0, 0, 0]]
38 |
39 |
40 | def relabel(img):
41 | '''
42 | This function relabels the predicted label ids to the original Cityscapes label ids so that the official Cityscapes scripts can process them
43 | :param img: predicted label map (train ids)
44 | :return: label map with the original Cityscapes label ids
45 | '''
46 | img[img == 19] = 255
47 | img[img == 18] = 33
48 | img[img == 17] = 32
49 | img[img == 16] = 31
50 | img[img == 15] = 28
51 | img[img == 14] = 27
52 | img[img == 13] = 26
53 | img[img == 12] = 25
54 | img[img == 11] = 24
55 | img[img == 10] = 23
56 | img[img == 9] = 22
57 | img[img == 8] = 21
58 | img[img == 7] = 20
59 | img[img == 6] = 19
60 | img[img == 5] = 17
61 | img[img == 4] = 13
62 | img[img == 3] = 12
63 | img[img == 2] = 11
64 | img[img == 1] = 8
65 | img[img == 0] = 7
66 | img[img == 255] = 0
67 | return img
68 |
69 |
70 | def evaluateModel(args, model, image_list):
71 | # global mean and std values
72 | mean = [72.3923111, 82.90893555, 73.15840149]
73 | std = [45.3192215, 46.15289307, 44.91483307]
74 |
75 | model.eval()
76 | for i, imgName in enumerate(image_list):
77 | img = cv2.imread(imgName)
78 | if args.overlay:
79 | img_orig = np.copy(img)
80 |
81 | img = img.astype(np.float32)
82 | for j in range(3):
83 | img[:, :, j] -= mean[j]
84 | for j in range(3):
85 | img[:, :, j] /= std[j]
86 |
87 | # resize the image to 1024x512x3
88 | img = cv2.resize(img, (args.inWidth, args.inHeight))
89 | if args.overlay:
90 | img_orig = cv2.resize(img_orig, (args.inWidth, args.inHeight))
91 |
92 | img /= 255
93 | img = img.transpose((2, 0, 1))
94 | img_tensor = torch.from_numpy(img)
95 | img_tensor = torch.unsqueeze(img_tensor, 0) # add a batch dimension
96 | if args.gpu:
97 | img_tensor = img_tensor.cuda()
98 | img_out = model(img_tensor)
99 |
100 | classMap_numpy = img_out[0].max(0)[1].byte().cpu().data.numpy()
101 | # upsample the feature maps to the same size as the input image using Nearest neighbour interpolation
102 | # upsample the feature map from 1024x512 to 2048x1024
103 | classMap_numpy = cv2.resize(classMap_numpy, (args.inWidth*2, args.inHeight*2), interpolation=cv2.INTER_NEAREST)
104 | if i % 100 == 0 and i > 0:
105 | print('Processed [{}/{}]'.format(i, len(image_list)))
106 |
107 | name = imgName.split('/')[-1]
108 |
109 | if args.colored:
110 | classMap_numpy_color = np.zeros((img.shape[1], img.shape[2], img.shape[0]), dtype=np.uint8)
111 | for idx in range(len(pallete)):
112 | [r, g, b] = pallete[idx]
113 | classMap_numpy_color[classMap_numpy == idx] = [b, g, r]
114 | cv2.imwrite(args.savedir + os.sep + 'c_' + name.replace(args.img_extn, 'png'), classMap_numpy_color)
115 | if args.overlay:
116 | overlayed = cv2.addWeighted(img_orig, 0.5, classMap_numpy_color, 0.5, 0)
117 | cv2.imwrite(args.savedir + os.sep + 'over_' + name.replace(args.img_extn, 'jpg'), overlayed)
118 |
119 | if args.cityFormat:
120 | classMap_numpy = relabel(classMap_numpy.astype(np.uint8))
121 |
122 |
123 | cv2.imwrite(args.savedir + os.sep + name.replace(args.img_extn, 'png'), classMap_numpy)
124 |
125 |
126 | def main(args):
127 | # read all the images in the folder
128 | image_list = glob.glob(args.data_dir + os.sep + '*.' + args.img_extn)
129 |
130 | modelA = net.EESPNet_Seg(args.classes, s=args.s)
131 | if not os.path.isfile(args.pretrained):
132 | print('Pre-trained model file does not exist. Please check ./pretrained_models folder')
133 | exit(-1)
134 | modelA = nn.DataParallel(modelA)
135 | modelA.load_state_dict(torch.load(args.pretrained))
136 | if args.gpu:
137 | modelA = modelA.cuda()
138 |
139 | # set to evaluation mode
140 | modelA.eval()
141 |
142 | if not os.path.isdir(args.savedir):
143 | os.mkdir(args.savedir)
144 |
145 | evaluateModel(args, modelA, image_list)
146 |
147 |
148 | if __name__ == '__main__':
149 | parser = ArgumentParser()
150 | parser.add_argument('--model', default="ESPNetv2", help='Model name')
151 | parser.add_argument('--data_dir', default="./data", help='Data directory')
152 | parser.add_argument('--img_extn', default="png", help='RGB Image format')
153 | parser.add_argument('--inWidth', type=int, default=1024, help='Width of RGB image')
154 | parser.add_argument('--inHeight', type=int, default=512, help='Height of RGB image')
155 | parser.add_argument('--savedir', default='./results', help='directory to save the results')
156 | parser.add_argument('--gpu', default=True, type=bool, help='Run on CPU or GPU. If TRUE, then GPU.')
157 | parser.add_argument('--pretrained', default='', help='Pretrained weights directory.')
158 | parser.add_argument('--s', default=0.5, type=float, help='scale')
159 | parser.add_argument('--cityFormat', default=True, type=bool, help='If you want to convert to cityscape '
160 | 'original label ids')
161 | parser.add_argument('--colored', default=False, type=bool, help='If you want to visualize the '
162 | 'segmentation masks in color')
163 | parser.add_argument('--overlay', default=False, type=bool, help='If you want to visualize the '
164 | 'segmentation masks overlayed on top of RGB image')
165 | parser.add_argument('--classes', default=20, type=int, help='Number of classes in the dataset. 20 for Cityscapes')
166 |
167 | args = parser.parse_args()
168 | if args.overlay:
169 | args.colored = True # This has to be true if you want to overlay
170 | main(args)
171 |
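
An assumed invocation of the script above (the pretrained-weights path is a placeholder; the flags correspond to the argparse options defined in __main__). Note that because these flags are declared with type=bool, argparse treats any non-empty string, including 'False', as True, so omit a flag entirely if you want it off.

python gen_cityscapes.py --data_dir ./data --img_extn png --s 0.5 \
    --pretrained <path to pretrained segmentation weights> --savedir ./results \
    --cityFormat True --colored True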
--------------------------------------------------------------------------------
/segmentation/loadData.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import pickle
4 |
5 |
6 | #============================================
7 | __author__ = "Sachin Mehta"
8 | __license__ = "MIT"
9 | __maintainer__ = "Sachin Mehta"
10 | #============================================
11 |
12 | class LoadData:
13 | '''
14 | Class to load the data
15 | '''
16 | def __init__(self, data_dir, classes, cached_data_file, normVal=1.10):
17 | '''
18 | :param data_dir: directory where the dataset is kept
19 | :param classes: number of classes in the dataset
20 | :param cached_data_file: location where cached file has to be stored
21 | :param normVal: normalization value, as defined in ERFNet paper
22 | '''
23 | self.data_dir = data_dir
24 | self.classes = classes
25 | self.classWeights = np.ones(self.classes, dtype=np.float32)
26 | self.normVal = normVal
27 | self.mean = np.zeros(3, dtype=np.float32)
28 | self.std = np.zeros(3, dtype=np.float32)
29 | self.trainImList = list()
30 | self.valImList = list()
31 | self.trainAnnotList = list()
32 | self.valAnnotList = list()
33 | self.cached_data_file = cached_data_file
34 |
35 | def compute_class_weights(self, histogram):
36 | '''
37 | Helper function to compute the class weights
38 | :param histogram: distribution of class samples
39 | :return: None, but updates the classWeights variable
40 | '''
41 | normHist = histogram / np.sum(histogram)
42 | for i in range(self.classes):
43 | self.classWeights[i] = 1 / (np.log(self.normVal + normHist[i]))
44 |
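# Worked example for compute_class_weights above (illustrative numbers, not computed from any dataset):
# with normVal = 1.10, a frequent class covering 30% of the pixels gets weight 1 / ln(1.10 + 0.30) ~= 2.97,
# while a rare class covering 1% gets 1 / ln(1.10 + 0.01) ~= 9.58, so rarer classes weigh more in the loss.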
45 | def readFile(self, fileName, trainStg=False):
46 | '''
47 | Function to read the data
48 | :param fileName: file that stores the image locations
49 | :param trainStg: if processing training or validation data
50 | :return: 0 if successful
51 | '''
52 | if trainStg == True:
53 | global_hist = np.zeros(self.classes, dtype=np.float32)
54 |
55 | no_files = 0
56 | min_val_al = 0
57 | max_val_al = 0
58 | with open(self.data_dir + '/' + fileName, 'r') as textFile:
59 | for line in textFile:
60 | # we expect the text file to contain the data in the following format
61 | # <RGB image path>, <label image path>