├── .gitignore
├── .idea
├── misc.xml
├── modules.xml
├── unet_pytorch.iml
├── vcs.xml
└── workspace.xml
├── README.md
├── assets
└── unet.png
├── crf
└── crf.py
├── data
└── membrane
│ ├── test
│ ├── 0.png
│ ├── 1.png
│ ├── 10.png
│ ├── 11.png
│ ├── 12.png
│ ├── 13.png
│ ├── 14.png
│ ├── 15.png
│ ├── 16.png
│ ├── 17.png
│ ├── 18.png
│ ├── 19.png
│ ├── 2.png
│ ├── 20.png
│ ├── 21.png
│ ├── 22.png
│ ├── 23.png
│ ├── 24.png
│ ├── 25.png
│ ├── 26.png
│ ├── 27.png
│ ├── 28.png
│ ├── 29.png
│ ├── 3.png
│ ├── 4.png
│ ├── 5.png
│ ├── 6.png
│ ├── 7.png
│ ├── 8.png
│ └── 9.png
│ └── train
│ ├── image
│ ├── 0.png
│ ├── 1.png
│ ├── 10.png
│ ├── 11.png
│ ├── 12.png
│ ├── 13.png
│ ├── 14.png
│ ├── 15.png
│ ├── 16.png
│ ├── 17.png
│ ├── 18.png
│ ├── 19.png
│ ├── 2.png
│ ├── 20.png
│ ├── 21.png
│ ├── 22.png
│ ├── 23.png
│ ├── 24.png
│ ├── 25.png
│ ├── 26.png
│ ├── 27.png
│ ├── 28.png
│ ├── 29.png
│ ├── 3.png
│ ├── 4.png
│ ├── 5.png
│ ├── 6.png
│ ├── 7.png
│ ├── 8.png
│ └── 9.png
│ └── label
│ ├── 0.png
│ ├── 1.png
│ ├── 10.png
│ ├── 11.png
│ ├── 12.png
│ ├── 13.png
│ ├── 14.png
│ ├── 15.png
│ ├── 16.png
│ ├── 17.png
│ ├── 18.png
│ ├── 19.png
│ ├── 2.png
│ ├── 20.png
│ ├── 21.png
│ ├── 22.png
│ ├── 23.png
│ ├── 24.png
│ ├── 25.png
│ ├── 26.png
│ ├── 27.png
│ ├── 28.png
│ ├── 29.png
│ ├── 3.png
│ ├── 4.png
│ ├── 5.png
│ ├── 6.png
│ ├── 7.png
│ ├── 8.png
│ └── 9.png
├── data_aug.py
├── dice_loss.py
├── load.py
├── main.py
├── model
├── __pycache__
│ └── unet.cpython-36.pyc
└── unet.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/unet_pytorch.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 | 1548992939158
230 |
231 |
232 | 1548992939158
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 |
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 |
297 |
298 |
299 |
300 |
301 |
302 |
303 |
304 |
305 |
306 |
307 |
308 |
309 |
310 |
311 |
312 |
313 |
314 |
315 |
316 |
317 |
318 |
319 |
320 |
321 |
322 |
323 |
324 |
325 |
326 |
327 |
328 |
329 |
330 |
331 |
332 |
333 |
334 |
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 |
343 |
344 |
345 |
346 |
347 |
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 |
356 |
357 |
358 |
359 |
360 |
361 |
362 |
363 |
364 |
365 |
366 |
367 |
368 |
369 |
370 |
371 |
372 |
373 |
374 |
375 |
376 |
377 |
378 |
379 |
380 |
381 |
382 |
383 |
384 |
385 |
386 |
387 |
388 |
389 |
390 |
391 |
392 |
393 |
394 |
395 |
396 |
397 |
398 |
399 |
400 |
401 |
402 |
403 |
404 |
405 |
406 |
407 |
408 |
409 |
410 |
411 |
412 |
413 |
414 |
415 |
416 |
417 |
418 |
419 |
420 |
421 |
422 |
423 |
424 |
425 |
426 |
427 |
428 |
429 |
430 |
431 |
432 |
433 |
434 |
435 |
436 |
437 |
438 |
439 |
440 |
441 |
442 |
443 |
444 |
445 |
446 |
447 |
448 |
449 |
450 |
451 |
452 |
453 |
454 |
455 |
456 |
457 |
458 |
459 |
460 |
461 |
462 |
463 |
464 |
465 |
466 |
467 |
468 |
469 |
470 |
471 |
472 |
473 |
474 |
475 |
476 |
477 |
478 |
479 |
480 |
481 |
482 |
483 |
484 |
485 |
486 |
487 |
488 |
489 |
490 |
491 |
492 |
493 |
494 |
495 |
496 |
497 |
498 |
499 |
500 |
501 |
502 |
503 |
504 |
505 |
506 |
507 |
508 |
509 |
510 |
511 |
512 |
513 |
514 |
515 |
516 |
517 |
518 |
519 |
520 |
521 |
522 |
523 |
524 |
525 |
526 |
527 |
528 |
529 |
530 |
531 |
532 |
533 |
534 |
535 |
536 |
537 |
538 |
539 |
540 |
541 |
542 |
543 |
544 |
545 |
546 |
547 |
548 |
549 |
550 |
551 |
552 |
553 |
554 |
555 |
556 |
557 |
558 |
559 |
560 |
561 |
562 |
563 |
564 |
565 |
566 |
567 |
568 |
569 |
570 |
571 |
572 |
573 |
574 |
575 |
576 |
577 |
578 |
579 |
580 |
581 |
582 |
583 |
584 |
585 |
586 |
587 |
588 |
589 |
590 |
591 |
592 |
593 |
594 |
595 |
596 |
597 |
598 |
599 |
600 |
601 |
602 |
603 |
604 |
605 |
606 |
607 |
608 |
609 |
610 |
611 |
612 |
613 |
614 |
615 |
616 |
617 |
618 |
619 |
620 |
621 |
622 |
623 |
624 |
625 |
626 |
627 |
628 |
629 |
630 |
631 |
632 |
633 |
634 |
635 |
636 |
637 |
638 |
639 |
640 |
641 |
642 |
643 |
644 |
645 |
646 |
647 |
648 |
649 |
650 |
651 |
652 |
653 |
654 |
655 |
656 |
657 |
658 |
659 |
660 |
661 |
662 |
663 |
664 |
665 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unet-Pytorch
2 |
3 | > U-Net for skin lesion segmentation on the ISIC 2018 dataset, implemented in PyTorch
4 |
5 | ### Network Architecture
6 |
7 | 
8 |
9 | ### Environment
10 |
11 | * PyTorch 1.0
12 | * CUDA 10.0
13 |
14 | ### Results
15 |
16 |
--------------------------------------------------------------------------------
/assets/unet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/assets/unet.png
--------------------------------------------------------------------------------
/crf/crf.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pydensecrf.densecrf as dcrf
3 |
def dense_crf(img, output_probs, n_iter=5, gaussian_sxy=20, gaussian_compat=3,
              bilateral_sxy=30, bilateral_srgb=20, bilateral_compat=10,
              eps=1e-8):
    """Refine a binary foreground probability map with a fully connected CRF.

    A two-label (background / foreground) DenseCRF2D is built from the
    probability map, a Gaussian and a bilateral pairwise term are added, and
    mean-field inference is run.

    Args:
        img: reference image for the bilateral term, passed to pydensecrf as
            ``rgbim`` after being made C-contiguous. NOTE(review): pydensecrf
            expects an (h, w, 3) uint8 array here; this function does not
            convert it, so callers must supply that layout -- confirm.
        output_probs: array whose first two axes are (h, w), holding the
            foreground probability of each pixel.
        n_iter: number of mean-field inference iterations.
        gaussian_sxy: spatial std-dev (``sxy``) of the Gaussian pairwise term.
        gaussian_compat: compatibility weight of the Gaussian pairwise term.
        bilateral_sxy: spatial std-dev (``sxy``) of the bilateral term.
        bilateral_srgb: colour std-dev (``srgb``) of the bilateral term.
        bilateral_compat: compatibility weight of the bilateral term.
        eps: probabilities are clipped to ``[eps, 1 - eps]`` before taking the
            log, so hard 0/1 predictions do not yield infinite unary energies.

    Returns:
        Integer array of shape (h, w) holding the argmax label per pixel:
        1 where foreground wins after CRF refinement, 0 otherwise.
    """
    # Work in float64 for the clip/log so that 1 - eps is still < 1; a
    # float32 clip at 1 - 1e-8 would round back to 1.0 and log(0) again.
    probs = np.clip(np.asarray(output_probs, dtype=np.float64), eps, 1.0 - eps)
    h = probs.shape[0]
    w = probs.shape[1]

    # Label 0 = background (1 - p), label 1 = foreground (p).
    two_class = np.stack([1.0 - probs, probs], axis=0)

    d = dcrf.DenseCRF2D(w, h, 2)

    # Unary energy is the negative log-probability, one row per label.
    # pydensecrf requires a C-contiguous float32 buffer here.
    unary = -np.log(two_class).reshape((2, -1))
    unary = np.ascontiguousarray(unary, dtype=np.float32)
    img = np.ascontiguousarray(img)

    d.setUnaryEnergy(unary)

    d.addPairwiseGaussian(sxy=gaussian_sxy, compat=gaussian_compat)
    d.addPairwiseBilateral(sxy=bilateral_sxy, srgb=bilateral_srgb,
                           rgbim=img, compat=bilateral_compat)

    q = d.inference(n_iter)
    # Per-pixel most likely label, reshaped back to the image grid.
    return np.argmax(np.array(q), axis=0).reshape((h, w))
--------------------------------------------------------------------------------
/data/membrane/test/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/0.png
--------------------------------------------------------------------------------
/data/membrane/test/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/1.png
--------------------------------------------------------------------------------
/data/membrane/test/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/10.png
--------------------------------------------------------------------------------
/data/membrane/test/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/11.png
--------------------------------------------------------------------------------
/data/membrane/test/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/12.png
--------------------------------------------------------------------------------
/data/membrane/test/13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/13.png
--------------------------------------------------------------------------------
/data/membrane/test/14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/14.png
--------------------------------------------------------------------------------
/data/membrane/test/15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/15.png
--------------------------------------------------------------------------------
/data/membrane/test/16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/16.png
--------------------------------------------------------------------------------
/data/membrane/test/17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/17.png
--------------------------------------------------------------------------------
/data/membrane/test/18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/18.png
--------------------------------------------------------------------------------
/data/membrane/test/19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/19.png
--------------------------------------------------------------------------------
/data/membrane/test/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/2.png
--------------------------------------------------------------------------------
/data/membrane/test/20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/20.png
--------------------------------------------------------------------------------
/data/membrane/test/21.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/21.png
--------------------------------------------------------------------------------
/data/membrane/test/22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/22.png
--------------------------------------------------------------------------------
/data/membrane/test/23.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/23.png
--------------------------------------------------------------------------------
/data/membrane/test/24.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/24.png
--------------------------------------------------------------------------------
/data/membrane/test/25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/25.png
--------------------------------------------------------------------------------
/data/membrane/test/26.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/26.png
--------------------------------------------------------------------------------
/data/membrane/test/27.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/27.png
--------------------------------------------------------------------------------
/data/membrane/test/28.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/28.png
--------------------------------------------------------------------------------
/data/membrane/test/29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/29.png
--------------------------------------------------------------------------------
/data/membrane/test/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/3.png
--------------------------------------------------------------------------------
/data/membrane/test/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/4.png
--------------------------------------------------------------------------------
/data/membrane/test/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/5.png
--------------------------------------------------------------------------------
/data/membrane/test/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/6.png
--------------------------------------------------------------------------------
/data/membrane/test/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/7.png
--------------------------------------------------------------------------------
/data/membrane/test/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/8.png
--------------------------------------------------------------------------------
/data/membrane/test/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/test/9.png
--------------------------------------------------------------------------------
/data/membrane/train/image/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/0.png
--------------------------------------------------------------------------------
/data/membrane/train/image/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/1.png
--------------------------------------------------------------------------------
/data/membrane/train/image/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/10.png
--------------------------------------------------------------------------------
/data/membrane/train/image/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/11.png
--------------------------------------------------------------------------------
/data/membrane/train/image/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/12.png
--------------------------------------------------------------------------------
/data/membrane/train/image/13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/13.png
--------------------------------------------------------------------------------
/data/membrane/train/image/14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/14.png
--------------------------------------------------------------------------------
/data/membrane/train/image/15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/15.png
--------------------------------------------------------------------------------
/data/membrane/train/image/16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/16.png
--------------------------------------------------------------------------------
/data/membrane/train/image/17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/17.png
--------------------------------------------------------------------------------
/data/membrane/train/image/18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/18.png
--------------------------------------------------------------------------------
/data/membrane/train/image/19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/19.png
--------------------------------------------------------------------------------
/data/membrane/train/image/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/2.png
--------------------------------------------------------------------------------
/data/membrane/train/image/20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/20.png
--------------------------------------------------------------------------------
/data/membrane/train/image/21.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/21.png
--------------------------------------------------------------------------------
/data/membrane/train/image/22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/22.png
--------------------------------------------------------------------------------
/data/membrane/train/image/23.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/23.png
--------------------------------------------------------------------------------
/data/membrane/train/image/24.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/24.png
--------------------------------------------------------------------------------
/data/membrane/train/image/25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/25.png
--------------------------------------------------------------------------------
/data/membrane/train/image/26.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/26.png
--------------------------------------------------------------------------------
/data/membrane/train/image/27.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/27.png
--------------------------------------------------------------------------------
/data/membrane/train/image/28.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/28.png
--------------------------------------------------------------------------------
/data/membrane/train/image/29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/29.png
--------------------------------------------------------------------------------
/data/membrane/train/image/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/3.png
--------------------------------------------------------------------------------
/data/membrane/train/image/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/4.png
--------------------------------------------------------------------------------
/data/membrane/train/image/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/5.png
--------------------------------------------------------------------------------
/data/membrane/train/image/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/6.png
--------------------------------------------------------------------------------
/data/membrane/train/image/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/7.png
--------------------------------------------------------------------------------
/data/membrane/train/image/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/8.png
--------------------------------------------------------------------------------
/data/membrane/train/image/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/image/9.png
--------------------------------------------------------------------------------
/data/membrane/train/label/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/0.png
--------------------------------------------------------------------------------
/data/membrane/train/label/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/1.png
--------------------------------------------------------------------------------
/data/membrane/train/label/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/10.png
--------------------------------------------------------------------------------
/data/membrane/train/label/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/11.png
--------------------------------------------------------------------------------
/data/membrane/train/label/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/12.png
--------------------------------------------------------------------------------
/data/membrane/train/label/13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/13.png
--------------------------------------------------------------------------------
/data/membrane/train/label/14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/14.png
--------------------------------------------------------------------------------
/data/membrane/train/label/15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/15.png
--------------------------------------------------------------------------------
/data/membrane/train/label/16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/16.png
--------------------------------------------------------------------------------
/data/membrane/train/label/17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/17.png
--------------------------------------------------------------------------------
/data/membrane/train/label/18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/18.png
--------------------------------------------------------------------------------
/data/membrane/train/label/19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/19.png
--------------------------------------------------------------------------------
/data/membrane/train/label/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/2.png
--------------------------------------------------------------------------------
/data/membrane/train/label/20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/20.png
--------------------------------------------------------------------------------
/data/membrane/train/label/21.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/21.png
--------------------------------------------------------------------------------
/data/membrane/train/label/22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/22.png
--------------------------------------------------------------------------------
/data/membrane/train/label/23.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/23.png
--------------------------------------------------------------------------------
/data/membrane/train/label/24.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/24.png
--------------------------------------------------------------------------------
/data/membrane/train/label/25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/25.png
--------------------------------------------------------------------------------
/data/membrane/train/label/26.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/26.png
--------------------------------------------------------------------------------
/data/membrane/train/label/27.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/27.png
--------------------------------------------------------------------------------
/data/membrane/train/label/28.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/28.png
--------------------------------------------------------------------------------
/data/membrane/train/label/29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/29.png
--------------------------------------------------------------------------------
/data/membrane/train/label/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/3.png
--------------------------------------------------------------------------------
/data/membrane/train/label/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/4.png
--------------------------------------------------------------------------------
/data/membrane/train/label/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/5.png
--------------------------------------------------------------------------------
/data/membrane/train/label/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/6.png
--------------------------------------------------------------------------------
/data/membrane/train/label/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/7.png
--------------------------------------------------------------------------------
/data/membrane/train/label/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/8.png
--------------------------------------------------------------------------------
/data/membrane/train/label/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/data/membrane/train/label/9.png
--------------------------------------------------------------------------------
/data_aug.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import random
3 | import math
4 | import torch
5 | import types
6 |
7 | # Random rotation.
def random_rotate(img, mask, degrees):
    """Rotate ``img`` and ``mask`` by the same random angle.

    Args:
        img: image-like object exposing ``rotate(angle)`` (a PIL image here).
        mask: segmentation mask, rotated by exactly the same angle as ``img``.
        degrees: either a single non-negative number ``d`` (the angle is drawn
            from ``[-d, d]``) or a ``(low, high)`` pair.  When both bounds are
            ``int`` the angle is an integer drawn with ``random.randint``
            (inclusive); otherwise it is a float from ``random.uniform``.

    Returns:
        ``(rotated_img, rotated_mask)``.

    Raises:
        ValueError: if a single ``degrees`` value is negative, or a sequence
            does not have exactly two elements.
    """
    import numbers

    if isinstance(degrees, numbers.Real):
        if degrees < 0:
            raise ValueError("If degrees is a single number, it must be non-negative.")
        degrees = (-degrees, degrees)
    elif len(degrees) != 2:
        raise ValueError("If degrees is a sequence, it must be of len 2.")
    low, high = degrees
    if isinstance(low, int) and isinstance(high, int):
        # Integer bounds keep the original integer-angle sampling.
        angle = random.randint(low, high)
    else:
        # randint() rejects floats, so fractional ranges are sampled uniformly.
        angle = random.uniform(low, high)
    # One shared angle keeps the mask aligned with the image.
    return img.rotate(angle), mask.rotate(angle)
21 |
22 | # Apply random cropping to original images, and resize the cropped result to a fixed size.
def random_resized_crop(img, mask, size=(512, 512), scale=(0.6, 1.0), ratio=(4. / 5., 5. / 4.)):
    """Crop the same random region from ``img`` and ``mask`` and resize both.

    The crop covers a random fraction ``scale`` of the image area with a random
    width/height aspect ratio from ``ratio``.  Width and height are clipped to
    the image, so the realised area/aspect may fall outside those ranges for
    extreme draws.

    Args:
        img: PIL image to crop (its ``size`` is ``(width, height)``).
        mask: PIL mask cropped with the same box as ``img``.
        size: output ``(width, height)`` passed to ``resize``.
        scale: ``(min, max)`` fraction of the original area to keep.
        ratio: ``(min, max)`` crop aspect ratio (width / height).

    Returns:
        ``(img_roi, mask_roi)``, both resized to ``size``.
    """
    width, height = img.size
    area = width * height
    target_area = random.uniform(*scale) * area
    aspect_ratio = random.uniform(*ratio)
    w = min(int(round(math.sqrt(target_area * aspect_ratio))), width)
    h = min(int(round(math.sqrt(target_area / aspect_ratio))), height)
    left = round(random.uniform(0, width - w))
    top = round(random.uniform(0, height - h))
    box = (left, top, left + w, top + h)
    img_roi = img.crop(box).resize(size)
    # Nearest-neighbour keeps mask labels discrete; Pillow's default resize
    # filter (bicubic since Pillow 7.0) would blend label values at edges.
    mask_roi = mask.crop(box).resize(size, Image.NEAREST)
    return img_roi, mask_roi
40 |
41 | # Apply random horizontal flip to original images with 0.5 as probability.
def random_horizontal_flip(img, mask):
    """Mirror ``img`` and ``mask`` left-to-right together with probability 0.5.

    A single ``random.random()`` draw decides for both inputs, so the mask
    always stays aligned with the image.  Returns the ``(img, mask)`` pair,
    flipped or untouched.
    """
    if random.random() < 0.5:
        return img, mask
    return (img.transpose(Image.FLIP_LEFT_RIGHT),
            mask.transpose(Image.FLIP_LEFT_RIGHT))
48 |
49 | # Apply random vertical flip to original images with 0.5 as probability.
def random_vertical_flip(img, mask):
    """Mirror ``img`` and ``mask`` top-to-bottom together with probability 0.5.

    One ``random.random()`` draw controls both inputs.  Returns the
    ``(img, mask)`` pair, flipped or untouched.
    """
    should_flip = random.random() >= 0.5
    if not should_flip:
        return img, mask
    flipped_img = img.transpose(Image.FLIP_TOP_BOTTOM)
    flipped_mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
    return flipped_img, flipped_mask
56 |
def image_resize(img, mask, size=(512, 512)):
    """Resize ``img`` and ``mask`` to ``size`` (``(width, height)``).

    The image uses Pillow's default filter.  The mask is resized with
    nearest-neighbour sampling so its label values stay discrete instead of
    being interpolated (Pillow's default resize filter is bicubic since 7.0).

    Returns:
        ``(resized_img, resized_mask)``.
    """
    return img.resize(size), mask.resize(size, Image.NEAREST)
59 |
def image_to_tensor(img, mask):
    """Scale numpy ``img``/``mask`` arrays by 1/255 and convert to channel-first tensors.

    Args:
        img: numpy array of shape ``(H, W, C)`` or, for single-channel images,
            ``(H, W)``.  Presumably 8-bit pixel values, hence the division by 255.
        mask: numpy array; a leading channel axis of size 1 is added
            (presumably ``(H, W)`` -> ``(1, H, W)``).

    Returns:
        ``(img_tensor, mask_tensor)``: the image as ``(C, H, W)`` (``(1, H, W)``
        for 2-D input) and the mask with an extra leading axis.  The dtype
        follows NumPy's promotion of ``array / 255.0``.
    """
    img = torch.from_numpy(img / 255.0)
    mask = torch.from_numpy(mask / 255.0)
    if img.dim() == 2:
        # Grayscale images (e.g. PIL mode "L") have no channel axis; add one.
        img = img.unsqueeze(0)
    else:
        # HWC -> CHW, the layout nn.Conv2d expects.
        img = img.permute(2, 0, 1)
    return img, mask.unsqueeze(0)
66 |
67 |
68 |
69 | #
70 | # def compose(image, mask, trans_opt=[]):
71 | # for i in range(len(trans_opt)):
72 | # trans_func = trans_opt[i]
73 | # if not isinstance(trans_func, types.FunctionType):
74 | # assert "Invalid function type."
75 | # return image, mask
76 | # image, mask = trans_func(image, mask)
--------------------------------------------------------------------------------
/dice_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function, Variable
3 |
class DiceCoeff(Function):
    """Dice coefficient for a single example, as a legacy-style autograd Function.

    ``forward(input, target)`` returns
    ``(2 * dot(input, target) + eps) / (sum(input) + sum(target) + eps)``
    with ``eps = 0.0001``; both tensors are flattened for the dot product.

    NOTE(review): this uses the old autograd API (non-static ``forward`` /
    ``backward`` on an instance, ``self.saved_variables``).  Calling
    ``forward`` directly on an instance only computes the score and does not
    record a custom backward.  Recent PyTorch releases may warn or error when
    a ``Function`` subclass is instantiated; confirm against the installed
    version.
    """

    def forward(self, input, target):
        # Stored for the hand-written backward below.
        self.save_for_backward(input, target)
        # Smoothing term: avoids 0/0 when both input and target sum to zero.
        eps = 0.0001
        # Flattened dot product == sum of element-wise products (soft intersection).
        self.inter = torch.dot(input.view(-1), target.view(-1))
        self.union = torch.sum(input) + torch.sum(target) + eps

        t = (2 * self.inter.float() + eps) / self.union.float()
        return t

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        """Gradient of the Dice score w.r.t. ``input``; ``target`` gets no gradient."""

        input, target = self.saved_variables
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            # d/d(input) of 2*inter/union = 2*(target*union - inter)/union**2;
            # the eps term of the numerator is ignored here.
            grad_input = grad_output * 2 * (target * self.union - self.inter) \
                / (self.union * self.union)
        if self.needs_input_grad[1]:
            grad_target = None

        return grad_input, grad_target
29 |
30 |
def dice_coeff(input, target):
    """Return the mean Dice *loss* (``1 - dice``) over paired samples.

    Despite the name, the value returned is ``1 - dice`` averaged over the
    leading dimension of ``input``/``target``.  Each sample uses the same
    formula as ``DiceCoeff.forward``:
    ``(2 * dot(p, t) + eps) / (sum(p) + sum(t) + eps)`` with ``eps = 0.0001``.
    The score is computed directly rather than by instantiating the legacy
    autograd ``Function`` for every sample.

    Args:
        input: predictions; iterated along dim 0 together with ``target``.
        target: ground truth with the same leading size as ``input``.

    Returns:
        A float32 tensor of shape ``(1,)`` on ``input``'s device.

    Raises:
        ValueError: if ``input`` has no samples along its first dimension.
    """
    eps = 0.0001
    total = torch.zeros(1, device=input.device)
    count = 0
    for pred, truth in zip(input, target):
        # reshape (not view) also accepts non-contiguous slices.
        inter = torch.dot(pred.reshape(-1), truth.reshape(-1))
        union = torch.sum(pred) + torch.sum(truth) + eps
        total = total + 1 - (2 * inter.float() + eps) / union.float()
        count += 1
    if count == 0:
        raise ValueError("dice_coeff needs at least one sample")
    return total / count
--------------------------------------------------------------------------------
/load.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms
3 | import torch.utils.data as Data
4 | import os
5 | import numpy as np
6 | from glob import glob
7 | from PIL import Image
8 | from data_aug import *
9 |
10 |
class CustomDataset(Data.Dataset):
    """Dataset of ``(image, mask)`` tensor pairs read from parallel path lists.

    Each item is opened with PIL, passed through the optional paired
    ``transform``, resized to 512x512 with ``image_resize``, converted to
    float32 numpy arrays and then to tensors with ``image_to_tensor``.
    """

    def __init__(self, image_frames, mask_frames, transform=None):
        """
        Args:
            image_frames: sequence of image file paths.
            mask_frames: sequence of mask file paths; ``mask_frames[i]`` is
                paired with ``image_frames[i]``.
            transform: optional callable ``(img, mask) -> (img, mask)`` applied
                to the PIL images before resizing.

        Raises:
            ValueError: if the two sequences have different lengths (pairing
                by index would otherwise silently go wrong).
        """
        super(CustomDataset, self).__init__()
        if len(image_frames) != len(mask_frames):
            raise ValueError("Got %d images but %d masks"
                             % (len(image_frames), len(mask_frames)))
        self._img_frames = image_frames
        self._mask_frames = mask_frames
        self._transform = transform

    def __len__(self):
        return len(self._img_frames)

    def __getitem__(self, index):
        img_path = self._img_frames[index]
        mask_path = self._mask_frames[index]
        # Close the files deterministically; copy() forces the pixel data to
        # load before the handles are released.
        with Image.open(img_path) as img_file, Image.open(mask_path) as mask_file:
            img = img_file.copy()
            mask = mask_file.copy()

        if self._transform is not None:
            img, mask = self._transform(img, mask)
        img, mask = image_resize(img, mask, (512, 512))

        img = np.array(img).astype("float32")
        mask = np.array(mask).astype("float32")
        img, mask = image_to_tensor(img, mask)
        return (img, mask)
35 |
36 |
class MyDataLoader(object):
    """Build train / validation ``CustomDataset``s from an image and a mask directory.

    Files from both directories are paired by sorted filename order.  The last
    ``round(n * split_ratio)`` pairs form the validation split.
    """

    def __init__(self, image_path, mask_path, split_ratio=0.05, transforms=None):
        img_frames = sorted(glob(os.path.join(image_path, "*")))
        mask_frames = sorted(glob(os.path.join(mask_path, "*")))
        self._split_ratio = split_ratio
        train_frames, val_frames = self.split_dataset(img_frames, mask_frames)
        self._train_dataset = CustomDataset(train_frames['image'], train_frames['mask'], transforms)
        # NOTE(review): validation reuses the training ``transforms`` (possibly
        # random augmentation) -- confirm this is intended.
        self._val_dataset = CustomDataset(val_frames['image'], val_frames['mask'], transforms)

    def split_dataset(self, img_frames, mask_frames):
        """Split paired frame lists into train / validation parts.

        Returns:
            Two dicts ``{"image": [...], "mask": [...]}`` (train, validation).
            The validation part is the last ``round(n * split_ratio)`` items,
            clamped to ``[0, n]``; it is empty when that rounds to 0.

        Raises:
            ValueError: if the image and mask lists differ in length.
        """
        if len(img_frames) != len(mask_frames):
            raise ValueError("Found %d images but %d masks"
                             % (len(img_frames), len(mask_frames)))
        total_len = len(img_frames)
        val_len = min(max(int(round(total_len * self._split_ratio)), 0), total_len)
        # Slice at an explicit index: with negative indices, val_len == 0 gives
        # frames[:-0] == [] (empty train set) and frames[-0:] == everything.
        cut = total_len - val_len
        train = {"image": img_frames[:cut], "mask": mask_frames[:cut]}
        val = {"image": img_frames[cut:], "mask": mask_frames[cut:]}
        return train, val

    def get_train_dataloader(self, batch_size=4, shuffle=True, num_works=4):
        return Data.DataLoader(self._train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_works)

    def get_val_dataloader(self, batch_size=4, num_works=4):
        return Data.DataLoader(self._val_dataset, batch_size=batch_size, num_workers=num_works)
60 |
61 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | import torch.nn as nn
3 | from model.unet import UNet
4 | from load import *
5 | from dice_loss import *
6 | import utils
7 |
8 |
def eval_net(net, dataloader):
    """Return the mean per-sample Dice loss of ``net`` on ``dataloader`` (no dense CRF).

    Predictions are binarised at 0.5 (so ``net`` is expected to output
    probabilities) and each sample is scored with ``dice_coeff``.  The result
    is averaged over every sample seen, not per batch.

    Args:
        net: model to evaluate; moved to CUDA when available and put in eval mode.
        dataloader: iterable of ``(img, mask)`` batches.

    Returns:
        The mean score as a float, or ``float('nan')`` if the loader yields
        no samples.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = net.to(device)
    net.eval()
    total = 0.0
    n_samples = 0
    # Evaluation needs no gradients; skipping graph construction saves memory.
    with torch.no_grad():
        for img, mask in dataloader:
            inputs, targets = img.to(device), mask.to(device)
            outputs = net(inputs)
            for mask_pred, true_mask in zip(outputs, targets):
                mask_pred = (mask_pred > 0.5).float()
                total += dice_coeff(mask_pred, true_mask).item()
                n_samples += 1
    if n_samples == 0:
        return float('nan')
    return total / n_samples
29 |
30 |
31 |
def train_net(net, epochs=10, batch_size=4, lr=0.1, val_percent=0.05, save_cp=True,
              image_path="./data/ISIC2018/image/", mask_path="./data/ISIC2018/mask/"):
    """Train ``net`` with SGD and BCE loss, reporting validation Dice loss each epoch.

    Args:
        net: segmentation model.  Its output goes straight into ``nn.BCELoss``,
            so it must already be a probability in [0, 1].
        epochs: number of passes over the training split.
        batch_size: batch size for both the training and validation loaders.
        lr: SGD learning rate (momentum 0.9, weight decay 5e-4).
        val_percent: fraction of the data held out for validation.
        save_cp: when true, save ``net.state_dict()`` to
            ``./checkpoints/CP<epoch>.pth`` after every epoch.
        image_path: directory containing the input images.
        mask_path: directory containing the masks, paired by sorted filename.
    """
    import os

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    net = net.to(device)
    net.apply(utils.weights_initializer)

    dataloader = MyDataLoader(image_path=image_path, mask_path=mask_path, split_ratio=val_percent)
    train_loader = dataloader.get_train_dataloader(batch_size=batch_size, shuffle=True, num_works=0)
    val_loader = dataloader.get_val_dataloader(batch_size=batch_size, num_works=0)

    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    criterion = nn.BCELoss()

    dir_checkpoint = './checkpoints/'
    if save_cp:
        # torch.save does not create missing parent directories.
        os.makedirs(dir_checkpoint, exist_ok=True)

    for epoch in range(epochs):
        epoch_train_loss = 0.0
        # eval_net switches to eval mode, so re-enable training mode each epoch.
        net.train()
        for (img, mask) in train_loader:
            inputs, targets = img.to(device), mask.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            epoch_train_loss += loss.item()
        val_dice_loss = eval_net(net, val_loader)
        print('Epoch %d, total train loss is %.5f, test dice loss is %.5f' % ((epoch + 1), epoch_train_loss, val_dice_loss))
        if save_cp:
            torch.save(net.state_dict(),
                       os.path.join(dir_checkpoint, 'CP{}.pth'.format(epoch + 1)))
            print('Checkpoint {} saved !'.format(epoch + 1))
69 |
70 |
if __name__ == '__main__':
    # 3 input channels, 1 output channel (trained with BCE in train_net).
    model = UNet(n_channels=3, n_classes=1)
    train_net(model)
74 |
--------------------------------------------------------------------------------
/model/__pycache__/unet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoshuaSXA/Unet-Pytorch/7e5452dfbc1276bbf3be1530831e1b13bad02229/model/__pycache__/unet.cpython-36.pyc
--------------------------------------------------------------------------------
/model/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
class double_conv(nn.Module):
    """Two stages of (3x3 conv -> BatchNorm -> ReLU).

    ``padding=1`` with a 3x3 kernel keeps H and W unchanged.  The layers live in
    ``self.conv`` (an ``nn.Sequential``), so state_dict keys are ``conv.0`` ... ``conv.5``.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        layers = []
        # First stage maps in_dim -> out_dim, second stage out_dim -> out_dim.
        for stage_in in (in_dim, out_dim):
            layers += [
                nn.Conv2d(stage_in, out_dim, 3, padding=1),
                nn.BatchNorm2d(out_dim),
                nn.ReLU(),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
22 |
23 |
class inconv(nn.Module):
    """Input stem: a single ``double_conv`` from ``in_dim`` to ``out_dim`` channels."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.conv = double_conv(in_dim, out_dim)

    def forward(self, x):
        return self.conv(x)
32 |
33 |
class down(nn.Module):
    """Encoder step: 2x2 max-pool (halving H and W), then ``double_conv``."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = double_conv(in_dim, out_dim)
        self.mpconv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.mpconv(x)
45 |
46 |
class up(nn.Module):
    """Decoder step: upsample ``x1``, pad it to ``x2``'s size, concatenate, convolve.

    ``in_dim`` is the channel count after concatenation.  The transposed conv
    expects ``x1`` to carry ``in_dim // 2`` channels, and the concatenation of
    the skip tensor ``x2`` with the upsampled ``x1`` must total ``in_dim``
    channels for ``double_conv``.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        half = in_dim // 2
        self.up = nn.ConvTranspose2d(half, half, 2, stride=2)
        self.conv = double_conv(in_dim, out_dim)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Tensors are NCHW: dim 2 is height, dim 3 is width.
        dy = x2.size()[2] - upsampled.size()[2]
        dx = x2.size()[3] - upsampled.size()[3]
        # Pad (left, right, top, bottom) so spatial sizes match the skip tensor;
        # an odd remainder goes to the right / bottom.
        left, top = dx // 2, dy // 2
        upsampled = F.pad(upsampled, (left, dx - left, top, dy - top))
        return self.conv(torch.cat([x2, upsampled], dim=1))
66 |
67 |
class outconv(nn.Module):
    """Final 1x1 convolution projecting ``in_dim`` feature maps to ``out_dim`` channels."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
76 |
77 |
78 |
class UNet(nn.Module):
    """U-Net: four down/up levels with skip connections and a sigmoid head.

    Args:
        n_channels: number of input channels.
        n_classes: number of output channels; each value is passed through an
            element-wise sigmoid, giving probabilities in (0, 1).

    ``forward`` maps ``(N, n_channels, H, W)`` to ``(N, n_classes, H, W)``.
    The ``up`` blocks pad back to the skip-connection sizes, so the output
    keeps the input resolution.
    """

    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 64)
        # down stream (each step max-pools by 2)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        # up stream; first argument is the channel count after concatenating the skip tensor
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        # torch.sigmoid replaces the deprecated F.sigmoid (identical values).
        return torch.sigmoid(x)
107 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 |
def weights_initializer(model):
    """Initialise ``nn.Conv2d`` / ``nn.Linear`` modules in place; leave others untouched.

    Meant for ``net.apply(weights_initializer)``.  Weights are drawn from
    N(0, 0.01**2) and biases are zeroed; layers built with ``bias=False`` are
    handled.

    NOTE(review): ``nn.ConvTranspose2d`` is not a subclass of ``nn.Conv2d``, so
    it keeps PyTorch's default initialisation -- confirm that is intended.
    """
    if isinstance(model, (nn.Conv2d, nn.Linear)):
        nn.init.normal_(model.weight.data, mean=0, std=0.01)
        # bias is None for layers created with bias=False.
        if model.bias is not None:
            nn.init.constant_(model.bias.data, 0.0)
13 |
def show_image(img, img_size=(512, 512)):
    """Display a tensor as a 2-D image with matplotlib, axes hidden.

    ``img`` must contain exactly ``img_size[0] * img_size[1]`` elements; it is
    reshaped to ``img_size`` before plotting.  The tensor is detached and
    copied to CPU first, so model outputs that require grad or live on a GPU
    can be shown directly.
    """
    img = img.detach().cpu().reshape(*img_size).numpy()
    plt.axis("off")
    plt.imshow(img)
    plt.show()
20 |
21 |
22 |
--------------------------------------------------------------------------------