├── .DS_Store
├── .gitattributes
├── .gitignore
├── .idea
│   ├── change_detection.iml
│   ├── dictionaries
│   │   └── leon.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── LICENSE
├── README.md
├── data_loaders.py
├── datasets.py
├── losses.py
├── metrics.py
├── models.py
├── test.py
├── train.py
└── utils.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hhhhhhao/change_detection/13b87c02166cc98d39d8be240a07abcf12893fe3/.DS_Store
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | #data folder
2 | data/*
3 | models/*
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | pip-wheel-metadata/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # celery beat schedule file
97 | celerybeat-schedule
98 |
99 | # SageMath parsed files
100 | *.sage.py
101 |
102 | # Environments
103 | .env
104 | .venv
105 | env/
106 | venv/
107 | ENV/
108 | env.bak/
109 | venv.bak/
110 |
111 | # Spyder project settings
112 | .spyderproject
113 | .spyproject
114 |
115 | # Rope project settings
116 | .ropeproject
117 |
118 | # mkdocs documentation
119 | /site
120 |
121 | # mypy
122 | .mypy_cache/
123 | .dmypy.json
124 | dmypy.json
125 |
126 | # Pyre type checker
127 | .pyre/
128 |
--------------------------------------------------------------------------------
/.idea/change_detection.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/dictionaries/leon.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
[workspace.xml: IDE state whose XML markup was lost in extraction; the only surviving fragments are editor breakpoints (utils.py lines 16 and 100, datasets.py lines 47, 52, 54, 57, data_loaders.py line 53), search-history entries such as "img_size" and "inputs, targets = next(iter(data_loader))", and the task timestamp 1566984986195]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Hao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # change_detection
2 | Change Detection for high-resolution satellite images
3 |
--------------------------------------------------------------------------------
/data_loaders.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import DataLoader
4 | from functools import partial
5 | from datasets import RssraiDataset
6 | from utils import calculate_bce_loss, get_pixels
7 |
8 |
9 | def collate_fn(data, batch_size, img_size):
10 | img, mask = data[0]
11 | batch_x, batch_y = [], []
12 | num_img = len(batch_x)
13 | num_white_img = 0
14 |
15 | # randomly crop image and mask at white pixels
16 | pixels = get_pixels(mask[0], img_size, img_size)
17 | num_pix = len(pixels[0])
18 |
19 | while True:
20 | if num_pix == 0:
21 | s_x = np.random.randint(0, img.shape[1] - img_size + 1)
22 | s_y = np.random.randint(0, img.shape[2] - img_size + 1)
23 | else:
24 | index = np.random.randint(num_pix)
25 | s_x = pixels[0][index]
26 | s_y = pixels[1][index]
27 | y = mask[:, s_x:s_x + img_size, s_y:s_y + img_size]
28 |
29 | if len(np.where(y != 0)[0]) > 0:
30 | num_white_img += 1
31 |
32 | x = img[:, s_x:s_x + img_size, s_y:s_y + img_size]
33 |
34 | if num_white_img < 1:  # resample until the batch contains at least one patch with change (assumes the mask is not all black)
35 | continue
36 |
37 | batch_x.append(x)
38 | batch_y.append(y)
39 | num_img = len(batch_x)
40 |
41 | if num_img == batch_size:
42 | break
43 |
44 | # for i in range(batch_size):
45 | # s_x = np.random.randint(0, img.shape[1] - img_size + 1)
46 | # s_y = np.random.randint(0, img.shape[2] - img_size + 1)
47 | # y = mask[:, s_x:s_x + img_size, s_y:s_y + img_size]
48 | # x = img[:, s_x:s_x + img_size, s_y:s_y + img_size]
49 | # batch_x.append(x)
50 | # batch_y.append(y)
51 |
52 | batch_x = np.array(batch_x)
53 | batch_y = np.array(batch_y)
54 | loss_weight = calculate_bce_loss(batch_y)
55 |
56 | if torch.cuda.is_available():
57 | batch_x = torch.cuda.FloatTensor(batch_x)
58 | batch_y = torch.cuda.FloatTensor(batch_y)
59 | else:
60 | batch_x = torch.FloatTensor(batch_x)
61 | batch_y = torch.FloatTensor(batch_y)
62 |
63 | return batch_x, batch_y, loss_weight
64 |
65 |
66 | class RssraiDataLoader(DataLoader):
67 | """
68 | rssrai2019 change detection data loader
69 | """
70 | def __init__(self,
71 | which_set='train',
72 | batch_size=16,
73 | img_size=256,
74 | shuffle=True
75 | ):
76 | self.dataset = RssraiDataset(which_set=which_set)
77 | self.batch_size = batch_size
78 | self.img_size = img_size
79 | self.shuffle = shuffle
80 |
81 | super(RssraiDataLoader, self).__init__(
82 | dataset=self.dataset,
83 | batch_size=1, # batch_size is set to 1 because each item is a single full image from which many patches are cropped
84 | shuffle=self.shuffle,
85 | num_workers=0,
86 | drop_last=self.shuffle,
87 | collate_fn=partial(collate_fn, batch_size=self.batch_size, img_size=self.img_size))
88 |
89 |
90 | if __name__ == '__main__':
91 | data_loader = RssraiDataLoader(which_set='train', batch_size=16, img_size=256, shuffle=True)
92 |
93 | for i, (input, mask, loss_weight) in enumerate(data_loader):
94 | print('{}th batch: input shape {}, mask shape {}, loss_weight {}'.format(i, input.shape, mask.shape, loss_weight))
95 |
96 |
--------------------------------------------------------------------------------
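
The collate function above turns one full scene into a batch of fixed-size patches biased toward changed (white) pixels, so that sparse change is not drowned out by all-black crops. A minimal self-contained sketch of the same sampling idea on synthetic arrays (toy data only; nothing below touches the repository's data folder):

    import numpy as np

    # toy 6-channel image pair and a binary change mask with one changed region
    img = np.random.rand(6, 512, 512).astype('float32')
    mask = np.zeros((1, 512, 512), dtype='float32')
    mask[0, 100:160, 200:260] = 1.0

    img_size = 256
    # candidate top-left corners: white pixels with the bottom/right border
    # excluded (as in utils.get_pixels) so the crop stays inside the image
    rows, cols = np.where(mask[0, :512 - img_size, :512 - img_size] != 0)

    idx = np.random.randint(len(rows))  # anchor the crop at one changed pixel
    s_x, s_y = rows[idx], cols[idx]
    patch_x = img[:, s_x:s_x + img_size, s_y:s_y + img_size]
    patch_y = mask[:, s_x:s_x + img_size, s_y:s_y + img_size]
    print(patch_x.shape, patch_y.shape, patch_y.sum() > 0)  # (6, 256, 256) (1, 256, 256) True
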
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tifffile import imread
4 | from torch.utils.data import Dataset
5 | from utils import get_files
6 |
7 |
8 | class RssraiDataset(Dataset):
9 | """
10 | rssrai2019 change detection dataset
11 | """
12 | def __init__(self, which_set):
13 | self.data_dir = os.path.join(os.path.dirname(__file__), 'data/rssrai2019_change_detection', which_set)
14 | assert os.path.exists(self.data_dir), "cannot find data folder: {}".format(self.data_dir)
15 |
16 | self.img1_ids = get_files(os.path.join(self.data_dir, 'img_2017/'))
17 | self.img2_ids = get_files(os.path.join(self.data_dir, 'img_2018/'))
18 | self.mask_ids = get_files(os.path.join(self.data_dir, 'mask/'))
19 |
20 | def __len__(self):
21 | return len(self.img1_ids)
22 |
23 | def __getitem__(self, idx):
24 | img1_id = self.img1_ids[idx]
25 | img2_id = self.img2_ids[idx]
26 | mask_id = self.mask_ids[idx]
27 | img1 = imread(img1_id).astype('uint8')
28 | img2 = imread(img2_id).astype('uint8')
29 | mask = imread(mask_id).astype('uint8')
30 |
31 | img1 = img1.astype('float32') / 255.
32 | img2 = img2.astype('float32') / 255.
33 | mask = mask.astype('float32') / 255.
34 |
35 | img = np.concatenate([img1, img2], axis=-1)
36 | img = img.transpose((2, 0, 1))
37 | mask = mask[:, :, np.newaxis]
38 | mask = mask.transpose((2, 0, 1))
39 | return img, mask
40 |
41 |
42 | if __name__ == '__main__':
43 | from tqdm import tqdm
44 | import matplotlib.pyplot as plt
45 | import cv2
46 |
47 | train_dataset = RssraiDataset(which_set='train')
48 | print('length of the dataset: {}'.format(len(train_dataset)))
49 |
50 | for i, (input, mask) in tqdm(enumerate(train_dataset), total=len(train_dataset)):
51 | print('input image shape:{}'.format(input.shape))
52 | print('mask shape:{}'.format(mask.shape))
53 | input = input.transpose(1, 2, 0)
54 |
55 | plt.imshow((input[:, :, :3] * 255).astype('uint8'))
56 | plt.show()
57 |
58 | cv2.imshow('image', input[:, :, :3] * 255)
59 | cv2.waitKey(0)
60 | # cv2.destroyAllWindows()
61 |
62 | break
63 |
64 |
--------------------------------------------------------------------------------
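
The dataset class expects the layout below under the project root (directory names taken from the constructor above; 'train', 'val' and 'test' are the split names used by train.py and test.py):

    data/rssrai2019_change_detection/
    ├── train/
    │   ├── img_2017/   (*.tif)
    │   ├── img_2018/   (*.tif)
    │   └── mask/       (*.tif)
    ├── val/            (same substructure)
    └── test/           (same substructure)
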
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def weighted_binary_cross_entropy(sigmoid_x, targets, pos_weight, weight=None, size_average=True, reduce=True):
7 | """
8 | Args:
 9 | sigmoid_x: predicted probability of size [N,C] (N samples, C classes); must be in the range [0,1], i.e. the output of a sigmoid
10 | targets: true values, a one-hot-like vector of size [N,C]
11 | pos_weight: weight for the positive class
12 | """
13 | if not (targets.size() == sigmoid_x.size()):
14 | raise ValueError("Target size ({}) must be the same as input size ({})".format(targets.size(), sigmoid_x.size()))
15 |
16 | sigmoid_x = sigmoid_x.clamp(min=1e-8, max=1-1e-8)
17 | loss = -pos_weight * targets * sigmoid_x.log() - (1-targets)*(1-sigmoid_x).log()
18 |
19 | if weight is not None:
20 | loss = loss * weight
21 |
22 | if not reduce:
23 | return loss
24 | elif size_average:
25 | return loss.mean()
26 | else:
27 | return loss.sum()
28 |
29 |
30 | def dice_loss(pred, target, smooth=1e-8):
31 | iflat = pred.view(-1)
32 | tflat = target.view(-1)
33 | intersection = (iflat * tflat).sum()
34 |
35 | return 1 - ((2. * intersection + smooth) /
36 | (iflat.sum() + tflat.sum() + smooth))
37 |
38 |
39 | class WBCEDiceLoss(nn.Module):
40 | def __init__(self, alpha=0.5):
41 | super(WBCEDiceLoss, self).__init__()
42 | self.alpha = alpha
43 |
44 | def forward(self, pred, target, weight):
45 | pred = torch.sigmoid(pred)
46 | dice = dice_loss(pred, target)
47 | bce = weighted_binary_cross_entropy(pred, target, weight)
48 |
49 | # bce = F.binary_cross_entropy_with_logits(pred, target)
50 | # pred = F.sigmoid(pred)
51 |
52 | loss = bce + dice * self.alpha
53 | return loss
54 |
55 |
56 | class BCEDiceLoss(nn.Module):
57 | def __init__(self, alpha=0.5):
58 | super(BCEDiceLoss, self).__init__()
59 | self.alpha = alpha
60 |
61 | def forward(self, pred, target, weight):
62 | bce = F.binary_cross_entropy_with_logits(pred, target)
63 | pred = torch.sigmoid(pred)
64 | dice = dice_loss(pred, target)
65 |
66 | loss = bce + dice * self.alpha
67 | return loss
68 |
--------------------------------------------------------------------------------
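
A quick numeric sanity check of the two losses on random logits, a minimal sketch assuming the repository modules are importable from the working directory:

    import torch
    from losses import WBCEDiceLoss, BCEDiceLoss

    logits = torch.randn(4, 1, 64, 64)                  # raw model outputs
    target = (torch.rand(4, 1, 64, 64) > 0.9).float()   # sparse change mask
    pos_weight = (1 - target).sum() / target.sum()      # neg/pos pixel ratio, as in utils.calculate_bce_loss

    print(WBCEDiceLoss(alpha=0.5)(logits, target, pos_weight).item())
    print(BCEDiceLoss(alpha=0.5)(logits, target, pos_weight).item())
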
/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
 7 | def object_mean_iou(y_true_in, y_pred_in, print_table=False):
 8 | if not np.sum(y_true_in.flatten()) == 0:
9 | labels = y_true_in
10 | y_pred = y_pred_in
11 |
12 | true_objects = 2
13 | pred_objects = 2
14 |
15 | intersection = np.histogram2d(labels.flatten(), y_pred.flatten(), bins=(true_objects, pred_objects))[0]
16 |
17 | # Compute areas (needed for finding the union between all objects)
18 | area_true = np.histogram(labels, bins = true_objects)[0]
19 | area_pred = np.histogram(y_pred, bins = pred_objects)[0]
20 | area_true = np.expand_dims(area_true, -1)
21 | area_pred = np.expand_dims(area_pred, 0)
22 |
23 | # Compute union
24 | union = area_true + area_pred - intersection
25 |
26 | # Exclude background from the analysis
27 | intersection = intersection[1:,1:]
28 | union = union[1:,1:]
29 | union[union == 0] = 1e-9
30 |
31 | # Compute the intersection over union
32 | iou = intersection / union
33 |
34 | # Precision helper function
35 | def precision_at(threshold, iou):
36 | matches = iou > threshold
37 | true_positives = np.sum(matches, axis=1) == 1 # Correct objects
38 | false_positives = np.sum(matches, axis=0) == 0 # Missed objects
39 | false_negatives = np.sum(matches, axis=1) == 0 # Extra objects
40 | tp, fp, fn = np.sum(true_positives), np.sum(false_positives), np.sum(false_negatives)
41 | return tp, fp, fn
42 |
43 | # Loop over IoU thresholds
44 | prec = []
45 | if print_table:
46 | print("Thresh\tTP\tFP\tFN\tPrec.")
47 | for t in np.arange(0.5, 1.0, 0.05):
48 | tp, fp, fn = precision_at(t, iou)
49 | if (tp + fp + fn) > 0:
50 | p = tp / (tp + fp + fn)
51 | else:
52 | p = 0
53 | if print_table:
54 | print("{:1.3f}\t{}\t{}\t{}\t{:1.3f}".format(t, tp, fp, fn, p))
55 | prec.append(p)
56 |
57 | if print_table:
58 | print("AP\t-\t-\t-\t{:1.3f}".format(np.mean(prec)))
59 | return np.mean(prec)
60 |
61 | else:
62 | if np.sum(y_pred_in.flatten()) == 0:
63 | return 1
64 | else:
65 | return 0
66 |
67 |
68 | def batch_iou(output, target):
69 | output = torch.sigmoid(output).data.cpu().numpy() > 0.5
70 | target = (target.data.cpu().numpy() > 0.5).astype('int')
71 | output = output[:,0,:,:]
72 | target = target[:,0,:,:]
73 |
74 | ious = []
75 | for i in range(output.shape[0]):
76 | ious.append(object_mean_iou(output[i], target[i]))
77 |
78 | return np.mean(ious)
79 |
80 |
81 | def mean_iou(output, target):
82 | smooth = 1e-5
83 |
84 | output = torch.sigmoid(output).data.cpu().numpy()
85 | target = target.data.cpu().numpy()
86 | ious = []
87 | for t in np.arange(0.5, 1.0, 0.05):
88 | output_ = output > t
89 | target_ = target > t
90 | intersection = (output_ & target_).sum()
91 | union = (output_ | target_).sum()
92 | iou = (intersection + smooth) / (union + smooth)
93 | ious.append(iou)
94 |
95 | return np.mean(ious)
96 |
97 |
98 | def iou_score(output, target):
99 | smooth = 1e-5
100 |
101 | if torch.is_tensor(output):
102 | output = torch.sigmoid(output).data.cpu().numpy()
103 | if torch.is_tensor(target):
104 | target = target.data.cpu().numpy()
105 | output_ = output > 0.5
106 | target_ = target > 0.5
107 | intersection = (output_ & target_).sum()
108 | union = (output_ | target_).sum()
109 |
110 | return (intersection + smooth) / (union + smooth)
111 |
112 |
113 | def dice_coef(output, target):
114 | smooth = 1e-5
115 |
116 | output = torch.sigmoid(output).view(-1).data.cpu().numpy()
117 | target = target.view(-1).data.cpu().numpy()
118 | intersection = (output * target).sum()
119 |
120 | return (2. * intersection + smooth) / \
121 | (output.sum() + target.sum() + smooth)
122 |
123 |
124 | def accuracy(output, target):
125 | output = torch.sigmoid(output).view(-1).data.cpu().numpy()
126 | output = (np.round(output)).astype('int')
127 | target = target.view(-1).data.cpu().numpy()
128 | target = (np.round(target)).astype('int')
130 |
131 | return (output == target).sum() / len(output)
132 |
--------------------------------------------------------------------------------
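
A small worked example of iou_score and dice_coef on a hand-built prediction (a sketch; the two half-overlapping squares make the expected values easy to verify):

    import torch
    from metrics import iou_score, dice_coef

    target = torch.zeros(1, 1, 8, 8)
    target[..., 0:4, 0:4] = 1.0           # 16 positive pixels

    logits = torch.full((1, 1, 8, 8), -10.0)
    logits[..., 2:6, 0:4] = 10.0          # prediction shifted down by two rows

    # overlap = 8 px, union = 24 px; Dice = 2*8/(16+16)
    print(iou_score(logits, target))      # ~0.3333
    print(dice_coef(logits, target))      # ~0.5
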
/models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 |
5 | class ConvBlock(nn.Module):
6 | """
7 | Convolution Block
8 | """
9 | def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, dilation=1):
10 | super(ConvBlock, self).__init__()
11 | block = [nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True)]
12 | block += [nn.BatchNorm2d(out_ch)]
13 | block += [nn.ReLU(inplace=True)]
14 | self.block = nn.Sequential(*block)
15 |
16 | def forward(self, x):
17 | return self.block(x)
18 |
19 |
20 | class UpConvBlock(nn.Module):
21 | """
22 | Upsampling Convolution Block
23 | """
24 | def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, dilation=1):
25 | super(UpConvBlock, self).__init__()
26 | block = [nn.Upsample(scale_factor=2)]
27 | block += [nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True)]
28 | block += [nn.BatchNorm2d(out_ch)]
29 | block += [nn.ReLU(inplace=True)]
30 | self.block = nn.Sequential(*block)
31 |
32 | def forward(self, x):
33 | return self.block(x)
34 |
35 |
36 | class ResBlock(nn.Module):
37 | """
38 | Residual Block
39 | """
40 | def __init__(self, in_ch, out_ch, stride=1):
41 | super(ResBlock, self).__init__()
42 | self.relu = nn.ReLU(inplace=True)
43 | self.conv1 = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=stride, padding=1, dilation=1, bias=True)
44 | self.bn1 = nn.BatchNorm2d(out_ch)
45 | self.conv2 = nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=stride, padding=1, dilation=1, bias=True)
46 | self.bn2 = nn.BatchNorm2d(out_ch)
47 |
48 | def forward(self, x):
49 | identity = self.conv1(x)  # the skip starts after conv1 so channel counts match; the add assumes stride=1 in conv2
50 | out = self.bn1(identity)
51 | out = self.relu(out)
52 |
53 | out = self.conv2(out)
54 | out = self.bn2(out)
55 |
56 | out += identity
57 | out = self.relu(out)
58 | return out
59 |
60 |
61 | class NestedUNet(nn.Module):
62 | """
63 | Implementation of nested Unet (Unet++)
64 | """
65 |
66 | def __init__(self, in_ch=3, out_ch=1, n=32):
67 | super(NestedUNet, self).__init__()
68 |
69 | filters = [n, n * 2, n * 4, n * 8, n * 16]
70 |
71 | self.conv0_0 = ResBlock(in_ch, filters[0])
72 | self.conv1_0 = ResBlock(filters[0], filters[1])
73 | self.conv2_0 = ResBlock(filters[1], filters[2])
74 | self.conv3_0 = ResBlock(filters[2], filters[3])
75 | self.conv4_0 = ResBlock(filters[3], filters[4])
76 |
77 | self.conv0_1 = ResBlock(filters[0] + filters[1], filters[0])
78 | self.conv1_1 = ResBlock(filters[1] + filters[2], filters[1])
79 | self.conv2_1 = ResBlock(filters[2] + filters[3], filters[2])
80 | self.conv3_1 = ResBlock(filters[3] + filters[4], filters[3])
81 |
82 | self.conv0_2 = ResBlock(filters[0] * 2 + filters[1], filters[0])
83 | self.conv1_2 = ResBlock(filters[1] * 2 + filters[2], filters[1])
84 | self.conv2_2 = ResBlock(filters[2] * 2 + filters[3], filters[2])
85 |
86 | self.conv0_3 = ResBlock(filters[0] * 3 + filters[1], filters[0])
87 | self.conv1_3 = ResBlock(filters[1] * 3 + filters[2], filters[1])
88 |
89 | self.conv0_4 = ResBlock(filters[0] * 4 + filters[1], filters[0])
90 |
91 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
92 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
93 |
94 | self.final1 = nn.Conv2d(filters[0], out_ch, kernel_size=1)
95 | self.final2 = nn.Conv2d(filters[0], out_ch, kernel_size=1)
96 | self.final3 = nn.Conv2d(filters[0], out_ch, kernel_size=1)
97 | self.final4 = nn.Conv2d(filters[0], out_ch, kernel_size=1)
98 | self.final5 = nn.Conv2d(4 * out_ch, out_ch, kernel_size=1)  # fuses the four deep-supervision outputs
99 |
100 | def forward(self, x):
101 |
102 | x0_0 = self.conv0_0(x)
103 |
104 | x1_0 = self.conv1_0(self.pool(x0_0))
105 | x0_1 = self.conv0_1(torch.cat((x0_0, self.up(x1_0)), 1))
106 |
107 | x2_0 = self.conv2_0(self.pool(x1_0))
108 | x1_1 = self.conv1_1(torch.cat((x1_0, self.up(x2_0)), 1))
109 | x0_2 = self.conv0_2(torch.cat((x0_0, x0_1, self.up(x1_1)), 1))
110 |
111 | x3_0 = self.conv3_0(self.pool(x2_0))
112 | x2_1 = self.conv2_1(torch.cat((x2_0, self.up(x3_0)), 1))
113 | x1_2 = self.conv1_2(torch.cat((x1_0, x1_1, self.up(x2_1)), 1))
114 | x0_3 = self.conv0_3(torch.cat((x0_0, x0_1, x0_2, self.up(x1_2)), 1))
115 |
116 | x4_0 = self.conv4_0(self.pool(x3_0))
117 | x3_1 = self.conv3_1(torch.cat((x3_0, self.up(x4_0)), 1))
118 | x2_2 = self.conv2_2(torch.cat((x2_0, x2_1, self.up(x3_1)), 1))
119 | x1_3 = self.conv1_3(torch.cat((x1_0, x1_1, x1_2, self.up(x2_2)), 1))
120 | x0_4 = self.conv0_4(torch.cat((x0_0, x0_1, x0_2, x0_3, self.up(x1_3)), 1))
121 |
122 | output1 = self.final1(x0_1)
123 | output2 = self.final2(x0_2)
124 | output3 = self.final3(x0_3)
125 | output4 = self.final4(x0_4)
126 | output5 = self.final5(torch.cat((output1, output2, output3, output4), 1))
127 |
128 | return [output1, output2, output3, output4, output5]
129 |
130 |
131 | if __name__ == '__main__':
132 | a = torch.ones((16, 6, 256, 256))
133 | Unet = NestedUNet(in_ch=6, out_ch=1)
134 | output = Unet(a)
--------------------------------------------------------------------------------
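
A shape check for the deep-supervision heads, a minimal sketch assuming the BatchNorm2d fix above: NestedUNet returns five logit maps at full input resolution, the last one fusing the first four:

    import torch
    from models import NestedUNet

    net = NestedUNet(in_ch=6, out_ch=1, n=32)
    net.eval()  # freeze BatchNorm statistics for the check
    x = torch.randn(2, 6, 256, 256)
    with torch.no_grad():
        outputs = net(x)
    for k, out in enumerate(outputs, start=1):
        print('output{}: {}'.format(k, tuple(out.shape)))  # each (2, 1, 256, 256)
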
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from glob import glob
4 | import warnings
5 |
6 | import numpy as np
7 | from tqdm import tqdm
8 |
9 | from skimage.io import imread, imsave
10 |
11 | import torch
12 | from sklearn.externals import joblib
13 |
13 | import models
14 | from metrics import dice_coef, batch_iou, mean_iou, iou_score
15 | import losses
16 | from utils import count_params
17 | from data_loaders import RssraiDataLoader
18 |
19 |
20 | def parse_args():
21 | parser = argparse.ArgumentParser()
22 |
23 | parser.add_argument('--name', default=None,
24 | help='model name')
25 |
26 | args = parser.parse_args()
27 |
28 | return args
29 |
30 |
31 | def main():
32 | val_args = parse_args()
33 |
34 | args = joblib.load('models/%s/args.pkl' %val_args.name)
35 |
36 | if not os.path.exists('output/%s' %args.name):
37 | os.makedirs('output/%s' %args.name)
38 |
39 | print('Config -----')
40 | for arg in vars(args):
41 | print('%s: %s' %(arg, getattr(args, arg)))
42 | print('------------')
43 |
44 | joblib.dump(args, 'models/%s/args.pkl' %args.name)
45 |
46 | # create model
47 | print("=> creating model %s" %args.arch)
48 | model = models.__dict__[args.arch](args.in_ch, args.out_ch, args.num_filters)
49 |
50 | if torch.cuda.is_available():
51 | model = model.cuda()
52 |
53 | model.load_state_dict(torch.load('models/%s/model.pth' %args.name))
54 | model.eval()
55 |
56 | val_loader = RssraiDataLoader(
57 | which_set='test',
58 | batch_size=args.batch_size,
59 | img_size=args.img_size,
60 | shuffle=False
61 | )
62 |
63 | # raw test image/mask paths on disk (assumed layout, mirroring datasets.py)
64 | val_img_paths = sorted(glob('data/rssrai2019_change_detection/test/img_2017/*.tif'))
65 | val_mask_paths = sorted(glob('data/rssrai2019_change_detection/test/mask/*.tif'))
66 |
63 | with warnings.catch_warnings():
64 | warnings.simplefilter('ignore')
65 |
66 | with torch.no_grad():
67 | for i, (input, target, _) in tqdm(enumerate(val_loader), total=len(val_loader)):  # the collate also yields a loss weight, unused at test time
68 | # compute output
69 | output = model(input)[-1]
70 | output = torch.sigmoid(output).data.cpu().numpy()
71 | img_paths = val_img_paths[args.batch_size*i:args.batch_size*(i+1)]
72 |
73 | for j in range(output.shape[0]):  # j, not i: don't clobber the outer batch index
74 | imsave('output/%s/'%args.name+os.path.basename(img_paths[j]), (output[j,0,:,:]*255).astype('uint8'))
75 |
76 | torch.cuda.empty_cache()
77 |
78 | # IoU
79 | ious = []
80 | for i in tqdm(range(len(val_mask_paths))):
81 | mask = imread(val_mask_paths[i])
82 | pb = imread('output/%s/'%args.name+os.path.basename(val_mask_paths[i]))
83 |
84 | mask = mask.astype('float32') / 255
85 | pb = pb.astype('float32') / 255
86 |
87 |
88 | '''
89 | plt.figure()
90 | plt.subplot(121)
91 | plt.imshow(mask)
92 | plt.subplot(122)
93 | plt.imshow(pb)
94 | plt.show()
95 | '''
96 |
97 | iou = iou_score(pb, mask)
98 | ious.append(iou)
99 | print('IoU: %.4f' %np.mean(ious))
100 |
101 |
102 | if __name__ == '__main__':
103 | main()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pandas as pd
4 | from tqdm import tqdm
5 | from collections import OrderedDict
6 | from datetime import datetime
7 | from sklearn.externals import joblib
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | import torch.backends.cudnn as cudnn
13 | from torch.optim.lr_scheduler import CosineAnnealingLR
14 |
15 | import models
16 | import losses
17 | from data_loaders import RssraiDataLoader
18 | from metrics import iou_score
19 | from utils import count_params, save_example
20 |
21 | arch_names = list(models.__dict__.keys())
22 | loss_names = list(losses.__dict__.keys())
23 | loss_names.append('BCEWithLogitsLoss')
24 |
25 |
26 | def parse_args():
27 | parser = argparse.ArgumentParser()
28 |
29 | parser.add_argument('--name', default='baseline',
30 | help='model name: (default: arch+timestamp)')
31 | parser.add_argument('--img_size', default=256, type=int, help='size of training image patches')
32 | parser.add_argument('--arch', '-a', metavar='ARCH', default='NestedUNet',
33 | choices=arch_names,
34 | help='model architecture: ' +
35 | ' | '.join(arch_names) +
36 | ' (default: NestedUNet)')
37 | parser.add_argument('--num_filters', default=32, type=int,
38 | help='number of starting filters in CNN.')
39 | parser.add_argument('--in_ch', default=8, type=int,
40 | help='input channels')
41 | parser.add_argument('--out_ch', default=1, type=int,
42 | help='output channels')
43 | parser.add_argument('--loss', default='BCEDiceLoss',
44 | choices=loss_names,
45 | help='loss: ' +
46 | ' | '.join(loss_names) +
47 | ' (default: BCEDiceLoss)')
48 | parser.add_argument('--epochs', default=1000, type=int, metavar='N',
49 | help='number of total epochs to run')
50 | parser.add_argument('--early-stop', default=20, type=int,
51 | metavar='N', help='early stopping (default: 20)')
52 | parser.add_argument('-b', '--batch_size', default=16, type=int,
53 | metavar='N', help='mini-batch size (default: 16)')
54 | parser.add_argument('--optimizer', default='Adam',
55 | choices=['Adam', 'SGD'],
56 | help='optimizer: ' +
57 | ' | '.join(['Adam', 'SGD']) +
58 | ' (default: Adam)')
59 | parser.add_argument('--lr', '--learning-rate', default=2e-4, type=float,
60 | metavar='LR', help='initial learning rate')
61 | parser.add_argument('--momentum', default=0.9, type=float,
62 | help='momentum')
63 | parser.add_argument('--weight-decay', default=1e-6, type=float,
64 | help='weight decay')
65 | args = parser.parse_args()
66 |
67 | return args
68 |
69 |
70 | class AverageMeter(object):
71 | """Computes and stores the average and current value"""
72 | def __init__(self):
73 | self.reset()
74 |
75 | def reset(self):
76 | self.val = 0
77 | self.avg = 0
78 | self.sum = 0
79 | self.count = 0
80 |
81 | def update(self, val, n=1):
82 | self.val = val
83 | self.sum += val * n
84 | self.count += n
85 | self.avg = self.sum / self.count
86 |
87 |
88 | def train(train_loader, model, criterion, optimizer):
89 | losses = AverageMeter()
90 | ious = AverageMeter()
91 |
92 | model.train()
93 |
94 | pbar = tqdm(enumerate(train_loader), total=len(train_loader))
95 |
96 | for i, (input, target, loss_weight) in pbar:
97 | # compute output
98 | outputs = model(input)
99 | loss = 0
100 | for output in outputs:
101 | loss += criterion(output, target, loss_weight)
102 | loss /= len(outputs)
103 | iou = iou_score(outputs[-1], target)
104 |
105 | # update log and progress bar
106 | losses.update(loss.item(), input.size(0))
107 | ious.update(iou, input.size(0))
108 | pbar.set_postfix({'loss': loss.item(), 'iou': iou})
109 |
110 | # compute gradient and do optimizing step
111 | optimizer.zero_grad()
112 | loss.backward()
113 | optimizer.step()
114 |
115 | log = OrderedDict([
116 | ('loss', losses.avg),
117 | ('iou', ious.avg),
118 | ])
119 |
120 | return log
121 |
122 |
123 | def validate(val_loader, model, criterion):
124 | losses = AverageMeter()
125 | ious = AverageMeter()
126 |
127 | # switch to evaluate mode
128 | model.eval()
129 |
130 | with torch.no_grad():
131 | for i, (input, target, loss_weight) in enumerate(val_loader):
132 | # compute output
133 | outputs = model(input)
134 | loss = 0
135 | for output in outputs:
136 | loss += criterion(output, target, loss_weight)
137 | loss /= len(outputs)
138 | iou = iou_score(outputs[-1], target)
139 |
140 | losses.update(loss.item(), input.size(0))
141 | ious.update(iou, input.size(0))
142 |
143 | log = OrderedDict([
144 | ('loss', losses.avg),
145 | ('iou', ious.avg),
146 | ])
147 |
148 | return log
149 |
150 |
151 | def main():
152 | args = parse_args()
153 | timestamp = datetime.now().strftime('%m%d_%H%M%S')
154 | model_dir = 'models/{}'.format(args.name + '_' + timestamp)
155 |
156 | # create the model directory if it does not exist
157 | if not os.path.exists(model_dir):
158 | os.makedirs(model_dir)
159 |
160 | # print configuration and save it
161 | print('Config -----')
162 | for arg in vars(args):
163 | print('{0}: {1}'.format(arg, getattr(args, arg)))
164 | print('------------')
165 |
166 | with open('{}/args.txt'.format(model_dir), 'w') as f:
167 | for arg in vars(args):
168 | print('%s: %s' %(arg, getattr(args, arg)), file=f)
169 |
170 | joblib.dump(args, '{}/args.pkl'.format(model_dir))
171 |
172 | # define loss function (criterion)
173 | if args.loss == 'BCEWithLogitsLoss':
174 | criterion = nn.BCEWithLogitsLoss()
175 | else:
176 | criterion = losses.__dict__[args.loss]()
177 | if torch.cuda.is_available():
178 | criterion = criterion.cuda()
177 |
178 | # create model
179 | print("=> creating model {}".format(args.arch))
180 | model = models.__dict__[args.arch](args.in_ch, args.out_ch, args.num_filters)
181 |
182 | if torch.cuda.is_available():
183 | cudnn.benchmark = True
184 | model = model.cuda()
185 |
186 | print(count_params(model))
187 |
188 | if args.optimizer == 'Adam':
189 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
190 | elif args.optimizer == 'SGD':
191 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr,
192 | momentum=args.momentum, weight_decay=args.weight_decay)
193 | else:
194 | raise ValueError('optimizer not specified')
195 |
196 | scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-8)
197 |
198 | train_loader = RssraiDataLoader(which_set='train', batch_size=args.batch_size, img_size=args.img_size, shuffle=True)
199 | val_loader = RssraiDataLoader(which_set='val', batch_size=1, img_size=960, shuffle=True)
200 |
201 | log = pd.DataFrame(index=[], columns=[
202 | 'epoch', 'lr', 'loss', 'iou', 'val_loss', 'val_iou'
203 | ])
204 |
205 | best_loss = 0
206 | trigger = 0
207 |
208 | for epoch in range(args.epochs):
209 | print('Epoch [{0:d}/{1:d}]'.format(epoch, args.epochs))
210 |
211 | # train for one epoch
212 | train_log = train(train_loader=train_loader, model=model, criterion=criterion, optimizer=optimizer)
213 | # evaluate on validation set
214 | val_log = validate(val_loader=val_loader, model=model, criterion=criterion)
215 | # update learning rate
216 | scheduler.step()
217 |
218 | print('loss {0:.4f} - iou {1:.4f} - val_loss {2:.4f} - val_iou {3:.4f}'.format(train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou']))
219 |
220 | tmp = pd.Series([
221 | epoch,
222 | args.lr,
223 | train_log['loss'],
224 | train_log['iou'],
225 | val_log['loss'],
226 | val_log['iou'],
227 | ], index=['epoch', 'lr', 'loss', 'iou', 'val_loss', 'val_iou'])
228 |
229 | save_example(model_dir, epoch, model, val_loader)
230 | log = log.append(tmp, ignore_index=True)
231 | log.to_csv(os.path.join(model_dir, 'log.csv'), index=False)
232 |
233 | trigger += 1
234 |
235 | if epoch == 0:
236 | best_loss = val_log['loss']
237 |
238 | if val_log['loss'] <= best_loss:
239 | torch.save(model.state_dict(), os.path.join(model_dir, 'model_best.pth'))
240 | best_loss = val_log['loss']
241 | print("=> saved best model")
242 | trigger = 0
243 |
244 | # early stopping
245 | if not args.early_stop is None:
246 | if trigger >= args.early_stop:
247 | print("=> early stopping")
248 | break
249 | torch.save(model.state_dict(), os.path.join(model_dir, 'model_final.pth'))
250 |
251 | torch.cuda.empty_cache()
252 |
253 |
254 | if __name__ == '__main__':
255 | main()
256 |
257 |
258 |
259 |
260 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import numpy as np
3 | import torch
4 | import matplotlib.pyplot as plt
5 |
6 |
 7 | def calculate_bce_loss(masks, epsilon=1e-8):
 8 | """
 9 | Compute the positive-class weight for weighted BCE: the ratio of
10 | background (black) to foreground (white) pixels over a batch of masks.
11 | """
8 | y_minus = 0
9 | y_plus = 0
10 | single_total = masks.shape[-1] * masks.shape[-2]
11 | for m in masks:
12 | black = len(np.where(m[0]==0)[0])
13 | white = single_total - black
14 | y_minus += black
15 | y_plus += white
16 |
17 | return y_minus / (y_plus + epsilon)
18 |
19 |
20 | def get_pixels(image, border_height, border_width):
21 | """
22 | Get the white pixels from a 2D image, excluding a border (defined by
23 | border_height and border_width) along the bottom and right edges
24 | :param image: 2D image array
25 | :param border_height: border height to exclude
26 | :param border_width: border width to exclude
27 | :return: points
28 | """
29 | image_height = image.shape[0]
30 | image_width = image.shape[1]
31 | masked_image = image.copy()
32 | mask = np.zeros_like(image)
33 | mask[:image_height-border_height, :image_width-border_width] = 1
33 | masked_image[mask == 0] = 0
34 | points = np.where(masked_image != 0)
35 | return points
36 |
37 |
38 | def get_files(directory, format='tif'):
39 | """
40 | To get a list of file names in one directory, especially images
41 | :param directory: a path to the directory of the image files
42 | :return: a list of all the file names in that directory
43 | """
44 | if format == 'png':
45 | file_list = glob.glob(directory + "*.png")
46 | elif format == 'tif':
47 | file_list = glob.glob(directory + "*.tif")
48 | else:
49 | raise ValueError("unsupported file format: {}".format(format))
50 |
51 | file_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
52 | return file_list
53 |
54 |
55 | def count_params(model):
56 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
57 |
58 |
59 | def save_example(folder, epoch, model, data_loader):
60 | """
61 | save prediction examples during training
62 | """
63 | if epoch == 0 or epoch % 5 == 0:
64 | model.eval()
65 | with torch.no_grad():
66 | inputs, targets, _ = next(iter(data_loader))
67 | # if torch.cuda.is_available():
68 | # inputs = inputs.cuda()
69 |
70 | preds = model(inputs)
71 | preds = preds[-1]
72 | preds = torch.sigmoid(preds)
73 |
74 | if torch.cuda.is_available():
75 | preds = preds.cpu()
76 | inputs = inputs.cpu()
77 | targets = targets.cpu()
78 |
79 | inputs, targets, preds = inputs.numpy(), targets.numpy(), preds.numpy()
80 | inputs, targets, preds = (inputs * 255).astype('uint8'), (targets * 255).astype('uint8'), (preds * 255).astype('uint8')
81 | imgs1 = inputs[:, :3, :, :].transpose(0, 2, 3, 1)
82 | imgs2 = inputs[:, 3:6, :, :].transpose(0, 2, 3, 1)
83 | targets = targets.transpose(0, 2, 3, 1)
84 | preds = preds.transpose(0, 2, 3, 1)
85 | targets = targets[:, :, :, 0]
86 | preds = preds[:, :, :, 0]
87 |
88 | fig, axs = plt.subplots(imgs1.shape[0], 4)
89 | for i in range(imgs1.shape[0]):
90 | if i == 0:
91 | axs[i, 0].set_title('I_1')
92 | axs[i, 1].set_title('I_2')
93 | axs[i, 2].set_title('M')
94 | axs[i, 3].set_title('P')
95 | for j, imgs in enumerate([imgs1, imgs2, targets, preds]):
96 | if j <= 1:
97 | axs[i, j].imshow(imgs[i])
98 | else:
99 | axs[i, j].imshow(imgs[i], cmap='gray')
100 | axs[i, j].axis('off')
101 |
102 | fig.savefig(folder + '/{}.png'.format(epoch))
103 | plt.close(fig)  # release the figure so repeated calls don't accumulate memory
103 |
104 |
105 |
106 |
107 |
108 |
--------------------------------------------------------------------------------
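
A toy check of the two mask helpers, a minimal sketch whose expected values can be verified by hand:

    import numpy as np
    from utils import get_pixels, calculate_bce_loss

    mask = np.zeros((8, 8), dtype='float32')
    mask[1:3, 1:3] = 1.0                        # 4 white (changed) pixels

    rows, cols = get_pixels(mask, 2, 2)         # exclude a 2-pixel bottom/right border
    print(list(zip(rows, cols)))                # the four white-pixel coordinates

    batch = mask[np.newaxis, np.newaxis, :, :]  # shape (1, 1, 8, 8), as in the data loader
    print(calculate_bce_loss(batch))            # 60 black / 4 white ~ 15.0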