├── License.txt
├── README.md
├── assets
│   ├── results.png
│   ├── table1.png
│   └── table2.png
├── backbone
│   └── resnext
│       ├── __init__.py
│       ├── resnext101_regular.py
│       └── resnext_101_32x4d_.py
├── ckpt
│   └── MirrorNet
│       └── placeholder
├── config.py
├── dataset.py
├── infer.py
├── mirrornet.py
├── misc.py
├── requirements.txt
└── utils
    ├── compute_contrast.py
    ├── compute_overlap.py
    ├── compute_size.py
    └── generate_overlap_map.py
/License.txt:
--------------------------------------------------------------------------------
1 | Where Is My Mirror?
2 | Xin Yang*, Haiyang Mei*, Ke Xu, Xiaopeng Wei, Baocai Yin, Rynson W.H. Lau (*Joint first authors)
3 | ICCV October, 2019
4 |
5 | Copyright (c) 2019
6 | All rights reserved.
7 |
8 | Computer Science and Technology
9 | Dalian University of Technology
10 |
11 | Department of Computer Science
12 | City University of Hong Kong
13 |
14 |
15 | -------------------------------------------------------
16 |
17 | Redistribution and use in source and binary forms, with or without
18 | modification, are permitted provided that the following conditions
19 | are met:
20 |
21 | 1. Redistributions of source code must retain the above copyright
22 | notice, this list of conditions and the following disclaimer.
23 |
24 | 2. Redistributions in binary form must reproduce the above copyright
25 | notice, this list of conditions and the following disclaimer in the
26 | documentation and/or other materials provided with the distribution.
27 |
28 | 3. Neither name of copyright holders nor the names of its contributors
29 | may be used to endorse or promote products derived from this software
30 | without specific prior written permission.
31 |
32 |
33 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
34 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
35 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
36 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR
37 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
38 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
39 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
40 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
41 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
42 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
43 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ICCV2019_MirrorNet
2 |
3 | ## Where Is My Mirror? (ICCV2019)
4 | Xin Yang\*, [Haiyang Mei](https://mhaiyang.github.io/)\*, Ke Xu, Xiaopeng Wei, Baocai Yin, [Rynson W.H. Lau](http://www.cs.cityu.edu.hk/~rynson/)† (* Joint first authors, † Rynson Lau is the corresponding author and he led this project.)
5 |
6 | [[Project Page](https://mhaiyang.github.io/ICCV2019_MirrorNet/index.html)][[Arxiv](https://arxiv.org/pdf/1908.09101v2.pdf)]
7 |
8 | ### Abstract
9 | Mirrors are everywhere in our daily lives. Existing computer vision systems do not consider mirrors, and hence may get confused by the reflected content inside a mirror, resulting in a severe performance degradation. However, separating the real content outside a mirror from the reflected content inside it is non-trivial. The key challenge lies in that mirrors typically reflect contents similar to their surroundings, making it very difficult to differentiate the two. In this paper, we present a novel method to accurately segment mirrors from an input image. To the best of our knowledge, this is the first work to address the mirror segmentation problem with a computational approach. We make the following contributions. First, we construct a large-scale mirror dataset that contains mirror images with the corresponding manually annotated masks. This dataset covers a variety of daily life scenes, and will be made publicly available for future research. Second, we propose a novel network, called MirrorNet, for mirror segmentation, by modeling both semantical and low-level color/texture discontinuities between the contents inside and outside of the mirrors. Third, we conduct extensive experiments to evaluate the proposed method, and show that it outperforms the carefully chosen baselines from the state-of-the-art detection and segmentation methods.
10 |
11 | ### Citation
12 | If you use this code or our dataset (including test set), please cite:
13 |
14 | ```
15 | @InProceedings{Yang_2019_ICCV,
16 | author = {Yang, Xin and Mei, Haiyang and Xu, Ke and Wei, Xiaopeng and Yin, Baocai and Lau, Rynson W.H.},
17 | title = {Where Is My Mirror?},
18 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
19 | month = {October},
20 | year = {2019}
21 | }
22 | ```
23 |
24 | ### Dataset
25 | See the [Project Page](https://mhaiyang.github.io/ICCV2019_MirrorNet/index.html).
26 |
27 | ### Requirements
28 | * PyTorch == 0.4.1
29 | * TorchVision == 0.2.1
30 | * CUDA 9.0 cudnn 7
31 | * Setup
32 | ```
33 | sudo pip3 install -r requirements.txt
34 | git clone https://github.com/Mhaiyang/dss_crf.git
35 | cd dss_crf && sudo python setup.py install
36 | ```
37 |
38 | ### Test
39 | Download the backbone weights `resnext_101_32x4d.pth` from [here](https://drive.google.com/file/d/1e7N7LLZFWX4z0AkMG9wSQDCkOZEaSuFa/view?usp=sharing) and the trained model `MirrorNet.pth` from the [project page](https://mhaiyang.github.io/ICCV2019_MirrorNet/index.html). Set `backbone_path` and `msd_testing_root` in `config.py` to match your machine, then run `infer.py`.
40 |
41 | ### Updated Main Results
42 |
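43 | ##### Quantitative Results
44 |
45 | ![Quantitative results](assets/table1.png)
46 |
47 | ##### Component Analysis
48 |
49 | ![Component analysis](assets/table2.png)
50 |
51 | ##### Qualitative Results
52 | ![Qualitative results](assets/results.png)
53 |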
54 | ### License
55 | Please see `License.txt`.
56 |
57 | ### Contact
58 | E-Mail: mhy666@mail.dlut.edu.cn
59 |
--------------------------------------------------------------------------------
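The Requirements section of the README pins exact PyTorch and TorchVision versions. A small check along the following lines could be run before `infer.py` to catch a mismatched install early. This is only a sketch: `REQUIRED`, `release` and `check_pins` are illustrative names that do not exist in the repository, and the installed versions are passed in as plain strings (in practice they would come from `torch.__version__` and `torchvision.__version__`).

```python
# Hypothetical helper, not part of the repository.
REQUIRED = {"torch": "0.4.1", "torchvision": "0.2.1"}  # pins listed in the README


def release(version):
    """Return the leading numeric release, e.g. '0.4.1.post2' -> (0, 4, 1)."""
    parts = []
    for piece in version.split("+")[0].split("."):
        if not piece.isdigit():
            break  # stop at suffixes such as 'post2' or 'dev'
        parts.append(int(piece))
    return tuple(parts)


def check_pins(installed, required=REQUIRED):
    """Return a list of mismatch messages; the list is empty if every pin matches."""
    problems = []
    for name, wanted in required.items():
        found = installed.get(name)
        if found is None:
            problems.append("%s is not installed (need %s)" % (name, wanted))
        elif release(found) != release(wanted):
            problems.append("%s %s found, README pins %s" % (name, found, wanted))
    return problems


if __name__ == "__main__":
    # Made-up installed versions, for illustration only.
    for message in check_pins({"torch": "0.4.1", "torchvision": "0.2.2"}):
        print(message)
```

Returning messages instead of raising keeps the helper usable both as a hard gate and as a warning, depending on how strictly the caller wants to enforce the pins.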
/assets/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mhaiyang/ICCV2019_MirrorNet/0fdfd0f3b1608c16fbc70f60450d0ddd2e2e5efb/assets/results.png
--------------------------------------------------------------------------------
/assets/table1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mhaiyang/ICCV2019_MirrorNet/0fdfd0f3b1608c16fbc70f60450d0ddd2e2e5efb/assets/table1.png
--------------------------------------------------------------------------------
/assets/table2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mhaiyang/ICCV2019_MirrorNet/0fdfd0f3b1608c16fbc70f60450d0ddd2e2e5efb/assets/table2.png
--------------------------------------------------------------------------------
/backbone/resnext/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnext101_regular import ResNeXt101
2 |
--------------------------------------------------------------------------------
/backbone/resnext/resnext101_regular.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from backbone.resnext import resnext_101_32x4d_
5 |
6 |
7 | class ResNeXt101(nn.Module):
8 |     def __init__(self, backbone_path):
9 |         super(ResNeXt101, self).__init__()
10 |         net = resnext_101_32x4d_.resnext_101_32x4d
11 |         if backbone_path is not None:
12 |             weights = torch.load(backbone_path)
13 |             net.load_state_dict(weights, strict=True)
14 |             print("Load ResNeXt Weights Succeed!")
15 |
16 |         net = list(net.children())
17 |         self.layer0 = nn.Sequential(*net[:3])  # 7x7 conv, batch norm, ReLU
18 |         self.layer1 = nn.Sequential(*net[3: 5])  # max pool + first residual stage
19 |         self.layer2 = net[5]
20 |         self.layer3 = net[6]
21 |         self.layer4 = net[7]  # net[8:] (avg pool, view, linear head) is dropped
22 |
23 |     def forward(self, x):
24 |         layer0 = self.layer0(x)
25 |         layer1 = self.layer1(layer0)
26 |         layer2 = self.layer2(layer1)
27 |         layer3 = self.layer3(layer2)
28 |         layer4 = self.layer4(layer3)
29 |         return layer4
30 |
--------------------------------------------------------------------------------
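The five stages kept by `ResNeXt101` each halve the spatial resolution, so the returned `layer4` map is 32 times smaller than the input. The sketch below traces one spatial dimension through the stages using the standard convolution output-size formula. It assumes the input is resized to 384 x 384, as the `scale` entry in `infer.py` suggests; `conv_out` and `backbone_sizes` are illustrative names, not repository functions.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a conv/pool layer along one spatial axis (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def backbone_sizes(size):
    """Trace one spatial dimension through layer0..layer4 of the backbone."""
    sizes = {}
    # layer0: 7x7 conv with stride 2 and padding 3 (BN and ReLU keep the size).
    size = conv_out(size, kernel=7, stride=2, padding=3)
    sizes["layer0"] = size
    # layer1: 3x3 max pool with stride 2; the first residual stage uses stride 1.
    size = conv_out(size, kernel=3, stride=2, padding=1)
    sizes["layer1"] = size
    # layer2..layer4: each stage opens with a stride-2 3x3 grouped conv (padding 1).
    for name in ("layer2", "layer3", "layer4"):
        size = conv_out(size, kernel=3, stride=2, padding=1)
        sizes[name] = size
    return sizes


print(backbone_sizes(384))
# {'layer0': 192, 'layer1': 96, 'layer2': 48, 'layer3': 24, 'layer4': 12}
```

According to the layer definitions in `resnext_101_32x4d_.py`, the matching channel counts are 64, 256, 512, 1024 and 2048.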
/backbone/resnext/resnext_101_32x4d_.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 |
3 | import torch.nn as nn
4 |
5 |
6 | class LambdaBase(nn.Sequential):
7 |     def __init__(self, fn, *args):
8 |         super(LambdaBase, self).__init__(*args)
9 |         self.lambda_func = fn
10 |
11 |     def forward_prepare(self, input):
12 |         output = []
13 |         for module in self._modules.values():
14 |             output.append(module(input))
15 |         return output if output else input
16 |
17 |
18 | class Lambda(LambdaBase):
19 |     def forward(self, input):
20 |         return self.lambda_func(self.forward_prepare(input))
21 |
22 |
23 | class LambdaMap(LambdaBase):
24 |     def forward(self, input):
25 |         return list(map(self.lambda_func, self.forward_prepare(input)))
26 |
27 |
28 | class LambdaReduce(LambdaBase):
29 |     def forward(self, input):
30 |         return reduce(self.lambda_func, self.forward_prepare(input))
31 |
32 |
33 | resnext_101_32x4d = nn.Sequential( # Sequential,
34 | nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias=False),
35 | nn.BatchNorm2d(64),
36 | nn.ReLU(),
37 | nn.MaxPool2d((3, 3), (2, 2), (1, 1)),
38 | nn.Sequential( # Sequential,
39 | nn.Sequential( # Sequential,
40 | LambdaMap(lambda x: x, # ConcatTable,
41 | nn.Sequential( # Sequential,
42 | nn.Sequential( # Sequential,
43 | nn.Conv2d(64, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
44 | nn.BatchNorm2d(128),
45 | nn.ReLU(),
46 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
47 | nn.BatchNorm2d(128),
48 | nn.ReLU(),
49 | ),
50 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
51 | nn.BatchNorm2d(256),
52 | ),
53 | nn.Sequential( # Sequential,
54 | nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
55 | nn.BatchNorm2d(256),
56 | ),
57 | ),
58 | LambdaReduce(lambda x, y: x + y), # CAddTable,
59 | nn.ReLU(),
60 | ),
61 | nn.Sequential( # Sequential,
62 | LambdaMap(lambda x: x, # ConcatTable,
63 | nn.Sequential( # Sequential,
64 | nn.Sequential( # Sequential,
65 | nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
66 | nn.BatchNorm2d(128),
67 | nn.ReLU(),
68 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
69 | nn.BatchNorm2d(128),
70 | nn.ReLU(),
71 | ),
72 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
73 | nn.BatchNorm2d(256),
74 | ),
75 | Lambda(lambda x: x), # Identity,
76 | ),
77 | LambdaReduce(lambda x, y: x + y), # CAddTable,
78 | nn.ReLU(),
79 | ),
80 | nn.Sequential( # Sequential,
81 | LambdaMap(lambda x: x, # ConcatTable,
82 | nn.Sequential( # Sequential,
83 | nn.Sequential( # Sequential,
84 | nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
85 | nn.BatchNorm2d(128),
86 | nn.ReLU(),
87 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
88 | nn.BatchNorm2d(128),
89 | nn.ReLU(),
90 | ),
91 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
92 | nn.BatchNorm2d(256),
93 | ),
94 | Lambda(lambda x: x), # Identity,
95 | ),
96 | LambdaReduce(lambda x, y: x + y), # CAddTable,
97 | nn.ReLU(),
98 | ),
99 | ),
100 | nn.Sequential( # Sequential,
101 | nn.Sequential( # Sequential,
102 | LambdaMap(lambda x: x, # ConcatTable,
103 | nn.Sequential( # Sequential,
104 | nn.Sequential( # Sequential,
105 | nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
106 | nn.BatchNorm2d(256),
107 | nn.ReLU(),
108 | nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), 1, 32, bias=False),
109 | nn.BatchNorm2d(256),
110 | nn.ReLU(),
111 | ),
112 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
113 | nn.BatchNorm2d(512),
114 | ),
115 | nn.Sequential( # Sequential,
116 | nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
117 | nn.BatchNorm2d(512),
118 | ),
119 | ),
120 | LambdaReduce(lambda x, y: x + y), # CAddTable,
121 | nn.ReLU(),
122 | ),
123 | nn.Sequential( # Sequential,
124 | LambdaMap(lambda x: x, # ConcatTable,
125 | nn.Sequential( # Sequential,
126 | nn.Sequential( # Sequential,
127 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
128 | nn.BatchNorm2d(256),
129 | nn.ReLU(),
130 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
131 | nn.BatchNorm2d(256),
132 | nn.ReLU(),
133 | ),
134 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
135 | nn.BatchNorm2d(512),
136 | ),
137 | Lambda(lambda x: x), # Identity,
138 | ),
139 | LambdaReduce(lambda x, y: x + y), # CAddTable,
140 | nn.ReLU(),
141 | ),
142 | nn.Sequential( # Sequential,
143 | LambdaMap(lambda x: x, # ConcatTable,
144 | nn.Sequential( # Sequential,
145 | nn.Sequential( # Sequential,
146 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
147 | nn.BatchNorm2d(256),
148 | nn.ReLU(),
149 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
150 | nn.BatchNorm2d(256),
151 | nn.ReLU(),
152 | ),
153 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
154 | nn.BatchNorm2d(512),
155 | ),
156 | Lambda(lambda x: x), # Identity,
157 | ),
158 | LambdaReduce(lambda x, y: x + y), # CAddTable,
159 | nn.ReLU(),
160 | ),
161 | nn.Sequential( # Sequential,
162 | LambdaMap(lambda x: x, # ConcatTable,
163 | nn.Sequential( # Sequential,
164 | nn.Sequential( # Sequential,
165 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
166 | nn.BatchNorm2d(256),
167 | nn.ReLU(),
168 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
169 | nn.BatchNorm2d(256),
170 | nn.ReLU(),
171 | ),
172 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
173 | nn.BatchNorm2d(512),
174 | ),
175 | Lambda(lambda x: x), # Identity,
176 | ),
177 | LambdaReduce(lambda x, y: x + y), # CAddTable,
178 | nn.ReLU(),
179 | ),
180 | ),
181 | nn.Sequential( # Sequential,
182 | nn.Sequential( # Sequential,
183 | LambdaMap(lambda x: x, # ConcatTable,
184 | nn.Sequential( # Sequential,
185 | nn.Sequential( # Sequential,
186 | nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
187 | nn.BatchNorm2d(512),
188 | nn.ReLU(),
189 | nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 32, bias=False),
190 | nn.BatchNorm2d(512),
191 | nn.ReLU(),
192 | ),
193 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
194 | nn.BatchNorm2d(1024),
195 | ),
196 | nn.Sequential( # Sequential,
197 | nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
198 | nn.BatchNorm2d(1024),
199 | ),
200 | ),
201 | LambdaReduce(lambda x, y: x + y), # CAddTable,
202 | nn.ReLU(),
203 | ),
204 | nn.Sequential( # Sequential,
205 | LambdaMap(lambda x: x, # ConcatTable,
206 | nn.Sequential( # Sequential,
207 | nn.Sequential( # Sequential,
208 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
209 | nn.BatchNorm2d(512),
210 | nn.ReLU(),
211 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
212 | nn.BatchNorm2d(512),
213 | nn.ReLU(),
214 | ),
215 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
216 | nn.BatchNorm2d(1024),
217 | ),
218 | Lambda(lambda x: x), # Identity,
219 | ),
220 | LambdaReduce(lambda x, y: x + y), # CAddTable,
221 | nn.ReLU(),
222 | ),
223 | nn.Sequential( # Sequential,
224 | LambdaMap(lambda x: x, # ConcatTable,
225 | nn.Sequential( # Sequential,
226 | nn.Sequential( # Sequential,
227 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
228 | nn.BatchNorm2d(512),
229 | nn.ReLU(),
230 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
231 | nn.BatchNorm2d(512),
232 | nn.ReLU(),
233 | ),
234 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
235 | nn.BatchNorm2d(1024),
236 | ),
237 | Lambda(lambda x: x), # Identity,
238 | ),
239 | LambdaReduce(lambda x, y: x + y), # CAddTable,
240 | nn.ReLU(),
241 | ),
242 | nn.Sequential( # Sequential,
243 | LambdaMap(lambda x: x, # ConcatTable,
244 | nn.Sequential( # Sequential,
245 | nn.Sequential( # Sequential,
246 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
247 | nn.BatchNorm2d(512),
248 | nn.ReLU(),
249 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
250 | nn.BatchNorm2d(512),
251 | nn.ReLU(),
252 | ),
253 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
254 | nn.BatchNorm2d(1024),
255 | ),
256 | Lambda(lambda x: x), # Identity,
257 | ),
258 | LambdaReduce(lambda x, y: x + y), # CAddTable,
259 | nn.ReLU(),
260 | ),
261 | nn.Sequential( # Sequential,
262 | LambdaMap(lambda x: x, # ConcatTable,
263 | nn.Sequential( # Sequential,
264 | nn.Sequential( # Sequential,
265 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
266 | nn.BatchNorm2d(512),
267 | nn.ReLU(),
268 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
269 | nn.BatchNorm2d(512),
270 | nn.ReLU(),
271 | ),
272 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
273 | nn.BatchNorm2d(1024),
274 | ),
275 | Lambda(lambda x: x), # Identity,
276 | ),
277 | LambdaReduce(lambda x, y: x + y), # CAddTable,
278 | nn.ReLU(),
279 | ),
280 | nn.Sequential( # Sequential,
281 | LambdaMap(lambda x: x, # ConcatTable,
282 | nn.Sequential( # Sequential,
283 | nn.Sequential( # Sequential,
284 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
285 | nn.BatchNorm2d(512),
286 | nn.ReLU(),
287 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
288 | nn.BatchNorm2d(512),
289 | nn.ReLU(),
290 | ),
291 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
292 | nn.BatchNorm2d(1024),
293 | ),
294 | Lambda(lambda x: x), # Identity,
295 | ),
296 | LambdaReduce(lambda x, y: x + y), # CAddTable,
297 | nn.ReLU(),
298 | ),
299 | nn.Sequential( # Sequential,
300 | LambdaMap(lambda x: x, # ConcatTable,
301 | nn.Sequential( # Sequential,
302 | nn.Sequential( # Sequential,
303 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
304 | nn.BatchNorm2d(512),
305 | nn.ReLU(),
306 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
307 | nn.BatchNorm2d(512),
308 | nn.ReLU(),
309 | ),
310 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
311 | nn.BatchNorm2d(1024),
312 | ),
313 | Lambda(lambda x: x), # Identity,
314 | ),
315 | LambdaReduce(lambda x, y: x + y), # CAddTable,
316 | nn.ReLU(),
317 | ),
318 | nn.Sequential( # Sequential,
319 | LambdaMap(lambda x: x, # ConcatTable,
320 | nn.Sequential( # Sequential,
321 | nn.Sequential( # Sequential,
322 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
323 | nn.BatchNorm2d(512),
324 | nn.ReLU(),
325 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
326 | nn.BatchNorm2d(512),
327 | nn.ReLU(),
328 | ),
329 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
330 | nn.BatchNorm2d(1024),
331 | ),
332 | Lambda(lambda x: x), # Identity,
333 | ),
334 | LambdaReduce(lambda x, y: x + y), # CAddTable,
335 | nn.ReLU(),
336 | ),
337 | nn.Sequential( # Sequential,
338 | LambdaMap(lambda x: x, # ConcatTable,
339 | nn.Sequential( # Sequential,
340 | nn.Sequential( # Sequential,
341 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
342 | nn.BatchNorm2d(512),
343 | nn.ReLU(),
344 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
345 | nn.BatchNorm2d(512),
346 | nn.ReLU(),
347 | ),
348 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
349 | nn.BatchNorm2d(1024),
350 | ),
351 | Lambda(lambda x: x), # Identity,
352 | ),
353 | LambdaReduce(lambda x, y: x + y), # CAddTable,
354 | nn.ReLU(),
355 | ),
356 | nn.Sequential( # Sequential,
357 | LambdaMap(lambda x: x, # ConcatTable,
358 | nn.Sequential( # Sequential,
359 | nn.Sequential( # Sequential,
360 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
361 | nn.BatchNorm2d(512),
362 | nn.ReLU(),
363 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
364 | nn.BatchNorm2d(512),
365 | nn.ReLU(),
366 | ),
367 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
368 | nn.BatchNorm2d(1024),
369 | ),
370 | Lambda(lambda x: x), # Identity,
371 | ),
372 | LambdaReduce(lambda x, y: x + y), # CAddTable,
373 | nn.ReLU(),
374 | ),
375 | nn.Sequential( # Sequential,
376 | LambdaMap(lambda x: x, # ConcatTable,
377 | nn.Sequential( # Sequential,
378 | nn.Sequential( # Sequential,
379 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
380 | nn.BatchNorm2d(512),
381 | nn.ReLU(),
382 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
383 | nn.BatchNorm2d(512),
384 | nn.ReLU(),
385 | ),
386 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
387 | nn.BatchNorm2d(1024),
388 | ),
389 | Lambda(lambda x: x), # Identity,
390 | ),
391 | LambdaReduce(lambda x, y: x + y), # CAddTable,
392 | nn.ReLU(),
393 | ),
394 | nn.Sequential( # Sequential,
395 | LambdaMap(lambda x: x, # ConcatTable,
396 | nn.Sequential( # Sequential,
397 | nn.Sequential( # Sequential,
398 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
399 | nn.BatchNorm2d(512),
400 | nn.ReLU(),
401 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
402 | nn.BatchNorm2d(512),
403 | nn.ReLU(),
404 | ),
405 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
406 | nn.BatchNorm2d(1024),
407 | ),
408 | Lambda(lambda x: x), # Identity,
409 | ),
410 | LambdaReduce(lambda x, y: x + y), # CAddTable,
411 | nn.ReLU(),
412 | ),
413 | nn.Sequential( # Sequential,
414 | LambdaMap(lambda x: x, # ConcatTable,
415 | nn.Sequential( # Sequential,
416 | nn.Sequential( # Sequential,
417 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
418 | nn.BatchNorm2d(512),
419 | nn.ReLU(),
420 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
421 | nn.BatchNorm2d(512),
422 | nn.ReLU(),
423 | ),
424 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
425 | nn.BatchNorm2d(1024),
426 | ),
427 | Lambda(lambda x: x), # Identity,
428 | ),
429 | LambdaReduce(lambda x, y: x + y), # CAddTable,
430 | nn.ReLU(),
431 | ),
432 | nn.Sequential( # Sequential,
433 | LambdaMap(lambda x: x, # ConcatTable,
434 | nn.Sequential( # Sequential,
435 | nn.Sequential( # Sequential,
436 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
437 | nn.BatchNorm2d(512),
438 | nn.ReLU(),
439 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
440 | nn.BatchNorm2d(512),
441 | nn.ReLU(),
442 | ),
443 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
444 | nn.BatchNorm2d(1024),
445 | ),
446 | Lambda(lambda x: x), # Identity,
447 | ),
448 | LambdaReduce(lambda x, y: x + y), # CAddTable,
449 | nn.ReLU(),
450 | ),
451 | nn.Sequential( # Sequential,
452 | LambdaMap(lambda x: x, # ConcatTable,
453 | nn.Sequential( # Sequential,
454 | nn.Sequential( # Sequential,
455 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
456 | nn.BatchNorm2d(512),
457 | nn.ReLU(),
458 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
459 | nn.BatchNorm2d(512),
460 | nn.ReLU(),
461 | ),
462 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
463 | nn.BatchNorm2d(1024),
464 | ),
465 | Lambda(lambda x: x), # Identity,
466 | ),
467 | LambdaReduce(lambda x, y: x + y), # CAddTable,
468 | nn.ReLU(),
469 | ),
470 | nn.Sequential( # Sequential,
471 | LambdaMap(lambda x: x, # ConcatTable,
472 | nn.Sequential( # Sequential,
473 | nn.Sequential( # Sequential,
474 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
475 | nn.BatchNorm2d(512),
476 | nn.ReLU(),
477 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
478 | nn.BatchNorm2d(512),
479 | nn.ReLU(),
480 | ),
481 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
482 | nn.BatchNorm2d(1024),
483 | ),
484 | Lambda(lambda x: x), # Identity,
485 | ),
486 | LambdaReduce(lambda x, y: x + y), # CAddTable,
487 | nn.ReLU(),
488 | ),
489 | nn.Sequential( # Sequential,
490 | LambdaMap(lambda x: x, # ConcatTable,
491 | nn.Sequential( # Sequential,
492 | nn.Sequential( # Sequential,
493 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
494 | nn.BatchNorm2d(512),
495 | nn.ReLU(),
496 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
497 | nn.BatchNorm2d(512),
498 | nn.ReLU(),
499 | ),
500 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
501 | nn.BatchNorm2d(1024),
502 | ),
503 | Lambda(lambda x: x), # Identity,
504 | ),
505 | LambdaReduce(lambda x, y: x + y), # CAddTable,
506 | nn.ReLU(),
507 | ),
508 | nn.Sequential( # Sequential,
509 | LambdaMap(lambda x: x, # ConcatTable,
510 | nn.Sequential( # Sequential,
511 | nn.Sequential( # Sequential,
512 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
513 | nn.BatchNorm2d(512),
514 | nn.ReLU(),
515 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
516 | nn.BatchNorm2d(512),
517 | nn.ReLU(),
518 | ),
519 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
520 | nn.BatchNorm2d(1024),
521 | ),
522 | Lambda(lambda x: x), # Identity,
523 | ),
524 | LambdaReduce(lambda x, y: x + y), # CAddTable,
525 | nn.ReLU(),
526 | ),
527 | nn.Sequential( # Sequential,
528 | LambdaMap(lambda x: x, # ConcatTable,
529 | nn.Sequential( # Sequential,
530 | nn.Sequential( # Sequential,
531 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
532 | nn.BatchNorm2d(512),
533 | nn.ReLU(),
534 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
535 | nn.BatchNorm2d(512),
536 | nn.ReLU(),
537 | ),
538 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
539 | nn.BatchNorm2d(1024),
540 | ),
541 | Lambda(lambda x: x), # Identity,
542 | ),
543 | LambdaReduce(lambda x, y: x + y), # CAddTable,
544 | nn.ReLU(),
545 | ),
546 | nn.Sequential( # Sequential,
547 | LambdaMap(lambda x: x, # ConcatTable,
548 | nn.Sequential( # Sequential,
549 | nn.Sequential( # Sequential,
550 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
551 | nn.BatchNorm2d(512),
552 | nn.ReLU(),
553 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
554 | nn.BatchNorm2d(512),
555 | nn.ReLU(),
556 | ),
557 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
558 | nn.BatchNorm2d(1024),
559 | ),
560 | Lambda(lambda x: x), # Identity,
561 | ),
562 | LambdaReduce(lambda x, y: x + y), # CAddTable,
563 | nn.ReLU(),
564 | ),
565 | nn.Sequential( # Sequential,
566 | LambdaMap(lambda x: x, # ConcatTable,
567 | nn.Sequential( # Sequential,
568 | nn.Sequential( # Sequential,
569 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
570 | nn.BatchNorm2d(512),
571 | nn.ReLU(),
572 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
573 | nn.BatchNorm2d(512),
574 | nn.ReLU(),
575 | ),
576 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
577 | nn.BatchNorm2d(1024),
578 | ),
579 | Lambda(lambda x: x), # Identity,
580 | ),
581 | LambdaReduce(lambda x, y: x + y), # CAddTable,
582 | nn.ReLU(),
583 | ),
584 | nn.Sequential( # Sequential,
585 | LambdaMap(lambda x: x, # ConcatTable,
586 | nn.Sequential( # Sequential,
587 | nn.Sequential( # Sequential,
588 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
589 | nn.BatchNorm2d(512),
590 | nn.ReLU(),
591 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
592 | nn.BatchNorm2d(512),
593 | nn.ReLU(),
594 | ),
595 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
596 | nn.BatchNorm2d(1024),
597 | ),
598 | Lambda(lambda x: x), # Identity,
599 | ),
600 | LambdaReduce(lambda x, y: x + y), # CAddTable,
601 | nn.ReLU(),
602 | ),
603 | nn.Sequential( # Sequential,
604 | LambdaMap(lambda x: x, # ConcatTable,
605 | nn.Sequential( # Sequential,
606 | nn.Sequential( # Sequential,
607 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
608 | nn.BatchNorm2d(512),
609 | nn.ReLU(),
610 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
611 | nn.BatchNorm2d(512),
612 | nn.ReLU(),
613 | ),
614 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
615 | nn.BatchNorm2d(1024),
616 | ),
617 | Lambda(lambda x: x), # Identity,
618 | ),
619 | LambdaReduce(lambda x, y: x + y), # CAddTable,
620 | nn.ReLU(),
621 | ),
622 | ),
623 | nn.Sequential( # Sequential,
624 | nn.Sequential( # Sequential,
625 | LambdaMap(lambda x: x, # ConcatTable,
626 | nn.Sequential( # Sequential,
627 | nn.Sequential( # Sequential,
628 | nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
629 | nn.BatchNorm2d(1024),
630 | nn.ReLU(),
631 | nn.Conv2d(1024, 1024, (3, 3), (2, 2), (1, 1), 1, 32, bias=False),
632 | nn.BatchNorm2d(1024),
633 | nn.ReLU(),
634 | ),
635 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
636 | nn.BatchNorm2d(2048),
637 | ),
638 | nn.Sequential( # Sequential,
639 | nn.Conv2d(1024, 2048, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
640 | nn.BatchNorm2d(2048),
641 | ),
642 | ),
643 | LambdaReduce(lambda x, y: x + y), # CAddTable,
644 | nn.ReLU(),
645 | ),
646 | nn.Sequential( # Sequential,
647 | LambdaMap(lambda x: x, # ConcatTable,
648 | nn.Sequential( # Sequential,
649 | nn.Sequential( # Sequential,
650 | nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
651 | nn.BatchNorm2d(1024),
652 | nn.ReLU(),
653 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
654 | nn.BatchNorm2d(1024),
655 | nn.ReLU(),
656 | ),
657 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
658 | nn.BatchNorm2d(2048),
659 | ),
660 | Lambda(lambda x: x), # Identity,
661 | ),
662 | LambdaReduce(lambda x, y: x + y), # CAddTable,
663 | nn.ReLU(),
664 | ),
665 | nn.Sequential( # Sequential,
666 | LambdaMap(lambda x: x, # ConcatTable,
667 | nn.Sequential( # Sequential,
668 | nn.Sequential( # Sequential,
669 | nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
670 | nn.BatchNorm2d(1024),
671 | nn.ReLU(),
672 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
673 | nn.BatchNorm2d(1024),
674 | nn.ReLU(),
675 | ),
676 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
677 | nn.BatchNorm2d(2048),
678 | ),
679 | Lambda(lambda x: x), # Identity,
680 | ),
681 | LambdaReduce(lambda x, y: x + y), # CAddTable,
682 | nn.ReLU(),
683 | ),
684 | ),
685 | nn.AvgPool2d((7, 7), (1, 1)),
686 | Lambda(lambda x: x.view(x.size(0), -1)), # View,
687 | nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x), nn.Linear(2048, 1000)), # Linear,
688 | )
689 |
--------------------------------------------------------------------------------
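Every bottleneck block in the listing above follows the same three-step shape: a `LambdaMap` sends the input through a transform branch and a shortcut branch, a `LambdaReduce` adds the two results, and a ReLU follows. The sketch below mimics that data flow with plain Python callables on numbers, so it runs without PyTorch. `lambda_map`, `lambda_reduce` and `residual_block` are stand-ins written for this note and are not the repository classes.

```python
from functools import reduce


def lambda_map(fn, *branches):
    """Analogue of LambdaMap: feed x to every branch, then apply fn to each result."""
    def forward(x):
        outputs = [branch(x) for branch in branches]
        return list(map(fn, outputs))
    return forward


def lambda_reduce(fn):
    """Analogue of a LambdaReduce without child modules: fold fn over the list."""
    def forward(outputs):
        return reduce(fn, outputs)
    return forward


def residual_block(transform):
    """Compose map -> reduce -> ReLU in the order used by each bottleneck block."""
    split = lambda_map(lambda t: t, transform, lambda t: t)  # ConcatTable: [F(x), x]
    merge = lambda_reduce(lambda a, b: a + b)                # CAddTable: F(x) + x
    relu = lambda t: max(t, 0.0)
    return lambda x: relu(merge(split(x)))


# A linear function stands in for the conv/BN stack of the transform branch.
block = residual_block(lambda x: 2.0 * x - 5.0)
print(block(1.0))  # F(1) + 1 = -2.0, clipped by ReLU to 0.0
print(block(4.0))  # F(4) + 4 = 7.0
```

In the first block of each stage the shortcut branch is a 1x1 convolution with batch norm rather than the identity, because the channel count (and, in later stages, the resolution) changes there.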
/ckpt/MirrorNet/placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mhaiyang/ICCV2019_MirrorNet/0fdfd0f3b1608c16fbc70f60450d0ddd2e2e5efb/ckpt/MirrorNet/placeholder
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | """
2 | @Time : 9/15/19 10:22
3 | @Author : TaylorMei
4 | @Email : mhy666@mail.dlut.edu.cn
5 |
6 | @Project : ICCV2019_MirrorNet
7 | @File : config.py
8 | @Function: configurations.
9 |
10 | """
11 | backbone_path = '/home/iccd/ICCV2019_MirrorNet/backbone/resnext/resnext_101_32x4d.pth'
12 |
13 | msd_training_root = "/media/iccd/disk/release/MSD/train"
14 |
15 | msd_testing_root = "/media/iccd/disk/release/MSD/test"
16 |
17 | msd_results_root = "/media/iccd/disk/release/MSD/results"
18 |
--------------------------------------------------------------------------------
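The paths in `config.py` are absolute paths on the authors' machine and have to be edited before anything runs. One possible way to avoid editing the file per machine is to let environment variables override the defaults. The sketch below shows that idea only; the `MIRRORNET_*` variable names and the `load_paths` helper are invented for this note and are not read by the repository code.

```python
import os

# Defaults copied from config.py; the MIRRORNET_* variable names are hypothetical.
DEFAULTS = {
    "backbone_path": "/home/iccd/ICCV2019_MirrorNet/backbone/resnext/resnext_101_32x4d.pth",
    "msd_training_root": "/media/iccd/disk/release/MSD/train",
    "msd_testing_root": "/media/iccd/disk/release/MSD/test",
    "msd_results_root": "/media/iccd/disk/release/MSD/results",
}


def load_paths(environ=None):
    """Return the config paths, letting MIRRORNET_<NAME> variables override defaults."""
    environ = os.environ if environ is None else environ
    return {
        name: environ.get("MIRRORNET_" + name.upper(), default)
        for name, default in DEFAULTS.items()
    }


# Example with an explicit mapping instead of the real process environment.
paths = load_paths({"MIRRORNET_MSD_TESTING_ROOT": "/data/MSD/test"})
print(paths["msd_testing_root"])  # overridden value
print(paths["msd_results_root"])  # default from config.py
```

Passing the environment in as an argument keeps the helper easy to test and leaves the existing module-level names in `config.py` untouched for code that imports them directly.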
/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | @Time : 10/2/19 18:00
3 | @Author : TaylorMei
4 | @Email : mhy666@mail.dlut.edu.cn
5 |
6 | @Project : ICCV2019_MirrorNet
7 | @File : dataset.py
8 | @Function: prepare data for training.
9 |
10 | """
11 | import os
12 | import os.path
13 |
14 | import torch.utils.data as data
15 | from PIL import Image
16 |
17 |
18 | def make_dataset(root):
19 |     img_list = [os.path.splitext(f)[0] for f in os.listdir(os.path.join(root, 'image')) if f.endswith('.jpg')]
20 |     return [
21 |         (os.path.join(root, 'image', img_name + '.jpg'), os.path.join(root, 'mask', img_name + '.png'))
22 |         for img_name in img_list]
23 |
24 |
25 | class ImageFolder(data.Dataset):
26 |     def __init__(self, root, joint_transform=None, img_transform=None, target_transform=None):
27 |         self.root = root
28 |         self.imgs = make_dataset(root)
29 |         self.joint_transform = joint_transform
30 |         self.img_transform = img_transform
31 |         self.target_transform = target_transform
32 |
33 |     def __getitem__(self, index):
34 |         img_path, gt_path = self.imgs[index]
35 |         img = Image.open(img_path).convert('RGB')
36 |         target = Image.open(gt_path)
37 |         if self.joint_transform is not None:
38 |             img, target = self.joint_transform(img, target)
39 |         if self.img_transform is not None:
40 |             img = self.img_transform(img)
41 |         if self.target_transform is not None:
42 |             target = self.target_transform(target)
43 |
44 |         return img, target
45 |
46 |     def __len__(self):
47 |         return len(self.imgs)
48 |
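The image/mask pairing done by `make_dataset` can be tried on a throwaway directory. This is a minimal sketch, assuming the `image/` and `mask/` layout used in `dataset.py`; the helper `pair_images_and_masks`, its `missing` list, and `demo` are hypothetical additions for illustration and are not part of the repository.

```python
import os
import tempfile


def pair_images_and_masks(root):
    """Pair root/image/<name>.jpg with root/mask/<name>.png (hypothetical helper)."""
    names = sorted(
        os.path.splitext(f)[0]
        for f in os.listdir(os.path.join(root, 'image'))
        if f.endswith('.jpg')
    )
    pairs, missing = [], []
    for name in names:
        img_path = os.path.join(root, 'image', name + '.jpg')
        mask_path = os.path.join(root, 'mask', name + '.png')
        # Unlike make_dataset, report images whose mask file is absent.
        if os.path.exists(mask_path):
            pairs.append((img_path, mask_path))
        else:
            missing.append(name)
    return pairs, missing


def demo():
    """Build a tiny fake dataset and return (paired image names, names without a mask)."""
    with tempfile.TemporaryDirectory() as root:
        os.makedirs(os.path.join(root, 'image'))
        os.makedirs(os.path.join(root, 'mask'))
        for rel in ('image/a.jpg', 'image/b.jpg', 'image/notes.txt', 'mask/a.png'):
            open(os.path.join(root, *rel.split('/')), 'w').close()
        pairs, missing = pair_images_and_masks(root)
        return [os.path.basename(p[0]) for p in pairs], missing
```

Reporting missing masks up front is a design choice of this sketch; `make_dataset` itself defers that failure until `Image.open` is called on the mask path.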
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | """
2 | @Time : 9/29/19 17:14
3 | @Author : TaylorMei
4 | @Email : mhy666@mail.dlut.edu.cn
5 |
6 | @Project : ICCV2019_MirrorNet
7 | @File : infer.py
8 | @Function: predict mirror map.
9 |
10 | """
11 | import numpy as np
12 | import os
13 | import time
14 |
15 | import torch
16 | from PIL import Image
17 | from torch.autograd import Variable
18 | from torchvision import transforms
19 |
20 | from config import msd_testing_root
21 | from misc import check_mkdir, crf_refine
22 | from mirrornet import MirrorNet
23 |
24 | device_ids = [0]
25 | torch.cuda.set_device(device_ids[0])
26 |
27 | ckpt_path = './ckpt'
28 | exp_name = 'MirrorNet'
29 | args = {
30 | 'snapshot': '160',
31 | 'scale': 384,
32 | 'crf': True
33 | }
34 |
35 | img_transform = transforms.Compose([
36 | transforms.Resize((args['scale'], args['scale'])),
37 | transforms.ToTensor(),
38 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
39 | ])
40 |
41 | to_test = {'MSD': msd_testing_root}
42 |
43 | to_pil = transforms.ToPILImage()
44 |
45 |
46 | def main():
47 | net = MirrorNet().cuda(device_ids[0])
48 |
49 | if len(args['snapshot']) > 0:
50 | print('Load snapshot {} for testing'.format(args['snapshot']))
51 | # net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'MirrorNet.pth')))
52 | # print('Load {} succeed!'.format(os.path.join(ckpt_path, exp_name, 'MirrorNet.pth')))
53 | net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
54 | print('Load {} succeed!'.format(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
55 |
56 | net.eval()
57 | with torch.no_grad():
58 | for name, root in to_test.items():
59 | img_list = [img_name for img_name in os.listdir(os.path.join(root, 'image'))]
60 | start = time.time()
61 | for idx, img_name in enumerate(img_list):
62 | print('predicting for {}: {:>4d} / {}'.format(name, idx + 1, len(img_list)))
63 | check_mkdir(os.path.join(ckpt_path, exp_name, '%s_%s_%s' % (exp_name, args['snapshot'], 'nocrf')))
64 | img = Image.open(os.path.join(root, 'image', img_name))
65 |                 if img.mode != 'RGB':
66 |                     print("{} is not RGB (mode {}); converting.".format(img_name, img.mode))
67 |                     img = img.convert('RGB')
68 | w, h = img.size
69 | img_var = Variable(img_transform(img).unsqueeze(0)).cuda(device_ids[0])
70 | f_4, f_3, f_2, f_1 = net(img_var)
71 | f_4 = f_4.data.squeeze(0).cpu()
72 | f_3 = f_3.data.squeeze(0).cpu()
73 | f_2 = f_2.data.squeeze(0).cpu()
74 | f_1 = f_1.data.squeeze(0).cpu()
75 | f_4 = np.array(transforms.Resize((h, w))(to_pil(f_4)))
76 | f_3 = np.array(transforms.Resize((h, w))(to_pil(f_3)))
77 | f_2 = np.array(transforms.Resize((h, w))(to_pil(f_2)))
78 | f_1 = np.array(transforms.Resize((h, w))(to_pil(f_1)))
79 | if args['crf']:
80 | f_1 = crf_refine(np.array(img.convert('RGB')), f_1)
81 |
82 |                 Image.fromarray(f_1).save(os.path.join(ckpt_path, exp_name, '%s_%s_%s' % (exp_name, args['snapshot'], 'nocrf'), os.path.splitext(img_name)[0] + ".png"))
83 |
84 | end = time.time()
85 | print("Average Time Is : {:.2f}".format((end - start) / len(img_list)))
86 |
87 |
88 | if __name__ == '__main__':
89 | main()
90 |
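The output naming in `infer.py` can be isolated into a small function. This is a sketch under stated assumptions: `prediction_path` is a hypothetical helper, and tagging the folder with `crf` when CRF refinement is enabled is an assumption of this example (the script above always uses the `nocrf` tag).

```python
import os


def prediction_path(ckpt_path, exp_name, snapshot, img_name, use_crf):
    """Return (output directory, output PNG path) for one prediction (hypothetical helper)."""
    # Assumption: tag the folder with the CRF setting actually used.
    tag = 'crf' if use_crf else 'nocrf'
    out_dir = os.path.join(ckpt_path, exp_name, '%s_%s_%s' % (exp_name, snapshot, tag))
    # splitext handles extensions of any length, unlike slicing off four characters.
    stem = os.path.splitext(img_name)[0]
    return out_dir, os.path.join(out_dir, stem + '.png')
```

Calling `os.makedirs(out_dir, exist_ok=True)` once before the prediction loop would then be enough to prepare the folder.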
--------------------------------------------------------------------------------
/mirrornet.py:
--------------------------------------------------------------------------------
1 | """
2 | @Time : 9/29/19 17:16
3 | @Author : TaylorMei
4 | @Email : mhy666@mail.dlut.edu.cn
5 |
6 | @Project : ICCV2019_MirrorNet
7 | @File : mirrornet.py
8 | @Function: MirrorNet.
9 |
10 | """
11 | import torch
12 | import torch.nn.functional as F
13 | from torch import nn
14 |
15 | from backbone.resnext.resnext101_regular import ResNeXt101
16 |
17 |
18 | ###################################################################
19 | # ########################## CBAM #################################
20 | ###################################################################
21 | class BasicConv(nn.Module):
22 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
23 | bn=True, bias=False):
24 | super(BasicConv, self).__init__()
25 | self.out_channels = out_planes
26 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
27 | dilation=dilation, groups=groups, bias=bias)
28 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
29 | self.relu = nn.ReLU() if relu else None
30 |
31 | def forward(self, x):
32 | x = self.conv(x)
33 | if self.bn is not None:
34 | x = self.bn(x)
35 | if self.relu is not None:
36 | x = self.relu(x)
37 | return x
38 |
39 |
40 | class Flatten(nn.Module):
41 | def forward(self, x):
42 | return x.view(x.size(0), -1)
43 |
44 |
45 | class ChannelGate(nn.Module):
46 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg']):
47 | super(ChannelGate, self).__init__()
48 | self.gate_channels = gate_channels
49 | self.mlp = nn.Sequential(
50 | Flatten(),
51 | nn.Linear(gate_channels, gate_channels // reduction_ratio),
52 | nn.ReLU(),
53 | nn.Linear(gate_channels // reduction_ratio, gate_channels)
54 | )
55 | self.pool_types = pool_types
56 |
57 | def forward(self, x):
58 | channel_att_sum = None
59 | for pool_type in self.pool_types:
60 | if pool_type == 'avg':
61 | avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
62 | channel_att_raw = self.mlp(avg_pool)
63 | elif pool_type == 'max':
64 | max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
65 | channel_att_raw = self.mlp(max_pool)
66 | elif pool_type == 'lp':
67 | lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
68 | channel_att_raw = self.mlp(lp_pool)
69 | elif pool_type == 'lse':
70 | # LSE pool only
71 | lse_pool = logsumexp_2d(x)
72 | channel_att_raw = self.mlp(lse_pool)
73 |
74 | if channel_att_sum is None:
75 | channel_att_sum = channel_att_raw
76 | else:
77 | channel_att_sum = channel_att_sum + channel_att_raw
78 |
79 |         scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
80 |         return x * scale
81 |
82 |
83 | def logsumexp_2d(tensor):
84 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
85 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
86 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
87 | return outputs
88 |
89 |
90 | class ChannelPool(nn.Module):
91 | def forward(self, x):
92 | # original
93 | # return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
94 | # max
95 | # torch.max(x, 1)[0].unsqueeze(1)
96 | # avg
97 | return torch.mean(x, 1).unsqueeze(1)
98 |
99 |
100 | class SpatialGate(nn.Module):
101 | def __init__(self):
102 | super(SpatialGate, self).__init__()
103 | kernel_size = 7
104 | self.compress = ChannelPool()
105 | self.spatial = BasicConv(1, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)
106 |
107 | def forward(self, x):
108 | x_compress = self.compress(x)
109 | x_out = self.spatial(x_compress)
110 |         scale = torch.sigmoid(x_out)  # broadcasting
111 |         return x * scale
112 |
113 |
114 | class CBAM(nn.Module):
115 | def __init__(self, gate_channels=128, reduction_ratio=16, pool_types=['avg'], no_spatial=False):
116 | super(CBAM, self).__init__()
117 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
118 | self.no_spatial = no_spatial
119 | if not no_spatial:
120 | self.SpatialGate = SpatialGate()
121 |
122 | def forward(self, x):
123 | x_out = self.ChannelGate(x)
124 | if not self.no_spatial:
125 | x_out = self.SpatialGate(x_out)
126 | return x_out
127 |
128 |
129 | ###################################################################
130 | # ###################### Contrast Module ##########################
131 | ###################################################################
132 | class Contrast_Module(nn.Module):
133 | def __init__(self, planes):
134 | super(Contrast_Module, self).__init__()
135 | self.inplanes = int(planes)
136 | self.inplanes_half = int(planes / 2)
137 | self.outplanes = int(planes / 4)
138 |
139 | self.conv1 = nn.Sequential(nn.Conv2d(self.inplanes, self.inplanes_half, 3, 1, 1),
140 | nn.BatchNorm2d(self.inplanes_half), nn.ReLU())
141 |
142 | self.conv2 = nn.Sequential(nn.Conv2d(self.inplanes_half, self.outplanes, 3, 1, 1),
143 | nn.BatchNorm2d(self.outplanes), nn.ReLU())
144 |
145 | self.contrast_block_1 = Contrast_Block(self.outplanes)
146 | self.contrast_block_2 = Contrast_Block(self.outplanes)
147 | self.contrast_block_3 = Contrast_Block(self.outplanes)
148 | self.contrast_block_4 = Contrast_Block(self.outplanes)
149 |
150 | self.cbam = CBAM(self.inplanes)
151 |
152 | def forward(self, x):
153 | conv1 = self.conv1(x)
154 | conv2 = self.conv2(conv1)
155 |
156 | contrast_block_1 = self.contrast_block_1(conv2)
157 | contrast_block_2 = self.contrast_block_2(contrast_block_1)
158 | contrast_block_3 = self.contrast_block_3(contrast_block_2)
159 | contrast_block_4 = self.contrast_block_4(contrast_block_3)
160 |
161 | output = self.cbam(torch.cat((contrast_block_1, contrast_block_2, contrast_block_3, contrast_block_4), 1))
162 |
163 | return output
164 |
165 |
166 | ###################################################################
167 | # ###################### Contrast Block ###########################
168 | ###################################################################
169 | class Contrast_Block(nn.Module):
170 | def __init__(self, planes):
171 | super(Contrast_Block, self).__init__()
172 | self.inplanes = int(planes)
173 | self.outplanes = int(planes / 4)
174 |
175 | self.local_1 = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=3, stride=1, padding=1, dilation=1)
176 | self.context_1 = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=3, stride=1, padding=2, dilation=2)
177 |
178 | self.local_2 = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=3, stride=1, padding=1, dilation=1)
179 | self.context_2 = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=3, stride=1, padding=4, dilation=4)
180 |
181 | self.local_3 = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=3, stride=1, padding=1, dilation=1)
182 | self.context_3 = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=3, stride=1, padding=8, dilation=8)
183 |
184 | self.local_4 = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=3, stride=1, padding=1, dilation=1)
185 | self.context_4 = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=3, stride=1, padding=16, dilation=16)
186 |
187 | self.bn = nn.BatchNorm2d(self.outplanes)
188 | self.relu = nn.ReLU()
189 |
190 | self.cbam = CBAM(self.inplanes)
191 |
192 | def forward(self, x):
193 | local_1 = self.local_1(x)
194 | context_1 = self.context_1(x)
195 | ccl_1 = local_1 - context_1
196 | ccl_1 = self.bn(ccl_1)
197 | ccl_1 = self.relu(ccl_1)
198 |
199 | local_2 = self.local_2(x)
200 | context_2 = self.context_2(x)
201 | ccl_2 = local_2 - context_2
202 | ccl_2 = self.bn(ccl_2)
203 | ccl_2 = self.relu(ccl_2)
204 |
205 | local_3 = self.local_3(x)
206 | context_3 = self.context_3(x)
207 | ccl_3 = local_3 - context_3
208 | ccl_3 = self.bn(ccl_3)
209 | ccl_3 = self.relu(ccl_3)
210 |
211 | local_4 = self.local_4(x)
212 | context_4 = self.context_4(x)
213 | ccl_4 = local_4 - context_4
214 | ccl_4 = self.bn(ccl_4)
215 | ccl_4 = self.relu(ccl_4)
216 |
217 | output = self.cbam(torch.cat((ccl_1, ccl_2, ccl_3, ccl_4), 1))
218 |
219 | return output
220 |
221 |
222 | ###################################################################
223 | # ########################## NETWORK ##############################
224 | ###################################################################
225 | class MirrorNet(nn.Module):
226 | def __init__(self, backbone_path=None):
227 | super(MirrorNet, self).__init__()
228 | resnext = ResNeXt101(backbone_path)
229 | self.layer0 = resnext.layer0
230 | self.layer1 = resnext.layer1
231 | self.layer2 = resnext.layer2
232 | self.layer3 = resnext.layer3
233 | self.layer4 = resnext.layer4
234 |
235 | self.contrast_4 = Contrast_Module(2048)
236 | self.contrast_3 = Contrast_Module(1024)
237 | self.contrast_2 = Contrast_Module(512)
238 | self.contrast_1 = Contrast_Module(256)
239 |
240 | self.up_4 = nn.Sequential(nn.ConvTranspose2d(2048, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.ReLU())
241 | self.up_3 = nn.Sequential(nn.ConvTranspose2d(1024, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU())
242 | self.up_2 = nn.Sequential(nn.ConvTranspose2d(512, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU())
243 | self.up_1 = nn.Sequential(nn.Conv2d(256, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU())
244 |
245 | self.cbam_4 = CBAM(512)
246 | self.cbam_3 = CBAM(256)
247 | self.cbam_2 = CBAM(128)
248 | self.cbam_1 = CBAM(64)
249 |
250 | self.layer4_predict = nn.Conv2d(512, 1, 3, 1, 1)
251 | self.layer3_predict = nn.Conv2d(256, 1, 3, 1, 1)
252 | self.layer2_predict = nn.Conv2d(128, 1, 3, 1, 1)
253 | self.layer1_predict = nn.Conv2d(64, 1, 3, 1, 1)
254 |
255 | for m in self.modules():
256 | if isinstance(m, nn.ReLU):
257 | m.inplace = True
258 |
259 |     def forward(self, x):
260 |         layer0 = self.layer0(x)
261 |         layer1 = self.layer1(layer0)
262 |         layer2 = self.layer2(layer1)
263 |         layer3 = self.layer3(layer2)
264 |         layer4 = self.layer4(layer3)
265 |
266 |         contrast_4 = self.contrast_4(layer4)
267 |         up_4 = self.up_4(contrast_4)
268 |         cbam_4 = self.cbam_4(up_4)
269 |         layer4_predict = self.layer4_predict(cbam_4)
270 |         layer4_map = torch.sigmoid(layer4_predict)
271 |
272 |         contrast_3 = self.contrast_3(layer3 * layer4_map)
273 |         up_3 = self.up_3(contrast_3)
274 |         cbam_3 = self.cbam_3(up_3)
275 |         layer3_predict = self.layer3_predict(cbam_3)
276 |         layer3_map = torch.sigmoid(layer3_predict)
277 |
278 |         contrast_2 = self.contrast_2(layer2 * layer3_map)
279 |         up_2 = self.up_2(contrast_2)
280 |         cbam_2 = self.cbam_2(up_2)
281 |         layer2_predict = self.layer2_predict(cbam_2)
282 |         layer2_map = torch.sigmoid(layer2_predict)
283 |
284 |         contrast_1 = self.contrast_1(layer1 * layer2_map)
285 |         up_1 = self.up_1(contrast_1)
286 |         cbam_1 = self.cbam_1(up_1)
287 |         layer1_predict = self.layer1_predict(cbam_1)
288 |
289 |         layer4_predict = F.interpolate(layer4_predict, size=x.size()[2:], mode='bilinear', align_corners=True)
290 |         layer3_predict = F.interpolate(layer3_predict, size=x.size()[2:], mode='bilinear', align_corners=True)
291 |         layer2_predict = F.interpolate(layer2_predict, size=x.size()[2:], mode='bilinear', align_corners=True)
292 |         layer1_predict = F.interpolate(layer1_predict, size=x.size()[2:], mode='bilinear', align_corners=True)
293 |
294 |         if self.training:
295 |             return layer4_predict, layer3_predict, layer2_predict, layer1_predict
296 |
297 |         return torch.sigmoid(layer4_predict), torch.sigmoid(layer3_predict), torch.sigmoid(layer2_predict), \
298 |                torch.sigmoid(layer1_predict)
299 |
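The spatial sizes in `mirrornet.py` can be checked with plain arithmetic. The sketch below uses the output-size formulas documented for `nn.Conv2d` and `nn.ConvTranspose2d`. The function names and the 12-pixel feature map are assumptions chosen for illustration. It shows why the local and context branches of `Contrast_Block` can be subtracted element-wise, and why `up_4`, `up_3` and `up_2` double the resolution.

```python
def conv2d_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of nn.Conv2d along one spatial axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_transpose2d_out(size, kernel, stride=1, padding=0):
    """Output length of nn.ConvTranspose2d (dilation 1, no output_padding)."""
    return (size - 1) * stride - 2 * padding + kernel


def check_shapes(size=12):
    # The local (dilation 1) and context (dilation 2, 4, 8, 16) branches of
    # Contrast_Block use a 3x3 kernel with padding == dilation, so every
    # branch keeps the input size and "local - context" is well defined.
    branch_sizes = [conv2d_out(size, 3, 1, d, d) for d in (1, 2, 4, 8, 16)]
    # up_4 / up_3 / up_2 use ConvTranspose2d(kernel=4, stride=2, padding=1).
    upsampled = conv_transpose2d_out(size, 4, 2, 1)
    return branch_sizes, upsampled
```

With `size=12` every branch returns 12 and the transposed convolution returns 24, so each decoder stage lines up with the next shallower backbone feature map before the element-wise multiplication.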
--------------------------------------------------------------------------------
/misc.py:
--------------------------------------------------------------------------------
1 | """
2 | @Time : 9/15/19 10:19
3 | @Author : TaylorMei
4 | @Email : mhy666@mail.dlut.edu.cn
5 |
6 | @Project : ICCV2019_MirrorNet
7 | @File : misc.py
8 | @Function: functions.
9 |
10 | """
11 | import numpy as np
12 | import os
13 | import skimage.io
14 | import skimage.transform
15 | import xlwt
16 |
17 | import pydensecrf.densecrf as dcrf
18 |
19 |
20 | ################################################################
21 | ######################## Train & Test ##########################
22 | ################################################################
23 | class AvgMeter(object):
24 | def __init__(self):
25 | self.reset()
26 |
27 | def reset(self):
28 | self.val = 0
29 | self.avg = 0
30 | self.sum = 0
31 | self.count = 0
32 |
33 | def update(self, val, n=1):
34 | self.val = val
35 | self.sum += val * n
36 | self.count += n
37 | self.avg = self.sum / self.count
38 |
39 |
40 | def check_mkdir(dir_name):
41 |     # Create the directory (and any missing parents); no error if it already exists.
42 |     os.makedirs(dir_name, exist_ok=True)
43 |
44 |
45 | def _sigmoid(x):
46 | return 1 / (1 + np.exp(-x))
47 |
48 |
49 | def crf_refine(img, annos):
50 | assert img.dtype == np.uint8
51 | assert annos.dtype == np.uint8
52 | assert img.shape[:2] == annos.shape
53 |
54 | # img and annos should be np array with data type uint8
55 |
56 | EPSILON = 1e-8
57 |
58 |     M = 2  # two classes: non-mirror (0) and mirror (1)
59 | tau = 1.05
60 | # Setup the CRF model
61 | d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M)
62 |
63 | anno_norm = annos / 255.
64 |
65 | n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm))
66 | p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm))
67 |
68 | U = np.zeros((M, img.shape[0] * img.shape[1]), dtype='float32')
69 | U[0, :] = n_energy.flatten()
70 | U[1, :] = p_energy.flatten()
71 |
72 | d.setUnaryEnergy(U)
73 |
74 | d.addPairwiseGaussian(sxy=3, compat=3)
75 | d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5)
76 |
77 | # Do the inference
78 | infer = np.array(d.inference(1)).astype('float32')
79 | res = infer[1, :]
80 |
81 | res = res * 255
82 | res = res.reshape(img.shape[:2])
83 | return res.astype('uint8')
84 |
85 |
86 | ################################################################
87 | ######################## Evaluation ############################
88 | ################################################################
89 | def data_write(file_path, datas):
90 | f = xlwt.Workbook()
91 | sheet1 = f.add_sheet(sheetname="sheet1", cell_overwrite_ok=True)
92 |
93 | j = 0
94 | for data in datas:
95 | for i in range(len(data)):
96 | sheet1.write(i, j, data[i])
97 | j = j + 1
98 |
99 | f.save(file_path)
100 |
101 |
102 | def get_gt_mask(imgname, MASK_DIR):
103 | filestr = imgname[:-4]
104 | mask_folder = MASK_DIR
105 | mask_path = mask_folder + "/" + filestr + ".png"
106 | mask = skimage.io.imread(mask_path)
107 | mask = np.where(mask == 255, 1, 0).astype(np.float32)
108 |
109 | return mask
110 |
111 |
112 | def get_normalized_predict_mask(imgname, PREDICT_MASK_DIR):
113 |     filestr = imgname[:-4]
114 |     mask_folder = PREDICT_MASK_DIR
115 |     mask_path = mask_folder + "/" + filestr + ".png"
116 |     if not os.path.exists(mask_path):
117 |         raise FileNotFoundError("{} has no predict mask!".format(imgname))
118 |     mask = skimage.io.imread(mask_path).astype(np.float32)
119 |     value_range = np.max(mask) - np.min(mask)
120 |     # Min-max normalize; a constant mask (zero range) is assumed 8-bit and scaled by 255.
121 |     mask = (mask - np.min(mask)) / value_range if value_range > 0 else mask / 255.
122 |
123 |     return mask
124 |
125 |
126 | def get_binary_predict_mask(imgname, PREDICT_MASK_DIR):
127 | filestr = imgname[:-4]
128 | mask_folder = PREDICT_MASK_DIR
129 | mask_path = mask_folder + "/" + filestr + ".png"
130 | if not os.path.exists(mask_path):
131 | print("{} has no predict mask!".format(imgname))
132 | mask = skimage.io.imread(mask_path).astype(np.float32)
133 | mask = np.where(mask >= 127.5, 1, 0).astype(np.float32)
134 |
135 | return mask
136 |
137 |
138 | def compute_iou(predict_mask, gt_mask):
139 | """
140 | (1/n_cl) * sum_i(n_ii / (t_i + sum_j(n_ji) - n_ii))
141 | Here, n_cl = 1 as we have only one class (mirror).
142 | """
143 |
144 | check_size(predict_mask, gt_mask)
145 |
146 | if np.sum(predict_mask) == 0 or np.sum(gt_mask) == 0:
147 | iou_ = 0
148 | return iou_
149 |
150 | n_ii = np.sum(np.logical_and(predict_mask, gt_mask))
151 | t_i = np.sum(gt_mask)
152 | n_ij = np.sum(predict_mask)
153 |
154 | iou_ = n_ii / (t_i + n_ij - n_ii)
155 |
156 | return iou_
157 |
158 |
159 | def compute_acc_mirror(predict_mask, gt_mask):
160 |
161 | check_size(predict_mask, gt_mask)
162 |
163 | N_p = np.sum(gt_mask)
164 | N_n = np.sum(np.logical_not(gt_mask))
165 |
166 | TP = np.sum(np.logical_and(predict_mask, gt_mask))
167 | TN = np.sum(np.logical_and(np.logical_not(predict_mask), np.logical_not(gt_mask)))
168 |
169 | accuracy_ = TP / N_p
170 |
171 | return accuracy_
172 |
173 |
174 | def compute_acc_image(predict_mask, gt_mask):
175 |
176 | check_size(predict_mask, gt_mask)
177 |
178 | N_p = np.sum(gt_mask)
179 | N_n = np.sum(np.logical_not(gt_mask))
180 |
181 | TP = np.sum(np.logical_and(predict_mask, gt_mask))
182 | TN = np.sum(np.logical_and(np.logical_not(predict_mask), np.logical_not(gt_mask)))
183 |
184 | accuracy_ = (TP + TN) / (N_p + N_n)
185 |
186 | return accuracy_
187 |
188 |
189 | def compute_mae(predict_mask, gt_mask):
190 |
191 | check_size(predict_mask, gt_mask)
192 |
193 | N_p = np.sum(gt_mask)
194 | N_n = np.sum(np.logical_not(gt_mask))
195 |
196 | mae_ = np.mean(abs(predict_mask - gt_mask)).item()
197 |
198 | return mae_
199 |
200 |
201 | def compute_ber(predict_mask, gt_mask):
202 |
203 | check_size(predict_mask, gt_mask)
204 |
205 | N_p = np.sum(gt_mask)
206 | N_n = np.sum(np.logical_not(gt_mask))
207 |
208 | TP = np.sum(np.logical_and(predict_mask, gt_mask))
209 | TN = np.sum(np.logical_and(np.logical_not(predict_mask), np.logical_not(gt_mask)))
210 |
211 | ber_ = 1 - (1 / 2) * ((TP / N_p) + (TN / N_n))
212 |
213 | return ber_
214 |
215 |
216 | def segm_size(segm):
217 | try:
218 | height = segm.shape[0]
219 | width = segm.shape[1]
220 | except IndexError:
221 | raise
222 |
223 | return height, width
224 |
225 |
226 | def check_size(eval_segm, gt_segm):
227 | h_e, w_e = segm_size(eval_segm)
228 | h_g, w_g = segm_size(gt_segm)
229 |
230 | if (h_e != h_g) or (w_e != w_g):
231 | raise EvalSegErr("DiffDim: Different dimensions of matrices!")
232 |
233 |
234 | class EvalSegErr(Exception):
235 | def __init__(self, value):
236 | self.value = value
237 |
238 | def __str__(self):
239 | return repr(self.value)
240 |
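The evaluation functions in `misc.py` all reduce to the four confusion counts. Below is a dependency-free sketch of that relationship, assuming binary masks given as nested lists of 0/1. The helper names `confusion_counts` and `metrics` are hypothetical and not part of the repository. Like `compute_ber`, it assumes the ground truth contains both mirror and non-mirror pixels.

```python
def confusion_counts(pred, gt):
    """Count TP, TN, FP, FN over two equally sized 2-D lists of 0/1 values."""
    tp = tn = fp = fn = 0
    for pred_row, gt_row in zip(pred, gt):
        for p, g in zip(pred_row, gt_row):
            if p and g:
                tp += 1
            elif not p and not g:
                tn += 1
            elif p:
                fp += 1
            else:
                fn += 1
    return tp, tn, fp, fn


def metrics(pred, gt):
    """Return (IoU, pixel accuracy, BER) for one predicted/ground-truth mask pair."""
    tp, tn, fp, fn = confusion_counts(pred, gt)
    n_p, n_n = tp + fn, tn + fp            # mirror / non-mirror pixels in gt
    iou = tp / (tp + fp + fn)              # same value as compute_iou: n_ii / (t_i + n_ij - n_ii)
    pixel_acc = (tp + tn) / (n_p + n_n)    # same value as compute_acc_image
    ber = 1 - 0.5 * (tp / n_p + tn / n_n)  # same value as compute_ber
    return iou, pixel_acc, ber
```

For example, `metrics([[1, 1, 0, 0], [0, 1, 0, 0]], [[1, 0, 0, 0], [1, 1, 0, 0]])` has TP=2, TN=4, FP=1, FN=1, giving an IoU of 0.5 and a pixel accuracy of 0.75.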
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Cython
2 | scikit-image
3 | tensorboardX==1.4
4 | xlwt
5 | tqdm
--------------------------------------------------------------------------------
/utils/compute_contrast.py:
--------------------------------------------------------------------------------
1 | """
2 | @Time : 9/28/19 16:25
3 | @Author : TaylorMei
4 | @Email : mhy666@mail.dlut.edu.cn
5 |
6 | @Project : ICCV2019_MirrorNet
7 | @File : compute_contrast.py
8 | @Function: compute color contrast distribution.
9 |
10 | """
11 | import os
12 | import numpy as np
13 | import cv2
14 | import skimage.io
15 | from misc import data_write
16 |
17 | image_path = '/media/iccd/disk/release/MSD/all_images/'
18 | mask_path = '/media/iccd/disk/release/MSD/all_masks/'
19 |
20 | imglist = os.listdir(image_path)
21 |
22 | chi_sq_color = []
23 |
24 | def chi2(arr1, arr2):
25 |     # Chi-square distance between two histograms; eps guards bins that are empty in both.
26 |     return np.sum((arr1 - arr2)**2 / (arr1 + arr2 + np.finfo(float).eps))
27 |
28 |
29 | for i, imgname in enumerate(imglist):
30 | print(i, imgname)
31 |
32 | image = skimage.io.imread(image_path + imgname)
33 |
34 | name = imgname.split('.')[0]
35 | mask = skimage.io.imread(mask_path + name + '.png')
36 | mask_f = np.where(mask != 0, 1, 0).astype(np.uint8)
37 | mask_b = np.where(mask == 0, 1, 0).astype(np.uint8)
38 |
39 | if np.sum(mask_f) == 0:
40 | print('**************************************************')
41 | continue
42 |
43 | hist_f_r = cv2.calcHist([image], [0], mask_f, [256], [0,256])
44 | hist_f_g = cv2.calcHist([image], [1], mask_f, [256], [0,256])
45 | hist_f_b = cv2.calcHist([image], [2], mask_f, [256], [0,256])
46 | hist_b_r = cv2.calcHist([image], [0], mask_b, [256], [0,256])
47 | hist_b_g = cv2.calcHist([image], [1], mask_b, [256], [0,256])
48 | hist_b_b = cv2.calcHist([image], [2], mask_b, [256], [0,256])
49 |
50 | chi_sq_r = chi2(hist_f_r.flatten()/np.sum(mask_f), hist_b_r.flatten()/np.sum(mask_b))
51 | chi_sq_g = chi2(hist_f_g.flatten()/np.sum(mask_f), hist_b_g.flatten()/np.sum(mask_b))
52 | chi_sq_b = chi2(hist_f_b.flatten()/np.sum(mask_f), hist_b_b.flatten()/np.sum(mask_b))
53 |
54 | chi_sq_color.append(((chi_sq_r + chi_sq_g + chi_sq_b) / 3).item())
55 |
56 | chi_sq_color = np.array(chi_sq_color)
57 | chi_sq_color = (chi_sq_color - np.min(chi_sq_color)) / (np.max(chi_sq_color) - np.min(chi_sq_color))
58 |
59 | print(chi_sq_color)
60 | data_write(os.path.join(os.getcwd(), 'msd_chi_sq.xls'), [chi_sq_color])  # xlwt writes the legacy .xls format
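The chi-square contrast measure used in `compute_contrast.py` can be illustrated without NumPy or OpenCV. This is a minimal sketch, assuming histograms given as plain lists of bin counts; `normalize` and `chi2_distance` are hypothetical names for this example.

```python
import sys


def normalize(counts):
    """Turn raw bin counts into a histogram that sums to 1."""
    total = float(sum(counts))
    return [c / total for c in counts]


def chi2_distance(hist_a, hist_b, eps=sys.float_info.epsilon):
    """Chi-square distance between two normalized histograms."""
    # eps keeps the division defined for bins that are empty in both histograms.
    return sum((a - b) ** 2 / (a + b + eps) for a, b in zip(hist_a, hist_b))
```

Identical histograms give a distance of 0, while two histograms with no shared non-empty bins give a value close to 2, which is the upper end of the range for normalized inputs.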
--------------------------------------------------------------------------------
/utils/compute_overlap.py:
--------------------------------------------------------------------------------
1 | """
2 | @Time : 9/28/19 15:51
3 | @Author : TaylorMei
4 | @Email : mhy666@mail.dlut.edu.cn
5 |
6 | @Project : ICCV2019_MirrorNet
7 | @File : compute_overlap.py
8 | @Function: compute mirror location distribution.
9 |
10 | """
11 | import os
12 | import numpy as np
13 | import skimage.io
14 | import skimage.transform
15 | import matplotlib.pyplot as plt
16 | from matplotlib import cm
17 | import seaborn as sns
18 |
19 | # image_path = '/media/iccd/disk/release/MSD/test/image/'
20 | # mask_path = '/media/iccd/disk/release/MSD/test/mask/'
21 | image_path = '/media/iccd/disk/release/MSD/all_images/'
22 | mask_path = '/media/iccd/disk/release/MSD/all_masks/'
23 |
24 | imglist = os.listdir(image_path)
25 | print(len(imglist))
26 |
27 | overlap = np.zeros([256, 256], dtype=np.float64)
28 | tall, wide = 0, 0
29 |
30 | for i, imgname in enumerate(imglist):
31 | print(i, imgname)
32 | name = imgname.split('.')[0]
33 |
34 | mask = skimage.io.imread(mask_path + name + '.png')
35 |
36 | height = mask.shape[0]
37 | width = mask.shape[1]
38 | if height > width:
39 | tall += 1
40 | else:
41 | wide += 1
42 | mask = skimage.transform.resize(mask, [256, 256], order=0)
43 | mask = np.where(mask != 0, 1, 0).astype(np.float64)
44 | overlap += mask
45 |
46 | overlap = overlap / len(imglist)
47 | overlap_normalized = (overlap - np.min(overlap)) / (np.max(overlap) - np.min(overlap))
48 | skimage.io.imsave('./msd_all.png', (overlap * 255).astype(np.uint8))
49 | skimage.io.imsave('./msd_all_normalized.png', (overlap_normalized * 255).astype(np.uint8))
50 |
51 | print(tall, wide)
52 |
53 | f, ax = plt.subplots()
54 | sns.set()
55 | ax = sns.heatmap(overlap, ax=ax, cmap=cm.summer, cbar=False)
56 | ax.set_xticklabels([])
57 | ax.set_yticklabels([])
58 | plt.xticks([])
59 | plt.yticks([])
60 | plt.show()
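The min-max normalization applied to the overlap map divides by `max - min`, which is zero for a constant map. Below is a small guarded variant, written as a sketch on nested lists. The function name and the choice to map a constant grid to zeros are assumptions of this example, not behaviour of the scripts above.

```python
def min_max_normalize(grid):
    """Rescale a 2-D list of numbers to [0, 1]; a constant grid maps to all zeros."""
    flat = [v for row in grid for v in row]
    lo, hi = min(flat), max(flat)
    if hi == lo:
        # Zero range: avoid dividing by zero.
        return [[0.0 for _ in row] for row in grid]
    return [[(v - lo) / (hi - lo) for v in row] for row in grid]
```

For instance, `min_max_normalize([[1, 3], [5, 9]])` yields `[[0.0, 0.25], [0.5, 1.0]]`.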
--------------------------------------------------------------------------------
/utils/compute_size.py:
--------------------------------------------------------------------------------
1 | """
2 | @Time : 9/28/19 15:37
3 | @Author : TaylorMei
4 | @Email : mhy666@mail.dlut.edu.cn
5 |
6 | @Project : ICCV2019_MirrorNet
7 | @File : compute_size.py
8 | @Function: compute mirror area distribution.
9 |
10 | """
11 | import os
12 | import numpy as np
13 | import skimage.io
14 | from misc import data_write
15 |
16 | image_path = '/media/iccd/disk/release/MSD/all_images/'
17 | mask_path = '/media/iccd/disk/release/MSD/all_masks/'
18 |
19 | imglist = os.listdir(image_path)
20 | print(len(imglist))
21 |
22 | output = []
23 |
24 | for i, imgname in enumerate(imglist):
25 | print(i, imgname)
26 | name = imgname.split('.')[0]
27 |
28 | mask = skimage.io.imread(mask_path + name + '.png')
29 | mask = np.where(mask != 0, 1, 0).astype(np.uint8)
30 |
31 | height = mask.shape[0]
32 | width = mask.shape[1]
33 | total_area = height * width
34 | if total_area != 640*512:
35 | print('size error!')
36 |
37 | mirror_area = np.sum(mask)
38 | proportion = mirror_area / total_area
39 | output.append(proportion)
40 | data_write(os.path.join(os.getcwd(), 'msd_size.xls'), [output])  # xlwt writes the legacy .xls format
--------------------------------------------------------------------------------
/utils/generate_overlap_map.py:
--------------------------------------------------------------------------------
1 | """
2 | @Time : 9/15/19 16:47
3 | @Author : TaylorMei
4 | @Email : mhy666@mail.dlut.edu.cn
5 |
6 | @Project : ICCV2019_MirrorNet
7 | @File : generate_overlap_map.py
8 | @Function: generate overlap map of each image in test set, according to the statistic on training set.
9 |
10 | """
11 | import os
12 | import numpy as np
13 | from skimage import io, transform
14 | from config import msd_training_root, msd_testing_root, msd_results_root
15 |
16 | train_image_path = os.path.join(msd_training_root, 'image')
17 | test_image_path = os.path.join(msd_testing_root, 'image')
18 | mask_path = os.path.join(msd_training_root, 'mask')
19 | output_path = os.path.join(msd_results_root, 'Statistics')
20 | # Create the output directory (and any missing parents) if needed.
21 | os.makedirs(output_path, exist_ok=True)
22 |
23 | overlap = np.zeros([256, 256], dtype=np.float64)
24 |
25 | train_imglist = os.listdir(train_image_path)
26 | for i, imgname in enumerate(train_imglist):
27 |
28 | print(i, imgname)
29 |
30 | name = imgname.split('.')[0]
31 |
32 | mask = io.imread(os.path.join(mask_path, name + '.png'))
33 |
34 | mask = transform.resize(mask, [256, 256], order=0)
35 | mask = np.where(mask != 0, 1, 0).astype(np.float64)
36 |
37 | overlap += mask
38 |
39 | overlap = overlap / len(train_imglist)
40 | overlap = (overlap - np.min(overlap)) / (np.max(overlap) - np.min(overlap))
41 |
42 | test_imglist = os.listdir(test_image_path)
43 | for j, imgname in enumerate(test_imglist):
44 |
45 | print(j, imgname)
46 |
47 | name = imgname.split('.')[0]
48 |
49 | image = io.imread(os.path.join(test_image_path, imgname))
50 |
51 | height = image.shape[0]
52 | width = image.shape[1]
53 |
54 |     mask = transform.resize(overlap, [height, width], order=0)
55 |
56 | save_path = os.path.join(output_path, name + '.png')
57 | io.imsave(save_path, (mask * 255).astype(np.uint8))
58 |
59 | print("OK!")
--------------------------------------------------------------------------------