├── .gitignore ├── License.txt ├── README.md ├── backbone └── resnext │ ├── __init__.py │ ├── resnext101_regular.py │ └── resnext_101_32x4d_.py ├── config.py ├── dataset.py ├── gdnet.py ├── infer.py ├── misc.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | /backbone/resnext/resnext_101_32x4d.pth 2 | /GDD/ 3 | /results/ 4 | /joint_transforms.py 5 | /loss.py 6 | /train.py 7 | /evaluation.py 8 | /excel/ 9 | /MSD/ 10 | /ckpt/ 11 | -------------------------------------------------------------------------------- /License.txt: -------------------------------------------------------------------------------- 1 | Don't Hit Me! Glass Detection in Real-world Scenes 2 | Haiyang Mei et al. 3 | CVPR June, 2020 4 | 5 | Copyright (c) 2020 6 | All rights reserved. 7 | 8 | School of Computer Science and Technology 9 | Dalian University of Technology 10 | 11 | 12 | ------------------------------------------------------- 13 | 14 | Redistribution and use in source and binary forms, with or without 15 | modification, are permitted provided that the following conditions 16 | are met: 17 | 18 | 1. Redistributions of source code must retain the above copyright 19 | notice, this list of conditions and the following disclaimer. 20 | 21 | 2. Redistributions in binary form must reproduce the above copyright 22 | notice, this list of conditions and the following disclaimer in the 23 | documentation and/or other materials provided with the distribution. 24 | 25 | 3. Neither name of copyright holders nor the names of its contributors 26 | may be used to endorse or promote products derived from this software 27 | without specific prior written permission. 28 | 29 | 30 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 31 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 32 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 33 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE REGENTS OR 34 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 35 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 36 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 37 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 38 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 39 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 40 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CVPR2020_GDNet 2 | 3 | ## Don't Hit Me! Glass Detection in Real-world Scenes 4 | [Haiyang Mei](https://mhaiyang.github.io/), Xin Yang, Yang Wang, Yuanyuan Liu, Shengfeng He, Qiang Zhang, Xiaopeng Wei, and Rynson W.H. Lau 5 | 6 | [[Paper](http://openaccess.thecvf.com/content_CVPR_2020/papers/Mei_Dont_Hit_Me_Glass_Detection_in_Real-World_Scenes_CVPR_2020_paper.pdf)] [[Project Page](https://mhaiyang.github.io/CVPR2020_GDNet/index.html)] 7 | 8 | ### Abstract 9 | Glass is very common in our daily life. Existing computer vision systems neglect the glass and thus might lead to severe consequences, e.g., the robot might crash into the glass wall. However, sensing the presence of the glass is not straightforward. The key challenge is that arbitrary objects/scenes can appear behind the glass and the content presented in the glass region is typically similar to the content outside of it. In this paper, we raise an interesting but important problem of detecting glass from a single RGB image. To address this problem, we construct a large-scale glass detection dataset (GDD) and design a glass detection network, called GDNet, by learning abundant contextual features from a global perspective with a novel large-field contextual feature integration module.
Extensive experiments demonstrate that the proposed method achieves superior glass detection results on our GDD test set. In particular, we outperform state-of-the-art methods that are fine-tuned for glass detection. 10 | 11 | ### Citation 12 | If you use this code or our dataset (including the test set), please cite: 13 | 14 | ``` 15 | @InProceedings{Mei_2020_CVPR, 16 | author = {Mei, Haiyang and Yang, Xin and Wang, Yang and Liu, Yuanyuan and He, Shengfeng and Zhang, Qiang and Wei, Xiaopeng and Lau, Rynson W.H.}, 17 | title = {Don't Hit Me! Glass Detection in Real-World Scenes}, 18 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 19 | month = {June}, 20 | year = {2020} 21 | } 22 | ``` 23 | 24 | ### Dataset 25 | See the [Project Page](https://mhaiyang.github.io/CVPR2020_GDNet/index.html). 26 | 27 | ### Requirements 28 | * PyTorch == 1.0.0 29 | * TorchVision == 0.2.1 30 | * CUDA 10.0, cuDNN 7.2 31 | * Setup 32 | ``` 33 | sudo pip3 install -r requirements.txt 34 | git clone https://github.com/Mhaiyang/dss_crf.git 35 | cd dss_crf && sudo python setup.py install 36 | ``` 37 | 38 | ### Test 39 | Download `resnext_101_32x4d.pth` from [here](https://drive.google.com/file/d/1e7N7LLZFWX4z0AkMG9wSQDCkOZEaSuFa/view?usp=sharing) (place it at `backbone/resnext/resnext_101_32x4d.pth`) and the trained model `GDNet.pth` from [here](https://mhaiyang.github.io/CVPR2020_GDNet/index.html), then run `infer.py`.
40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | ### License 55 | Please see `license.txt` 56 | 57 | ### Contact 58 | E-Mail: mhy666@mail.dlut.edu.cn 59 | -------------------------------------------------------------------------------- /backbone/resnext/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnext101_regular import ResNeXt101 2 | -------------------------------------------------------------------------------- /backbone/resnext/resnext101_regular.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from backbone.resnext import resnext_101_32x4d_ 5 | 6 | 7 | class ResNeXt101(nn.Module): 8 | def __init__(self, backbone_path): 9 | super(ResNeXt101, self).__init__() 10 | net = resnext_101_32x4d_.resnext_101_32x4d 11 | if backbone_path is not None: 12 | weights = torch.load(backbone_path) 13 | net.load_state_dict(weights, strict=True) 14 | print("Load ResNeXt Weights Succeed!") 15 | 16 | net = list(net.children()) 17 | self.layer0 = nn.Sequential(*net[:3]) 18 | self.layer1 = nn.Sequential(*net[3: 5]) 19 | self.layer2 = net[5] 20 | self.layer3 = net[6] 21 | self.layer4 = net[7] 22 | 23 | def forward(self, x): 24 | layer0 = self.layer0(x) 25 | layer1 = self.layer1(layer0) 26 | layer2 = self.layer2(layer1) 27 | layer3 = self.layer3(layer2) 28 | layer4 = self.layer4(layer3) 29 | return layer4 30 | -------------------------------------------------------------------------------- /backbone/resnext/resnext_101_32x4d_.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class LambdaBase(nn.Sequential): 7 | def __init__(self, fn, *args): 8 | super(LambdaBase, self).__init__(*args) 9 | self.lambda_func = fn 10 | 11 | def forward_prepare(self, input): 12 | output = [] 13 | for module in 
self._modules.values(): 14 | output.append(module(input)) 15 | return output if output else input 16 | 17 | 18 | class Lambda(LambdaBase): 19 | def forward(self, input): 20 | return self.lambda_func(self.forward_prepare(input)) 21 | 22 | 23 | class LambdaMap(LambdaBase): 24 | def forward(self, input): 25 | return list(map(self.lambda_func, self.forward_prepare(input))) 26 | 27 | 28 | class LambdaReduce(LambdaBase): 29 | def forward(self, input): 30 | return reduce(self.lambda_func, self.forward_prepare(input)) 31 | 32 | 33 | resnext_101_32x4d = nn.Sequential( # Sequential, 34 | nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias=False), 35 | nn.BatchNorm2d(64), 36 | nn.ReLU(), 37 | nn.MaxPool2d((3, 3), (2, 2), (1, 1)), 38 | nn.Sequential( # Sequential, 39 | nn.Sequential( # Sequential, 40 | LambdaMap(lambda x: x, # ConcatTable, 41 | nn.Sequential( # Sequential, 42 | nn.Sequential( # Sequential, 43 | nn.Conv2d(64, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 44 | nn.BatchNorm2d(128), 45 | nn.ReLU(), 46 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 47 | nn.BatchNorm2d(128), 48 | nn.ReLU(), 49 | ), 50 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 51 | nn.BatchNorm2d(256), 52 | ), 53 | nn.Sequential( # Sequential, 54 | nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 55 | nn.BatchNorm2d(256), 56 | ), 57 | ), 58 | LambdaReduce(lambda x, y: x + y), # CAddTable, 59 | nn.ReLU(), 60 | ), 61 | nn.Sequential( # Sequential, 62 | LambdaMap(lambda x: x, # ConcatTable, 63 | nn.Sequential( # Sequential, 64 | nn.Sequential( # Sequential, 65 | nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 66 | nn.BatchNorm2d(128), 67 | nn.ReLU(), 68 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 69 | nn.BatchNorm2d(128), 70 | nn.ReLU(), 71 | ), 72 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 73 | nn.BatchNorm2d(256), 74 | ), 75 | Lambda(lambda x: x), # Identity, 76 | ), 77 | 
LambdaReduce(lambda x, y: x + y), # CAddTable, 78 | nn.ReLU(), 79 | ), 80 | nn.Sequential( # Sequential, 81 | LambdaMap(lambda x: x, # ConcatTable, 82 | nn.Sequential( # Sequential, 83 | nn.Sequential( # Sequential, 84 | nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 85 | nn.BatchNorm2d(128), 86 | nn.ReLU(), 87 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 88 | nn.BatchNorm2d(128), 89 | nn.ReLU(), 90 | ), 91 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 92 | nn.BatchNorm2d(256), 93 | ), 94 | Lambda(lambda x: x), # Identity, 95 | ), 96 | LambdaReduce(lambda x, y: x + y), # CAddTable, 97 | nn.ReLU(), 98 | ), 99 | ), 100 | nn.Sequential( # Sequential, 101 | nn.Sequential( # Sequential, 102 | LambdaMap(lambda x: x, # ConcatTable, 103 | nn.Sequential( # Sequential, 104 | nn.Sequential( # Sequential, 105 | nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 106 | nn.BatchNorm2d(256), 107 | nn.ReLU(), 108 | nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), 1, 32, bias=False), 109 | nn.BatchNorm2d(256), 110 | nn.ReLU(), 111 | ), 112 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 113 | nn.BatchNorm2d(512), 114 | ), 115 | nn.Sequential( # Sequential, 116 | nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), 117 | nn.BatchNorm2d(512), 118 | ), 119 | ), 120 | LambdaReduce(lambda x, y: x + y), # CAddTable, 121 | nn.ReLU(), 122 | ), 123 | nn.Sequential( # Sequential, 124 | LambdaMap(lambda x: x, # ConcatTable, 125 | nn.Sequential( # Sequential, 126 | nn.Sequential( # Sequential, 127 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 128 | nn.BatchNorm2d(256), 129 | nn.ReLU(), 130 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 131 | nn.BatchNorm2d(256), 132 | nn.ReLU(), 133 | ), 134 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 135 | nn.BatchNorm2d(512), 136 | ), 137 | Lambda(lambda x: x), # Identity, 138 | ), 139 | LambdaReduce(lambda x, y: 
x + y), # CAddTable, 140 | nn.ReLU(), 141 | ), 142 | nn.Sequential( # Sequential, 143 | LambdaMap(lambda x: x, # ConcatTable, 144 | nn.Sequential( # Sequential, 145 | nn.Sequential( # Sequential, 146 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 147 | nn.BatchNorm2d(256), 148 | nn.ReLU(), 149 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 150 | nn.BatchNorm2d(256), 151 | nn.ReLU(), 152 | ), 153 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 154 | nn.BatchNorm2d(512), 155 | ), 156 | Lambda(lambda x: x), # Identity, 157 | ), 158 | LambdaReduce(lambda x, y: x + y), # CAddTable, 159 | nn.ReLU(), 160 | ), 161 | nn.Sequential( # Sequential, 162 | LambdaMap(lambda x: x, # ConcatTable, 163 | nn.Sequential( # Sequential, 164 | nn.Sequential( # Sequential, 165 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 166 | nn.BatchNorm2d(256), 167 | nn.ReLU(), 168 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 169 | nn.BatchNorm2d(256), 170 | nn.ReLU(), 171 | ), 172 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 173 | nn.BatchNorm2d(512), 174 | ), 175 | Lambda(lambda x: x), # Identity, 176 | ), 177 | LambdaReduce(lambda x, y: x + y), # CAddTable, 178 | nn.ReLU(), 179 | ), 180 | ), 181 | nn.Sequential( # Sequential, 182 | nn.Sequential( # Sequential, 183 | LambdaMap(lambda x: x, # ConcatTable, 184 | nn.Sequential( # Sequential, 185 | nn.Sequential( # Sequential, 186 | nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 187 | nn.BatchNorm2d(512), 188 | nn.ReLU(), 189 | nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 32, bias=False), 190 | nn.BatchNorm2d(512), 191 | nn.ReLU(), 192 | ), 193 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 194 | nn.BatchNorm2d(1024), 195 | ), 196 | nn.Sequential( # Sequential, 197 | nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), 198 | nn.BatchNorm2d(1024), 199 | ), 200 | ), 201 | LambdaReduce(lambda x, y: 
x + y), # CAddTable, 202 | nn.ReLU(), 203 | ), 204 | nn.Sequential( # Sequential, 205 | LambdaMap(lambda x: x, # ConcatTable, 206 | nn.Sequential( # Sequential, 207 | nn.Sequential( # Sequential, 208 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 209 | nn.BatchNorm2d(512), 210 | nn.ReLU(), 211 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 212 | nn.BatchNorm2d(512), 213 | nn.ReLU(), 214 | ), 215 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 216 | nn.BatchNorm2d(1024), 217 | ), 218 | Lambda(lambda x: x), # Identity, 219 | ), 220 | LambdaReduce(lambda x, y: x + y), # CAddTable, 221 | nn.ReLU(), 222 | ), 223 | nn.Sequential( # Sequential, 224 | LambdaMap(lambda x: x, # ConcatTable, 225 | nn.Sequential( # Sequential, 226 | nn.Sequential( # Sequential, 227 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 228 | nn.BatchNorm2d(512), 229 | nn.ReLU(), 230 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 231 | nn.BatchNorm2d(512), 232 | nn.ReLU(), 233 | ), 234 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 235 | nn.BatchNorm2d(1024), 236 | ), 237 | Lambda(lambda x: x), # Identity, 238 | ), 239 | LambdaReduce(lambda x, y: x + y), # CAddTable, 240 | nn.ReLU(), 241 | ), 242 | nn.Sequential( # Sequential, 243 | LambdaMap(lambda x: x, # ConcatTable, 244 | nn.Sequential( # Sequential, 245 | nn.Sequential( # Sequential, 246 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 247 | nn.BatchNorm2d(512), 248 | nn.ReLU(), 249 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 250 | nn.BatchNorm2d(512), 251 | nn.ReLU(), 252 | ), 253 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 254 | nn.BatchNorm2d(1024), 255 | ), 256 | Lambda(lambda x: x), # Identity, 257 | ), 258 | LambdaReduce(lambda x, y: x + y), # CAddTable, 259 | nn.ReLU(), 260 | ), 261 | nn.Sequential( # Sequential, 262 | LambdaMap(lambda x: x, # ConcatTable, 263 | 
nn.Sequential( # Sequential, 264 | nn.Sequential( # Sequential, 265 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 266 | nn.BatchNorm2d(512), 267 | nn.ReLU(), 268 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 269 | nn.BatchNorm2d(512), 270 | nn.ReLU(), 271 | ), 272 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 273 | nn.BatchNorm2d(1024), 274 | ), 275 | Lambda(lambda x: x), # Identity, 276 | ), 277 | LambdaReduce(lambda x, y: x + y), # CAddTable, 278 | nn.ReLU(), 279 | ), 280 | nn.Sequential( # Sequential, 281 | LambdaMap(lambda x: x, # ConcatTable, 282 | nn.Sequential( # Sequential, 283 | nn.Sequential( # Sequential, 284 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 285 | nn.BatchNorm2d(512), 286 | nn.ReLU(), 287 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 288 | nn.BatchNorm2d(512), 289 | nn.ReLU(), 290 | ), 291 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 292 | nn.BatchNorm2d(1024), 293 | ), 294 | Lambda(lambda x: x), # Identity, 295 | ), 296 | LambdaReduce(lambda x, y: x + y), # CAddTable, 297 | nn.ReLU(), 298 | ), 299 | nn.Sequential( # Sequential, 300 | LambdaMap(lambda x: x, # ConcatTable, 301 | nn.Sequential( # Sequential, 302 | nn.Sequential( # Sequential, 303 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 304 | nn.BatchNorm2d(512), 305 | nn.ReLU(), 306 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 307 | nn.BatchNorm2d(512), 308 | nn.ReLU(), 309 | ), 310 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 311 | nn.BatchNorm2d(1024), 312 | ), 313 | Lambda(lambda x: x), # Identity, 314 | ), 315 | LambdaReduce(lambda x, y: x + y), # CAddTable, 316 | nn.ReLU(), 317 | ), 318 | nn.Sequential( # Sequential, 319 | LambdaMap(lambda x: x, # ConcatTable, 320 | nn.Sequential( # Sequential, 321 | nn.Sequential( # Sequential, 322 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 323 | 
nn.BatchNorm2d(512), 324 | nn.ReLU(), 325 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 326 | nn.BatchNorm2d(512), 327 | nn.ReLU(), 328 | ), 329 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 330 | nn.BatchNorm2d(1024), 331 | ), 332 | Lambda(lambda x: x), # Identity, 333 | ), 334 | LambdaReduce(lambda x, y: x + y), # CAddTable, 335 | nn.ReLU(), 336 | ), 337 | nn.Sequential( # Sequential, 338 | LambdaMap(lambda x: x, # ConcatTable, 339 | nn.Sequential( # Sequential, 340 | nn.Sequential( # Sequential, 341 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 342 | nn.BatchNorm2d(512), 343 | nn.ReLU(), 344 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 345 | nn.BatchNorm2d(512), 346 | nn.ReLU(), 347 | ), 348 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 349 | nn.BatchNorm2d(1024), 350 | ), 351 | Lambda(lambda x: x), # Identity, 352 | ), 353 | LambdaReduce(lambda x, y: x + y), # CAddTable, 354 | nn.ReLU(), 355 | ), 356 | nn.Sequential( # Sequential, 357 | LambdaMap(lambda x: x, # ConcatTable, 358 | nn.Sequential( # Sequential, 359 | nn.Sequential( # Sequential, 360 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 361 | nn.BatchNorm2d(512), 362 | nn.ReLU(), 363 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 364 | nn.BatchNorm2d(512), 365 | nn.ReLU(), 366 | ), 367 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 368 | nn.BatchNorm2d(1024), 369 | ), 370 | Lambda(lambda x: x), # Identity, 371 | ), 372 | LambdaReduce(lambda x, y: x + y), # CAddTable, 373 | nn.ReLU(), 374 | ), 375 | nn.Sequential( # Sequential, 376 | LambdaMap(lambda x: x, # ConcatTable, 377 | nn.Sequential( # Sequential, 378 | nn.Sequential( # Sequential, 379 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 380 | nn.BatchNorm2d(512), 381 | nn.ReLU(), 382 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 383 | nn.BatchNorm2d(512), 384 
| nn.ReLU(), 385 | ), 386 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 387 | nn.BatchNorm2d(1024), 388 | ), 389 | Lambda(lambda x: x), # Identity, 390 | ), 391 | LambdaReduce(lambda x, y: x + y), # CAddTable, 392 | nn.ReLU(), 393 | ), 394 | nn.Sequential( # Sequential, 395 | LambdaMap(lambda x: x, # ConcatTable, 396 | nn.Sequential( # Sequential, 397 | nn.Sequential( # Sequential, 398 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 399 | nn.BatchNorm2d(512), 400 | nn.ReLU(), 401 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 402 | nn.BatchNorm2d(512), 403 | nn.ReLU(), 404 | ), 405 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 406 | nn.BatchNorm2d(1024), 407 | ), 408 | Lambda(lambda x: x), # Identity, 409 | ), 410 | LambdaReduce(lambda x, y: x + y), # CAddTable, 411 | nn.ReLU(), 412 | ), 413 | nn.Sequential( # Sequential, 414 | LambdaMap(lambda x: x, # ConcatTable, 415 | nn.Sequential( # Sequential, 416 | nn.Sequential( # Sequential, 417 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 418 | nn.BatchNorm2d(512), 419 | nn.ReLU(), 420 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 421 | nn.BatchNorm2d(512), 422 | nn.ReLU(), 423 | ), 424 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 425 | nn.BatchNorm2d(1024), 426 | ), 427 | Lambda(lambda x: x), # Identity, 428 | ), 429 | LambdaReduce(lambda x, y: x + y), # CAddTable, 430 | nn.ReLU(), 431 | ), 432 | nn.Sequential( # Sequential, 433 | LambdaMap(lambda x: x, # ConcatTable, 434 | nn.Sequential( # Sequential, 435 | nn.Sequential( # Sequential, 436 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 437 | nn.BatchNorm2d(512), 438 | nn.ReLU(), 439 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 440 | nn.BatchNorm2d(512), 441 | nn.ReLU(), 442 | ), 443 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 444 | nn.BatchNorm2d(1024), 445 | ), 446 | 
Lambda(lambda x: x), # Identity, 447 | ), 448 | LambdaReduce(lambda x, y: x + y), # CAddTable, 449 | nn.ReLU(), 450 | ), 451 | nn.Sequential( # Sequential, 452 | LambdaMap(lambda x: x, # ConcatTable, 453 | nn.Sequential( # Sequential, 454 | nn.Sequential( # Sequential, 455 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 456 | nn.BatchNorm2d(512), 457 | nn.ReLU(), 458 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 459 | nn.BatchNorm2d(512), 460 | nn.ReLU(), 461 | ), 462 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 463 | nn.BatchNorm2d(1024), 464 | ), 465 | Lambda(lambda x: x), # Identity, 466 | ), 467 | LambdaReduce(lambda x, y: x + y), # CAddTable, 468 | nn.ReLU(), 469 | ), 470 | nn.Sequential( # Sequential, 471 | LambdaMap(lambda x: x, # ConcatTable, 472 | nn.Sequential( # Sequential, 473 | nn.Sequential( # Sequential, 474 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 475 | nn.BatchNorm2d(512), 476 | nn.ReLU(), 477 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 478 | nn.BatchNorm2d(512), 479 | nn.ReLU(), 480 | ), 481 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 482 | nn.BatchNorm2d(1024), 483 | ), 484 | Lambda(lambda x: x), # Identity, 485 | ), 486 | LambdaReduce(lambda x, y: x + y), # CAddTable, 487 | nn.ReLU(), 488 | ), 489 | nn.Sequential( # Sequential, 490 | LambdaMap(lambda x: x, # ConcatTable, 491 | nn.Sequential( # Sequential, 492 | nn.Sequential( # Sequential, 493 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 494 | nn.BatchNorm2d(512), 495 | nn.ReLU(), 496 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 497 | nn.BatchNorm2d(512), 498 | nn.ReLU(), 499 | ), 500 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 501 | nn.BatchNorm2d(1024), 502 | ), 503 | Lambda(lambda x: x), # Identity, 504 | ), 505 | LambdaReduce(lambda x, y: x + y), # CAddTable, 506 | nn.ReLU(), 507 | ), 508 | 
nn.Sequential( # Sequential, 509 | LambdaMap(lambda x: x, # ConcatTable, 510 | nn.Sequential( # Sequential, 511 | nn.Sequential( # Sequential, 512 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 513 | nn.BatchNorm2d(512), 514 | nn.ReLU(), 515 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 516 | nn.BatchNorm2d(512), 517 | nn.ReLU(), 518 | ), 519 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 520 | nn.BatchNorm2d(1024), 521 | ), 522 | Lambda(lambda x: x), # Identity, 523 | ), 524 | LambdaReduce(lambda x, y: x + y), # CAddTable, 525 | nn.ReLU(), 526 | ), 527 | nn.Sequential( # Sequential, 528 | LambdaMap(lambda x: x, # ConcatTable, 529 | nn.Sequential( # Sequential, 530 | nn.Sequential( # Sequential, 531 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 532 | nn.BatchNorm2d(512), 533 | nn.ReLU(), 534 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 535 | nn.BatchNorm2d(512), 536 | nn.ReLU(), 537 | ), 538 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 539 | nn.BatchNorm2d(1024), 540 | ), 541 | Lambda(lambda x: x), # Identity, 542 | ), 543 | LambdaReduce(lambda x, y: x + y), # CAddTable, 544 | nn.ReLU(), 545 | ), 546 | nn.Sequential( # Sequential, 547 | LambdaMap(lambda x: x, # ConcatTable, 548 | nn.Sequential( # Sequential, 549 | nn.Sequential( # Sequential, 550 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 551 | nn.BatchNorm2d(512), 552 | nn.ReLU(), 553 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 554 | nn.BatchNorm2d(512), 555 | nn.ReLU(), 556 | ), 557 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 558 | nn.BatchNorm2d(1024), 559 | ), 560 | Lambda(lambda x: x), # Identity, 561 | ), 562 | LambdaReduce(lambda x, y: x + y), # CAddTable, 563 | nn.ReLU(), 564 | ), 565 | nn.Sequential( # Sequential, 566 | LambdaMap(lambda x: x, # ConcatTable, 567 | nn.Sequential( # Sequential, 568 | nn.Sequential( # 
Sequential, 569 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 570 | nn.BatchNorm2d(512), 571 | nn.ReLU(), 572 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 573 | nn.BatchNorm2d(512), 574 | nn.ReLU(), 575 | ), 576 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 577 | nn.BatchNorm2d(1024), 578 | ), 579 | Lambda(lambda x: x), # Identity, 580 | ), 581 | LambdaReduce(lambda x, y: x + y), # CAddTable, 582 | nn.ReLU(), 583 | ), 584 | nn.Sequential( # Sequential, 585 | LambdaMap(lambda x: x, # ConcatTable, 586 | nn.Sequential( # Sequential, 587 | nn.Sequential( # Sequential, 588 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 589 | nn.BatchNorm2d(512), 590 | nn.ReLU(), 591 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 592 | nn.BatchNorm2d(512), 593 | nn.ReLU(), 594 | ), 595 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 596 | nn.BatchNorm2d(1024), 597 | ), 598 | Lambda(lambda x: x), # Identity, 599 | ), 600 | LambdaReduce(lambda x, y: x + y), # CAddTable, 601 | nn.ReLU(), 602 | ), 603 | nn.Sequential( # Sequential, 604 | LambdaMap(lambda x: x, # ConcatTable, 605 | nn.Sequential( # Sequential, 606 | nn.Sequential( # Sequential, 607 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 608 | nn.BatchNorm2d(512), 609 | nn.ReLU(), 610 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 611 | nn.BatchNorm2d(512), 612 | nn.ReLU(), 613 | ), 614 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 615 | nn.BatchNorm2d(1024), 616 | ), 617 | Lambda(lambda x: x), # Identity, 618 | ), 619 | LambdaReduce(lambda x, y: x + y), # CAddTable, 620 | nn.ReLU(), 621 | ), 622 | ), 623 | nn.Sequential( # Sequential, 624 | nn.Sequential( # Sequential, 625 | LambdaMap(lambda x: x, # ConcatTable, 626 | nn.Sequential( # Sequential, 627 | nn.Sequential( # Sequential, 628 | nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 629 | 
nn.BatchNorm2d(1024), 630 | nn.ReLU(), 631 | nn.Conv2d(1024, 1024, (3, 3), (2, 2), (1, 1), 1, 32, bias=False), 632 | nn.BatchNorm2d(1024), 633 | nn.ReLU(), 634 | ), 635 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 636 | nn.BatchNorm2d(2048), 637 | ), 638 | nn.Sequential( # Sequential, 639 | nn.Conv2d(1024, 2048, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), 640 | nn.BatchNorm2d(2048), 641 | ), 642 | ), 643 | LambdaReduce(lambda x, y: x + y), # CAddTable, 644 | nn.ReLU(), 645 | ), 646 | nn.Sequential( # Sequential, 647 | LambdaMap(lambda x: x, # ConcatTable, 648 | nn.Sequential( # Sequential, 649 | nn.Sequential( # Sequential, 650 | nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 651 | nn.BatchNorm2d(1024), 652 | nn.ReLU(), 653 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 654 | nn.BatchNorm2d(1024), 655 | nn.ReLU(), 656 | ), 657 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 658 | nn.BatchNorm2d(2048), 659 | ), 660 | Lambda(lambda x: x), # Identity, 661 | ), 662 | LambdaReduce(lambda x, y: x + y), # CAddTable, 663 | nn.ReLU(), 664 | ), 665 | nn.Sequential( # Sequential, 666 | LambdaMap(lambda x: x, # ConcatTable, 667 | nn.Sequential( # Sequential, 668 | nn.Sequential( # Sequential, 669 | nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 670 | nn.BatchNorm2d(1024), 671 | nn.ReLU(), 672 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 673 | nn.BatchNorm2d(1024), 674 | nn.ReLU(), 675 | ), 676 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 677 | nn.BatchNorm2d(2048), 678 | ), 679 | Lambda(lambda x: x), # Identity, 680 | ), 681 | LambdaReduce(lambda x, y: x + y), # CAddTable, 682 | nn.ReLU(), 683 | ), 684 | ), 685 | nn.AvgPool2d((7, 7), (1, 1)), 686 | Lambda(lambda x: x.view(x.size(0), -1)), # View, 687 | nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x), nn.Linear(2048, 1000)), # Linear, 688 | ) 689 | 
# ==================================== /config.py ====================================
"""
@Time : 2020/3/15 18:44
@Author : TaylorMei
@E-mail : mhy666@mail.dlut.edu.cn

@Project : CVPR2020_GDNet
@File : config.py
@Function: Filesystem locations of the backbone weights, datasets and results.

All paths hang off ``project_root``. It defaults to the original author's checkout
and can be overridden with the ``GDNET_ROOT`` environment variable.
"""
import os

project_root = os.environ.get('GDNET_ROOT', '/home/iccd/CVPR2020_GDNet')

backbone_path = os.path.join(project_root, 'backbone', 'resnext', 'resnext_101_32x4d.pth')

gdd_training_root = os.path.join(project_root, 'GDD', 'train')
gdd_testing_root = os.path.join(project_root, 'GDD', 'test')
gdd_results_root = os.path.join(project_root, 'results', 'GDD')

msd_training_root = os.path.join(project_root, 'MSD', 'train')
msd_testing_root = os.path.join(project_root, 'MSD', 'test')
msd_results_root = os.path.join(project_root, 'results', 'MSD')


# ==================================== /dataset.py ====================================
"""
@Time : 2020/3/15 18:56
@Author : TaylorMei
@E-mail : mhy666@mail.dlut.edu.cn

@Project : CVPR2020_GDNet
@File : dataset.py
@Function: Paired (image, mask) dataset used for training/testing.

"""
import os
import os.path

import torch.utils.data as data
from PIL import Image


def make_dataset(root):
    """List ``(image_path, mask_path)`` pairs found under *root*.

    Every ``root/image/<name>.jpg`` is paired with ``root/mask/<name>.png``.
    The mask's existence is not checked here; a missing mask surfaces when
    ``ImageFolder.__getitem__`` opens it.
    """
    image_dir = os.path.join(root, 'image')
    mask_dir = os.path.join(root, 'mask')
    names = [os.path.splitext(f)[0] for f in os.listdir(image_dir) if f.endswith('.jpg')]
    return [(os.path.join(image_dir, name + '.jpg'), os.path.join(mask_dir, name + '.png'))
            for name in names]


class ImageFolder(data.Dataset):
    """Dataset yielding ``(image, mask)`` pairs listed by :func:`make_dataset`.

    Images are loaded as RGB and masks as single-channel ('L') PIL images.
    ``joint_transform(img, target)`` runs first on both together; then
    ``img_transform`` and ``target_transform`` run on each separately.
    """

    def __init__(self, root, joint_transform=None, img_transform=None, target_transform=None):
        self.root = root
        self.imgs = make_dataset(root)
        self.joint_transform = joint_transform
        self.img_transform = img_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        img_path, gt_path = self.imgs[index]
        # Context managers close the underlying files; convert() returns a loaded copy.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        with Image.open(gt_path) as src:
            target = src.convert('L')
        if self.joint_transform is not None:
            img, target = self.joint_transform(img, target)
        if self.img_transform is not None:
            img = self.img_transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.imgs)


# ==================================== /gdnet.py ====================================
"""
@Time : 2020/3/15 22:09
@Author : TaylorMei
@E-mail : mhy666@mail.dlut.edu.cn

@Project : CVPR2020_GDNet
@File : gdnet.py
@Function: GDNet glass-detection network (ResNeXt101 backbone + LCFI modules + CBAM fusion).

"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from backbone.resnext.resnext101_regular import ResNeXt101


###################################################################
# ########################## CBAM #################################
###################################################################
class BasicConv(nn.Module):
    """Conv2d, optionally followed by BatchNorm (momentum 0.01) and ReLU."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
                 bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class Flatten(nn.Module):
    """Flatten every dimension after the batch dimension."""

    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelGate(nn.Module):
    """CBAM channel attention: pooled descriptors -> shared MLP -> sigmoid scale per channel.

    Args:
        gate_channels: number of channels C of the input.
        reduction_ratio: the MLP's hidden width is ``C // reduction_ratio``.
        pool_types: non-empty iterable of 'avg', 'max', 'lp', 'lse'; the MLP
            outputs for all listed poolings are summed before the sigmoid.

    Raises:
        ValueError: if *pool_types* is empty or contains an unsupported name.
    """

    _POOL_TYPES = ('avg', 'max', 'lp', 'lse')

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg',)):
        super(ChannelGate, self).__init__()
        # Validate up front: an unknown name used to surface as UnboundLocalError,
        # or silently re-add the previous branch's attention in forward().
        unknown = [p for p in pool_types if p not in self._POOL_TYPES]
        if not pool_types or unknown:
            raise ValueError('unsupported pool_types {!r}; expected a non-empty subset of {}'
                             .format(pool_types, self._POOL_TYPES))
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types

    @staticmethod
    def _pool(x, pool_type):
        """Reduce *x* over its two spatial dimensions with the named pooling."""
        kernel = (x.size(2), x.size(3))
        if pool_type == 'avg':
            return F.avg_pool2d(x, kernel, stride=kernel)
        if pool_type == 'max':
            return F.max_pool2d(x, kernel, stride=kernel)
        if pool_type == 'lp':
            return F.lp_pool2d(x, 2, kernel, stride=kernel)
        return logsumexp_2d(x)  # 'lse'

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            channel_att_raw = self.mlp(self._pool(x, pool_type))
            channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale


def logsumexp_2d(tensor):
    """Numerically stable log-sum-exp over the spatial positions of each channel.

    Returns a ``(N, C, 1)`` tensor; subtracting the per-channel max before
    ``exp`` avoids overflow.
    """
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs


class ChannelPool(nn.Module):
    """Compress the channel dimension to one map by taking the channel-wise mean.

    Standard CBAM concatenates channel-wise max and mean (2 maps); this variant
    keeps only the mean, which is why SpatialGate's conv has 1 input channel.
    """

    def forward(self, x):
        return torch.mean(x, 1).unsqueeze(1)


class SpatialGate(nn.Module):
    """CBAM spatial attention: channel-mean map -> 7x7 conv + BN -> sigmoid scale per pixel."""

    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(1, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # (N, 1, H, W), broadcast over channels
        return x * scale


class CBAM(nn.Module):
    """Channel attention followed (unless ``no_spatial``) by spatial attention."""

    def __init__(self, gate_channels=128, reduction_ratio=16, pool_types=('avg',), no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out


###################################################################
# ########################## LCFI #################################
###################################################################
def _conv_bn_relu(in_channels, out_channels):
    """3x3 conv (stride 1, padding 1) -> BatchNorm -> ReLU, as an ``nn.Sequential``."""
    return nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, 1, 1, dilation=1),
                         nn.BatchNorm2d(out_channels), nn.ReLU())


def _factorized_conv(in_channels, out_channels, k, padding, dilation, vertical_first):
    """Two dilated 1-D convs (k x 1 and 1 x k) -> BatchNorm -> ReLU.

    ``vertical_first`` selects the order of the k x 1 and 1 x k convs. The
    Sequential layout (conv, conv, bn, relu) matches the state_dict keys of
    released checkpoints.
    """
    vertical = dict(kernel_size=(k, 1), padding=(padding, 0), dilation=(dilation, 1))
    horizontal = dict(kernel_size=(1, k), padding=(0, padding), dilation=(1, dilation))
    first, second = (vertical, horizontal) if vertical_first else (horizontal, vertical)
    return nn.Sequential(nn.Conv2d(in_channels, out_channels, stride=1, **first),
                         nn.Conv2d(out_channels, out_channels, stride=1, **second),
                         nn.BatchNorm2d(out_channels), nn.ReLU())


class LCFI(nn.Module):
    """Large-field Contextual Feature Integration module.

    There are four chained paths, i = 1..4. Each path applies a pair of
    spatially separable (2i+1)-tap dilated convolutions in both orders and
    fuses the two results. Path i > 1 also receives path i-1's output. The
    four path outputs are concatenated, re-weighted by CBAM, and reduced to
    ``input_channels // 4`` channels.

    Args:
        input_channels: channel count C of the input. It is assumed to be
            divisible by 4, because CBAM is built for C channels and the
            concatenation has 4 * (C // 4) channels.
        dr1, dr2, dr3, dr4: dilation rates of paths 1..4.

    Attribute names are state_dict keys (checkpoints are loaded by name);
    do not rename them.
    """

    def __init__(self, input_channels, dr1=1, dr2=2, dr3=3, dr4=4):
        super(LCFI, self).__init__()
        self.input_channels = input_channels
        self.channels_single = input_channels // 4
        self.channels_double = input_channels // 2
        self.dr1, self.dr2, self.dr3, self.dr4 = dr1, dr2, dr3, dr4
        # 'Same' padding for a (2i+1)-tap kernel with dilation d is i * d.
        self.padding1 = 1 * dr1
        self.padding2 = 2 * dr2
        self.padding3 = 3 * dr3
        self.padding4 = 4 * dr4

        single, double = self.channels_single, self.channels_double

        # Per-path 3x3 reduction of the raw input to C/4 channels.
        self.p1_channel_reduction = _conv_bn_relu(input_channels, single)
        self.p2_channel_reduction = _conv_bn_relu(input_channels, single)
        self.p3_channel_reduction = _conv_bn_relu(input_channels, single)
        self.p4_channel_reduction = _conv_bn_relu(input_channels, single)

        # Path 1 sees only its reduction (C/4 channels).
        # Paths 2-4 see their reduction concatenated with the previous path (C/2 channels).
        self.p1_d1 = _factorized_conv(single, single, 3, self.padding1, dr1, vertical_first=True)
        self.p1_d2 = _factorized_conv(single, single, 3, self.padding1, dr1, vertical_first=False)
        self.p1_fusion = _conv_bn_relu(double, single)

        self.p2_d1 = _factorized_conv(double, single, 5, self.padding2, dr2, vertical_first=True)
        self.p2_d2 = _factorized_conv(double, single, 5, self.padding2, dr2, vertical_first=False)
        self.p2_fusion = _conv_bn_relu(double, single)

        self.p3_d1 = _factorized_conv(double, single, 7, self.padding3, dr3, vertical_first=True)
        self.p3_d2 = _factorized_conv(double, single, 7, self.padding3, dr3, vertical_first=False)
        self.p3_fusion = _conv_bn_relu(double, single)

        self.p4_d1 = _factorized_conv(double, single, 9, self.padding4, dr4, vertical_first=True)
        self.p4_d2 = _factorized_conv(double, single, 9, self.padding4, dr4, vertical_first=False)
        self.p4_fusion = _conv_bn_relu(double, single)

        self.cbam = CBAM(self.input_channels)

        self.channel_reduction = _conv_bn_relu(self.input_channels, single)

    def forward(self, x):
        p1_input = self.p1_channel_reduction(x)
        p1 = self.p1_fusion(torch.cat((self.p1_d1(p1_input), self.p1_d2(p1_input)), 1))

        p2_input = torch.cat((self.p2_channel_reduction(x), p1), 1)
        p2 = self.p2_fusion(torch.cat((self.p2_d1(p2_input), self.p2_d2(p2_input)), 1))

        p3_input = torch.cat((self.p3_channel_reduction(x), p2), 1)
        p3 = self.p3_fusion(torch.cat((self.p3_d1(p3_input), self.p3_d2(p3_input)), 1))

        p4_input = torch.cat((self.p4_channel_reduction(x), p3), 1)
        p4 = self.p4_fusion(torch.cat((self.p4_d1(p4_input), self.p4_d2(p4_input)), 1))

        channel_reduction = self.channel_reduction(self.cbam(torch.cat((p1, p2, p3, p4), 1)))

        return channel_reduction


###################################################################
# ########################## NETWORK ##############################
###################################################################
class GDNet(nn.Module):
    """Glass detection network.

    The four backbone stages feed LCFI modules. The three deepest outputs
    (512 + 256 + 128 = 896 channels, at the layer3 resolution) are fused with
    CBAM into a high-level map. That map gates the low-level (layer1, 64-ch)
    LCFI features through a sigmoid attention map. The upsampled high-level
    features and the gated low-level features are then fused once more.

    ``forward`` returns three sigmoid probability maps ``(h, l, final)``,
    each resized to the input's spatial size.
    """

    def __init__(self, backbone_path=None):
        super(GDNet, self).__init__()
        # backbone
        resnext = ResNeXt101(backbone_path)
        self.layer0 = resnext.layer0
        self.layer1 = resnext.layer1
        self.layer2 = resnext.layer2
        self.layer3 = resnext.layer3
        self.layer4 = resnext.layer4

        # Each LCFI outputs input_channels // 4 channels.
        self.h5_conv = LCFI(2048, 1, 2, 3, 4)
        self.h4_conv = LCFI(1024, 1, 2, 3, 4)
        self.h3_conv = LCFI(512, 1, 2, 3, 4)
        self.l2_conv = LCFI(256, 1, 2, 3, 4)

        # h fusion: bring h5 (x2 up) and h3 (x2 down) to h4's resolution; 512 + 256 + 128 = 896.
        self.h5_up = nn.UpsamplingBilinear2d(scale_factor=2)
        self.h3_down = nn.AvgPool2d((2, 2), stride=2)
        self.h_fusion = CBAM(896)
        self.h_fusion_conv = nn.Sequential(nn.Conv2d(896, 896, 3, 1, 1), nn.BatchNorm2d(896), nn.ReLU())

        # l fusion: stride-4 transposed conv maps the high-level features to a 1-channel attention map.
        self.l_fusion_conv = nn.Sequential(nn.Conv2d(64, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU())
        self.h2l = nn.ConvTranspose2d(896, 1, 8, 4, 2)

        # final fusion: 256 upsampled high-level + 64 low-level = 320 channels.
        self.h_up_for_final_fusion = nn.ConvTranspose2d(896, 256, 8, 4, 2)
        self.final_fusion = CBAM(320)
        self.final_fusion_conv = nn.Sequential(nn.Conv2d(320, 320, 3, 1, 1), nn.BatchNorm2d(320), nn.ReLU())

        # predict conv
        self.h_predict = nn.Conv2d(896, 1, 3, 1, 1)
        self.l_predict = nn.Conv2d(64, 1, 3, 1, 1)
        self.final_predict = nn.Conv2d(320, 1, 3, 1, 1)

        for m in self.modules():
            if isinstance(m, nn.ReLU):
                m.inplace = True

    def forward(self, x):
        # x: [batch_size, channel=3, h, w]
        layer0 = self.layer0(x)  # [-1, 64, h/2, w/2]
        layer1 = self.layer1(layer0)  # [-1, 256, h/4, w/4]
        layer2 = self.layer2(layer1)  # [-1, 512, h/8, w/8]
        layer3 = self.layer3(layer2)  # [-1, 1024, h/16, w/16]
        layer4 = self.layer4(layer3)  # [-1, 2048, h/32, w/32]

        h5_conv = self.h5_conv(layer4)
        h4_conv = self.h4_conv(layer3)
        h3_conv = self.h3_conv(layer2)
        l2_conv = self.l2_conv(layer1)

        # h fusion
        h5_up = self.h5_up(h5_conv)
        h3_down = self.h3_down(h3_conv)
        h_fusion = self.h_fusion(torch.cat((h5_up, h4_conv, h3_down), 1))
        h_fusion = self.h_fusion_conv(h_fusion)

        # l fusion: high-level features act as a spatial attention on low-level ones
        l_fusion = self.l_fusion_conv(l2_conv)
        h2l = self.h2l(h_fusion)
        l_fusion = torch.sigmoid(h2l) * l_fusion

        # final fusion
        h_up_for_final_fusion = self.h_up_for_final_fusion(h_fusion)
        final_fusion = self.final_fusion(torch.cat((h_up_for_final_fusion, l_fusion), 1))
        final_fusion = self.final_fusion_conv(final_fusion)

        h_predict = self.h_predict(h_fusion)
        l_predict = self.l_predict(l_fusion)
        final_predict = self.final_predict(final_fusion)

        # Rescale logits to the input size (F.interpolate replaces the deprecated F.upsample).
        size = x.size()[2:]
        h_predict = F.interpolate(h_predict, size=size, mode='bilinear', align_corners=True)
        l_predict = F.interpolate(l_predict, size=size, mode='bilinear', align_corners=True)
        final_predict = F.interpolate(final_predict, size=size, mode='bilinear', align_corners=True)

        return torch.sigmoid(h_predict), torch.sigmoid(l_predict), torch.sigmoid(final_predict)


# ==================================== /infer.py ====================================
"""
@Time : 2020/3/15 20:43
@Author : TaylorMei
@E-mail : mhy666@mail.dlut.edu.cn

@Project : CVPR2020_GDNet
@File : infer.py
@Function: Run a trained GDNet snapshot over the test images and save the final
    glass maps (optionally CRF-refined) as PNG files.

"""
import os
import time

import numpy as np

import torch
from PIL import Image
from torchvision import transforms

from config import gdd_testing_root, gdd_results_root
from misc import check_mkdir, crf_refine
from gdnet import GDNet

device_ids = [0]
torch.cuda.set_device(device_ids[0])

ckpt_path = './ckpt'
exp_name = 'GDNet'
args = {
    'snapshot': '200',
    'scale': 416,
    'crf': False,  # True: refine the final map with misc.crf_refine
}

print(torch.__version__)

img_transform = transforms.Compose([
    transforms.Resize((args['scale'], args['scale'])),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

to_test = {'GDD': gdd_testing_root}

to_pil = transforms.ToPILImage()


def main():
    """Predict a glass map for every file in ``<root>/image`` of each dataset in ``to_test``.

    Only the network's final map is saved. It is resized back to the source
    image size and written as ``<results>/<exp>_<snapshot>/<stem>.png``.
    """
    net = GDNet().cuda(device_ids[0])

    if len(args['snapshot']) > 0:
        snapshot_path = os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')
        print('Load snapshot {} for testing'.format(args['snapshot']))
        net.load_state_dict(torch.load(snapshot_path))
        print('Load {} succeed!'.format(snapshot_path))

    # NOTE: results always go under gdd_results_root; only GDD is listed in to_test.
    save_dir = os.path.join(gdd_results_root, '%s_%s' % (exp_name, args['snapshot']))
    check_mkdir(save_dir)

    net.eval()
    with torch.no_grad():
        for name, root in to_test.items():
            img_list = sorted(os.listdir(os.path.join(root, 'image')))
            start = time.time()
            for idx, img_name in enumerate(img_list):
                print('predicting for {}: {:>4d} / {}'.format(name, idx + 1, len(img_list)))
                with Image.open(os.path.join(root, 'image', img_name)) as src:
                    if src.mode != 'RGB':
                        print("{} is a {} image, converting to RGB.".format(img_name, src.mode))
                    img = src.convert('RGB')
                w, h = img.size
                img_var = img_transform(img).unsqueeze(0).cuda(device_ids[0])
                # Outputs are (high-level, low-level, final) maps; only the final one is saved.
                _, _, final = net(img_var)
                final = np.array(transforms.Resize((h, w))(to_pil(final.squeeze(0).cpu())))
                if args['crf']:
                    final = crf_refine(np.array(img), final)

                stem = os.path.splitext(img_name)[0]
                Image.fromarray(final).save(os.path.join(save_dir, stem + ".png"))

            end = time.time()
            if img_list:
                print("Average Time Is : {:.2f}".format((end - start) / len(img_list)))


if __name__ == '__main__':
    main()

# ==================================== /misc.py ====================================
"""
@Time : 2020/3/15 20:09
@Author : TaylorMei
@E-mail : mhy666@mail.dlut.edu.cn

@Project : CVPR2020_GDNet
@File : misc.py
@Function: Directory/Excel helpers, a running-average meter, dense-CRF refinement,
    mask loading and evaluation metrics (IoU, accuracy, F-measure, MAE, BER).

"""
import os
import xlwt
import numpy as np
import pydensecrf.densecrf as dcrf
from skimage import io


################################################################
######################## Utils #################################
################################################################
def check_mkdir(dir_name):
    """Create *dir_name*, including missing parents, unless it already exists."""
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(dir_name, exist_ok=True)


def data_write(file_path, datas):
    """Write *datas* to an .xls workbook at *file_path*, one sequence per column.

    ``datas[j][i]`` is written to row ``i``, column ``j`` of a sheet named "sheet1".
    """
    workbook = xlwt.Workbook()
    sheet1 = workbook.add_sheet(sheetname="sheet1", cell_overwrite_ok=True)

    for j, data in enumerate(datas):
        for i, value in enumerate(data):
            sheet1.write(i, j, value)

    workbook.save(file_path)


################################################################
######################## Train & Test ##########################
################################################################
class AvgMeter(object):
    """Track the latest value and the weighted running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* as the mean of *n* samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# codes of this function are borrowed from https://github.com/Andrew-Qibin/dss_crf
def crf_refine(img, annos):
    """Refine a soft mask with one mean-field iteration of a 2-label dense CRF.

    Args:
        img: uint8 image whose first two dims match *annos*. It is used as the
            bilateral term's RGB guide (presumably H x W x 3 -- confirm).
        annos: uint8 H x W soft mask; 255 means "label 1" (glass).

    Returns:
        uint8 H x W map equal to 255 * the CRF's label-1 probability.
    """
    def _sigmoid(x):
        return 1 / (1 + np.exp(-x))

    assert img.dtype == np.uint8
    assert annos.dtype == np.uint8
    assert img.shape[:2] == annos.shape

    EPSILON = 1e-8  # keeps log() finite at probabilities 0 and 1

    M = 2  # number of labels: 0 = background, 1 = glass
    tau = 1.05
    # Setup the CRF model (width, height, labels)
    d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M)

    anno_norm = annos / 255.

    # Unary energies: negative log-probability, scaled by tau * sigmoid(prob).
    n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm))
    p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm))

    U = np.zeros((M, img.shape[0] * img.shape[1]), dtype='float32')
    U[0, :] = n_energy.flatten()
    U[1, :] = p_energy.flatten()

    d.setUnaryEnergy(U)

    d.addPairwiseGaussian(sxy=3, compat=3)
    d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5)

    # Do the inference (a single mean-field iteration)
    infer = np.array(d.inference(1)).astype('float32')
    res = infer[1, :]

    res = res * 255
    res = res.reshape(img.shape[:2])
    return res.astype('uint8')


def _mask_path(imgname, mask_dir):
    """Return ``mask_dir/<imgname without its extension>.png``."""
    return os.path.join(mask_dir, os.path.splitext(imgname)[0] + ".png")


def _read_predict_mask(imgname, predict_mask_dir):
    """Read the predicted mask for *imgname* as float32.

    Raises:
        FileNotFoundError: if the prediction file does not exist.
    """
    mask_path = _mask_path(imgname, predict_mask_dir)
    if not os.path.exists(mask_path):
        raise FileNotFoundError("{} has no predict mask!".format(imgname))
    return io.imread(mask_path).astype(np.float32)


def get_gt_mask(imgname, MASK_DIR):
    """Load the ground-truth mask for *imgname* as float32: 1 where the pixel is 255, else 0."""
    mask = io.imread(_mask_path(imgname, MASK_DIR))
    mask = np.where(mask == 255, 1, 0).astype(np.float32)

    return mask


def get_normalized_predict_mask(imgname, PREDICT_MASK_DIR):
    """Load a predicted mask min-max normalized to [0, 1].

    A constant mask cannot be min-max normalized, so it is divided by 255 instead.
    """
    mask = _read_predict_mask(imgname, PREDICT_MASK_DIR)
    if np.max(mask) - np.min(mask) > 0:
        mask = (mask - np.min(mask)) / (np.max(mask) - np.min(mask))
    else:
        mask = mask / 255.0
    mask = mask.astype(np.float32)

    return mask


def get_binary_predict_mask(imgname, PREDICT_MASK_DIR):
    """Load a predicted mask binarized at 127.5 (1 where >= 127.5, else 0) as float32."""
    mask = _read_predict_mask(imgname, PREDICT_MASK_DIR)
    mask = np.where(mask >= 127.5, 1, 0).astype(np.float32)

    return mask


################################################################
######################## Evaluation ############################
################################################################
def compute_iou(predict_mask, gt_mask):
    """Intersection over union of two binary masks.

    Returns 0 when either mask is empty. This includes the case where both are
    empty; that convention is kept as-is for comparability of reported numbers.
    """
    check_size(predict_mask, gt_mask)

    if np.sum(predict_mask) == 0 or np.sum(gt_mask) == 0:
        iou_ = 0
        return iou_

    n_ii = np.sum(np.logical_and(predict_mask, gt_mask))
    t_i = np.sum(gt_mask)
    n_ij = np.sum(predict_mask)

    iou_ = n_ii / (t_i + n_ij - n_ii)

    return iou_


def compute_acc(predict_mask, gt_mask):
    """Recall of positive pixels: TP / (number of positive ground-truth pixels).

    The NumPy scalar division yields nan/inf (with a RuntimeWarning) when *gt_mask*
    has no positives.
    """
    check_size(predict_mask, gt_mask)

    N_p = np.sum(gt_mask)
    TP = np.sum(np.logical_and(predict_mask, gt_mask))

    accuracy_ = TP / N_p

    return accuracy_


def compute_acc_image(predict_mask, gt_mask):
    """Pixel accuracy: (TP + TN) / total number of pixels."""
    check_size(predict_mask, gt_mask)

    N_p = np.sum(gt_mask)
    N_n = np.sum(np.logical_not(gt_mask))

    TP = np.sum(np.logical_and(predict_mask, gt_mask))
    TN = np.sum(np.logical_and(np.logical_not(predict_mask), np.logical_not(gt_mask)))

    accuracy_ = (TP + TN) / (N_p + N_n)

    return accuracy_


def compute_precision_recall(prediction, gt):
    """Precision and recall of a soft *prediction* in [0, 1] at 256 thresholds.

    The thresholds are ``t / 255`` for ``t = 0..255``; a pixel is positive when
    ``prediction > threshold``. The ground truth is binarized at 0.5. ``eps``
    keeps both ratios defined for empty masks.

    Returns:
        Two lists of 256 floats: ``(precision, recall)``.
    """
    assert prediction.dtype == np.float32
    assert gt.dtype == np.float32
    assert prediction.shape == gt.shape

    eps = 1e-4

    hard_gt = np.zeros(prediction.shape)
    hard_gt[gt > 0.5] = 1
    t = np.sum(hard_gt)

    precision, recall = [], []
    # calculating precision and recall at 256 different binarizing thresholds
    for threshold in range(256):
        threshold = threshold / 255.

        hard_prediction = np.zeros(prediction.shape)
        hard_prediction[prediction > threshold] = 1

        tp = np.sum(hard_prediction * hard_gt)
        p = np.sum(hard_prediction)

        precision.append((tp + eps) / (p + eps))
        recall.append((tp + eps) / (t + eps))

    return precision, recall


def compute_fmeasure(precision, recall):
    """Maximum F-beta (beta^2 = 0.3) over the 256 thresholds from compute_precision_recall."""
    assert len(precision) == 256
    assert len(recall) == 256
    beta_square = 0.3
    max_fmeasure = max((1 + beta_square) * p * r / (beta_square * p + r) for p, r in zip(precision, recall))

    return max_fmeasure


def compute_mae(predict_mask, gt_mask):
    """Mean absolute error between two equally sized masks, as a Python float."""
    check_size(predict_mask, gt_mask)

    mae_ = np.mean(np.abs(predict_mask - gt_mask)).item()

    return mae_


def compute_ber(predict_mask, gt_mask):
    """Balanced error rate in percent: 100 * (1 - (TP/N_p + TN/N_n) / 2).

    NumPy scalar division yields nan/inf (with a warning) when *gt_mask* is all
    positive or all negative.
    """
    check_size(predict_mask, gt_mask)

    N_p = np.sum(gt_mask)
    N_n = np.sum(np.logical_not(gt_mask))

    TP = np.sum(np.logical_and(predict_mask, gt_mask))
    TN = np.sum(np.logical_and(np.logical_not(predict_mask), np.logical_not(gt_mask)))

    ber_ = 100 * (1 - (1 / 2) * ((TP / N_p) + (TN / N_n)))

    return ber_


def segm_size(segm):
    """Return ``(height, width)`` of a mask; IndexError if it has fewer than 2 dims."""
    return segm.shape[0], segm.shape[1]


def check_size(eval_segm, gt_segm):
    """Raise EvalSegErr unless both masks have the same height and width."""
    h_e, w_e = segm_size(eval_segm)
    h_g, w_g = segm_size(gt_segm)

    if (h_e != h_g) or (w_e != w_g):
        raise EvalSegErr("DiffDim: Different dimensions of matrices!")


class EvalSegErr(Exception):
    """Raised by the evaluation helpers on inconsistent inputs."""

    def __init__(self, value):
        super(EvalSegErr, self).__init__(value)
        self.value = value

    def __str__(self):
        return repr(self.value)

# ==================================== /requirements.txt ====================================
# Cython
# scikit-image==0.16.2
# xlwt
# tqdm
# NOTE(review): the code also imports torch, torchvision, numpy, Pillow and pydensecrf,
# none of which are listed above.