├── License.txt
├── README.md
├── assets
│   ├── results.png
│   ├── table1.png
│   └── table2.png
├── backbone
│   └── resnext
│       ├── __init__.py
│       ├── resnext101_regular.py
│       └── resnext_101_32x4d_.py
├── ckpt
│   └── MirrorNet
│       └── placeholder
├── config.py
├── dataset.py
├── infer.py
├── mirrornet.py
├── misc.py
├── requirements.txt
└── utils
    ├── compute_contrast.py
    ├── compute_overlap.py
    ├── compute_size.py
    └── generate_overlap_map.py

/License.txt:
--------------------------------------------------------------------------------
Where Is My Mirror?
Xin Yang*, Haiyang Mei*, Ke Xu, Xiaopeng Wei, Baocai Yin, Rynson W.H. Lau (*Joint first authors)
ICCV October, 2019

Copyright (c) 2019
All rights reserved.

Computer Science and Technology
Dalian University of Technology

Department of Computer Science
City University of Hong Kong


-------------------------------------------------------

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:

1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.

3. Neither name of copyright holders nor the names of its contributors
may be used to endorse or promote products derived from this software
without specific prior written permission.


THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE REGENTS OR
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ICCV2019_MirrorNet

## Where Is My Mirror? (ICCV2019)
Xin Yang\*, [Haiyang Mei](https://mhaiyang.github.io/)\*, Ke Xu, Xiaopeng Wei, Baocai Yin, [Rynson W.H. Lau](http://www.cs.cityu.edu.hk/~rynson/) (* Joint first authors, Rynson Lau is the corresponding author and he led this project.)

[[Project Page](https://mhaiyang.github.io/ICCV2019_MirrorNet/index.html)][[Arxiv](https://arxiv.org/pdf/1908.09101v2.pdf)]

### Abstract
Mirrors are everywhere in our daily lives. Existing computer vision systems do not consider mirrors, and hence may get confused by the reflected content inside a mirror, resulting in a severe performance degradation. However, separating the real content outside a mirror from the reflected content inside it is non-trivial. The key challenge lies in that mirrors typically reflect contents similar to their surroundings, making it very difficult to differentiate the two. In this paper, we present a novel method to accurately segment mirrors from an input image. To the best of our knowledge, this is the first work to address the mirror segmentation problem with a computational approach. We make the following contributions.
First, we construct a large-scale mirror dataset that contains mirror images with the corresponding manually annotated masks. This dataset covers a variety of daily life scenes, and will be made publicly available for future research. Second, we propose a novel network, called MirrorNet, for mirror segmentation, by modeling both semantical and low-level color/texture discontinuities between the contents inside and outside of the mirrors. Third, we conduct extensive experiments to evaluate the proposed method, and show that it outperforms the carefully chosen baselines from the state-of-the-art detection and segmentation methods.

### Citation
If you use this code or our dataset (including the test set), please cite:

```
@InProceedings{Yang_2019_ICCV,
    author = {Yang, Xin and Mei, Haiyang and Xu, Ke and Wei, Xiaopeng and Yin, Baocai and Lau, Rynson W.H.},
    title = {Where Is My Mirror?},
    booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
    month = {October},
    year = {2019}
}
```

### Dataset
See the [Project Page](https://mhaiyang.github.io/ICCV2019_MirrorNet/index.html).

### Requirements
* PyTorch == 0.4.1
* TorchVision == 0.2.1
* CUDA 9.0, cuDNN 7
* Setup
```
sudo pip3 install -r requirements.txt
git clone https://github.com/Mhaiyang/dss_crf.git
cd dss_crf
sudo python setup.py install
```

### Test
Download the backbone weights `resnext_101_32x4d.pth` from [here](https://drive.google.com/file/d/1e7N7LLZFWX4z0AkMG9wSQDCkOZEaSuFa/view?usp=sharing) and the trained model `MirrorNet.pth` from the [Project Page](https://mhaiyang.github.io/ICCV2019_MirrorNet/index.html). Point `backbone_path` and `msd_testing_root` in `config.py` to your local copies, put the trained model in `ckpt/MirrorNet/`, then run `infer.py`. Note that by default `infer.py` loads `ckpt/MirrorNet/160.pth` (the name comes from `args['snapshot']`), so either rename `MirrorNet.pth` accordingly or switch to the commented-out `MirrorNet.pth` loading lines in `infer.py`.
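`infer.py` builds both the checkpoint path and the output folder from `ckpt_path`, `exp_name` and `args['snapshot']`, so a misplaced or misnamed weight file only surfaces as an error after the model has been built on the GPU. The sketch below reproduces that path logic so the files can be checked beforehand. `inference_paths` and `missing_weights` are hypothetical helpers, not part of the repository; their defaults are copied from `infer.py`, and the backbone location is an assumption based on the layout implied by `config.py`.

```python
import os


# Hypothetical helper (not part of the repository): rebuilds the paths that
# infer.py derives from ckpt_path, exp_name and args['snapshot'].
def inference_paths(ckpt_path="./ckpt", exp_name="MirrorNet", snapshot="160"):
    """Return (checkpoint file, output directory) as infer.py computes them."""
    checkpoint = os.path.join(ckpt_path, exp_name, snapshot + ".pth")
    # infer.py suffixes the output folder with 'nocrf' even when args['crf'] is True.
    output_dir = os.path.join(ckpt_path, exp_name, "%s_%s_%s" % (exp_name, snapshot, "nocrf"))
    return checkpoint, output_dir


def missing_weights(paths):
    """Return the entries of `paths` that are not existing regular files."""
    return [p for p in paths if not os.path.isfile(p)]


if __name__ == "__main__":
    checkpoint, output_dir = inference_paths()
    # Assumed backbone location (relative to the repo root); adjust to config.backbone_path.
    backbone = os.path.join("backbone", "resnext", "resnext_101_32x4d.pth")
    missing = missing_weights([backbone, checkpoint])
    if missing:
        print("Missing weight files:", ", ".join(missing))
    else:
        print("All weights found; predictions will be written to", output_dir)
```

Calling `inference_paths(snapshot="MirrorNet")` yields the path for the released file name `MirrorNet.pth`, which is what the commented-out lines in `infer.py` load.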

### Updated Main Results

##### Quantitative Results

##### Component Analysis

##### Qualitative Results

### License
Please see `License.txt`.

### Contact
E-Mail: mhy666@mail.dlut.edu.cn

--------------------------------------------------------------------------------
/assets/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mhaiyang/ICCV2019_MirrorNet/0fdfd0f3b1608c16fbc70f60450d0ddd2e2e5efb/assets/results.png

--------------------------------------------------------------------------------
/assets/table1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mhaiyang/ICCV2019_MirrorNet/0fdfd0f3b1608c16fbc70f60450d0ddd2e2e5efb/assets/table1.png

--------------------------------------------------------------------------------
/assets/table2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mhaiyang/ICCV2019_MirrorNet/0fdfd0f3b1608c16fbc70f60450d0ddd2e2e5efb/assets/table2.png

--------------------------------------------------------------------------------
/backbone/resnext/__init__.py:
--------------------------------------------------------------------------------
from .resnext101_regular import ResNeXt101

--------------------------------------------------------------------------------
/backbone/resnext/resnext101_regular.py:
--------------------------------------------------------------------------------
import torch
from torch import nn

from backbone.resnext import resnext_101_32x4d_


class ResNeXt101(nn.Module):
    def __init__(self, backbone_path):
        super(ResNeXt101, self).__init__()
        net = resnext_101_32x4d_.resnext_101_32x4d
        if backbone_path is not None:
            weights = torch.load(backbone_path)
            net.load_state_dict(weights, strict=True)
            print("Load ResNeXt Weights Succeed!")

        # Keep only the convolutional part: net[:3] is conv1/bn/relu, net[3:5] is the
        # max-pool plus stage 1, and net[5], net[6], net[7] are stages 2 to 4.
        # The pooling/classifier head (net[8:]) is not used.
        net = list(net.children())
        self.layer0 = nn.Sequential(*net[:3])
        self.layer1 = nn.Sequential(*net[3: 5])
        self.layer2 = net[5]
        self.layer3 = net[6]
        self.layer4 = net[7]

    def forward(self, x):
        layer0 = self.layer0(x)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)
        return layer4

--------------------------------------------------------------------------------
/backbone/resnext/resnext_101_32x4d_.py:
--------------------------------------------------------------------------------
from functools import reduce

import torch.nn as nn


class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input


class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))


class LambdaMap(LambdaBase):
    def forward(self, input):
        return list(map(self.lambda_func, self.forward_prepare(input)))


class LambdaReduce(LambdaBase):
    def forward(self, input):
        return reduce(self.lambda_func, self.forward_prepare(input))


# Conv2d positional arguments: in, out, kernel, stride, padding, dilation, groups.
# groups=32 in every 3x3 convolution gives the 32x4d cardinality.
resnext_101_32x4d = nn.Sequential(  # Sequential,
    nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d((3, 3), (2, 2), (1, 1)),
    # Stage 1 (net[4]): 3 blocks, 64 -> 256 channels, stride 1.
    nn.Sequential(  # Sequential,
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(64, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(128),
                        nn.ReLU(),
                        nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(128),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(256),
                ),
                nn.Sequential(  # Sequential,
                    nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(256),
                ),
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(128),
                        nn.ReLU(),
                        nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(128),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(256),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(128),
                        nn.ReLU(),
                        nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(128),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(256),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
    ),
    # Stage 2 (net[5]): 4 blocks, 256 -> 512 channels, first block stride 2.
    nn.Sequential(  # Sequential,
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(256),
                        nn.ReLU(),
                        nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(256),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(512),
                ),
                nn.Sequential(  # Sequential,
                    nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(512),
                ),
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(256),
                        nn.ReLU(),
                        nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(256),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(512),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(256),
                        nn.ReLU(),
                        nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(256),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(512),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(256),
                        nn.ReLU(),
                        nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(256),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(512),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
    ),
    # Stage 3 (net[6]): 23 blocks, 512 -> 1024 channels, first block stride 2.
    nn.Sequential(  # Sequential,
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                nn.Sequential(  # Sequential,
                    nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(1024),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
    ),
    # Stage 4 (net[7]): 3 blocks, 1024 -> 2048 channels, first block stride 2.
    nn.Sequential(  # Sequential,
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(1024),
                        nn.ReLU(),
                        nn.Conv2d(1024, 1024, (3, 3), (2, 2), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(1024),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(2048),
                ),
                nn.Sequential(  # Sequential,
                    nn.Conv2d(1024, 2048, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(2048),
                ),
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(1024),
                        nn.ReLU(),
                        nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(1024),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(2048),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
        nn.Sequential(  # Sequential,
            LambdaMap(lambda x: x,  # ConcatTable,
                nn.Sequential(  # Sequential,
                    nn.Sequential(  # Sequential,
                        nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                        nn.BatchNorm2d(1024),
                        nn.ReLU(),
                        nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
                        nn.BatchNorm2d(1024),
                        nn.ReLU(),
                    ),
                    nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                    nn.BatchNorm2d(2048),
                ),
                Lambda(lambda x: x),  # Identity,
            ),
            LambdaReduce(lambda x, y: x + y),  # CAddTable,
            nn.ReLU(),
        ),
    ),
    # Classifier head (not used by ResNeXt101): pooling, flatten, 1000-way linear layer.
    nn.AvgPool2d((7, 7), (1, 1)),
    Lambda(lambda x: x.view(x.size(0), -1)),  # View,
    nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x), nn.Linear(2048, 1000)),  # Linear,
)

--------------------------------------------------------------------------------
/ckpt/MirrorNet/placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mhaiyang/ICCV2019_MirrorNet/0fdfd0f3b1608c16fbc70f60450d0ddd2e2e5efb/ckpt/MirrorNet/placeholder

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
"""
 @Time : 9/15/19 10:22
 @Author : TaylorMei
 @Email : mhy666@mail.dlut.edu.cn

 @Project : ICCV2019_MirrorNet
 @File : config.py
 @Function: configurations.

"""
# These are absolute paths on the authors' machine; edit them to match your local setup.
backbone_path = '/home/iccd/ICCV2019_MirrorNet/backbone/resnext/resnext_101_32x4d.pth'

msd_training_root = "/media/iccd/disk/release/MSD/train"

msd_testing_root = "/media/iccd/disk/release/MSD/test"

msd_results_root = "/media/iccd/disk/release/MSD/results"

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
"""
 @Time : 10/2/19 18:00
 @Author : TaylorMei
 @Email : mhy666@mail.dlut.edu.cn

 @Project : ICCV2019_MirrorNet
 @File : dataset.py
 @Function: prepare data for training.

"""
import os
import os.path

import torch.utils.data as data
from PIL import Image


def make_dataset(root):
    # Pair every root/image/<name>.jpg with root/mask/<name>.png (non-.jpg files are skipped).
    img_list = [os.path.splitext(f)[0] for f in os.listdir(os.path.join(root, 'image')) if f.endswith('.jpg')]
    return [
        (os.path.join(root, 'image', img_name + '.jpg'), os.path.join(root, 'mask', img_name + '.png'))
        for img_name in img_list]


class ImageFolder(data.Dataset):
    def __init__(self, root, joint_transform=None, img_transform=None, target_transform=None):
        self.root = root
        self.imgs = make_dataset(root)
        self.joint_transform = joint_transform
        self.img_transform = img_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        img_path, gt_path = self.imgs[index]
        img = Image.open(img_path).convert('RGB')
        target = Image.open(gt_path)
        # joint_transform receives (image, mask) so that spatial augmentations stay aligned.
        if self.joint_transform is not None:
            img, target = self.joint_transform(img, target)
        if self.img_transform is not None:
            img = self.img_transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.imgs)
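`make_dataset` expects `root` to contain an `image/` folder of `.jpg` files and a `mask/` folder of `.png` masks with the same file stems. It does not check that each mask exists, so a missing mask only fails later, when `Image.open` is called inside `__getitem__`. The sketch below exercises that pairing logic on a throwaway directory, using a stdlib-only copy of `make_dataset`; `missing_masks` and `build_toy_root` are hypothetical helpers, not part of the repository.

```python
import os
import tempfile


# Same logic as make_dataset() in dataset.py, copied here so the sketch runs
# with the standard library only (no torch or PIL needed).
def make_dataset(root):
    img_list = [os.path.splitext(f)[0] for f in os.listdir(os.path.join(root, 'image')) if f.endswith('.jpg')]
    return [
        (os.path.join(root, 'image', img_name + '.jpg'), os.path.join(root, 'mask', img_name + '.png'))
        for img_name in img_list]


def missing_masks(pairs):
    """Hypothetical helper: mask paths listed by make_dataset() that do not exist."""
    return [gt for _, gt in pairs if not os.path.isfile(gt)]


def build_toy_root(root):
    """Hypothetical toy layout: root/image/<stem>.jpg with root/mask/<stem>.png."""
    os.makedirs(os.path.join(root, 'image'))
    os.makedirs(os.path.join(root, 'mask'))
    for name in ('0001.jpg', 'notes.txt'):  # notes.txt is not a .jpg, so it is ignored
        open(os.path.join(root, 'image', name), 'w').close()
    open(os.path.join(root, 'mask', '0001.png'), 'w').close()
    return root


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp:
        pairs = make_dataset(build_toy_root(tmp))
        print(len(pairs))            # 1: only the .jpg image is paired
        print(missing_masks(pairs))  # []: every listed mask exists
```

Running `missing_masks(make_dataset(root))` once on the real training root before constructing `ImageFolder` surfaces unpaired images up front instead of partway through an epoch.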
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
"""
 @Time    : 9/29/19 17:14
 @Author  : TaylorMei
 @Email   : mhy666@mail.dlut.edu.cn

 @Project : ICCV2019_MirrorNet
 @File    : infer.py
 @Function: predict mirror map.

"""
import numpy as np
import os
import time

import torch
from PIL import Image
from torchvision import transforms

from config import msd_testing_root
from misc import check_mkdir, crf_refine
from mirrornet import MirrorNet

device_ids = [0]
torch.cuda.set_device(device_ids[0])

ckpt_path = './ckpt'
exp_name = 'MirrorNet'
args = {
    'snapshot': '160',  # checkpoint file name (without .pth) under ckpt/MirrorNet/
    'scale': 384,       # square network input resolution
    'crf': True         # refine the final map with a dense CRF
}

img_transform = transforms.Compose([
    transforms.Resize((args['scale'], args['scale'])),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean / std
])

to_test = {'MSD': msd_testing_root}

to_pil = transforms.ToPILImage()


def main():
    net = MirrorNet().cuda(device_ids[0])

    if len(args['snapshot']) > 0:
        print('Load snapshot {} for testing'.format(args['snapshot']))
        # net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'MirrorNet.pth')))
        # print('Load {} succeed!'.format(os.path.join(ckpt_path, exp_name, 'MirrorNet.pth')))
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
        print('Load {} succeed!'.format(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))

    net.eval()
    with torch.no_grad():
        for name, root in to_test.items():
            img_list = sorted(os.listdir(os.path.join(root, 'image')))
            # Name the output folder after whether CRF refinement is actually applied.
            save_dir = os.path.join(ckpt_path, exp_name, '%s_%s_%s' % (
                exp_name, args['snapshot'], 'crf' if args['crf'] else 'nocrf'))
            check_mkdir(save_dir)
            start = time.time()
            for idx, img_name in enumerate(img_list):
                print('predicting for {}: {:>4d} / {}'.format(name, idx + 1, len(img_list)))
                img = Image.open(os.path.join(root, 'image', img_name))
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                    print("{} is not an RGB image; converted to RGB.".format(img_name))
                w, h = img.size
                img_var = img_transform(img).unsqueeze(0).cuda(device_ids[0])
                # Four sigmoid maps, coarse (f_4) to fine (f_1); only f_1 is saved.
                f_4, f_3, f_2, f_1 = net(img_var)
                f_1 = f_1.squeeze(0).cpu()
                # Back to the original image size as a uint8 map (0-255).
                f_1 = np.array(transforms.Resize((h, w))(to_pil(f_1)))
                if args['crf']:
                    f_1 = crf_refine(np.array(img), f_1)

                Image.fromarray(f_1).save(os.path.join(save_dir, os.path.splitext(img_name)[0] + ".png"))

            end = time.time()
            print("Average time per image: {:.2f} s".format((end - start) / len(img_list)))


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/mirrornet.py:
--------------------------------------------------------------------------------
"""
 @Time    : 9/29/19 17:16
 @Author  : TaylorMei
 @Email   : mhy666@mail.dlut.edu.cn

 @Project : ICCV2019_MirrorNet
 @File    : mirrornet.py
 @Function: MirrorNet.

"""
import torch
import torch.nn.functional as F
from torch import nn

from backbone.resnext.resnext101_regular import ResNeXt101


###################################################################
# ########################## CBAM #################################
###################################################################
class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
                 bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg',)):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        # Sum the shared MLP's response to each global pooling, then gate every channel.
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == 'max':
                max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(max_pool)
            elif pool_type == 'lp':
                lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(lp_pool)
            elif pool_type == 'lse':
                # LSE pool only
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp(lse_pool)

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale


def logsumexp_2d(tensor):
    # Numerically stable log-sum-exp over the spatial positions of each channel.
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs


class ChannelPool(nn.Module):
    def forward(self, x):
        # The original CBAM stacks the channel-wise max and mean (2 maps):
        # return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
        # Max only would be: torch.max(x, 1)[0].unsqueeze(1)
        # Here only the channel-wise mean is kept (1 map), matching SpatialGate's BasicConv(1, 1, ...).
        return torch.mean(x, 1).unsqueeze(1)


class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(1, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # broadcast over channels
        return x * scale


class CBAM(nn.Module):
    def __init__(self, gate_channels=128, reduction_ratio=16, pool_types=('avg',), no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out


###################################################################
# ###################### Contrast Module ##########################
###################################################################
class Contrast_Module(nn.Module):
    def __init__(self, planes):
        super(Contrast_Module, self).__init__()
        self.inplanes = int(planes)
        self.inplanes_half = int(planes / 2)
        self.outplanes = int(planes / 4)

        self.conv1 = nn.Sequential(nn.Conv2d(self.inplanes, self.inplanes_half, 3, 1, 1),
                                   nn.BatchNorm2d(self.inplanes_half), nn.ReLU())

        self.conv2 = nn.Sequential(nn.Conv2d(self.inplanes_half, self.outplanes, 3, 1, 1),
                                   nn.BatchNorm2d(self.outplanes), nn.ReLU())

        self.contrast_block_1 = Contrast_Block(self.outplanes)
        self.contrast_block_2 = Contrast_Block(self.outplanes)
        self.contrast_block_3 = Contrast_Block(self.outplanes)
        self.contrast_block_4 = Contrast_Block(self.outplanes)

        self.cbam = CBAM(self.inplanes)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)

        # Four chained contrast blocks, each producing planes / 4 channels.
        contrast_block_1 = self.contrast_block_1(conv2)
        contrast_block_2 = self.contrast_block_2(contrast_block_1)
        contrast_block_3 = self.contrast_block_3(contrast_block_2)
        contrast_block_4 = self.contrast_block_4(contrast_block_3)

        # Concatenating the four outputs restores `planes` channels before CBAM.
        output = self.cbam(torch.cat((contrast_block_1, contrast_block_2, contrast_block_3, contrast_block_4), 1))

        return output


###################################################################
# ###################### Contrast Block ###########################
###################################################################
class Contrast_Block(nn.Module):
    def __init__(self, planes):
        super(Contrast_Block, self).__init__()
        self.inplanes = int(planes)
        self.outplanes = int(planes / 4)

        # Each branch pairs a local 3x3 conv with a dilated 3x3 "context" conv
        # (dilation 2, 4, 8, 16); padding equals dilation, so the spatial size is kept.
        self.local_1 = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=3, stride=1, padding=1, dilation=1)
        self.context_1 = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=3, stride=1, padding=2, dilation=2)

        self.local_2 = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=3, stride=1, padding=1, dilation=1)
        self.context_2 = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=3, stride=1, padding=4, dilation=4)

        self.local_3 = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=3, stride=1, padding=1, dilation=1)
        self.context_3 = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=3, stride=1, padding=8, dilation=8)

        self.local_4 = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=3, stride=1, padding=1, dilation=1)
        self.context_4 = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=3, stride=1, padding=16, dilation=16)

        self.bn = nn.BatchNorm2d(self.outplanes)  # shared by all four branches
        self.relu = nn.ReLU()

        self.cbam = CBAM(self.inplanes)

    def forward(self, x):
        # Context-contrasted local (CCL) features: local response minus dilated-context response.
        local_1 = self.local_1(x)
        context_1 = self.context_1(x)
        ccl_1 = local_1 - context_1
        ccl_1 = self.bn(ccl_1)
        ccl_1 = self.relu(ccl_1)

        local_2 = self.local_2(x)
        context_2 = self.context_2(x)
        ccl_2 = local_2 - context_2
        ccl_2 = self.bn(ccl_2)
        ccl_2 = self.relu(ccl_2)

        local_3 = self.local_3(x)
        context_3 = self.context_3(x)
        ccl_3 = local_3 - context_3
        ccl_3 = self.bn(ccl_3)
        ccl_3 = self.relu(ccl_3)

        local_4 = self.local_4(x)
        context_4 = self.context_4(x)
        ccl_4 = local_4 - context_4
        ccl_4 = self.bn(ccl_4)
        ccl_4 = self.relu(ccl_4)

        output = self.cbam(torch.cat((ccl_1, ccl_2, ccl_3, ccl_4), 1))

        return output


###################################################################
# ########################## NETWORK ##############################
###################################################################
class MirrorNet(nn.Module):
    def __init__(self, backbone_path=None):
        super(MirrorNet, self).__init__()
        resnext = ResNeXt101(backbone_path)
        self.layer0 = resnext.layer0
        self.layer1 = resnext.layer1
        self.layer2 = resnext.layer2
        self.layer3 = resnext.layer3
        self.layer4 = resnext.layer4

        self.contrast_4 = Contrast_Module(2048)
        self.contrast_3 = Contrast_Module(1024)
        self.contrast_2 = Contrast_Module(512)
        self.contrast_1 = Contrast_Module(256)

        # up_4..up_2 upsample 2x with transposed convs so each map matches the next
        # shallower backbone feature; up_1 keeps layer1's resolution (3x3 conv, stride 1).
        self.up_4 = nn.Sequential(nn.ConvTranspose2d(2048, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.ReLU())
        self.up_3 = nn.Sequential(nn.ConvTranspose2d(1024, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU())
        self.up_2 = nn.Sequential(nn.ConvTranspose2d(512, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU())
        self.up_1 = nn.Sequential(nn.Conv2d(256, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU())

        self.cbam_4 = CBAM(512)
        self.cbam_3 = CBAM(256)
        self.cbam_2 = CBAM(128)
        self.cbam_1 = CBAM(64)

        self.layer4_predict = nn.Conv2d(512, 1, 3, 1, 1)
        self.layer3_predict = nn.Conv2d(256, 1, 3, 1, 1)
        self.layer2_predict = nn.Conv2d(128, 1, 3, 1, 1)
        self.layer1_predict = nn.Conv2d(64, 1, 3, 1, 1)

        for m in self.modules():
            if isinstance(m, nn.ReLU):
                m.inplace = True

    def forward(self, x):
        layer0 = self.layer0(x)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        # Coarse to fine: each stage's sigmoid map re-weights the next shallower backbone feature.
        contrast_4 = self.contrast_4(layer4)
        up_4 = self.up_4(contrast_4)
        cbam_4 = self.cbam_4(up_4)
        layer4_predict = self.layer4_predict(cbam_4)
        layer4_map = torch.sigmoid(layer4_predict)

        contrast_3 = self.contrast_3(layer3 * layer4_map)
        up_3 = self.up_3(contrast_3)
        cbam_3 = self.cbam_3(up_3)
        layer3_predict = self.layer3_predict(cbam_3)
        layer3_map = torch.sigmoid(layer3_predict)

        contrast_2 = self.contrast_2(layer2 * layer3_map)
        up_2 = self.up_2(contrast_2)
        cbam_2 = self.cbam_2(up_2)
        layer2_predict = self.layer2_predict(cbam_2)
        layer2_map = torch.sigmoid(layer2_predict)

        contrast_1 = self.contrast_1(layer1 * layer2_map)
        up_1 = self.up_1(contrast_1)
        cbam_1 = self.cbam_1(up_1)
        layer1_predict = self.layer1_predict(cbam_1)

        # Upsample all four logit maps to the input resolution.
        layer4_predict = F.interpolate(layer4_predict, size=x.size()[2:], mode='bilinear', align_corners=True)
        layer3_predict = F.interpolate(layer3_predict, size=x.size()[2:], mode='bilinear', align_corners=True)
        layer2_predict = F.interpolate(layer2_predict, size=x.size()[2:], mode='bilinear', align_corners=True)
        layer1_predict = F.interpolate(layer1_predict, size=x.size()[2:], mode='bilinear', align_corners=True)

        # Training returns raw logits; evaluation returns sigmoid probabilities.
        if self.training:
            return layer4_predict, layer3_predict, layer2_predict, layer1_predict

        return torch.sigmoid(layer4_predict), torch.sigmoid(layer3_predict), torch.sigmoid(layer2_predict), \
            torch.sigmoid(layer1_predict)
--------------------------------------------------------------------------------
/misc.py:
--------------------------------------------------------------------------------
"""
 @Time    : 9/15/19 10:19
 @Author  : TaylorMei
 @Email   : mhy666@mail.dlut.edu.cn

 @Project : ICCV2019_MirrorNet
 @File    : misc.py
 @Function: functions.

"""
import numpy as np
import os
import skimage.io
import skimage.transform
import xlwt

import pydensecrf.densecrf as dcrf


################################################################
######################## Train & Test ##########################
################################################################
class AvgMeter(object):
    """Running average of a scalar (e.g. a loss), weighted by the batch size n."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def check_mkdir(dir_name):
    # Create the directory (and any missing parents) if it does not exist yet.
    os.makedirs(dir_name, exist_ok=True)


def _sigmoid(x):
    return 1 / (1 + np.exp(-x))


def crf_refine(img, annos):
    # img: H x W x 3 uint8 RGB image; annos: H x W uint8 probability map (0-255).
    assert img.dtype == np.uint8
    assert annos.dtype == np.uint8
    assert img.shape[:2] == annos.shape

    EPSILON = 1e-8

    M = 2  # two labels: background (0) and mirror (1)
    tau = 1.05
    # Set up the CRF model
    d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M)

    anno_norm = annos / 255.

    # Unary energies per label: -log(p) / (tau * sigmoid(p)), p = predicted probability.
    n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm))
    p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm))

    U = np.zeros((M, img.shape[0] * img.shape[1]), dtype='float32')
    U[0, :] = n_energy.flatten()
    U[1, :] = p_energy.flatten()

    d.setUnaryEnergy(U)

    d.addPairwiseGaussian(sxy=3, compat=3)
    d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5)

    # Do the inference (a single mean-field iteration)
    infer = np.array(d.inference(1)).astype('float32')
    res = infer[1, :]  # marginal probability of the mirror label

    res = res * 255
    res = res.reshape(img.shape[:2])
    return res.astype('uint8')


################################################################
######################## Evaluation ############################
################################################################
def data_write(file_path, datas):
    # Write each sequence in `datas` to its own column of a spreadsheet.
    # xlwt produces the legacy Excel 97-2003 format, so file_path should end in .xls.
    f = xlwt.Workbook()
    sheet1 = f.add_sheet(sheetname="sheet1", cell_overwrite_ok=True)

    j = 0
    for data in datas:
        for i in range(len(data)):
            sheet1.write(i, j, data[i])
        j = j + 1

    f.save(file_path)


def get_gt_mask(imgname, MASK_DIR):
    filestr = os.path.splitext(imgname)[0]
    mask_path = os.path.join(MASK_DIR, filestr + ".png")
    mask = skimage.io.imread(mask_path)
    mask = np.where(mask == 255, 1, 0).astype(np.float32)

    return mask


def get_normalized_predict_mask(imgname, PREDICT_MASK_DIR):
    filestr = os.path.splitext(imgname)[0]
    mask_path = os.path.join(PREDICT_MASK_DIR, filestr + ".png")
    if not os.path.exists(mask_path):
        raise FileNotFoundError("{} has no predict mask!".format(imgname))
    mask = skimage.io.imread(mask_path).astype(np.float32)
    # Min-max normalize to [0, 1]; an all-zero prediction is left unchanged.
    if np.max(mask) > 0:
        mask = (mask - np.min(mask)) / (np.max(mask) - np.min(mask))
    mask = mask.astype(np.float32)

    return mask


def get_binary_predict_mask(imgname, PREDICT_MASK_DIR):
    filestr = os.path.splitext(imgname)[0]
    mask_path = os.path.join(PREDICT_MASK_DIR, filestr + ".png")
    if not os.path.exists(mask_path):
        raise FileNotFoundError("{} has no predict mask!".format(imgname))
    mask = skimage.io.imread(mask_path).astype(np.float32)
    mask = np.where(mask >= 127.5, 1, 0).astype(np.float32)

    return mask


def compute_iou(predict_mask, gt_mask):
    """
    (1/n_cl) * sum_i(n_ii / (t_i + sum_j(n_ji) - n_ii))
    Here, n_cl = 1 as we have only one class (mirror).
    Returns 0 if either mask is empty.
    """

    check_size(predict_mask, gt_mask)

    if np.sum(predict_mask) == 0 or np.sum(gt_mask) == 0:
        iou_ = 0
        return iou_

    n_ii = np.sum(np.logical_and(predict_mask, gt_mask))
    t_i = np.sum(gt_mask)
    n_ij = np.sum(predict_mask)

    iou_ = n_ii / (t_i + n_ij - n_ii)

    return iou_


def compute_acc_mirror(predict_mask, gt_mask):
    """Fraction of ground-truth mirror pixels that are predicted as mirror."""

    check_size(predict_mask, gt_mask)

    N_p = np.sum(gt_mask)
    TP = np.sum(np.logical_and(predict_mask, gt_mask))

    accuracy_ = TP / N_p

    return accuracy_


def compute_acc_image(predict_mask, gt_mask):
    """Pixel accuracy over the whole image."""

    check_size(predict_mask, gt_mask)

    N_p = np.sum(gt_mask)
    N_n = np.sum(np.logical_not(gt_mask))

    TP = np.sum(np.logical_and(predict_mask, gt_mask))
    TN = np.sum(np.logical_and(np.logical_not(predict_mask), np.logical_not(gt_mask)))

    accuracy_ = (TP + TN) / (N_p + N_n)

    return accuracy_


def compute_mae(predict_mask, gt_mask):
    """Mean absolute error between the prediction and the binary ground truth."""

    check_size(predict_mask, gt_mask)

    mae_ = np.mean(abs(predict_mask - gt_mask)).item()

    return mae_


def compute_ber(predict_mask, gt_mask):
    """Balanced error rate: 1 - (TP / N_p + TN / N_n) / 2."""

    check_size(predict_mask, gt_mask)

    N_p = np.sum(gt_mask)
    N_n = np.sum(np.logical_not(gt_mask))

    TP = np.sum(np.logical_and(predict_mask, gt_mask))
    TN = np.sum(np.logical_and(np.logical_not(predict_mask), np.logical_not(gt_mask)))

    ber_ = 1 - (1 / 2) * ((TP / N_p) + (TN / N_n))

    return ber_


def segm_size(segm):
    # Raises IndexError for arrays with fewer than two dimensions.
    height = segm.shape[0]
    width = segm.shape[1]

    return height, width


def check_size(eval_segm, gt_segm):
    h_e, w_e = segm_size(eval_segm)
    h_g, w_g = segm_size(gt_segm)

    if (h_e != h_g) or (w_e != w_g):
        raise EvalSegErr("DiffDim: Different dimensions of matrices!")


class EvalSegErr(Exception):
    def __init__(self, value):
        self.value = value

    def __str__(self):
        return repr(self.value)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
Cython
scikit-image
tensorboardX==1.4
xlwt
tqdm
--------------------------------------------------------------------------------
/utils/compute_contrast.py:
--------------------------------------------------------------------------------
"""
 @Time    : 9/28/19 16:25
 @Author  : TaylorMei
 @Email   : mhy666@mail.dlut.edu.cn

 @Project : ICCV2019_MirrorNet
 @File    : compute_contrast.py
 @Function: compute color contrast distribution.

"""
import os
import numpy as np
import cv2
import skimage.io
from misc import data_write

image_path = '/media/iccd/disk/release/MSD/all_images/'
mask_path = '/media/iccd/disk/release/MSD/all_masks/'

imglist = os.listdir(image_path)

chi_sq_color = []


def chi2(arr1, arr2):
    # Chi-square distance between two histograms; eps guards against empty bins.
    return np.sum((arr1 - arr2)**2 / (arr1 + arr2 + np.finfo(np.float64).eps))


for i, imgname in enumerate(imglist):
    print(i, imgname)

    image = skimage.io.imread(image_path + imgname)

    name = imgname.split('.')[0]
    mask = skimage.io.imread(mask_path + name + '.png')
    mask_f = np.where(mask != 0, 1, 0).astype(np.uint8)  # mirror pixels
    mask_b = np.where(mask == 0, 1, 0).astype(np.uint8)  # non-mirror pixels

    if np.sum(mask_f) == 0:
        print('{} has an empty mirror mask; skipped.'.format(imgname))
        continue

    # 256-bin histograms of the R, G and B channels inside and outside the mirror.
    hist_f_r = cv2.calcHist([image], [0], mask_f, [256], [0, 256])
    hist_f_g = cv2.calcHist([image], [1], mask_f, [256], [0, 256])
    hist_f_b = cv2.calcHist([image], [2], mask_f, [256], [0, 256])
    hist_b_r = cv2.calcHist([image], [0], mask_b, [256], [0, 256])
    hist_b_g = cv2.calcHist([image], [1], mask_b, [256], [0, 256])
    hist_b_b = cv2.calcHist([image], [2], mask_b, [256], [0, 256])

    # Normalize each histogram by its pixel count before comparing.
    chi_sq_r = chi2(hist_f_r.flatten() / np.sum(mask_f), hist_b_r.flatten() / np.sum(mask_b))
    chi_sq_g = chi2(hist_f_g.flatten() / np.sum(mask_f), hist_b_g.flatten() / np.sum(mask_b))
    chi_sq_b = chi2(hist_f_b.flatten() / np.sum(mask_f), hist_b_b.flatten() / np.sum(mask_b))

    chi_sq_color.append(((chi_sq_r + chi_sq_g + chi_sq_b) / 3).item())

# Min-max normalize the per-image contrast values to [0, 1].
chi_sq_color = np.array(chi_sq_color)
chi_sq_color = (chi_sq_color - np.min(chi_sq_color)) / (np.max(chi_sq_color) - np.min(chi_sq_color))

print(chi_sq_color)
data_write(os.path.join(os.getcwd(), 'msd_chi_sq.xls'), [chi_sq_color])
--------------------------------------------------------------------------------
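compute_contrast.py scores how different a mirror's colors are from its surroundings. For each RGB channel it:

1. builds 256-bin histograms of the mirror and non-mirror pixels;
2. normalizes each histogram by its pixel count;
3. computes the chi-square distance between the two.

The score is the average of the three distances.

Below is a NumPy-only sketch of that per-image score. `color_contrast` is a hypothetical name, and `np.bincount` stands in for `cv2.calcHist`. For normalized histograms the distance lies between 0 (identical distributions) and 2 (no shared bins).

```python
import numpy as np


def chi2(arr1, arr2):
    # Same chi-square form as the script; eps avoids 0/0 on bins empty in both.
    return np.sum((arr1 - arr2) ** 2 / (arr1 + arr2 + np.finfo(np.float64).eps))


def color_contrast(image, mask):
    """Mean chi-square distance between mirror and non-mirror color histograms.

    image: H x W x 3 uint8 array; mask: H x W array, nonzero = mirror.
    Returns None if either region is empty (the script only skips empty mirrors).
    """
    fg = mask != 0
    bg = ~fg
    if fg.sum() == 0 or bg.sum() == 0:
        return None
    dists = []
    for c in range(3):
        channel = image[..., c]
        # 256-bin histograms inside / outside the mirror, normalized by pixel count.
        hist_f = np.bincount(channel[fg], minlength=256) / fg.sum()
        hist_b = np.bincount(channel[bg], minlength=256) / bg.sum()
        dists.append(chi2(hist_f, hist_b))
    return float(np.mean(dists))
```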
/utils/compute_overlap.py:
--------------------------------------------------------------------------------
"""
 @Time    : 9/28/19 15:51
 @Author  : TaylorMei
 @Email   : mhy666@mail.dlut.edu.cn

 @Project : ICCV2019_MirrorNet
 @File    : compute_overlap.py
 @Function: compute mirror location distribution.

"""
import os
import numpy as np
import skimage.io
import skimage.transform
import matplotlib.pyplot as plt
from matplotlib import cm
import seaborn as sns

# image_path = '/media/iccd/disk/release/MSD/test/image/'
# mask_path = '/media/iccd/disk/release/MSD/test/mask/'
image_path = '/media/iccd/disk/release/MSD/all_images/'
mask_path = '/media/iccd/disk/release/MSD/all_masks/'

imglist = os.listdir(image_path)
print(len(imglist))

overlap = np.zeros([256, 256], dtype=np.float64)
tall, wide = 0, 0

for i, imgname in enumerate(imglist):
    print(i, imgname)
    name = imgname.split('.')[0]

    mask = skimage.io.imread(mask_path + name + '.png')

    height = mask.shape[0]
    width = mask.shape[1]
    if height > width:
        tall += 1
    else:
        wide += 1
    # Nearest-neighbour resize without smoothing, so the mask stays binary.
    mask = skimage.transform.resize(mask, [256, 256], order=0, anti_aliasing=False)
    mask = np.where(mask != 0, 1, 0).astype(np.float64)
    overlap += mask

# Fraction of images with a mirror at each location, plus a min-max normalized copy.
overlap = overlap / len(imglist)
overlap_normalized = (overlap - np.min(overlap)) / (np.max(overlap) - np.min(overlap))
skimage.io.imsave('./msd_all.png', (overlap * 255).astype(np.uint8))
skimage.io.imsave('./msd_all_normalized.png', (overlap_normalized * 255).astype(np.uint8))

print('tall: {}, wide: {}'.format(tall, wide))

f, ax = plt.subplots()
sns.set()
ax = sns.heatmap(overlap, ax=ax, cmap=cm.summer, cbar=False)
ax.set_xticklabels([])
ax.set_yticklabels([])
plt.xticks([])
plt.yticks([])
plt.show()
--------------------------------------------------------------------------------
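compute_overlap.py turns the masks into a location frequency map. Each mask is resized to 256×256 with nearest-neighbour sampling (`order=0`), binarized, and added to a running sum. Dividing that sum by the number of images gives, for each cell, the fraction of images with a mirror at that position.

Below is a NumPy-only sketch of that accumulation. `nearest_resize` and `location_frequency` are hypothetical helpers. The integer index mapping is an assumption: it approximates, but may not exactly reproduce, `skimage.transform.resize` with `order=0`.

```python
import numpy as np


def nearest_resize(mask, out_h, out_w):
    """Hypothetical nearest-neighbour resize by integer index mapping.

    Output pixel (i, j) samples input pixel (i * H // out_h, j * W // out_w).
    """
    h, w = mask.shape[:2]
    rows = (np.arange(out_h) * h) // out_h
    cols = (np.arange(out_w) * w) // out_w
    return mask[rows[:, None], cols[None, :]]


def location_frequency(masks, size=256):
    """Per-pixel fraction of masks that mark a mirror at that location."""
    overlap = np.zeros((size, size), dtype=np.float64)
    for mask in masks:
        resized = nearest_resize(mask, size, size)
        overlap += (resized != 0).astype(np.float64)  # binarize as the script does
    return overlap / len(masks)
```

Saving `(location_frequency(masks) * 255).astype(np.uint8)` would correspond to the script's `msd_all.png`.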
/utils/compute_size.py:
--------------------------------------------------------------------------------
"""
 @Time    : 9/28/19 15:37
 @Author  : TaylorMei
 @Email   : mhy666@mail.dlut.edu.cn

 @Project : ICCV2019_MirrorNet
 @File    : compute_size.py
 @Function: compute mirror area distribution.

"""
import os
import numpy as np
import skimage.io
from misc import data_write

image_path = '/media/iccd/disk/release/MSD/all_images/'
mask_path = '/media/iccd/disk/release/MSD/all_masks/'

imglist = os.listdir(image_path)
print(len(imglist))

output = []

for i, imgname in enumerate(imglist):
    print(i, imgname)
    name = imgname.split('.')[0]

    mask = skimage.io.imread(mask_path + name + '.png')
    mask = np.where(mask != 0, 1, 0).astype(np.uint8)

    height = mask.shape[0]
    width = mask.shape[1]
    total_area = height * width
    if total_area != 640 * 512:
        print('size error: {} is {}x{}, expected 640x512 pixels in total'.format(imgname, width, height))

    # Fraction of the image covered by mirror pixels.
    mirror_area = np.sum(mask)
    proportion = mirror_area / total_area
    output.append(proportion)
data_write(os.path.join(os.getcwd(), 'msd_size.xls'), [output])
--------------------------------------------------------------------------------
/utils/generate_overlap_map.py:
--------------------------------------------------------------------------------
"""
 @Time    : 9/15/19 16:47
 @Author  : TaylorMei
 @Email   : mhy666@mail.dlut.edu.cn

 @Project : ICCV2019_MirrorNet
 @File    : generate_overlap_map.py
 @Function: generate the overlap map of each test image from the statistics of the training set.

"""
import os
import numpy as np
from skimage import io, transform
from config import msd_training_root, msd_testing_root, msd_results_root

train_image_path = os.path.join(msd_training_root, 'image')
test_image_path = os.path.join(msd_testing_root, 'image')
mask_path = os.path.join(msd_training_root, 'mask')
output_path = os.path.join(msd_results_root, 'Statistics')
os.makedirs(output_path, exist_ok=True)

overlap = np.zeros([256, 256], dtype=np.float64)

# Accumulate the binarized 256x256 training masks.
train_imglist = os.listdir(train_image_path)
for i, imgname in enumerate(train_imglist):

    print(i, imgname)

    name = imgname.split('.')[0]

    mask = io.imread(os.path.join(mask_path, name + '.png'))

    mask = transform.resize(mask, [256, 256], order=0, anti_aliasing=False)
    mask = np.where(mask != 0, 1, 0).astype(np.float64)

    overlap += mask

overlap = overlap / len(train_imglist)
overlap = (overlap - np.min(overlap)) / (np.max(overlap) - np.min(overlap))

# Resize the normalized map to each test image and save it as an 8-bit PNG.
test_imglist = os.listdir(test_image_path)
for j, imgname in enumerate(test_imglist):

    print(j, imgname)

    name = imgname.split('.')[0]

    image = io.imread(os.path.join(test_image_path, imgname))

    height = image.shape[0]
    width = image.shape[1]

    mask = transform.resize(overlap, [height, width], order=0)

    save_path = os.path.join(output_path, name + '.png')
    io.imsave(save_path, (mask * 255).astype(np.uint8))

print("OK!")
--------------------------------------------------------------------------------
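generate_overlap_map.py reuses the frequency map as a location prior, in three steps:

1. min-max normalizes the training-set map to [0, 1];
2. resizes it to each test image's height and width;
3. saves it as an 8-bit PNG.

The sketch below covers that second half in NumPy, under the following assumptions:

- `minmax_normalize` and `prior_for_image` are hypothetical names.
- The nearest-neighbour index mapping stands in for `transform.resize(..., order=0)`.
- The guard for a constant map is an addition; the script's normalization has no such guard.

```python
import numpy as np


def minmax_normalize(overlap):
    """Rescale to [0, 1]; a constant map becomes all zeros instead of dividing by 0."""
    lo, hi = overlap.min(), overlap.max()
    if hi == lo:
        return np.zeros_like(overlap, dtype=np.float64)
    return (overlap - lo) / (hi - lo)


def prior_for_image(overlap, height, width):
    """Hypothetical helper: normalized prior resized to (height, width) as uint8."""
    prior = minmax_normalize(overlap)
    # Nearest-neighbour sampling: output (i, j) reads prior[i * h // height, j * w // width].
    rows = (np.arange(height) * prior.shape[0]) // height
    cols = (np.arange(width) * prior.shape[1]) // width
    resized = prior[rows[:, None], cols[None, :]]
    return (resized * 255).astype(np.uint8)  # same quantization as the script
```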