├── .idea ├── ESPNetv2.iml ├── misc.xml ├── modules.xml ├── vcs.xml └── workspace.xml ├── LICENSE ├── README.md ├── imagenet ├── LRSchedule.py ├── Model.py ├── README.md ├── cnn_utils.py ├── evaluate.py ├── main.py ├── pretrained_weights │ ├── espnetv2_s_0.5.pth │ ├── espnetv2_s_1.0.pth │ ├── espnetv2_s_1.25.pth │ ├── espnetv2_s_1.5.pth │ └── espnetv2_s_2.0.pth └── utils.py ├── images ├── ReadMe.md ├── effCompare.png └── powerTX2.png └── segmentation ├── DataSet.py ├── IOUEval.py ├── README.md ├── Transforms.py ├── cnn ├── Model.py ├── SegmentationModel.py └── cnn_utils.py ├── gen_cityscapes.py ├── loadData.py ├── main.py └── train_utils.py /.idea/ESPNetv2.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Sachin Mehta 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ESPNetv2: A Light-weight, Power Efficient, and General Purpose Convolutional Neural Network 2 | 3 | **IMPORTANT NOTE 1 (7 June, 2019)**: We have released new code base that supports several datasets and models, including ESPNetv2. Please see [here](https://github.com/sacmehta/EdgeNets) for more details. 4 | 5 | **IMPORTANT NOTE 2 (7 June, 2019)**: This repository is obsolete and we are not maintaining it anymore. 
6 | 7 | This repository contains the source code of our paper, [ESPNetv2](https://arxiv.org/abs/1811.11431), which was accepted for publication at CVPR'19. 8 | 9 | ***Note:*** New segmentation models for the PASCAL VOC and the Cityscapes datasets are coming soon. Our new models achieve mIOUs of [68.0](http://host.robots.ox.ac.uk:8080/anonymous/DAMVRR.html) and [66.15](https://www.cityscapes-dataset.com/anonymous-results/?id=2267c613d55dd75d5301850c913b1507bf2f10586ca73eb8ebcf357cdcf3e036) on the PASCAL VOC and the Cityscapes test sets, respectively. 10 |
Real-time semantic segmentation using ESPNetv2 on iPhone7 (see EdgeNets for details)
24 | 25 | ## Comparison with SOTA methods 26 | Compared to state-of-the-art efficient networks, our network delivers competitive performance while being much more **power efficient**. Sample results are shown in the figures below. For more details, please read our paper. 27 |
Figure: FLOPs vs. accuracy on the ImageNet dataset
Figure: Power consumption on TX2 device
46 | 47 | 48 | 49 | If you find our project useful in your research, please consider citing: 50 | 51 | ``` 52 | @inproceedings{mehta2019espnetv2, 53 | title={ESPNetv2: A Light-weight, Power Efficient, and General Purpose Convolutional Neural Network}, 54 | author={Sachin Mehta and Mohammad Rastegari and Linda Shapiro and Hannaneh Hajishirzi}, 55 | booktitle={CVPR}, 56 | year={2019} 57 | } 58 | 59 | @inproceedings{mehta2018espnet, 60 | title={ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation}, 61 | author={Sachin Mehta and Mohammad Rastegari and Anat Caspi and Linda Shapiro and Hannaneh Hajishirzi}, 62 | booktitle={ECCV}, 63 | year={2018} 64 | } 65 | ``` 66 | 67 | ## Structure 68 | This repository contains source code and pretrained for the following: 69 | * **Object classification:** We provide source code along with pre-trained models at different network complexities 70 | for the ImageNet dataset. Click [here](imagenet) for more details. 71 | * **Semantic segmentation:** We provide source code along with pre-trained models on the Cityscapes dataset. Check [here](segmentation) for more details. 72 | 73 | ## Requirements 74 | 75 | To run this repository, you should have following softwares installed: 76 | * PyTorch - We tested with v0.4.1 77 | * OpenCV - We tested with version 3.4.3 78 | * Python3 - Our code is written in Python3. We recommend to use [Anaconda](https://www.anaconda.com/) for the same. 79 | 80 | ## Instructions to install Pytorch and OpenCV with Anaconda 81 | 82 | Assuming that you have installed Anaconda successfully, you can follow the following instructions to install the packeges: 83 | 84 | ### PyTorch 85 | ``` 86 | conda install pytorch torchvision -c pytorch 87 | ``` 88 | 89 | Once installed, run the following commands in your terminal to verify the version: 90 | ``` 91 | import torch 92 | torch.__version__ 93 | ``` 94 | This should print something like this `0.4.1.post2`. 95 | 96 | If your version is different, then follow PyTorch website [here](https://pytorch.org/) for more details. 97 | 98 | ### OpenCV 99 | ``` 100 | conda install pip 101 | pip install --upgrade pip 102 | pip install opencv-python 103 | ``` 104 | 105 | Once installed, run the following commands in your terminal to verify the version: 106 | ``` 107 | import cv2 108 | cv2.__version__ 109 | ``` 110 | This should print something like this `3.4.3`. 111 | 112 | 113 | ## Implementation note 114 | 115 | You will see that `EESP` unit, the core building block of the ESPNetv2 architecture, has a `for` loop to process the input at different dilation rates. 116 | You can parallelize it using **Streams** in PyTorch. It improves the inference speed. 117 | 118 | A snippet to parallelize a `for` loop in pytorch is shown below: 119 | ``` 120 | # Sequential version 121 | output = [] 122 | a = torch.randn(1, 3, 10, 10) 123 | for i in range(4): 124 | output.append(a) 125 | torch.cat(output, 1) 126 | ``` 127 | 128 | ``` 129 | # Parallel version 130 | num_branches = 4 131 | streams = [(idx, torch.cuda.Stream()) for idx in range(num_branches)] 132 | output = [] 133 | a = torch.randn(1, 3, 10, 10) 134 | for idx, s in streams: 135 | with torch.cuda.stream(s): 136 | output.append(a) 137 | torch.cuda.synchronize() 138 | torch.cat(output, 1) 139 | ``` 140 | 141 | **Note:** 142 | * we have used above strategy to measure inference related statistics, including power consumption and run time on a single GPU. 143 | * We have not tested it (for training as well as inference) across multiple GPUs. 
If you want to use Streams and facing issues, please use PyTorch forums to resolve your queries. 144 | -------------------------------------------------------------------------------- /imagenet/LRSchedule.py: -------------------------------------------------------------------------------- 1 | 2 | #============================================ 3 | __author__ = "Sachin Mehta" 4 | __license__ = "MIT" 5 | __maintainer__ = "Sachin Mehta" 6 | #============================================ 7 | 8 | class MyLRScheduler(object): 9 | ''' 10 | CLass that defines cyclic learning rate that decays the learning rate linearly till the end of cycle and then restarts 11 | at the maximum value. 12 | ''' 13 | def __init__(self, initial=0.1, cycle_len=5, steps=[51, 101, 131, 161, 191, 221, 251, 281], gamma=2): 14 | super(MyLRScheduler, self).__init__() 15 | assert len(steps) > 1, 'Please specify step intervals.' 16 | self.min_lr = initial # minimum learning rate 17 | self.m = cycle_len 18 | self.steps = steps 19 | self.warm_up_interval = 1 # we do not start from max value for the first epoch, because some time it diverges 20 | self.counter = 0 21 | self.decayFactor = gamma # factor by which we should decay learning rate 22 | self.count_cycles = 0 23 | self.step_counter = 0 24 | self.stepping = True 25 | print('Using Cyclic LR Scheduler with warm restarts') 26 | 27 | def get_lr(self, epoch): 28 | if epoch%self.steps[self.step_counter] == 0 and epoch > 1 and self.stepping: 29 | self.min_lr = self.min_lr / self.decayFactor 30 | self.count_cycles = 0 31 | if self.step_counter < len(self.steps) - 1: 32 | self.step_counter += 1 33 | else: 34 | self.stepping = False 35 | current_lr = self.min_lr 36 | # warm-up or cool-down phase 37 | if self.count_cycles < self.warm_up_interval: 38 | self.count_cycles += 1 39 | # We do not need warm up after first step. 40 | # so, we set warm up interval to 0 after first step 41 | if self.count_cycles == self.warm_up_interval: 42 | self.warm_up_interval = 0 43 | else: 44 | #Cyclic learning rate with warm restarts 45 | # max_lr (= min_lr * step_size) is decreased to min_lr using linear decay before 46 | # it is set to max value at the end of cycle. 
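# Illustrative example (assuming the defaults initial=0.1, cycle_len=5, gamma=2):
# within one cycle, current_lr = min_lr * m - counter * min_lr, i.e. 0.5, 0.4, 0.3, 0.2, 0.1,
# after which the cycle restarts at 0.5; min_lr itself is divided by gamma at the epochs listed in steps.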
47 | if self.counter >= self.m: 48 | self.counter = 0 49 | current_lr = round((self.min_lr * self.m) - (self.counter * self.min_lr), 5) 50 | self.counter += 1 51 | self.count_cycles += 1 52 | return current_lr 53 | 54 | 55 | if __name__ == '__main__': 56 | lrSched = MyLRScheduler(0.1)#MyLRScheduler(0.1, 5, [51, 101, 131, 161, 191, 221, 251, 281, 311, 341, 371]) 57 | max_epochs = 300 58 | for i in range(max_epochs): 59 | print(i, lrSched.get_lr(i)) 60 | -------------------------------------------------------------------------------- /imagenet/Model.py: -------------------------------------------------------------------------------- 1 | from torch.nn import init 2 | import torch.nn.functional as F 3 | from cnn_utils import * 4 | import math 5 | import torch 6 | 7 | #============================================ 8 | __author__ = "Sachin Mehta" 9 | __license__ = "MIT" 10 | __maintainer__ = "Sachin Mehta" 11 | #============================================ 12 | 13 | class EESP(nn.Module): 14 | ''' 15 | This class defines the EESP block, which is based on the following principle 16 | REDUCE ---> SPLIT ---> TRANSFORM --> MERGE 17 | ''' 18 | 19 | def __init__(self, nIn, nOut, stride=1, k=4, r_lim=7, down_method='esp'): #down_method --> ['avg' or 'esp'] 20 | ''' 21 | :param nIn: number of input channels 22 | :param nOut: number of output channels 23 | :param stride: factor by which we should skip (useful for down-sampling). If 2, then down-samples the feature map by 2 24 | :param k: # of parallel branches 25 | :param r_lim: A maximum value of receptive field allowed for EESP block 26 | :param down_method: Downsample or not (equivalent to say stride is 2 or not) 27 | ''' 28 | super().__init__() 29 | self.stride = stride 30 | n = int(nOut / k) 31 | n1 = nOut - (k - 1) * n 32 | assert down_method in ['avg', 'esp'], 'One of these is suppported (avg or esp)' 33 | assert n == n1, "n(={}) and n1(={}) should be equal for Depth-wise Convolution ".format(n, n1) 34 | self.proj_1x1 = CBR(nIn, n, 1, stride=1, groups=k) 35 | 36 | # (For convenience) Mapping between dilation rate and receptive field for a 3x3 kernel 37 | map_receptive_ksize = {3: 1, 5: 2, 7: 3, 9: 4, 11: 5, 13: 6, 15: 7, 17: 8} 38 | self.k_sizes = list() 39 | for i in range(k): 40 | ksize = int(3 + 2 * i) 41 | # After reaching the receptive field limit, fall back to the base kernel size of 3 with a dilation rate of 1 42 | ksize = ksize if ksize <= r_lim else 3 43 | self.k_sizes.append(ksize) 44 | # sort (in ascending order) these kernel sizes based on their receptive field 45 | # This enables us to ignore the kernels (3x3 in our case) with the same effective receptive field in hierarchical 46 | # feature fusion because kernels with 3x3 receptive fields does not have gridding artifact. 
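# Example for clarity (defaults k=4, r_lim=7): candidate kernel sizes are 3, 5, 7, 9; since 9 exceeds r_lim it
# falls back to 3, so k_sizes = [3, 5, 7, 3], which sorts to [3, 3, 5, 7] with dilation rates [1, 1, 2, 3].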
47 | self.k_sizes.sort() 48 | self.spp_dw = nn.ModuleList() 49 | for i in range(k): 50 | d_rate = map_receptive_ksize[self.k_sizes[i]] 51 | self.spp_dw.append(CDilated(n, n, kSize=3, stride=stride, groups=n, d=d_rate)) 52 | # Performing a group convolution with K groups is the same as performing K point-wise convolutions 53 | self.conv_1x1_exp = CB(nOut, nOut, 1, 1, groups=k) 54 | self.br_after_cat = BR(nOut) 55 | self.module_act = nn.PReLU(nOut) 56 | self.downAvg = True if down_method == 'avg' else False 57 | 58 | def forward(self, input): 59 | ''' 60 | :param input: input feature map 61 | :return: transformed feature map 62 | ''' 63 | 64 | # Reduce --> project high-dimensional feature maps to low-dimensional space 65 | output1 = self.proj_1x1(input) 66 | output = [self.spp_dw[0](output1)] 67 | # compute the output for each branch and hierarchically fuse them 68 | # i.e. Split --> Transform --> HFF 69 | for k in range(1, len(self.spp_dw)): 70 | out_k = self.spp_dw[k](output1) 71 | # HFF 72 | out_k = out_k + output[k - 1] 73 | output.append(out_k) 74 | # Merge 75 | expanded = self.conv_1x1_exp( # learn linear combinations using group point-wise convolutions 76 | self.br_after_cat( # apply batch normalization followed by activation function (PRelu in this case) 77 | torch.cat(output, 1) # concatenate the output of different branches 78 | ) 79 | ) 80 | del output 81 | # if down-sampling, then return the concatenated vector 82 | # because Downsampling function will combine it with avg. pooled feature map and then threshold it 83 | if self.stride == 2 and self.downAvg: 84 | return expanded 85 | 86 | # if dimensions of input and concatenated vector are the same, add them (RESIDUAL LINK) 87 | if expanded.size() == input.size(): 88 | expanded = expanded + input 89 | 90 | # Threshold the feature map using activation function (PReLU in this case) 91 | return self.module_act(expanded) 92 | 93 | 94 | class DownSampler(nn.Module): 95 | ''' 96 | Down-sampling fucntion that has three parallel branches: (1) avg pooling, 97 | (2) EESP block with stride of 2 and (3) efficient long-range connection with the input. 98 | The output feature maps of branches from (1) and (2) are concatenated and then additively fused with (3) to produce 99 | the final output. 100 | ''' 101 | 102 | def __init__(self, nin, nout, k=4, r_lim=9, reinf=True): 103 | ''' 104 | :param nin: number of input channels 105 | :param nout: number of output channels 106 | :param k: # of parallel branches 107 | :param r_lim: A maximum value of receptive field allowed for EESP block 108 | :param reinf: Use long range shortcut connection with the input or not. 
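Note: the EESP branch produces nout - nin channels, so concatenating it with the nin-channel avg-pooled branch yields exactly nout output channels.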
109 | ''' 110 | super().__init__() 111 | nout_new = nout - nin 112 | self.eesp = EESP(nin, nout_new, stride=2, k=k, r_lim=r_lim, down_method='avg') 113 | self.avg = nn.AvgPool2d(kernel_size=3, padding=1, stride=2) 114 | if reinf: 115 | self.inp_reinf = nn.Sequential( 116 | CBR(config_inp_reinf, config_inp_reinf, 3, 1), 117 | CB(config_inp_reinf, nout, 1, 1) 118 | ) 119 | self.act = nn.PReLU(nout) 120 | 121 | def forward(self, input, input2=None): 122 | ''' 123 | :param input: input feature map 124 | :return: feature map down-sampled by a factor of 2 125 | ''' 126 | avg_out = self.avg(input) 127 | eesp_out = self.eesp(input) 128 | output = torch.cat([avg_out, eesp_out], 1) 129 | 130 | if input2 is not None: 131 | #assuming the input is a square image 132 | # Shortcut connection with the input image 133 | w1 = avg_out.size(2) 134 | while True: 135 | input2 = F.avg_pool2d(input2, kernel_size=3, padding=1, stride=2) 136 | w2 = input2.size(2) 137 | if w2 == w1: 138 | break 139 | output = output + self.inp_reinf(input2) 140 | 141 | return self.act(output) 142 | 143 | class EESPNet(nn.Module): 144 | ''' 145 | This class defines the ESPNetv2 architecture for the ImageNet classification 146 | ''' 147 | 148 | def __init__(self, classes=1000, s=1): 149 | ''' 150 | :param classes: number of classes in the dataset. Default is 1000 for the ImageNet dataset 151 | :param s: factor that scales the number of output feature maps 152 | ''' 153 | super().__init__() 154 | reps = [0, 3, 7, 3] # how many times EESP blocks should be repeated at each spatial level. 155 | channels = 3 156 | 157 | r_lim = [13, 11, 9, 7, 5] # receptive field at each spatial level 158 | K = [4]*len(r_lim) # No. of parallel branches at different levels 159 | 160 | base = 32 #base configuration 161 | config_len = 5 162 | config = [base] * config_len 163 | base_s = 0 164 | for i in range(config_len): 165 | if i== 0: 166 | base_s = int(base * s) 167 | base_s = math.ceil(base_s / K[0]) * K[0] 168 | config[i] = base if base_s > base else base_s 169 | else: 170 | config[i] = base_s * pow(2, i) 171 | if s <= 1.5: 172 | config.append(1024) 173 | elif s <= 2.0: 174 | config.append(1280) 175 | else: 176 | ValueError('Configuration not supported') 177 | 178 | 179 | global config_inp_reinf 180 | config_inp_reinf = 3 181 | self.input_reinforcement = True # True for the shortcut connection with input 182 | 183 | assert len(K) == len(r_lim), 'Length of branching factor array and receptive field array should be the same.' 
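# For reference (derived from the configuration logic above): at scale s=1.0, config = [32, 64, 128, 256, 512, 1024].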
184 | 185 | self.level1 = CBR(channels, config[0], 3, 2) # 112 L1 186 | 187 | self.level2_0 = DownSampler(config[0], config[1], k=K[0], r_lim=r_lim[0], reinf=self.input_reinforcement) # out = 56 188 | 189 | self.level3_0 = DownSampler(config[1], config[2], k=K[1], r_lim=r_lim[1], reinf=self.input_reinforcement) # out = 28 190 | self.level3 = nn.ModuleList() 191 | for i in range(reps[1]): 192 | self.level3.append(EESP(config[2], config[2], stride=1, k=K[2], r_lim=r_lim[2])) 193 | 194 | self.level4_0 = DownSampler(config[2], config[3], k=K[2], r_lim=r_lim[2], reinf=self.input_reinforcement) #out = 14 195 | self.level4 = nn.ModuleList() 196 | for i in range(reps[2]): 197 | self.level4.append(EESP(config[3], config[3], stride=1, k=K[3], r_lim=r_lim[3])) 198 | 199 | self.level5_0 = DownSampler(config[3], config[4], k=K[3], r_lim=r_lim[3]) #7 200 | self.level5 = nn.ModuleList() 201 | for i in range(reps[3]): 202 | self.level5.append(EESP(config[4], config[4], stride=1, k=K[4], r_lim=r_lim[4])) 203 | 204 | # expand the feature maps using depth-wise convolution followed by group point-wise convolution 205 | self.level5.append(CBR(config[4], config[4], 3, 1, groups=config[4])) 206 | self.level5.append(CBR(config[4], config[5], 1, 1, groups=K[4])) 207 | 208 | self.classifier = nn.Linear(config[5], classes) 209 | self.init_params() 210 | 211 | def init_params(self): 212 | ''' 213 | Function to initialze the parameters 214 | ''' 215 | for m in self.modules(): 216 | if isinstance(m, nn.Conv2d): 217 | init.kaiming_normal_(m.weight, mode='fan_out') 218 | if m.bias is not None: 219 | init.constant_(m.bias, 0) 220 | elif isinstance(m, nn.BatchNorm2d): 221 | init.constant_(m.weight, 1) 222 | init.constant_(m.bias, 0) 223 | elif isinstance(m, nn.Linear): 224 | init.normal_(m.weight, std=0.001) 225 | if m.bias is not None: 226 | init.constant_(m.bias, 0) 227 | 228 | def forward(self, input, p=0.2): 229 | ''' 230 | :param input: Receives the input RGB image 231 | :return: a C-dimensional vector, C=# of classes 232 | ''' 233 | out_l1 = self.level1(input) # 112 234 | if not self.input_reinforcement: 235 | del input 236 | input = None 237 | 238 | out_l2 = self.level2_0(out_l1, input) # 56 239 | 240 | out_l3_0 = self.level3_0(out_l2, input) # down-sample 241 | for i, layer in enumerate(self.level3): 242 | if i == 0: 243 | out_l3 = layer(out_l3_0) 244 | else: 245 | out_l3 = layer(out_l3) 246 | 247 | out_l4_0 = self.level4_0(out_l3, input) # down-sample 248 | for i, layer in enumerate(self.level4): 249 | if i == 0: 250 | out_l4 = layer(out_l4_0) 251 | else: 252 | out_l4 = layer(out_l4) 253 | 254 | out_l5_0 = self.level5_0(out_l4) # down-sample 255 | for i, layer in enumerate(self.level5): 256 | if i == 0: 257 | out_l5 = layer(out_l5_0) 258 | else: 259 | out_l5 = layer(out_l5) 260 | 261 | output_g = F.adaptive_avg_pool2d(out_l5, output_size=1) 262 | output_g = F.dropout(output_g, p=p, training=self.training) 263 | output_1x1 = output_g.view(output_g.size(0), -1) 264 | 265 | return self.classifier(output_1x1) 266 | 267 | 268 | if __name__ == '__main__': 269 | input = torch.Tensor(1, 3, 224, 224).cuda() 270 | model = EESPNet(classes=1000, s=1.0) 271 | out = model(input) 272 | print('Output size') 273 | print(out.size()) 274 | -------------------------------------------------------------------------------- /imagenet/README.md: -------------------------------------------------------------------------------- 1 | # [ESPNetv2: A Light-weight, Power Efficient, and General Purpose Convolutional Neural 
Network](https://arxiv.org/abs/1811.11431) 2 | 3 | This repository contains the source code for training on the ImageNet dataset along with the pre-trained models 4 | 5 | ## Training and Evaluation on the ImageNet dataset 6 | 7 | Below are the commands to train and test the network at scale `s=1.0`. 8 | 9 | ### Training 10 | To train the network, you can use the following command: 11 | 12 | ``` 13 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --batch-size 512 --s 1.0 --data 14 | ``` 15 | 16 | ### Evaluation 17 | To evaluate our pretrained models (or the ones trained by you), you can use `evaluate.py` file. 18 | 19 | Use below command to evaluate the performance of our model at scale `s=1.0` on the ImageNet dataset. 20 | ``` 21 | CUDA_VISIBLE_DEVICES=0 python evaluate.py --batch-size 512 --s 1.0 --weightFile ./pretrained_weights/espnetv2_s_1.0.pth --data 22 | ``` 23 | 24 | ## Results and pre-trained models 25 | We release the pre-trained models at different computational complexities. Following state-of-the-art methods, we measure top-1 accuracy on a 26 | cropped center view of size 224x224. 27 | 28 | Below table provide details about the performance of our model on the ImageNet validation set at different computational complexities along with links to download the pre-trained weights. 29 | 30 | 31 | | s | Params | FLOPs | top-1 (val) | Link | 32 | | -------- |--------|--------|-------| -------| 33 | | 0.5 | 1.24 | 28.37 | 57.7 | [here](pretrained_weights/espnetv2_s_0.5.pth) | 34 | | 1.0 | 1.67 | 85.72 | 66.1 | [here](pretrained_weights/espnetv2_s_1.0.pth) | 35 | | 1.25 | 1.96 | 123.39 | 67.9 | [here](pretrained_weights/espnetv2_s_1.25.pth) | 36 | | 1.5 | 2.31 | 168.6 | 69.2 | [here](pretrained_weights/espnetv2_s_1.5.pth) | 37 | | 2.0 | 3.49 | 284.8 | 72.1 | [here](pretrained_weights/espnetv2_s_2.0.pth) | 38 | 39 | 40 | ## ImageNet dataset preparation 41 | To prepare the dataset, follow instructions [here](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset). 42 | 43 | -------------------------------------------------------------------------------- /imagenet/cnn_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | #============================================ 4 | __author__ = "Sachin Mehta" 5 | __license__ = "MIT" 6 | __maintainer__ = "Sachin Mehta" 7 | #============================================ 8 | 9 | class CBR(nn.Module): 10 | ''' 11 | This class defines the convolution layer with batch normalization and PReLU activation 12 | ''' 13 | 14 | def __init__(self, nIn, nOut, kSize, stride=1, groups=1): 15 | ''' 16 | 17 | :param nIn: number of input channels 18 | :param nOut: number of output channels 19 | :param kSize: kernel size 20 | :param stride: stride rate for down-sampling. 
Default is 1 21 | ''' 22 | super().__init__() 23 | padding = int((kSize - 1) / 2) 24 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, groups=groups) 25 | self.bn = nn.BatchNorm2d(nOut) 26 | self.act = nn.PReLU(nOut) 27 | 28 | def forward(self, input): 29 | ''' 30 | :param input: input feature map 31 | :return: transformed feature map 32 | ''' 33 | output = self.conv(input) 34 | # output = self.conv1(output) 35 | output = self.bn(output) 36 | output = self.act(output) 37 | return output 38 | 39 | 40 | class BR(nn.Module): 41 | ''' 42 | This class groups the batch normalization and PReLU activation 43 | ''' 44 | 45 | def __init__(self, nOut): 46 | ''' 47 | :param nOut: output feature maps 48 | ''' 49 | super().__init__() 50 | self.bn = nn.BatchNorm2d(nOut) 51 | self.act = nn.PReLU(nOut) 52 | 53 | def forward(self, input): 54 | ''' 55 | :param input: input feature map 56 | :return: normalized and thresholded feature map 57 | ''' 58 | output = self.bn(input) 59 | output = self.act(output) 60 | return output 61 | 62 | 63 | class CB(nn.Module): 64 | ''' 65 | This class groups the convolution and batch normalization 66 | ''' 67 | 68 | def __init__(self, nIn, nOut, kSize, stride=1, groups=1): 69 | ''' 70 | :param nIn: number of input channels 71 | :param nOut: number of output channels 72 | :param kSize: kernel size 73 | :param stride: optinal stide for down-sampling 74 | ''' 75 | super().__init__() 76 | padding = int((kSize - 1) / 2) 77 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, 78 | groups=groups) 79 | self.bn = nn.BatchNorm2d(nOut) 80 | 81 | def forward(self, input): 82 | ''' 83 | 84 | :param input: input feature map 85 | :return: transformed feature map 86 | ''' 87 | output = self.conv(input) 88 | output = self.bn(output) 89 | return output 90 | 91 | 92 | class C(nn.Module): 93 | ''' 94 | This class is for a convolutional layer. 95 | ''' 96 | 97 | def __init__(self, nIn, nOut, kSize, stride=1, groups=1): 98 | ''' 99 | 100 | :param nIn: number of input channels 101 | :param nOut: number of output channels 102 | :param kSize: kernel size 103 | :param stride: optional stride rate for down-sampling 104 | ''' 105 | super().__init__() 106 | padding = int((kSize - 1) / 2) 107 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, 108 | groups=groups) 109 | 110 | def forward(self, input): 111 | ''' 112 | :param input: input feature map 113 | :return: transformed feature map 114 | ''' 115 | output = self.conv(input) 116 | return output 117 | 118 | 119 | class CDilated(nn.Module): 120 | ''' 121 | This class defines the dilated convolution. 122 | ''' 123 | 124 | def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1): 125 | ''' 126 | :param nIn: number of input channels 127 | :param nOut: number of output channels 128 | :param kSize: kernel size 129 | :param stride: optional stride rate for down-sampling 130 | :param d: optional dilation rate 131 | ''' 132 | super().__init__() 133 | padding = int((kSize - 1) / 2) * d 134 | self.conv = nn.Conv2d(nIn, nOut,kSize, stride=stride, padding=padding, bias=False, 135 | dilation=d, groups=groups) 136 | 137 | def forward(self, input): 138 | ''' 139 | :param input: input feature map 140 | :return: transformed feature map 141 | ''' 142 | output = self.conv(input) 143 | return output 144 | 145 | class CDilatedB(nn.Module): 146 | ''' 147 | This class defines the dilated convolution with batch normalization. 
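Padding is set to (kSize - 1) // 2 * d, so the spatial resolution is preserved when stride is 1.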
148 | ''' 149 | 150 | def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1): 151 | ''' 152 | :param nIn: number of input channels 153 | :param nOut: number of output channels 154 | :param kSize: kernel size 155 | :param stride: optional stride rate for down-sampling 156 | :param d: optional dilation rate 157 | ''' 158 | super().__init__() 159 | padding = int((kSize - 1) / 2) * d 160 | self.conv = nn.Conv2d(nIn, nOut,kSize, stride=stride, padding=padding, bias=False, 161 | dilation=d, groups=groups) 162 | self.bn = nn.BatchNorm2d(nOut) 163 | 164 | def forward(self, input): 165 | ''' 166 | :param input: input feature map 167 | :return: transformed feature map 168 | ''' 169 | return self.bn(self.conv(input)) 170 | -------------------------------------------------------------------------------- /imagenet/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from torch.autograd import Variable 4 | import torch.nn.parallel 5 | import torch.backends.cudnn as cudnn 6 | import torch.optim 7 | import torch.utils.data 8 | import torch.utils.data.distributed 9 | import torchvision.transforms as transforms 10 | import torchvision.datasets as datasets 11 | import Model as Net 12 | import numpy as np 13 | from utils import * 14 | import os 15 | import torchvision.models as preModels 16 | 17 | cudnn.benchmark = True 18 | 19 | ''' 20 | This file is mostly adapted from the PyTorch ImageNet example 21 | ''' 22 | 23 | #============================================ 24 | __author__ = "Sachin Mehta" 25 | __license__ = "MIT" 26 | __maintainer__ = "Sachin Mehta" 27 | #============================================ 28 | 29 | def main(args): 30 | 31 | model = Net.EESPNet(classes=1000, s=args.s) 32 | model = torch.nn.DataParallel(model).cuda() 33 | if not os.path.isfile(args.weightFile): 34 | print('Weight file does not exist') 35 | exit(-1) 36 | dict_model = torch.load(args.weightFile) 37 | model.load_state_dict(dict_model) 38 | 39 | 40 | n_params = sum([np.prod(p.size()) for p in model.parameters()]) 41 | print('Parameters: ' + str(n_params)) 42 | 43 | # Data loading code 44 | valdir = os.path.join(args.data, 'val') 45 | traindir = os.path.join(args.data, 'train') 46 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 47 | std=[0.229, 0.224, 0.225]) 48 | 49 | val_loader = torch.utils.data.DataLoader( 50 | datasets.ImageFolder(valdir, transforms.Compose([ 51 | transforms.Resize(int(args.inpSize/0.875)), 52 | transforms.CenterCrop(args.inpSize), 53 | transforms.ToTensor(), 54 | normalize, 55 | ])), 56 | batch_size=args.batch_size, shuffle=False, 57 | num_workers=args.workers, pin_memory=True) 58 | 59 | validate(val_loader, model) 60 | return 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser(description='ESPNetv2 Training on the ImageNet') 65 | parser.add_argument('--data', default='/home/ubuntu/ILSVRC2015/Data/CLS-LOC/', help='path to dataset') 66 | parser.add_argument('--workers', default=12, type=int, help='number of data loading workers (default: 4)') 67 | parser.add_argument('-b', '--batch-size', default=512, type=int, 68 | metavar='N', help='mini-batch size (default: 256)') 69 | parser.add_argument('--print-freq', '-p', default=10, type=int, 70 | metavar='N', help='print frequency (default: 10)') 71 | parser.add_argument('--s', default=1, type=float, 72 | help='Factor by which output channels should be reduced (s > 1 for increasing the dims while < 1 for decreasing)') 73 | 
parser.add_argument('--weightFile', type=str, default='', help='weight file') 74 | parser.add_argument('--inpSize', default=224, type=int, 75 | help='Input size') 76 | 77 | 78 | args = parser.parse_args() 79 | args.parallel = True 80 | 81 | main(args) 82 | -------------------------------------------------------------------------------- /imagenet/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch.nn.parallel 3 | import torch.backends.cudnn as cudnn 4 | import torch.optim 5 | import torch.utils.data 6 | import torch.utils.data.distributed 7 | import torchvision.transforms as transforms 8 | import torchvision.datasets as datasets 9 | import Model as Net 10 | import numpy as np 11 | from utils import * 12 | import random 13 | import os 14 | from LRSchedule import MyLRScheduler 15 | 16 | cudnn.benchmark = True 17 | 18 | 19 | #============================================ 20 | __author__ = "Sachin Mehta" 21 | __license__ = "MIT" 22 | __maintainer__ = "Sachin Mehta" 23 | #============================================ 24 | 25 | def compute_params(model): 26 | return sum([np.prod(p.size()) for p in model.parameters()]) 27 | 28 | def main(args): 29 | best_prec1 = 0.0 30 | 31 | if not os.path.isdir(args.savedir): 32 | os.mkdir(args.savedir) 33 | 34 | # create model 35 | model = Net.EESPNet(classes=1000, s=args.s) 36 | print('Network Parameters: ' + str(compute_params(model))) 37 | 38 | #check if the cuda is available or not 39 | cuda_available = torch.cuda.is_available() 40 | 41 | num_gpus = torch.cuda.device_count() 42 | if num_gpus >= 1: 43 | model = torch.nn.DataParallel(model) 44 | 45 | if cuda_available: 46 | model = model.cuda() 47 | 48 | 49 | logFileLoc = args.savedir + 'logs.txt' 50 | if os.path.isfile(logFileLoc): 51 | logger = open(logFileLoc, 'a') 52 | else: 53 | logger = open(logFileLoc, 'w') 54 | logger.write("\n%s\t%s\t%s\t%s\t%s\t" % ('Epoch', 'Loss(Tr)', 'Loss(val)', 'top1 (tr)', 'top1 (val')) 55 | 56 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 57 | momentum=args.momentum, 58 | weight_decay=args.weight_decay, nesterov=True) 59 | 60 | # optionally resume from a checkpoint 61 | if args.resume: 62 | if os.path.isfile(args.resume): 63 | print("=> loading checkpoint '{}'".format(args.resume)) 64 | checkpoint = torch.load(args.resume) 65 | args.start_epoch = checkpoint['epoch'] 66 | best_prec1 = checkpoint['best_prec1'] 67 | model.load_state_dict(checkpoint['state_dict']) 68 | optimizer.load_state_dict(checkpoint['optimizer']) 69 | print("=> loaded checkpoint '{}' (epoch {})" 70 | .format(args.resume, checkpoint['epoch'])) 71 | else: 72 | print("=> no checkpoint found at '{}'".format(args.resume)) 73 | 74 | # Data loading code 75 | traindir = os.path.join(args.data, 'train') 76 | valdir = os.path.join(args.data, 'val') 77 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 78 | std=[0.229, 0.224, 0.225]) 79 | 80 | train_loader1 = torch.utils.data.DataLoader( 81 | datasets.ImageFolder( 82 | traindir, 83 | transforms.Compose([ 84 | transforms.RandomResizedCrop(args.inpSize), 85 | transforms.RandomHorizontalFlip(), 86 | transforms.ToTensor(), 87 | normalize, 88 | ])), 89 | batch_size=args.batch_size, shuffle=True, 90 | num_workers=args.workers, pin_memory=True) 91 | 92 | 93 | val_loader = torch.utils.data.DataLoader( 94 | datasets.ImageFolder(valdir, transforms.Compose([ 95 | transforms.Resize(int(args.inpSize/0.875)), 96 | transforms.CenterCrop(args.inpSize), 97 | transforms.ToTensor(), 98 | 
normalize, 99 | ])), 100 | batch_size=args.batch_size, shuffle=False, 101 | num_workers=args.workers, pin_memory=True) 102 | 103 | # global customLR 104 | # steps at which we should decrease the learning rate 105 | step_sizes = [51, 101, 131, 161, 191, 221, 251, 281] 106 | 107 | #ImageNet experiments consume a lot of time 108 | # Just for safety, store the checkpoint before we decrease the learning rate 109 | # i.e. store the model at step -1 110 | step_store = list() 111 | for step in step_sizes: 112 | step_store.append(step-1) 113 | 114 | 115 | customLR = MyLRScheduler(args.lr, 5, step_sizes) 116 | #set up the variables in case of resuming 117 | if args.start_epoch != 0: 118 | for epoch in range(args.start_epoch): 119 | customLR.get_lr(epoch) 120 | 121 | for epoch in range(args.start_epoch, args.epochs): 122 | lr_log = customLR.get_lr(epoch) 123 | # set the optimizer with the learning rate 124 | # This can be done inside the MyLRScheduler 125 | for param_group in optimizer.param_groups: 126 | param_group['lr'] = lr_log 127 | print("LR for epoch {} = {:.5f}".format(epoch, lr_log)) 128 | train_prec1, train_loss = train(train_loader1, model, optimizer, epoch) 129 | # evaluate on validation set 130 | val_prec1, val_loss = validate(val_loader, model) 131 | 132 | # remember best prec@1 and save checkpoint 133 | is_best = val_prec1.item() > best_prec1 134 | best_prec1 = max(val_prec1.item(), best_prec1) 135 | back_check = True if epoch in step_store else False #backup checkpoint or not 136 | save_checkpoint({ 137 | 'epoch': epoch + 1, 138 | 'arch': 'ESPNet', 139 | 'state_dict': model.state_dict(), 140 | 'best_prec1': best_prec1, 141 | 'optimizer': optimizer.state_dict(), 142 | }, is_best, back_check, epoch, args.savedir) 143 | 144 | logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.7f" % (epoch, train_loss, val_loss, train_prec1, 145 | val_prec1, lr_log)) 146 | logger.flush() 147 | 148 | 149 | if __name__ == '__main__': 150 | parser = argparse.ArgumentParser(description='ESPNetv2 Training on the ImageNet') 151 | parser.add_argument('--data', default='/home/ubuntu/ILSVRC2015/Data/CLS-LOC/', help='path to dataset') 152 | parser.add_argument('--workers', default=12, type=int, help='number of data loading workers (default: 4)') 153 | parser.add_argument('--epochs', default=300, type=int, help='number of total epochs to run') 154 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)') 155 | parser.add_argument('--batch-size', default=512, type=int, help='mini-batch size (default: 512)') 156 | parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate') 157 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum') 158 | parser.add_argument('--weight-decay', default=4e-5, type=float, help='weight decay (default: 4e-5)') 159 | parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint (default: none)') 160 | parser.add_argument('--savedir', type=str, default='./results', help='Location to save the results') 161 | parser.add_argument('--s', default=1, type=float, help='Factor by which output channels should be reduced (s > 1 ' 162 | 'for increasing the dims while < 1 for decreasing)') 163 | parser.add_argument('--inpSize', default=224, type=int, help='Input image size (default: 224 x 224)') 164 | 165 | 166 | args = parser.parse_args() 167 | args.parallel = True 168 | cudnn.deterministic = True 169 | random.seed(1882) 170 | torch.manual_seed(1882) 171 | 172 | args.savedir = 
args.savedir + '_s_' + str(args.s) + '_inp_' + str(args.inpSize) + os.sep 173 | 174 | main(args) 175 | -------------------------------------------------------------------------------- /imagenet/pretrained_weights/espnetv2_s_0.5.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sacmehta/ESPNetv2/b78e323039908f31347d8ca17f49d5502ef1a594/imagenet/pretrained_weights/espnetv2_s_0.5.pth -------------------------------------------------------------------------------- /imagenet/pretrained_weights/espnetv2_s_1.0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sacmehta/ESPNetv2/b78e323039908f31347d8ca17f49d5502ef1a594/imagenet/pretrained_weights/espnetv2_s_1.0.pth -------------------------------------------------------------------------------- /imagenet/pretrained_weights/espnetv2_s_1.25.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sacmehta/ESPNetv2/b78e323039908f31347d8ca17f49d5502ef1a594/imagenet/pretrained_weights/espnetv2_s_1.25.pth -------------------------------------------------------------------------------- /imagenet/pretrained_weights/espnetv2_s_1.5.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sacmehta/ESPNetv2/b78e323039908f31347d8ca17f49d5502ef1a594/imagenet/pretrained_weights/espnetv2_s_1.5.pth -------------------------------------------------------------------------------- /imagenet/pretrained_weights/espnetv2_s_2.0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sacmehta/ESPNetv2/b78e323039908f31347d8ca17f49d5502ef1a594/imagenet/pretrained_weights/espnetv2_s_2.0.pth -------------------------------------------------------------------------------- /imagenet/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import shutil 4 | import torch.nn.functional as F 5 | import time 6 | 7 | 8 | #============================================ 9 | __author__ = "Sachin Mehta" 10 | __license__ = "MIT" 11 | __maintainer__ = "Sachin Mehta" 12 | #============================================ 13 | 14 | ''' 15 | This file is mostly adapted from the PyTorch ImageNet example 16 | ''' 17 | 18 | class AverageMeter(object): 19 | """Computes and stores the average and current value""" 20 | def __init__(self): 21 | self.reset() 22 | 23 | def reset(self): 24 | self.val = 0 25 | self.avg = 0 26 | self.sum = 0 27 | self.count = 0 28 | 29 | def update(self, val, n=1): 30 | self.val = val 31 | self.sum += val * n 32 | self.count += n 33 | self.avg = self.sum / self.count 34 | 35 | 36 | def accuracy(output, target, topk=(1,)): 37 | """Computes the precision@k for the specified values of k""" 38 | maxk = max(topk) 39 | batch_size = target.size(0) 40 | 41 | _, pred = output.topk(maxk, 1, True, True) 42 | pred = pred.t() 43 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 44 | 45 | res = [] 46 | for k in topk: 47 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 48 | res.append(correct_k.mul_(100.0 / batch_size)) 49 | return res 50 | 51 | ''' 52 | Utility to save checkpoint or not 53 | ''' 54 | def save_checkpoint(state, is_best, back_check, epoch, dir): 55 | check_pt_file = dir + os.sep + 'checkpoint.pth.tar' 56 | torch.save(state, check_pt_file) 57 | if is_best: 58 | #We 
only need best models weight and not check point states, etc. 59 | torch.save(state['state_dict'], dir + os.sep + 'model_best.pth') 60 | if back_check: 61 | shutil.copyfile(check_pt_file, dir + os.sep + 'checkpoint_back' + str(epoch) + '.pth.tar') 62 | 63 | ''' 64 | Cross entropy loss function 65 | ''' 66 | def loss_fn(outputs, labels): 67 | 68 | return F.cross_entropy(outputs, labels) 69 | 70 | ''' 71 | Training loop 72 | ''' 73 | def train(train_loader, model, optimizer, epoch): 74 | 75 | batch_time = AverageMeter() 76 | data_time = AverageMeter() 77 | losses = AverageMeter() 78 | top1 = AverageMeter() 79 | top5 = AverageMeter() 80 | 81 | # switch to train mode 82 | model.train() 83 | 84 | end = time.time() 85 | for i, (input, target) in enumerate(train_loader): 86 | 87 | # measure data loading time 88 | data_time.update(time.time() - end) 89 | 90 | input = input.cuda(non_blocking=True) 91 | target = target.cuda(non_blocking=True) 92 | 93 | # compute output 94 | output = model(input) 95 | 96 | # compute loss 97 | loss = loss_fn(output, target) 98 | 99 | # measure accuracy and record loss 100 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 101 | #losses.update(loss.data[0], input.size(0)) 102 | losses.update(loss.item(), input.size(0)) 103 | top1.update(prec1[0], input.size(0)) 104 | top5.update(prec5[0], input.size(0)) 105 | 106 | # compute gradient and do SGD step 107 | optimizer.zero_grad() 108 | loss.backward() 109 | optimizer.step() 110 | 111 | # measure elapsed time 112 | batch_time.update(time.time() - end) 113 | end = time.time() 114 | 115 | if i % 100 == 0: #print after every 100 batches 116 | print("Epoch: %d[%d/%d]\t\tBatch Time:%.4f\t\tLoss:%.4f\t\ttop1:%.4f (%.4f)\t\ttop5:%.4f (%.4f)" % 117 | (epoch, i, len(train_loader), batch_time.avg, losses.avg, top1.val, top1.avg, top5.val, top5.avg)) 118 | 119 | 120 | return top1.avg, losses.avg 121 | 122 | ''' 123 | Validation loop 124 | ''' 125 | def validate(val_loader, model): 126 | batch_time = AverageMeter() 127 | losses = AverageMeter() 128 | top1 = AverageMeter() 129 | top5 = AverageMeter() 130 | 131 | # switch to evaluate mode 132 | model.eval() 133 | 134 | # with torch.no_grad(): 135 | end = time.time() 136 | with torch.no_grad(): 137 | for i, (input, target) in enumerate(val_loader): 138 | input = input.cuda(non_blocking=True) 139 | target = target.cuda(non_blocking=True) 140 | 141 | # compute output 142 | output = model(input) 143 | loss = loss_fn(output, target) 144 | 145 | # measure accuracy and record loss 146 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 147 | 148 | # replace if using pytorch version < 0.4 149 | #losses.update(loss.data[0], input.size(0)) 150 | losses.update(loss.item(), input.size(0)) 151 | top1.update(prec1[0], input.size(0)) 152 | top5.update(prec5[0], input.size(0)) 153 | 154 | # measure elapsed time 155 | batch_time.update(time.time() - end) 156 | end = time.time() 157 | 158 | if i % 100 == 0: # print after every 100 batches 159 | print("Batch:[%d/%d]\t\tBatchTime:%.3f\t\tLoss:%.3f\t\ttop1:%.3f (%.3f)\t\ttop5:%.3f(%.3f)" % 160 | (i, len(val_loader), batch_time.avg, losses.avg, top1.val, top1.avg, top5.val, top5.avg)) 161 | 162 | print(' * Prec@1:%.3f Prec@5:%.3f' % (top1.avg, top5.avg)) 163 | 164 | return top1.avg, losses.avg -------------------------------------------------------------------------------- /images/ReadMe.md: -------------------------------------------------------------------------------- 1 | Directory for sample images 2 | 
-------------------------------------------------------------------------------- /images/effCompare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sacmehta/ESPNetv2/b78e323039908f31347d8ca17f49d5502ef1a594/images/effCompare.png -------------------------------------------------------------------------------- /images/powerTX2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sacmehta/ESPNetv2/b78e323039908f31347d8ca17f49d5502ef1a594/images/powerTX2.png -------------------------------------------------------------------------------- /segmentation/DataSet.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch.utils.data 3 | import numpy as np 4 | 5 | 6 | #============================================ 7 | __author__ = "Sachin Mehta" 8 | __license__ = "MIT" 9 | __maintainer__ = "Sachin Mehta" 10 | #============================================ 11 | 12 | 13 | class MyDataset(torch.utils.data.Dataset): 14 | ''' 15 | Class to load the dataset 16 | ''' 17 | def __init__(self, imList, labelList, transform=None): 18 | ''' 19 | :param imList: image list (Note that these lists have been processed and pickled using the loadData.py) 20 | :param labelList: label list (Note that these lists have been processed and pickled using the loadData.py) 21 | :param transform: Type of transformation. SEe Transforms.py for supported transformations 22 | ''' 23 | self.imList = imList 24 | self.labelList = labelList 25 | self.transform = transform 26 | 27 | def __len__(self): 28 | return len(self.imList) 29 | 30 | def __getitem__(self, idx): 31 | ''' 32 | 33 | :param idx: Index of the image file 34 | :return: returns the image and corresponding label file. 
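Note: images are read with OpenCV (cv2.imread), so the returned image is in BGR channel order.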
35 | ''' 36 | image_name = self.imList[idx] 37 | label_name = self.labelList[idx] 38 | image = cv2.imread(image_name) 39 | label = cv2.imread(label_name, 0) 40 | # if you have 255 label in your label files, map it to the background class (19) in the Cityscapes dataset 41 | if 255 in np.unique(label): 42 | label[label==255] = 19 43 | 44 | if self.transform: 45 | [image, label] = self.transform(image, label) 46 | return (image, label) 47 | -------------------------------------------------------------------------------- /segmentation/IOUEval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | #adapted from https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/score.py 4 | 5 | #============================================ 6 | __author__ = "Sachin Mehta" 7 | __license__ = "MIT" 8 | __maintainer__ = "Sachin Mehta" 9 | #============================================ 10 | 11 | class iouEval: 12 | def __init__(self, nClasses): 13 | self.nClasses = nClasses 14 | self.reset() 15 | 16 | def reset(self): 17 | self.overall_acc = 0 18 | self.per_class_acc = np.zeros(self.nClasses, dtype=np.float32) 19 | self.per_class_iu = np.zeros(self.nClasses, dtype=np.float32) 20 | self.mIOU = 0 21 | self.batchCount = 0 22 | 23 | def fast_hist(self, a, b): 24 | k = (a >= 0) & (a < self.nClasses) 25 | return np.bincount(self.nClasses * a[k].astype(int) + b[k], minlength=self.nClasses ** 2).reshape(self.nClasses, self.nClasses) 26 | 27 | def compute_hist(self, predict, gth): 28 | hist = self.fast_hist(gth, predict) 29 | return hist 30 | 31 | def addBatch(self, predict, gth): 32 | predict = predict.cpu().numpy().flatten() 33 | gth = gth.cpu().numpy().flatten() 34 | 35 | epsilon = 0.00000001 36 | hist = self.compute_hist(predict, gth) 37 | overall_acc = np.diag(hist).sum() / (hist.sum() + epsilon) 38 | per_class_acc = np.diag(hist) / (hist.sum(1) + epsilon) 39 | per_class_iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon) 40 | mIou = np.nanmean(per_class_iu) 41 | 42 | self.overall_acc +=overall_acc 43 | self.per_class_acc += per_class_acc 44 | self.per_class_iu += per_class_iu 45 | self.mIOU += mIou 46 | self.batchCount += 1 47 | 48 | def getMetric(self): 49 | overall_acc = self.overall_acc/self.batchCount 50 | per_class_acc = self.per_class_acc / self.batchCount 51 | per_class_iu = self.per_class_iu / self.batchCount 52 | mIOU = self.mIOU / self.batchCount 53 | 54 | return overall_acc, per_class_acc, per_class_iu, mIOU 55 | -------------------------------------------------------------------------------- /segmentation/README.md: -------------------------------------------------------------------------------- 1 | # ESPNetv2: A Light-weight, Power Efficient, and General Purpose Convolutional Neural Network 2 | 3 | This repository contains the source code that we used for semantic segmentation in our paper of our paper, [ESPNetv2](https://arxiv.org/abs/1811.11431). 4 | 5 | ***Note:*** New segmentation models for the PASCAL VOC and the Cityscapes are coming soon. Our new models achieves mIOU of [68.0](http://host.robots.ox.ac.uk:8080/anonymous/DAMVRR.html) and [66.15](https://www.cityscapes-dataset.com/anonymous-results/?id=2267c613d55dd75d5301850c913b1507bf2f10586ca73eb8ebcf357cdcf3e036) on the PASCAL VOC and the Cityscapes test sets, respectively. 6 | 7 | **IMPORTANT NOTE** We released a new repository EdgeNets that contains our work on efficient network design. 
See [EdgeNets](https://github.com/sacmehta/EdgeNets/) for performance comparison and pretrained models. 8 | 9 | **THIS REPO IS OBSOLETE** 10 | -------------------------------------------------------------------------------- /segmentation/Transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | import cv2 5 | 6 | 7 | #============================================ 8 | __author__ = "Sachin Mehta" 9 | __license__ = "MIT" 10 | __maintainer__ = "Sachin Mehta" 11 | #============================================ 12 | 13 | 14 | class Scale(object): 15 | """ 16 | Randomly crop and resize the given PIL image with a probability of 0.5 17 | """ 18 | def __init__(self, wi, he): 19 | ''' 20 | 21 | :param wi: width after resizing 22 | :param he: height after reszing 23 | ''' 24 | self.w = wi 25 | self.h = he 26 | 27 | def __call__(self, img, label): 28 | ''' 29 | :param img: RGB image 30 | :param label: semantic label image 31 | :return: resized images 32 | ''' 33 | #bilinear interpolation for RGB image 34 | img = cv2.resize(img, (self.w, self.h)) 35 | # nearest neighbour interpolation for label image 36 | label = cv2.resize(label, (self.w, self.h), interpolation=cv2.INTER_NEAREST) 37 | 38 | return [img, label] 39 | 40 | 41 | 42 | class RandomCropResize(object): 43 | """ 44 | Randomly crop and resize the given PIL image with a probability of 0.5 45 | """ 46 | def __init__(self, size): 47 | ''' 48 | :param crop_area: area to be cropped (this is the max value and we select between o and crop area 49 | ''' 50 | self.size = size 51 | 52 | def __call__(self, img, label): 53 | h, w = img.shape[:2] 54 | x1 = random.randint(0, int(w*0.1)) # 25% to 10% 55 | y1 = random.randint(0, int(h*0.1)) 56 | 57 | img_crop = img[y1:h-y1, x1:w-x1] 58 | label_crop = label[y1:h-y1, x1:w-x1] 59 | 60 | img_crop = cv2.resize(img_crop, self.size) 61 | label_crop = cv2.resize(label_crop, self.size, interpolation=cv2.INTER_NEAREST) 62 | return img_crop, label_crop 63 | 64 | class RandomCrop(object): 65 | ''' 66 | This class if for random cropping 67 | ''' 68 | def __init__(self, cropArea): 69 | ''' 70 | :param cropArea: amount of cropping (in pixels) 71 | ''' 72 | self.crop = cropArea 73 | 74 | def __call__(self, img, label): 75 | 76 | if random.random() < 0.5: 77 | h, w = img.shape[:2] 78 | img_crop = img[self.crop:h-self.crop, self.crop:w-self.crop] 79 | label_crop = label[self.crop:h-self.crop, self.crop:w-self.crop] 80 | return img_crop, label_crop 81 | else: 82 | return [img, label] 83 | 84 | 85 | 86 | class RandomFlip(object): 87 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 88 | """ 89 | 90 | def __call__(self, image, label): 91 | if random.random() < 0.5: 92 | x1 = 0#random.randint(0, 1) #if you want to do vertical flip, uncomment this line 93 | if x1 == 0: 94 | image = cv2.flip(image, 0) # horizontal flip 95 | label = cv2.flip(label, 0) # horizontal flip 96 | else: 97 | image = cv2.flip(image, 1) # veritcal flip 98 | label = cv2.flip(label, 1) # veritcal flip 99 | return [image, label] 100 | 101 | 102 | class Normalize(object): 103 | """Given mean: (R, G, B) and std: (R, G, B), 104 | will normalize each channel of the torch.*Tensor, i.e. 
105 | channel = (channel - mean) / std 106 | """ 107 | 108 | def __init__(self, mean, std): 109 | ''' 110 | :param mean: global mean computed from dataset 111 | :param std: global std computed from dataset 112 | ''' 113 | self.mean = mean 114 | self.std = std 115 | 116 | def __call__(self, image, label): 117 | image = image.astype(np.float32) 118 | for i in range(3): 119 | image[:,:,i] -= self.mean[i] 120 | for i in range(3): 121 | image[:,:, i] /= self.std[i] 122 | 123 | return [image, label] 124 | 125 | class ToTensor(object): 126 | ''' 127 | This class converts the data to tensor so that it can be processed by PyTorch 128 | ''' 129 | def __init__(self, scale=1): 130 | ''' 131 | :param scale: ESPNet-C's output is 1/8th of original image size, so set this parameter accordingly 132 | ''' 133 | self.scale = scale # original images are 2048 x 1024 134 | 135 | def __call__(self, image, label): 136 | 137 | if self.scale != 1: 138 | h, w = label.shape[:2] 139 | image = cv2.resize(image, (int(w), int(h))) 140 | label = cv2.resize(label, (int(w/self.scale), int(h/self.scale)), interpolation=cv2.INTER_NEAREST) 141 | 142 | image = image.transpose((2,0,1)) 143 | 144 | image_tensor = torch.from_numpy(image).div(255) 145 | label_tensor = torch.LongTensor(np.array(label, dtype=np.int)) #torch.from_numpy(label) 146 | 147 | return [image_tensor, label_tensor] 148 | 149 | class Compose(object): 150 | """Composes several transforms together. 151 | """ 152 | 153 | def __init__(self, transforms): 154 | self.transforms = transforms 155 | 156 | def __call__(self, *args): 157 | for t in self.transforms: 158 | args = t(*args) 159 | return args 160 | -------------------------------------------------------------------------------- /segmentation/cnn/Model.py: -------------------------------------------------------------------------------- 1 | from torch.nn import init 2 | import torch.nn.functional as F 3 | from cnn.cnn_utils import * 4 | import math 5 | import torch 6 | 7 | __author__ = "Sachin Mehta" 8 | __version__ = "1.0.1" 9 | __maintainer__ = "Sachin Mehta" 10 | 11 | class EESP(nn.Module): 12 | ''' 13 | This class defines the EESP block, which is based on the following principle 14 | REDUCE ---> SPLIT ---> TRANSFORM --> MERGE 15 | ''' 16 | 17 | def __init__(self, nIn, nOut, stride=1, k=4, r_lim=7, down_method='esp'): #down_method --> ['avg' or 'esp'] 18 | ''' 19 | :param nIn: number of input channels 20 | :param nOut: number of output channels 21 | :param stride: factor by which we should skip (useful for down-sampling). If 2, then down-samples the feature map by 2 22 | :param k: # of parallel branches 23 | :param r_lim: A maximum value of receptive field allowed for EESP block 24 | :param g: number of groups to be used in the feature map reduction step. 
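:param down_method: down-sampling strategy used when stride is 2; one of 'avg' or 'esp'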
25 | ''' 26 | super().__init__() 27 | self.stride = stride 28 | n = int(nOut / k) 29 | n1 = nOut - (k - 1) * n 30 | assert down_method in ['avg', 'esp'], 'One of these is suppported (avg or esp)' 31 | assert n == n1, "n(={}) and n1(={}) should be equal for Depth-wise Convolution ".format(n, n1) 32 | #assert nIn%k == 0, "Number of input channels ({}) should be divisible by # of branches ({})".format(nIn, k) 33 | #assert n % k == 0, "Number of output channels ({}) should be divisible by # of branches ({})".format(n, k) 34 | self.proj_1x1 = CBR(nIn, n, 1, stride=1, groups=k) 35 | 36 | # (For convenience) Mapping between dilation rate and receptive field for a 3x3 kernel 37 | map_receptive_ksize = {3: 1, 5: 2, 7: 3, 9: 4, 11: 5, 13: 6, 15: 7, 17: 8} 38 | self.k_sizes = list() 39 | for i in range(k): 40 | ksize = int(3 + 2 * i) 41 | # After reaching the receptive field limit, fall back to the base kernel size of 3 with a dilation rate of 1 42 | ksize = ksize if ksize <= r_lim else 3 43 | self.k_sizes.append(ksize) 44 | # sort (in ascending order) these kernel sizes based on their receptive field 45 | # This enables us to ignore the kernels (3x3 in our case) with the same effective receptive field in hierarchical 46 | # feature fusion because kernels with 3x3 receptive fields does not have gridding artifact. 47 | self.k_sizes.sort() 48 | self.spp_dw = nn.ModuleList() 49 | #self.bn = nn.ModuleList() 50 | for i in range(k): 51 | d_rate = map_receptive_ksize[self.k_sizes[i]] 52 | self.spp_dw.append(CDilated(n, n, kSize=3, stride=stride, groups=n, d=d_rate)) 53 | #self.bn.append(nn.BatchNorm2d(n)) 54 | self.conv_1x1_exp = CB(nOut, nOut, 1, 1, groups=k) 55 | self.br_after_cat = BR(nOut) 56 | self.module_act = nn.PReLU(nOut) 57 | self.downAvg = True if down_method == 'avg' else False 58 | 59 | def forward(self, input): 60 | ''' 61 | :param input: input feature map 62 | :return: transformed feature map 63 | ''' 64 | 65 | # Reduce --> project high-dimensional feature maps to low-dimensional space 66 | output1 = self.proj_1x1(input) 67 | output = [self.spp_dw[0](output1)] 68 | # compute the output for each branch and hierarchically fuse them 69 | # i.e. Split --> Transform --> HFF 70 | for k in range(1, len(self.spp_dw)): 71 | out_k = self.spp_dw[k](output1) 72 | # HFF 73 | # We donot combine the branches that have the same effective receptive (3x3 in our case) 74 | # because there are no holes in those kernels. 75 | out_k = out_k + output[k - 1] 76 | #apply batch norm after fusion and then append to the list 77 | output.append(out_k) 78 | # Merge 79 | expanded = self.conv_1x1_exp( # Aggregate the feature maps using point-wise convolution 80 | self.br_after_cat( # apply batch normalization followed by activation function (PRelu in this case) 81 | torch.cat(output, 1) # concatenate the output of different branches 82 | ) 83 | ) 84 | del output 85 | # if down-sampling, then return the concatenated vector 86 | # as Downsampling function will combine it with avg. 
pooled feature map and then threshold it 87 | if self.stride == 2 and self.downAvg: 88 | return expanded 89 | 90 | # if dimensions of input and concatenated vector are the same, add them (RESIDUAL LINK) 91 | if expanded.size() == input.size(): 92 | expanded = expanded + input 93 | 94 | # Threshold the feature map using activation function (PReLU in this case) 95 | return self.module_act(expanded) 96 | 97 | 98 | class DownSampler(nn.Module): 99 | ''' 100 | Down-sampling fucntion that has two parallel branches: (1) avg pooling 101 | and (2) EESP block with stride of 2. The output feature maps of these branches 102 | are then concatenated and thresholded using an activation function (PReLU in our 103 | case) to produce the final output. 104 | ''' 105 | 106 | def __init__(self, nin, nout, k=4, r_lim=9, reinf=True): 107 | ''' 108 | :param nin: number of input channels 109 | :param nout: number of output channels 110 | :param k: # of parallel branches 111 | :param r_lim: A maximum value of receptive field allowed for EESP block 112 | :param g: number of groups to be used in the feature map reduction step. 113 | ''' 114 | super().__init__() 115 | nout_new = nout - nin 116 | self.eesp = EESP(nin, nout_new, stride=2, k=k, r_lim=r_lim, down_method='avg') 117 | self.avg = nn.AvgPool2d(kernel_size=3, padding=1, stride=2) 118 | if reinf: 119 | self.inp_reinf = nn.Sequential( 120 | CBR(config_inp_reinf, config_inp_reinf, 3, 1), 121 | CB(config_inp_reinf, nout, 1, 1) 122 | ) 123 | self.act = nn.PReLU(nout) 124 | 125 | def forward(self, input, input2=None): 126 | ''' 127 | :param input: input feature map 128 | :return: feature map down-sampled by a factor of 2 129 | ''' 130 | avg_out = self.avg(input) 131 | eesp_out = self.eesp(input) 132 | output = torch.cat([avg_out, eesp_out], 1) 133 | if input2 is not None: 134 | #assuming the input is a square image 135 | w1 = avg_out.size(2) 136 | while True: 137 | input2 = F.avg_pool2d(input2, kernel_size=3, padding=1, stride=2) 138 | w2 = input2.size(2) 139 | if w2 == w1: 140 | break 141 | output = output + self.inp_reinf(input2) 142 | 143 | return self.act(output) #self.act(output) 144 | 145 | class EESPNet(nn.Module): 146 | ''' 147 | This class defines the ESPNetv2 architecture for the ImageNet classification 148 | ''' 149 | 150 | def __init__(self, classes=20, s=1): 151 | ''' 152 | :param classes: number of classes in the dataset. Default is 20 for the cityscapes 153 | :param s: factor that scales the number of output feature maps 154 | ''' 155 | super().__init__() 156 | reps = [0, 3, 7, 3] # how many times EESP blocks should be repeated. 157 | channels = 3 158 | 159 | r_lim = [13, 11, 9, 7, 5] # receptive field at each spatial level 160 | K = [4]*len(r_lim) # No. of parallel branches at different levels 161 | 162 | base = 32 #base configuration 163 | config_len = 5 164 | config = [base] * config_len 165 | base_s = 0 166 | for i in range(config_len): 167 | if i== 0: 168 | base_s = int(base * s) 169 | base_s = math.ceil(base_s / K[0]) * K[0] 170 | config[i] = base if base_s > base else base_s 171 | else: 172 | config[i] = base_s * pow(2, i) 173 | if s <= 1.5: 174 | config.append(1024) 175 | elif s in [1.5, 2]: 176 | config.append(1280) 177 | else: 178 | ValueError('Configuration not supported') 179 | 180 | #print('Config: ', config) 181 | 182 | global config_inp_reinf 183 | config_inp_reinf = 3 184 | self.input_reinforcement = True 185 | assert len(K) == len(r_lim), 'Length of branching factor array and receptive field array should be the same.' 
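# Worked example (illustrative only; assumes the default branching factors K = [4]*5 used above):
#   s = 0.5: base_s = ceil(int(32*0.5)/4)*4 = 16  ->  config = [16, 32, 64, 128, 256, 1024]
#   s = 1.0: base_s = ceil(int(32*1.0)/4)*4 = 32  ->  config = [32, 64, 128, 256, 512, 1024]
#   s = 2.0: base_s = ceil(int(32*2.0)/4)*4 = 64  ->  config = [32, 128, 256, 512, 1024, 1280]
# (for s = 2.0, config[0] stays at base = 32 because base_s > base)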
186 | 187 | self.level1 = CBR(channels, config[0], 3, 2) # 112 L1 188 | 189 | self.level2_0 = DownSampler(config[0], config[1], k=K[0], r_lim=r_lim[0], reinf=self.input_reinforcement) # out = 56 190 | self.level3_0 = DownSampler(config[1], config[2], k=K[1], r_lim=r_lim[1], reinf=self.input_reinforcement) # out = 28 191 | self.level3 = nn.ModuleList() 192 | for i in range(reps[1]): 193 | self.level3.append(EESP(config[2], config[2], stride=1, k=K[2], r_lim=r_lim[2])) 194 | 195 | self.level4_0 = DownSampler(config[2], config[3], k=K[2], r_lim=r_lim[2], reinf=self.input_reinforcement) #out = 14 196 | self.level4 = nn.ModuleList() 197 | for i in range(reps[2]): 198 | self.level4.append(EESP(config[3], config[3], stride=1, k=K[3], r_lim=r_lim[3])) 199 | 200 | self.level5_0 = DownSampler(config[3], config[4], k=K[3], r_lim=r_lim[3]) #7 201 | self.level5 = nn.ModuleList() 202 | for i in range(reps[3]): 203 | self.level5.append(EESP(config[4], config[4], stride=1, k=K[4], r_lim=r_lim[4])) 204 | 205 | # expand the feature maps using depth-wise separable convolution 206 | self.level5.append(CBR(config[4], config[4], 3, 1, groups=config[4])) 207 | self.level5.append(CBR(config[4], config[5], 1, 1, groups=K[4])) 208 | 209 | 210 | 211 | #self.level5_exp = nn.ModuleList() 212 | #assert config[5]%config[4] == 0, '{} should be divisible by {}'.format(config[5], config[4]) 213 | #gr = int(config[5]/config[4]) 214 | #for i in range(gr): 215 | # self.level5_exp.append(CBR(config[4], config[4], 1, 1, groups=pow(2, i))) 216 | 217 | self.classifier = nn.Linear(config[5], classes) 218 | self.init_params() 219 | 220 | def init_params(self): 221 | ''' 222 | Function to initialze the parameters 223 | ''' 224 | for m in self.modules(): 225 | if isinstance(m, nn.Conv2d): 226 | init.kaiming_normal_(m.weight, mode='fan_out') 227 | if m.bias is not None: 228 | init.constant_(m.bias, 0) 229 | elif isinstance(m, nn.BatchNorm2d): 230 | init.constant_(m.weight, 1) 231 | init.constant_(m.bias, 0) 232 | elif isinstance(m, nn.Linear): 233 | init.normal_(m.weight, std=0.001) 234 | if m.bias is not None: 235 | init.constant_(m.bias, 0) 236 | 237 | def forward(self, input, p=0.2, seg=True): 238 | ''' 239 | :param input: Receives the input RGB image 240 | :return: a C-dimensional vector, C=# of classes 241 | ''' 242 | out_l1 = self.level1(input) # 112 243 | if not self.input_reinforcement: 244 | del input 245 | input = None 246 | 247 | out_l2 = self.level2_0(out_l1, input) # 56 248 | 249 | out_l3_0 = self.level3_0(out_l2, input) # out_l2_inp_rein 250 | for i, layer in enumerate(self.level3): 251 | if i == 0: 252 | out_l3 = layer(out_l3_0) 253 | else: 254 | out_l3 = layer(out_l3) 255 | 256 | out_l4_0 = self.level4_0(out_l3, input) # down-sampled 257 | for i, layer in enumerate(self.level4): 258 | if i == 0: 259 | out_l4 = layer(out_l4_0) 260 | else: 261 | out_l4 = layer(out_l4) 262 | 263 | if not seg: 264 | out_l5_0 = self.level5_0(out_l4) # down-sampled 265 | for i, layer in enumerate(self.level5): 266 | if i == 0: 267 | out_l5 = layer(out_l5_0) 268 | else: 269 | out_l5 = layer(out_l5) 270 | 271 | #out_e = [] 272 | #for layer in self.level5_exp: 273 | # out_e.append(layer(out_l5)) 274 | #out_exp = torch.cat(out_e, dim=1) 275 | 276 | 277 | 278 | output_g = F.adaptive_avg_pool2d(out_l5, output_size=1) 279 | output_g = F.dropout(output_g, p=p, training=self.training) 280 | output_1x1 = output_g.view(output_g.size(0), -1) 281 | 282 | return self.classifier(output_1x1) 283 | return out_l1, out_l2, out_l3, out_l4 284 | 285 | 286 | 
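
For reference, below is a minimal usage sketch of the `EESPNet` backbone defined in `segmentation/cnn/Model.py` above. It is a sketch only, not part of the repository: it assumes the script is run from the `segmentation` directory so that the `cnn` package resolves, and the 224x224 random input is arbitrary. With the default `seg=True`, the forward pass returns the feature maps at strides 2, 4, 8, and 16 that `EESPNet_Seg` consumes; `seg=False` continues through `level5` and the linear classifier.

```python
# Minimal sketch (not part of the repository): assumes execution from the
# `segmentation` directory so that `cnn.Model` is importable, and uses a
# random 224x224 input purely for illustration.
import torch
from cnn.Model import EESPNet

model = EESPNet(classes=1000, s=1.0).eval()
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    # Default seg=True: returns the multi-scale features used by the decoder
    out_l1, out_l2, out_l3, out_l4 = model(x)   # strides 2, 4, 8, 16
    # seg=False: continues through level5 and the linear classifier
    logits = model(x, seg=False)

print([o.shape[2] for o in (out_l1, out_l2, out_l3, out_l4)])  # [112, 56, 28, 14]
print(logits.shape)                                            # torch.Size([1, 1000])
```
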
-------------------------------------------------------------------------------- /segmentation/cnn/SegmentationModel.py: -------------------------------------------------------------------------------- 1 | #============================================ 2 | __author__ = "Sachin Mehta" 3 | __license__ = "MIT" 4 | __maintainer__ = "Sachin Mehta" 5 | #============================================ 6 | import torch 7 | from torch import nn 8 | 9 | from cnn.Model import EESPNet, EESP 10 | import os 11 | import torch.nn.functional as F 12 | from cnn.cnn_utils import * 13 | 14 | class EESPNet_Seg(nn.Module): 15 | def __init__(self, classes=20, s=1, pretrained=None, gpus=1): 16 | super().__init__() 17 | classificationNet = EESPNet(classes=1000, s=s) 18 | if gpus >=1: 19 | classificationNet = nn.DataParallel(classificationNet) 20 | # load the pretrained weights 21 | if pretrained: 22 | if not os.path.isfile(pretrained): 23 | print('Weight file does not exist. Training without pre-trained weights') 24 | print('Model initialized with pretrained weights') 25 | classificationNet.load_state_dict(torch.load(pretrained)) 26 | 27 | self.net = classificationNet.module 28 | 29 | del classificationNet 30 | # delete last few layers 31 | del self.net.classifier 32 | del self.net.level5 33 | del self.net.level5_0 34 | if s <=0.5: 35 | p = 0.1 36 | else: 37 | p=0.2 38 | 39 | self.proj_L4_C = CBR(self.net.level4[-1].module_act.num_parameters, self.net.level3[-1].module_act.num_parameters, 1, 1) 40 | pspSize = 2*self.net.level3[-1].module_act.num_parameters 41 | self.pspMod = nn.Sequential(EESP(pspSize, pspSize //2, stride=1, k=4, r_lim=7), 42 | PSPModule(pspSize // 2, pspSize //2)) 43 | self.project_l3 = nn.Sequential(nn.Dropout2d(p=p), C(pspSize // 2, classes, 1, 1)) 44 | self.act_l3 = BR(classes) 45 | self.project_l2 = CBR(self.net.level2_0.act.num_parameters + classes, classes, 1, 1) 46 | self.project_l1 = nn.Sequential(nn.Dropout2d(p=p), C(self.net.level1.act.num_parameters + classes, classes, 1, 1)) 47 | 48 | def hierarchicalUpsample(self, x, factor=3): 49 | for i in range(factor): 50 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) 51 | return x 52 | 53 | 54 | def forward(self, input): 55 | out_l1, out_l2, out_l3, out_l4 = self.net(input, seg=True) 56 | out_l4_proj = self.proj_L4_C(out_l4) 57 | up_l4_to_l3 = F.interpolate(out_l4_proj, scale_factor=2, mode='bilinear', align_corners=True) 58 | merged_l3_upl4 = self.pspMod(torch.cat([out_l3, up_l4_to_l3], 1)) 59 | proj_merge_l3_bef_act = self.project_l3(merged_l3_upl4) 60 | proj_merge_l3 = self.act_l3(proj_merge_l3_bef_act) 61 | out_up_l3 = F.interpolate(proj_merge_l3, scale_factor=2, mode='bilinear', align_corners=True) 62 | merge_l2 = self.project_l2(torch.cat([out_l2, out_up_l3], 1)) 63 | out_up_l2 = F.interpolate(merge_l2, scale_factor=2, mode='bilinear', align_corners=True) 64 | merge_l1 = self.project_l1(torch.cat([out_l1, out_up_l2], 1)) 65 | if self.training: 66 | return F.interpolate(merge_l1, scale_factor=2, mode='bilinear', align_corners=True), self.hierarchicalUpsample(proj_merge_l3_bef_act) 67 | else: 68 | return F.interpolate(merge_l1, scale_factor=2, mode='bilinear', align_corners=True) 69 | 70 | 71 | if __name__ == '__main__': 72 | input = torch.Tensor(1, 3, 512, 1024).cuda() 73 | net = EESPNet_Seg(classes=20, s=2).cuda() 74 | out_x_8 = net(input) 75 | print(out_x_8.size()) 76 | 77 | -------------------------------------------------------------------------------- /segmentation/cnn/cnn_utils.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | __author__ = "Sachin Mehta" 7 | __version__ = "1.0.1" 8 | __maintainer__ = "Sachin Mehta" 9 | 10 | 11 | class PSPModule(nn.Module): 12 | def __init__(self, features, out_features=1024, sizes=(1, 2, 4, 8)): 13 | super().__init__() 14 | self.stages = [] 15 | self.stages = nn.ModuleList([C(features, features, 3, 1, groups=features) for size in sizes]) 16 | self.project = CBR(features * (len(sizes) + 1), out_features, 1, 1) 17 | 18 | def forward(self, feats): 19 | h, w = feats.size(2), feats.size(3) 20 | out = [feats] 21 | for stage in self.stages: 22 | feats = F.avg_pool2d(feats, kernel_size=3, stride=2, padding=1) 23 | upsampled = F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) 24 | out.append(upsampled) 25 | return self.project(torch.cat(out, dim=1)) 26 | 27 | class CBR(nn.Module): 28 | ''' 29 | This class defines the convolution layer with batch normalization and PReLU activation 30 | ''' 31 | 32 | def __init__(self, nIn, nOut, kSize, stride=1, groups=1): 33 | ''' 34 | 35 | :param nIn: number of input channels 36 | :param nOut: number of output channels 37 | :param kSize: kernel size 38 | :param stride: stride rate for down-sampling. Default is 1 39 | ''' 40 | super().__init__() 41 | padding = int((kSize - 1) / 2) 42 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, groups=groups) 43 | self.bn = nn.BatchNorm2d(nOut) 44 | self.act = nn.PReLU(nOut) 45 | 46 | def forward(self, input): 47 | ''' 48 | :param input: input feature map 49 | :return: transformed feature map 50 | ''' 51 | output = self.conv(input) 52 | # output = self.conv1(output) 53 | output = self.bn(output) 54 | output = self.act(output) 55 | return output 56 | 57 | 58 | class BR(nn.Module): 59 | ''' 60 | This class groups the batch normalization and PReLU activation 61 | ''' 62 | 63 | def __init__(self, nOut): 64 | ''' 65 | :param nOut: output feature maps 66 | ''' 67 | super().__init__() 68 | self.bn = nn.BatchNorm2d(nOut) 69 | self.act = nn.PReLU(nOut) 70 | 71 | def forward(self, input): 72 | ''' 73 | :param input: input feature map 74 | :return: normalized and thresholded feature map 75 | ''' 76 | output = self.bn(input) 77 | output = self.act(output) 78 | return output 79 | 80 | 81 | class CB(nn.Module): 82 | ''' 83 | This class groups the convolution and batch normalization 84 | ''' 85 | 86 | def __init__(self, nIn, nOut, kSize, stride=1, groups=1): 87 | ''' 88 | :param nIn: number of input channels 89 | :param nOut: number of output channels 90 | :param kSize: kernel size 91 | :param stride: optinal stide for down-sampling 92 | ''' 93 | super().__init__() 94 | padding = int((kSize - 1) / 2) 95 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, 96 | groups=groups) 97 | self.bn = nn.BatchNorm2d(nOut) 98 | 99 | def forward(self, input): 100 | ''' 101 | 102 | :param input: input feature map 103 | :return: transformed feature map 104 | ''' 105 | output = self.conv(input) 106 | output = self.bn(output) 107 | return output 108 | 109 | 110 | class C(nn.Module): 111 | ''' 112 | This class is for a convolutional layer. 
113 | ''' 114 | 115 | def __init__(self, nIn, nOut, kSize, stride=1, groups=1): 116 | ''' 117 | 118 | :param nIn: number of input channels 119 | :param nOut: number of output channels 120 | :param kSize: kernel size 121 | :param stride: optional stride rate for down-sampling 122 | ''' 123 | super().__init__() 124 | padding = int((kSize - 1) / 2) 125 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, 126 | groups=groups) 127 | 128 | def forward(self, input): 129 | ''' 130 | :param input: input feature map 131 | :return: transformed feature map 132 | ''' 133 | output = self.conv(input) 134 | return output 135 | 136 | 137 | class CDilated(nn.Module): 138 | ''' 139 | This class defines the dilated convolution. 140 | ''' 141 | 142 | def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1): 143 | ''' 144 | :param nIn: number of input channels 145 | :param nOut: number of output channels 146 | :param kSize: kernel size 147 | :param stride: optional stride rate for down-sampling 148 | :param d: optional dilation rate 149 | ''' 150 | super().__init__() 151 | padding = int((kSize - 1) / 2) * d 152 | self.conv = nn.Conv2d(nIn, nOut,kSize, stride=stride, padding=padding, bias=False, 153 | dilation=d, groups=groups) 154 | 155 | def forward(self, input): 156 | ''' 157 | :param input: input feature map 158 | :return: transformed feature map 159 | ''' 160 | output = self.conv(input) 161 | return output 162 | 163 | class CDilatedB(nn.Module): 164 | ''' 165 | This class defines the dilated convolution with batch normalization. 166 | ''' 167 | 168 | def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1): 169 | ''' 170 | :param nIn: number of input channels 171 | :param nOut: number of output channels 172 | :param kSize: kernel size 173 | :param stride: optional stride rate for down-sampling 174 | :param d: optional dilation rate 175 | ''' 176 | super().__init__() 177 | padding = int((kSize - 1) / 2) * d 178 | self.conv = nn.Conv2d(nIn, nOut,kSize, stride=stride, padding=padding, bias=False, 179 | dilation=d, groups=groups) 180 | self.bn = nn.BatchNorm2d(nOut) 181 | 182 | def forward(self, input): 183 | ''' 184 | :param input: input feature map 185 | :return: transformed feature map 186 | ''' 187 | return self.bn(self.conv(input)) 188 | -------------------------------------------------------------------------------- /segmentation/gen_cityscapes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import glob 4 | 5 | import cv2 6 | import os 7 | from argparse import ArgumentParser 8 | from cnn import SegmentationModel as net 9 | from torch import nn 10 | 11 | 12 | #============================================ 13 | __author__ = "Sachin Mehta" 14 | __license__ = "MIT" 15 | __maintainer__ = "Sachin Mehta" 16 | #============================================ 17 | 18 | pallete = [[128, 64, 128], 19 | [244, 35, 232], 20 | [70, 70, 70], 21 | [102, 102, 156], 22 | [190, 153, 153], 23 | [153, 153, 153], 24 | [250, 170, 30], 25 | [220, 220, 0], 26 | [107, 142, 35], 27 | [152, 251, 152], 28 | [70, 130, 180], 29 | [220, 20, 60], 30 | [255, 0, 0], 31 | [0, 0, 142], 32 | [0, 0, 70], 33 | [0, 60, 100], 34 | [0, 80, 100], 35 | [0, 0, 230], 36 | [119, 11, 32], 37 | [0, 0, 0]] 38 | 39 | 40 | def relabel(img): 41 | ''' 42 | This function relabels the predicted labels so that cityscape dataset can process 43 | :param img: 44 | :return: 45 | ''' 46 | img[img == 19] = 255 47 | img[img == 18] = 33 48 | img[img == 
17] = 32 49 | img[img == 16] = 31 50 | img[img == 15] = 28 51 | img[img == 14] = 27 52 | img[img == 13] = 26 53 | img[img == 12] = 25 54 | img[img == 11] = 24 55 | img[img == 10] = 23 56 | img[img == 9] = 22 57 | img[img == 8] = 21 58 | img[img == 7] = 20 59 | img[img == 6] = 19 60 | img[img == 5] = 17 61 | img[img == 4] = 13 62 | img[img == 3] = 12 63 | img[img == 2] = 11 64 | img[img == 1] = 8 65 | img[img == 0] = 7 66 | img[img == 255] = 0 67 | return img 68 | 69 | 70 | def evaluateModel(args, model, image_list): 71 | # gloabl mean and std values 72 | mean = [72.3923111, 82.90893555, 73.15840149] 73 | std = [45.3192215, 46.15289307, 44.91483307] 74 | 75 | model.eval() 76 | for i, imgName in enumerate(image_list): 77 | img = cv2.imread(imgName) 78 | if args.overlay: 79 | img_orig = np.copy(img) 80 | 81 | img = img.astype(np.float32) 82 | for j in range(3): 83 | img[:, :, j] -= mean[j] 84 | for j in range(3): 85 | img[:, :, j] /= std[j] 86 | 87 | # resize the image to 1024x512x3 88 | img = cv2.resize(img, (args.inWidth, args.inHeight)) 89 | if args.overlay: 90 | img_orig = cv2.resize(img_orig, (args.inWidth, args.inHeight)) 91 | 92 | img /= 255 93 | img = img.transpose((2, 0, 1)) 94 | img_tensor = torch.from_numpy(img) 95 | img_tensor = torch.unsqueeze(img_tensor, 0) # add a batch dimension 96 | if args.gpu: 97 | img_tensor = img_tensor.cuda() 98 | img_out = model(img_tensor) 99 | 100 | classMap_numpy = img_out[0].max(0)[1].byte().cpu().data.numpy() 101 | # upsample the feature maps to the same size as the input image using Nearest neighbour interpolation 102 | # upsample the feature map from 1024x512 to 2048x1024 103 | classMap_numpy = cv2.resize(classMap_numpy, (args.inWidth*2, args.inHeight*2), interpolation=cv2.INTER_NEAREST) 104 | if i % 100 == 0 and i > 0: 105 | print('Processed [{}/{}]'.format(i, len(image_list))) 106 | 107 | name = imgName.split('/')[-1] 108 | 109 | if args.colored: 110 | classMap_numpy_color = np.zeros((img.shape[1], img.shape[2], img.shape[0]), dtype=np.uint8) 111 | for idx in range(len(pallete)): 112 | [r, g, b] = pallete[idx] 113 | classMap_numpy_color[classMap_numpy == idx] = [b, g, r] 114 | cv2.imwrite(args.savedir + os.sep + 'c_' + name.replace(args.img_extn, 'png'), classMap_numpy_color) 115 | if args.overlay: 116 | overlayed = cv2.addWeighted(img_orig, 0.5, classMap_numpy_color, 0.5, 0) 117 | cv2.imwrite(args.savedir + os.sep + 'over_' + name.replace(args.img_extn, 'jpg'), overlayed) 118 | 119 | if args.cityFormat: 120 | classMap_numpy = relabel(classMap_numpy.astype(np.uint8)) 121 | 122 | 123 | cv2.imwrite(args.savedir + os.sep + name.replace(args.img_extn, 'png'), classMap_numpy) 124 | 125 | 126 | def main(args): 127 | # read all the images in the folder 128 | image_list = glob.glob(args.data_dir + os.sep + '*.' + args.img_extn) 129 | 130 | modelA = net.EESPNet_Seg(args.classes, s=args.s) 131 | if not os.path.isfile(args.pretrained): 132 | print('Pre-trained model file does not exist. 
Please check ./pretrained_models folder') 133 | exit(-1) 134 | modelA = nn.DataParallel(modelA) 135 | modelA.load_state_dict(torch.load(args.pretrained)) 136 | if args.gpu: 137 | modelA = modelA.cuda() 138 | 139 | # set to evaluation mode 140 | modelA.eval() 141 | 142 | if not os.path.isdir(args.savedir): 143 | os.mkdir(args.savedir) 144 | 145 | evaluateModel(args, modelA, image_list) 146 | 147 | 148 | if __name__ == '__main__': 149 | parser = ArgumentParser() 150 | parser.add_argument('--model', default="ESPNetv2", help='Model name') 151 | parser.add_argument('--data_dir', default="./data", help='Data directory') 152 | parser.add_argument('--img_extn', default="png", help='RGB Image format') 153 | parser.add_argument('--inWidth', type=int, default=1024, help='Width of RGB image') 154 | parser.add_argument('--inHeight', type=int, default=512, help='Height of RGB image') 155 | parser.add_argument('--savedir', default='./results', help='directory to save the results') 156 | parser.add_argument('--gpu', default=True, type=bool, help='Run on CPU or GPU. If TRUE, then GPU.') 157 | parser.add_argument('--pretrained', default='', help='Pretrained weights directory.') 158 | parser.add_argument('--s', default=0.5, type=float, help='scale') 159 | parser.add_argument('--cityFormat', default=True, type=bool, help='If you want to convert to cityscape ' 160 | 'original label ids') 161 | parser.add_argument('--colored', default=False, type=bool, help='If you want to visualize the ' 162 | 'segmentation masks in color') 163 | parser.add_argument('--overlay', default=False, type=bool, help='If you want to visualize the ' 164 | 'segmentation masks overlayed on top of RGB image') 165 | parser.add_argument('--classes', default=20, type=int, help='Number of classes in the dataset. 
20 for Cityscapes') 166 | 167 | args = parser.parse_args() 168 | if args.overlay: 169 | args.colored = True # This has to be true if you want to overlay 170 | main(args) 171 | -------------------------------------------------------------------------------- /segmentation/loadData.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import pickle 4 | 5 | 6 | #============================================ 7 | __author__ = "Sachin Mehta" 8 | __license__ = "MIT" 9 | __maintainer__ = "Sachin Mehta" 10 | #============================================ 11 | 12 | class LoadData: 13 | ''' 14 | Class to laod the data 15 | ''' 16 | def __init__(self, data_dir, classes, cached_data_file, normVal=1.10): 17 | ''' 18 | :param data_dir: directory where the dataset is kept 19 | :param classes: number of classes in the dataset 20 | :param cached_data_file: location where cached file has to be stored 21 | :param normVal: normalization value, as defined in ERFNet paper 22 | ''' 23 | self.data_dir = data_dir 24 | self.classes = classes 25 | self.classWeights = np.ones(self.classes, dtype=np.float32) 26 | self.normVal = normVal 27 | self.mean = np.zeros(3, dtype=np.float32) 28 | self.std = np.zeros(3, dtype=np.float32) 29 | self.trainImList = list() 30 | self.valImList = list() 31 | self.trainAnnotList = list() 32 | self.valAnnotList = list() 33 | self.cached_data_file = cached_data_file 34 | 35 | def compute_class_weights(self, histogram): 36 | ''' 37 | Helper function to compute the class weights 38 | :param histogram: distribution of class samples 39 | :return: None, but updates the classWeights variable 40 | ''' 41 | normHist = histogram / np.sum(histogram) 42 | for i in range(self.classes): 43 | self.classWeights[i] = 1 / (np.log(self.normVal + normHist[i])) 44 | 45 | def readFile(self, fileName, trainStg=False): 46 | ''' 47 | Function to read the data 48 | :param fileName: file that stores the image locations 49 | :param trainStg: if processing training or validation data 50 | :return: 0 if successful 51 | ''' 52 | if trainStg == True: 53 | global_hist = np.zeros(self.classes, dtype=np.float32) 54 | 55 | no_files = 0 56 | min_val_al = 0 57 | max_val_al = 0 58 | with open(self.data_dir + '/' + fileName, 'r') as textFile: 59 | for line in textFile: 60 | # we expect the text file to contain the data in following format 61 | # ,