├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── cs2net.md
├── dataloader
│   ├── MRABrainLoader.py
│   ├── __init__.py
│   ├── drive.py
│   ├── octa.py
│   ├── padova1.py
│   ├── padova2.py
│   └── stare.py
├── model
│   ├── __init__.py
│   ├── csnet.py
│   └── csnet_3d.py
├── predict.py
├── predict3d.py
├── train.py
├── train3d.py
└── utils
    ├── __init__.py
    ├── dice_loss_single_class.py
    ├── evaluation_metrics.py
    ├── evaluation_metrics3D.py
    ├── losses.py
    ├── misc.py
    ├── train_metrics.py
    └── visualize.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | .DS_Store
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 ineedzx
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CS-Net: Channel and Spatial Attention Network for Curvilinear Structure Segmentation
2 |
3 | Implementation of [CS-Net: Channel and Spatial Attention Network for Curvilinear Structure Segmentation](https://link.springer.com/chapter/10.1007/978-3-030-32239-7_80)
4 |
5 | For details of the 3D extended version of CS-Net, please refer to [CS2-Net: Deep Learning Segmentation of Curvilinear Structures in Medical Imaging](cs2net.md)
6 |
7 | ---
8 |
9 | ## Overview
10 |
11 |
12 |
14 |
15 | The main contribution of this work is the publication of two scarce datasets in the medical imaging field. Please click the link below to access the details and source data: [iMED dataset page](http://www.imed-lab.com/?p=16073)
16 |
17 | ## Requirements
18 |
19 |     
20 |
21 | The attention module was implemented based on [DANet](https://github.com/junfu1115/DANet). The difference from the original block is that we added new 1×3 and 3×1 kernel convolution layers to the spatial attention module, as sketched below. Please refer to the paper for details.
22 |
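A minimal sketch of that change, mirroring the `SpatialAttentionBlock` in `model/csnet.py` (here `in_channels` is just the bottleneck width used in this repository):

```python
import torch.nn as nn

in_channels = 512  # bottleneck width in CSNet

# Query/key projections with asymmetric 1x3 and 3x1 kernels,
# replacing the plain 1x1 convolutions of the original DANet block.
query = nn.Sequential(
    nn.Conv2d(in_channels, in_channels // 8, kernel_size=(1, 3), padding=(0, 1)),
    nn.BatchNorm2d(in_channels // 8),
    nn.ReLU(inplace=True),
)
key = nn.Sequential(
    nn.Conv2d(in_channels, in_channels // 8, kernel_size=(3, 1), padding=(1, 0)),
    nn.BatchNorm2d(in_channels // 8),
    nn.ReLU(inplace=True),
)
```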
23 | ## Get Started
24 |
25 | Use ```train.py``` and ```predict.py``` to train and test the model on your own dataset, respectively; a minimal training sketch follows.
26 |
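A minimal sketch of such a run, assuming a local DRIVE copy under `dataset/DRIVE/` and a CUDA device (defaults mirror the `args` dict in `train.py`):

```python
from torch.utils.data import DataLoader
from model.csnet import CSNet
from dataloader.drive import Data

# 3 input channels for RGB fundus images; the dataset path is an assumption
net = CSNet(classes=1, channels=3).cuda()
train_loader = DataLoader(Data('dataset/DRIVE/', train=True),
                          batch_size=8, num_workers=2, shuffle=True)
# see train.py for the full loop (MSE + Dice loss, poly learning-rate decay)
```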
27 | ## Examples
28 |
29 | - Vessel segmentation on Fundus
30 |
31 |
32 |
33 |
34 |
35 | - Vessel segmentation on OCT-A images
36 |
37 |
38 |
39 |
40 |
41 | - Nerve fiber tracing on CCM
42 |
43 |
44 |
45 |
46 |
47 | ## Citation
48 |
49 | ```
50 | @inproceedings{mou2019cs,
51 | title={CS-Net: channel and spatial attention network for curvilinear structure segmentation},
52 | author={Mou, Lei and Zhao, Yitian and Chen, Li and Cheng, Jun and Gu, Zaiwang and Hao, Huaying and Qi, Hong and Zheng, Yalin and Frangi, Alejandro and Liu, Jiang},
53 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
54 | pages={721--730},
55 | year={2019},
56 | organization={Springer}
57 | }
58 | ```
59 |
60 |
61 |
62 | ## Useful Links
63 |
64 | | DRIVE | http://www.isi.uu.nl/Research/Databases/DRIVE/ |
65 | | :------------- | :---------------------------------------------------------- |
66 | | **STARE** | **http://www.ces.clemson.edu/ahoover/stare/** |
67 | | **IOSTAR** | **http://www.retinacheck.org/** |
68 | | **ToF MIDAS** | **http://insight-journal.org/midas/community/view/21** |
69 | | **Synthetic** | **https://github.com/giesekow/deepvesselnet/wiki/Datasets** |
70 | | **VascuSynth** | **http://vascusynth.cs.sfu.ca/Data.html** |
71 |
--------------------------------------------------------------------------------
/cs2net.md:
--------------------------------------------------------------------------------
1 |
2 | # CS2-Net: Deep Learning Segmentation of Curvilinear Structures in Medical Imaging
3 |
4 | Implementation of [CS2-Net MedIA 2020](https://www.sciencedirect.com/science/article/pii/S1361841520302383)
5 |
6 | ---
7 |
8 | ## Overview
9 |
10 |
11 |
13 |
14 | ## Requirements
15 |
16 |   
17 |
18 | ## Get Started
19 |
20 | - ```train3d.py``` is used to train the 3D segmentation network.
21 |
22 | - ```predict3d.py``` is used to test the trained model.
23 |
24 | - Please note that you must change the dataloader definition in ```train3d.py``` to match your own dataset; see the sketch below.
25 |
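A minimal sketch of that swap, assuming the MRA layout expected by `dataloader/MRABrainLoader.py` (the dataset root is a placeholder, and `classes=2` matches the channel-wise `argmax` in `predict3d.py`):

```python
from torch.utils.data import DataLoader
from model.csnet_3d import CSNet3D
from dataloader.MRABrainLoader import Data  # swap in your own 3D loader here

net = CSNet3D(classes=2, channels=1).cuda()
train_loader = DataLoader(Data('dataset/MRA/', train=True),
                          batch_size=2, shuffle=True)
```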
26 | ## Examples
27 |
28 | - MRA brain vessel segmentation
29 |
30 |
31 |
32 |
33 |
34 | - Synthetic & VascuSynth
35 |
36 |
37 |
38 |
39 |
40 | ## Citation
41 |
42 | ```
43 | @article{mou2020cs2,
44 | title={CS2-Net: Deep Learning Segmentation of Curvilinear Structures in Medical Imaging},
45 | author={Mou, Lei and Zhao, Yitian and Fu, Huazhu and Liu, Yonghuai and Cheng, Jun and Zheng, Yalin and Su, Pan and Yang, Jianlong and Chen, Li and Frangi, Alejandro F and others},
46 | journal={Medical Image Analysis},
47 | pages={101874},
48 | year={2020},
49 | publisher={Elsevier}
50 | }
51 | ```
52 |
53 |
54 |
55 | #### Corrections to: CS2-Net: Deep learning segmentation of curvilinear structures in medical imaging
56 |
57 | The original comparison results in Table 8 on page 14 are:
58 |
59 |
60 |
61 | The corrected comparison results are:
62 |
63 |
64 |
65 | ## Useful Links
66 |
67 | | DRIVE | http://www.isi.uu.nl/Research/Databases/DRIVE/ |
68 | | :------------- | :---------------------------------------------------------- |
69 | | **STARE** | **http://www.ces.clemson.edu/ahoover/stare/** |
70 | | **IOSTAR** | **http://www.retinacheck.org/** |
71 | | **ToF MIDAS** | **http://insight-journal.org/midas/community/view/21** |
72 | | **Synthetic** | **https://github.com/giesekow/deepvesselnet/wiki/Datasets** |
73 | | **VascuSynth** | **http://vascusynth.cs.sfu.ca/Data.html** |
74 |
75 |
--------------------------------------------------------------------------------
/dataloader/MRABrainLoader.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import glob
4 | import torch
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 | import random
8 | import warnings
9 | import SimpleITK as sitk
10 | import numpy as np
11 | from scipy.ndimage import rotate, map_coordinates, gaussian_filter
12 |
13 | warnings.filterwarnings('ignore')
14 |
15 |
16 | def load_dataset(root_dir, train=True):
17 | images = []
18 | groundtruth = []
19 | if train:
20 | sub_dir = 'training'
21 | else:
22 | sub_dir = 'test'
23 | images_path = os.path.join(root_dir, sub_dir, 'images')
24 | groundtruth_path = os.path.join(root_dir, sub_dir, 'mesh_label')
25 |
26 | for file in glob.glob(os.path.join(images_path, '*.mha')):
27 | image_name = os.path.basename(file)[:-8]
28 | groundtruth_name = image_name + '.mha'
29 |
30 | images.append(file)
31 | groundtruth.append(os.path.join(groundtruth_path, groundtruth_name))
32 |
33 | return images, groundtruth
34 |
35 |
36 | class Data(Dataset):
37 | def __init__(self,
38 | root_dir,
39 | train=True,
40 | rotate=40,
41 | flip=True,
42 | random_crop=True,
43 | scale1=512):
44 |
45 | self.root_dir = root_dir
46 | self.train = train
47 | self.rotate = rotate
48 | self.flip = flip
49 | self.random_crop = random_crop
50 | self.transform = transforms.ToTensor()
51 | self.resize = scale1
52 | self.images, self.groundtruth = load_dataset(self.root_dir, self.train)
53 |
54 | def __len__(self):
55 | return len(self.images)
56 |
57 | def RandomCrop(self, image, label, crop_factor=(0, 0, 0)):
58 | """
59 | Make a random crop of the whole volume
60 | :param image:
61 | :param label:
62 | :param crop_factor: The crop size that you want to crop
63 | :return:
64 | """
65 | w, h, d = image.shape
66 | z = random.randint(0, w - crop_factor[0])
67 | y = random.randint(0, h - crop_factor[1])
68 | x = random.randint(0, d - crop_factor[2])
69 |
70 | image = image[z:z + crop_factor[0], y:y + crop_factor[1], x:x + crop_factor[2]]
71 | label = label[z:z + crop_factor[0], y:y + crop_factor[1], x:x + crop_factor[2]]
72 | return image, label
73 |
74 | def __getitem__(self, idx):
75 | img_path = self.images[idx]
76 | gt_path = self.groundtruth[idx]
77 |
78 | image = sitk.ReadImage(img_path)
79 | image = sitk.GetArrayFromImage(image).astype(np.float32) # [x,y,z] -> [z,y,x]
80 |
81 | label = sitk.ReadImage(gt_path)
82 |         # use astype(np.int64) for CrossEntropyLoss, or astype(np.float32) for MSELoss
83 | label = sitk.GetArrayFromImage(label).astype(np.int64) # [x,y,z] -> [z,y,x]
84 |
85 | image, label = self.RandomCrop(image, label, crop_factor=(64, 104, 112)) # [z,y,x]
86 |
87 |         # the train and test branches were identical; one conversion path suffices
88 |         image = torch.from_numpy(np.ascontiguousarray(image)).unsqueeze(0)
89 |         label = torch.from_numpy(np.ascontiguousarray(label)).unsqueeze(0)
94 |
95 | image = image / 255
96 | label = label // 255
97 |
98 | return image, label
99 |
--------------------------------------------------------------------------------
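A quick sanity check for this loader (a sketch; the dataset root is a placeholder, and the crop size comes from `__getitem__` above):

```python
from torch.utils.data import DataLoader
from dataloader.MRABrainLoader import Data

loader = DataLoader(Data('./dataset/MRA/', train=True), batch_size=1)
image, label = next(iter(loader))
# both are [1, 1, 64, 104, 112]: batch, channel, then the (z, y, x) crop
print(image.shape, label.shape)
```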
/dataloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iMED-Lab/CS-Net/25079c377f8db4b57f25c0adc7b70d1a02a3ee62/dataloader/__init__.py
--------------------------------------------------------------------------------
/dataloader/drive.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import glob
4 | from torch.utils.data import Dataset
5 | from torchvision import transforms
6 | from PIL import Image, ImageEnhance
7 | from utils.misc import ReScaleSize
8 | import random
9 | import warnings
10 | import numpy as np
11 | import scipy.misc as misc
12 |
13 | warnings.filterwarnings('ignore')
14 |
15 |
16 | def load_dataset(root_dir, train=True):
17 | images = []
18 | groundtruth = []
19 | if train:
20 | sub_dir = 'training'
21 | else:
22 | sub_dir = 'test'
23 | images_path = os.path.join(root_dir, sub_dir, 'images')
24 | groundtruth_path = os.path.join(root_dir, sub_dir, '1st_manual')
25 |
26 | for file in glob.glob(os.path.join(images_path, '*.tif')):
27 | image_name = os.path.basename(file)
28 | groundtruth_name = image_name[:3] + 'manual1.gif'
29 |
30 | images.append(os.path.join(images_path, image_name))
31 | groundtruth.append(os.path.join(groundtruth_path, groundtruth_name))
32 |
33 | return images, groundtruth
34 |
35 |
36 | class Data(Dataset):
37 | def __init__(self,
38 | root_dir,
39 | train=True,
40 | rotate=40,
41 | flip=True,
42 | random_crop=True,
43 | scale1=512):
44 |
45 | self.root_dir = root_dir
46 | self.train = train
47 | self.rotate = rotate
48 | self.flip = flip
49 | self.random_crop = random_crop
50 | self.transform = transforms.ToTensor()
51 | self.resize = scale1
52 | self.images, self.groundtruth = load_dataset(self.root_dir, self.train)
53 |
54 | def __len__(self):
55 | return len(self.images)
56 |
57 | def RandomCrop(self, image, label, crop_size):
58 | crop_width, crop_height = crop_size
59 | w, h = image.size
60 | left = random.randint(0, w - crop_width)
61 | top = random.randint(0, h - crop_height)
62 | right = left + crop_width
63 | bottom = top + crop_height
64 | new_image = image.crop((left, top, right, bottom))
65 | new_label = label.crop((left, top, right, bottom))
66 | return new_image, new_label
67 |
68 | def RandomEnhance(self, image):
69 | value = random.uniform(-2, 2)
70 | random_seed = random.randint(1, 4)
71 | if random_seed == 1:
72 |             img_enhanced = ImageEnhance.Brightness(image)
73 |         elif random_seed == 2:
74 |             img_enhanced = ImageEnhance.Color(image)
75 |         elif random_seed == 3:
76 |             img_enhanced = ImageEnhance.Contrast(image)
77 |         else:
78 |             img_enhanced = ImageEnhance.Sharpness(image)
79 |         image = img_enhanced.enhance(value)
80 | return image
81 |
82 | def rescale(self, img, re_size):
83 | w, h = img.size
84 | min_len = min(w, h)
85 | new_w, new_h = min_len, min_len
86 | scale_w = (w - new_w) // 2
87 | scale_h = (h - new_h) // 2
88 | box = (scale_w, scale_h, scale_w + new_w, scale_h + new_h)
89 | img = img.crop(box)
90 | img = img.resize((re_size, re_size))
91 | return img
92 |
93 | def __getitem__(self, idx):
94 | img_path = self.images[idx]
95 | gt_path = self.groundtruth[idx]
96 | image = Image.open(img_path)
97 | label = Image.open(gt_path)
98 |
99 | image = self.rescale(image, self.resize)
100 | label = self.rescale(label, self.resize)
101 |
102 | if self.train:
103 |             # augmentation
104 |             angle = random.randint(-self.rotate, self.rotate)
105 |             image = image.rotate(angle)
106 |             label = label.rotate(angle)
107 |
108 | if random.random() > 0.5:
109 | image = self.RandomEnhance(image)
110 |
111 | image, label = self.RandomCrop(image, label, crop_size=[self.resize, self.resize])
112 |
113 | # flip
114 | if self.flip and random.random() > 0.5:
115 | image = image.transpose(Image.FLIP_LEFT_RIGHT)
116 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
117 |
118 | # img_size = image.size
119 | # if img_size[0] != self.resize:
120 | # image = image.resize((self.resize, self.resize))
121 | # label = label.resize((self.resize, self.resize))
122 | else:
123 | img_size = image.size
124 | if img_size[0] != self.resize:
125 | image = image.resize((self.resize, self.resize))
126 | label = label.resize((self.resize, self.resize))
127 |
128 | image = self.transform(image)
129 | label = self.transform(label)
130 |
131 | return image, label
132 |
--------------------------------------------------------------------------------
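Likewise for the DRIVE loader (a sketch; assumes DRIVE unpacked under `./dataset/DRIVE/` with the `training/images` and `training/1st_manual` folders used above):

```python
from torch.utils.data import DataLoader
from dataloader.drive import Data

loader = DataLoader(Data('./dataset/DRIVE/', train=True), batch_size=1)
image, label = next(iter(loader))
print(image.shape, label.shape)  # [1, 3, 512, 512] and [1, 1, 512, 512]
```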
/dataloader/octa.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import glob
4 | from torch.utils.data import Dataset
5 | from torchvision import transforms
6 | from PIL import Image, ImageEnhance, ImageOps
7 | import random
8 | import warnings
9 |
10 | warnings.filterwarnings('ignore')
11 |
12 |
13 | def load_dataset(root_dir, train=True):
14 | labels = []
15 | images = []
16 | if train:
17 | sub_dir = 'training'
18 | else:
19 | sub_dir = 'test'
20 | label_path = os.path.join(root_dir, sub_dir, 'label')
21 | image_path = os.path.join(root_dir, sub_dir, 'images')
22 |
23 | for file in glob.glob(os.path.join(image_path, '*.tif')):
24 | image_name = os.path.basename(file)
25 | label_name = image_name[:-4] + '_nerve_ann.tif'
26 | labels.append(os.path.join(label_path, label_name))
27 | images.append(os.path.join(image_path, image_name))
28 | return images, labels
29 |
30 |
31 | class Data(Dataset):
32 | def __init__(self,
33 | root_dir,
34 | train=True,
35 | rotate=45,
36 | flip=True,
37 | random_crop=True,
38 | scale1=512):
39 |
40 | self.root_dir = root_dir
41 | self.train = train
42 | self.rotate = rotate
43 | self.flip = flip
44 | self.random_crop = random_crop
45 | self.transform = transforms.ToTensor()
46 | self.resize = scale1
47 | self.images, self.groundtruth = load_dataset(self.root_dir, self.train)
48 |
49 | def __len__(self):
50 | return len(self.images)
51 |
52 | def RandomCrop(self, image, label, crop_size):
53 | crop_width, crop_height = crop_size
54 | w, h = image.size
55 | left = random.randint(0, w - crop_width)
56 | top = random.randint(0, h - crop_height)
57 | right = left + crop_width
58 | bottom = top + crop_height
59 | new_image = image.crop((left, top, right, bottom))
60 | new_label = label.crop((left, top, right, bottom))
61 | return new_image, new_label
62 |
63 | def RandomEnhance(self, image):
64 | value = random.uniform(-2, 2)
65 | random_seed = random.randint(1, 4)
66 | if random_seed == 1:
67 |             img_enhanced = ImageEnhance.Brightness(image)
68 |         elif random_seed == 2:
69 |             img_enhanced = ImageEnhance.Color(image)
70 |         elif random_seed == 3:
71 |             img_enhanced = ImageEnhance.Contrast(image)
72 |         else:
73 |             img_enhanced = ImageEnhance.Sharpness(image)
74 |         image = img_enhanced.enhance(value)
75 | return image
76 |
77 | def Crop(self, image):
78 | left = 261
79 | top = 1
80 | right = 1110
81 | bottom = 850
82 | image = image.crop((left, top, right, bottom))
83 | return image
84 |
85 | def ReScaleSize(self, image, re_size=512):
86 | w, h = image.size
87 | max_len = max(w, h)
88 | new_w, new_h = max_len, max_len
89 | delta_w = new_w - w
90 | delta_h = new_h - h
91 | padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
92 | image = ImageOps.expand(image, padding, fill=0)
93 | # origin_w, origin_h = w, h
94 | image = image.resize((re_size, re_size))
95 | return image # , origin_w, origin_h
96 |
97 | def __getitem__(self, idx):
98 | img_path = self.images[idx]
99 | gt_path = self.groundtruth[idx]
100 |
101 | image = Image.open(img_path)
102 | label = Image.open(gt_path)
103 | image = self.Crop(image)
104 | label = self.Crop(label)
105 | image = self.ReScaleSize(image, self.resize)
106 | label = self.ReScaleSize(label, self.resize)
107 |
108 | if self.train:
109 |             # augmentation
110 |             angle = random.randint(-self.rotate, self.rotate)
111 |             image = image.rotate(angle)
112 |             label = label.rotate(angle)
113 |
114 | if random.random() > 0.5:
115 | image = self.RandomEnhance(image)
116 |
117 | image, label = self.RandomCrop(image, label, crop_size=[self.resize, self.resize])
118 |
119 | # flip
120 | if self.flip and random.random() > 0.5:
121 | image = image.transpose(Image.FLIP_LEFT_RIGHT)
122 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
123 |
124 | else:
125 | img_size = image.size
126 | if img_size[0] != self.resize:
127 | image = image.resize((self.resize, self.resize))
128 | label = label.resize((self.resize, self.resize))
129 |
130 | image = self.transform(image)
131 | label = self.transform(label)
132 |
133 | return image, label
134 |
--------------------------------------------------------------------------------
/dataloader/padova1.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import glob
4 | from torch.utils.data import Dataset
5 | from torchvision import transforms
6 | from PIL import Image, ImageEnhance
7 | import random
8 | import warnings
9 |
10 | warnings.filterwarnings('ignore')
11 |
12 |
13 | def load_dataset(root_dir, train=True):
14 | labels = []
15 | images = []
16 | if train:
17 | sub_dir = 'training'
18 | else:
19 | sub_dir = 'test'
20 | label_path = os.path.join(root_dir, sub_dir, 'label2')
21 | image_path = os.path.join(root_dir, sub_dir, 'images')
22 |
23 | for file in glob.glob(os.path.join(image_path, '*.tif')):
24 | image_name = os.path.basename(file)
25 | label_name = image_name[:-4] + '_centerline_overlay.tif'
26 | labels.append(os.path.join(label_path, label_name))
27 | images.append(os.path.join(image_path, image_name))
28 | return images, labels
29 |
30 |
31 | class Data(Dataset):
32 | def __init__(self,
33 | root_dir,
34 | train=True,
35 | rotate=45,
36 | flip=True,
37 | random_crop=True,
38 | scale1=384):
39 |
40 | self.root_dir = root_dir
41 | self.train = train
42 | self.rotate = rotate
43 | self.flip = flip
44 | self.random_crop = random_crop
45 | self.transform = transforms.ToTensor()
46 | self.resize = scale1
47 | self.images, self.groundtruth = load_dataset(self.root_dir, self.train)
48 |
49 | def __len__(self):
50 | return len(self.images)
51 |
52 | def RandomCrop(self, image, label, crop_size):
53 | crop_width, crop_height = crop_size
54 | w, h = image.size
55 | left = random.randint(0, w - crop_width)
56 | top = random.randint(0, h - crop_height)
57 | right = left + crop_width
58 | bottom = top + crop_height
59 | new_image = image.crop((left, top, right, bottom))
60 | new_label = label.crop((left, top, right, bottom))
61 | return new_image, new_label
62 |
63 | def RandomEnhance(self, image):
64 | value = random.uniform(-2, 2)
65 | random_seed = random.randint(1, 4)
66 | if random_seed == 1:
67 |             img_enhanced = ImageEnhance.Brightness(image)
68 |         elif random_seed == 2:
69 |             img_enhanced = ImageEnhance.Color(image)
70 |         elif random_seed == 3:
71 |             img_enhanced = ImageEnhance.Contrast(image)
72 |         else:
73 |             img_enhanced = ImageEnhance.Sharpness(image)
74 |         image = img_enhanced.enhance(value)
75 | return image
76 |
77 | def __getitem__(self, idx):
78 | img_path = self.images[idx]
79 | gt_path = self.groundtruth[idx]
80 |
81 | image = Image.open(img_path)
82 | label = Image.open(gt_path)
83 |
84 | if self.train:
85 |             # augmentation
86 |             angle = random.randint(-self.rotate, self.rotate)
87 |             image = image.rotate(angle)
88 |             label = label.rotate(angle)
89 |
90 | if random.random() > 0.5:
91 | image = self.RandomEnhance(image)
92 |
93 | image, label = self.RandomCrop(image, label, crop_size=[self.resize, self.resize])
94 |
95 | # flip
96 | if self.flip and random.random() > 0.5:
97 | image = image.transpose(Image.FLIP_LEFT_RIGHT)
98 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
99 |
100 | else:
101 | img_size = image.size
102 | if img_size[0] != self.resize:
103 | image = image.resize((self.resize, self.resize))
104 | label = label.resize((self.resize, self.resize))
105 |
106 | image = self.transform(image)
107 | label = self.transform(label)
108 |
109 | return image, label
110 |
--------------------------------------------------------------------------------
/dataloader/padova2.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import glob
4 | from torch.utils.data import Dataset
5 | from torchvision import transforms
6 | from PIL import Image, ImageEnhance
7 | import random
8 | import warnings
9 |
10 | warnings.filterwarnings('ignore')
11 |
12 |
13 | def load_dataset(root_dir, train=True):
14 | labels = []
15 | images = []
16 | if train:
17 | sub_dir = 'training'
18 | else:
19 | sub_dir = 'test'
20 | label_path = os.path.join(root_dir, sub_dir, 'label2')
21 | image_path = os.path.join(root_dir, sub_dir, 'images')
22 |
23 | for file in glob.glob(os.path.join(image_path, '*.tif')):
24 | image_name = os.path.basename(file)
25 | label_name = image_name[:-4] + '_centerline_overlay.tif'
26 | labels.append(os.path.join(label_path, label_name))
27 | images.append(os.path.join(image_path, image_name))
28 | return images, labels
29 |
30 |
31 | class Data(Dataset):
32 | def __init__(self,
33 | root_dir,
34 | train=True,
35 | rotate=45,
36 | flip=True,
37 | random_crop=True,
38 | scale1=384):
39 |
40 | self.root_dir = root_dir
41 | self.train = train
42 | self.rotate = rotate
43 | self.flip = flip
44 | self.random_crop = random_crop
45 | self.transform = transforms.ToTensor()
46 | self.resize = scale1
47 | self.images, self.groundtruth = load_dataset(self.root_dir, self.train)
48 |
49 | def __len__(self):
50 | return len(self.images)
51 |
52 | def RandomCrop(self, image, label, crop_size):
53 | crop_width, crop_height = crop_size
54 | w, h = image.size
55 | left = random.randint(0, w - crop_width)
56 | top = random.randint(0, h - crop_height)
57 | right = left + crop_width
58 | bottom = top + crop_height
59 | new_image = image.crop((left, top, right, bottom))
60 | new_label = label.crop((left, top, right, bottom))
61 | return new_image, new_label
62 |
63 | def RandomEnhance(self, image):
64 | value = random.uniform(-2, 2)
65 | random_seed = random.randint(1, 4)
66 | if random_seed == 1:
67 |             img_enhanced = ImageEnhance.Brightness(image)
68 |         elif random_seed == 2:
69 |             img_enhanced = ImageEnhance.Color(image)
70 |         elif random_seed == 3:
71 |             img_enhanced = ImageEnhance.Contrast(image)
72 |         else:
73 |             img_enhanced = ImageEnhance.Sharpness(image)
74 |         image = img_enhanced.enhance(value)
75 | return image
76 |
77 | def __getitem__(self, idx):
78 | img_path = self.images[idx]
79 | gt_path = self.groundtruth[idx]
80 |
81 | image = Image.open(img_path)
82 | label = Image.open(gt_path)
83 |
84 | # image = ReScaleSize(image, self.resize)
85 | # label = ReScaleSize(label, self.resize)
86 |
87 | if self.train:
88 |             # augmentation
89 |             angle = random.randint(-self.rotate, self.rotate)
90 |             image = image.rotate(angle)
91 |             label = label.rotate(angle)
92 |
93 | if random.random() > 0.5:
94 | image = self.RandomEnhance(image)
95 |
96 | image, label = self.RandomCrop(image, label, crop_size=[self.resize, self.resize])
97 |
98 | # flip
99 | if self.flip and random.random() > 0.5:
100 | image = image.transpose(Image.FLIP_LEFT_RIGHT)
101 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
102 |
103 | else:
104 | img_size = image.size
105 | if img_size[0] != self.resize:
106 | image = image.resize((self.resize, self.resize))
107 | label = label.resize((self.resize, self.resize))
108 |
109 | image = self.transform(image)
110 | label = self.transform(label)
111 |
112 | return image, label
113 |
--------------------------------------------------------------------------------
/dataloader/stare.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import glob
4 | from torch.utils.data import Dataset
5 | from torchvision import transforms
6 | from PIL import Image, ImageEnhance
7 | from utils.misc import ReScaleSize
8 | import random
9 | import warnings
10 |
11 | warnings.filterwarnings('ignore')
12 |
13 |
14 | def load_dataset(root_dir, train=True):
15 | images = []
16 | groundtruth = []
17 | if train:
18 | sub_dir = 'training'
19 | else:
20 | sub_dir = 'test'
21 | images_path = os.path.join(root_dir, sub_dir, 'images')
22 | groundtruth_path = os.path.join(root_dir, sub_dir, 'labels-ah')
23 |
24 | for file in glob.glob(os.path.join(images_path, '*.ppm')):
25 | image_name = os.path.basename(file)
26 | groundtruth_name = image_name[:-4] + '.ah.ppm'
27 | images.append(os.path.join(images_path, image_name))
28 | groundtruth.append(os.path.join(groundtruth_path, groundtruth_name))
29 |
30 | return images, groundtruth
31 |
32 |
33 | class Data(Dataset):
34 | def __init__(self,
35 | root_dir,
36 | train=True,
37 | rotate=40,
38 | flip=True,
39 | random_crop=True,
40 | scale1=688):
41 |
42 | self.root_dir = root_dir
43 | self.train = train
44 | self.rotate = rotate
45 | self.flip = flip
46 | self.random_crop = random_crop
47 | self.transform = transforms.ToTensor()
48 | self.resize = scale1
49 | self.images, self.groundtruth = load_dataset(self.root_dir, self.train)
50 |
51 | def __len__(self):
52 | return len(self.images)
53 |
54 | def RandomCrop(self, image, label, crop_size):
55 | crop_width, crop_height = crop_size
56 | w, h = image.size
57 | left = random.randint(0, w - crop_width)
58 | top = random.randint(0, h - crop_height)
59 | right = left + crop_width
60 | bottom = top + crop_height
61 | new_image = image.crop((left, top, right, bottom))
62 | new_label = label.crop((left, top, right, bottom))
63 | return new_image, new_label
64 |
65 | def RandomEnhance(self, image):
66 | value = random.uniform(-2, 2)
67 | random_seed = random.randint(1, 4)
68 | if random_seed == 1:
69 |             img_enhanced = ImageEnhance.Brightness(image)
70 |         elif random_seed == 2:
71 |             img_enhanced = ImageEnhance.Color(image)
72 |         elif random_seed == 3:
73 |             img_enhanced = ImageEnhance.Contrast(image)
74 |         else:
75 |             img_enhanced = ImageEnhance.Sharpness(image)
76 |         image = img_enhanced.enhance(value)
77 | return image
78 |
79 | def __getitem__(self, idx):
80 | img_path = self.images[idx]
81 | gt_path = self.groundtruth[idx]
82 | image = Image.open(img_path)
83 | label = Image.open(gt_path)
84 | image = ReScaleSize(image, self.resize)
85 | label = ReScaleSize(label, self.resize)
86 |
87 | if self.train:
88 |             # augmentation
89 |             angle = random.randint(-self.rotate, self.rotate)
90 |             image = image.rotate(angle)
91 |             label = label.rotate(angle)
92 |
93 | if random.random() > 0.5:
94 | image = self.RandomEnhance(image)
95 |
96 | image, label = self.RandomCrop(image, label, crop_size=[self.resize, self.resize])
97 |
98 | # flip
99 | if self.flip and random.random() > 0.5:
100 | image = image.transpose(Image.FLIP_LEFT_RIGHT)
101 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
102 |
103 | else:
104 | img_size = image.size
105 | if img_size[0] != self.resize:
106 | image = image.resize((self.resize, self.resize))
107 | label = label.resize((self.resize, self.resize))
108 |
109 | image = self.transform(image)
110 | label = self.transform(label)
111 |
112 | return image, label
113 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iMED-Lab/CS-Net/25079c377f8db4b57f25c0adc7b70d1a02a3ee62/model/__init__.py
--------------------------------------------------------------------------------
/model/csnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Channel and Spatial Attention Network (CS-Net).
3 | """
4 | from __future__ import division
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | def downsample():
11 | return nn.MaxPool2d(kernel_size=2, stride=2)
12 |
13 |
14 | def deconv(in_channels, out_channels):
15 | return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
16 |
17 |
18 | def initialize_weights(*models):
19 | for model in models:
20 | for m in model.modules():
21 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
22 |                 nn.init.kaiming_normal_(m.weight)
23 | if m.bias is not None:
24 | m.bias.data.zero_()
25 | elif isinstance(m, nn.BatchNorm2d):
26 | m.weight.data.fill_(1)
27 | m.bias.data.zero_()
28 |
29 |
30 | class ResEncoder(nn.Module):
31 | def __init__(self, in_channels, out_channels):
32 | super(ResEncoder, self).__init__()
33 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
34 | self.bn1 = nn.BatchNorm2d(out_channels)
35 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
36 | self.bn2 = nn.BatchNorm2d(out_channels)
37 | self.relu = nn.ReLU(inplace=False)
38 | self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
39 |
40 | def forward(self, x):
41 | residual = self.conv1x1(x)
42 | out = self.relu(self.bn1(self.conv1(x)))
43 | out = self.relu(self.bn2(self.conv2(out)))
44 | out += residual
45 | out = self.relu(out)
46 | return out
47 |
48 |
49 | class Decoder(nn.Module):
50 | def __init__(self, in_channels, out_channels):
51 | super(Decoder, self).__init__()
52 | self.conv = nn.Sequential(
53 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
54 | nn.BatchNorm2d(out_channels),
55 | nn.ReLU(inplace=True),
56 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
57 | nn.BatchNorm2d(out_channels),
58 | nn.ReLU(inplace=True)
59 | )
60 |
61 | def forward(self, x):
62 | out = self.conv(x)
63 | return out
64 |
65 |
66 | class SpatialAttentionBlock(nn.Module):
67 | def __init__(self, in_channels):
68 | super(SpatialAttentionBlock, self).__init__()
69 | self.query = nn.Sequential(
70 | nn.Conv2d(in_channels,in_channels//8,kernel_size=(1,3), padding=(0,1)),
71 | nn.BatchNorm2d(in_channels//8),
72 | nn.ReLU(inplace=True)
73 | )
74 | self.key = nn.Sequential(
75 | nn.Conv2d(in_channels, in_channels//8, kernel_size=(3,1), padding=(1,0)),
76 | nn.BatchNorm2d(in_channels//8),
77 | nn.ReLU(inplace=True)
78 | )
79 | self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
80 | self.gamma = nn.Parameter(torch.zeros(1))
81 | self.softmax = nn.Softmax(dim=-1)
82 |
83 | def forward(self, x):
84 | """
85 | :param x: input( BxCxHxW )
86 | :return: affinity value + x
87 | """
88 | B, C, H, W = x.size()
89 |         # project and flatten: query -> [B, H*W, C//8], key -> [B, C//8, H*W]
90 | proj_query = self.query(x).view(B, -1, W * H).permute(0, 2, 1)
91 | proj_key = self.key(x).view(B, -1, W * H)
92 | affinity = torch.matmul(proj_query, proj_key)
93 | affinity = self.softmax(affinity)
94 | proj_value = self.value(x).view(B, -1, H * W)
95 | weights = torch.matmul(proj_value, affinity.permute(0, 2, 1))
96 | weights = weights.view(B, C, H, W)
97 | out = self.gamma * weights + x
98 | return out
99 |
100 |
101 | class ChannelAttentionBlock(nn.Module):
102 | def __init__(self, in_channels):
103 | super(ChannelAttentionBlock, self).__init__()
104 | self.gamma = nn.Parameter(torch.zeros(1))
105 | self.softmax = nn.Softmax(dim=-1)
106 |
107 | def forward(self, x):
108 | """
109 | :param x: input( BxCxHxW )
110 | :return: affinity value + x
111 | """
112 | B, C, H, W = x.size()
113 | proj_query = x.view(B, C, -1)
114 | proj_key = x.view(B, C, -1).permute(0, 2, 1)
115 | affinity = torch.matmul(proj_query, proj_key)
116 | affinity_new = torch.max(affinity, -1, keepdim=True)[0].expand_as(affinity) - affinity
117 | affinity_new = self.softmax(affinity_new)
118 | proj_value = x.view(B, C, -1)
119 | weights = torch.matmul(affinity_new, proj_value)
120 | weights = weights.view(B, C, H, W)
121 | out = self.gamma * weights + x
122 | return out
123 |
124 |
125 | class AffinityAttention(nn.Module):
126 | """ Affinity attention module """
127 |
128 | def __init__(self, in_channels):
129 | super(AffinityAttention, self).__init__()
130 | self.sab = SpatialAttentionBlock(in_channels)
131 | self.cab = ChannelAttentionBlock(in_channels)
132 | # self.conv1x1 = nn.Conv2d(in_channels * 2, in_channels, kernel_size=1)
133 |
134 | def forward(self, x):
135 | """
136 | sab: spatial attention block
137 | cab: channel attention block
138 | :param x: input tensor
139 | :return: sab + cab
140 | """
141 | sab = self.sab(x)
142 | cab = self.cab(x)
143 | out = sab + cab
144 | return out
145 |
146 |
147 | class CSNet(nn.Module):
148 | def __init__(self, classes, channels):
149 | """
150 | :param classes: the object classes number.
151 | :param channels: the channels of the input image.
152 | """
153 | super(CSNet, self).__init__()
154 | self.enc_input = ResEncoder(channels, 32)
155 | self.encoder1 = ResEncoder(32, 64)
156 | self.encoder2 = ResEncoder(64, 128)
157 | self.encoder3 = ResEncoder(128, 256)
158 | self.encoder4 = ResEncoder(256, 512)
159 | self.downsample = downsample()
160 | self.affinity_attention = AffinityAttention(512)
161 | self.attention_fuse = nn.Conv2d(512 * 2, 512, kernel_size=1)
162 | self.decoder4 = Decoder(512, 256)
163 | self.decoder3 = Decoder(256, 128)
164 | self.decoder2 = Decoder(128, 64)
165 | self.decoder1 = Decoder(64, 32)
166 | self.deconv4 = deconv(512, 256)
167 | self.deconv3 = deconv(256, 128)
168 | self.deconv2 = deconv(128, 64)
169 | self.deconv1 = deconv(64, 32)
170 | self.final = nn.Conv2d(32, classes, kernel_size=1)
171 | initialize_weights(self)
172 |
173 | def forward(self, x):
174 | enc_input = self.enc_input(x)
175 | down1 = self.downsample(enc_input)
176 |
177 | enc1 = self.encoder1(down1)
178 | down2 = self.downsample(enc1)
179 |
180 | enc2 = self.encoder2(down2)
181 | down3 = self.downsample(enc2)
182 |
183 | enc3 = self.encoder3(down3)
184 | down4 = self.downsample(enc3)
185 |
186 | input_feature = self.encoder4(down4)
187 |
188 |         # Do attention operations here
189 | attention = self.affinity_attention(input_feature)
190 |
191 | # attention_fuse = self.attention_fuse(torch.cat((input_feature, attention), dim=1))
192 | attention_fuse = input_feature + attention
193 |
194 | # Do decoder operations here
195 | up4 = self.deconv4(attention_fuse)
196 | up4 = torch.cat((enc3, up4), dim=1)
197 | dec4 = self.decoder4(up4)
198 |
199 | up3 = self.deconv3(dec4)
200 | up3 = torch.cat((enc2, up3), dim=1)
201 | dec3 = self.decoder3(up3)
202 |
203 | up2 = self.deconv2(dec3)
204 | up2 = torch.cat((enc1, up2), dim=1)
205 | dec2 = self.decoder2(up2)
206 |
207 | up1 = self.deconv1(dec2)
208 | up1 = torch.cat((enc_input, up1), dim=1)
209 | dec1 = self.decoder1(up1)
210 |
211 | final = self.final(dec1)
212 |         final = torch.sigmoid(final)
213 | return final
214 |
--------------------------------------------------------------------------------
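A minimal forward-pass sanity check for `CSNet` (a sketch; any input whose height and width are divisible by 16 works, and 512 matches the DRIVE setup):

```python
import torch
from model.csnet import CSNet

net = CSNet(classes=1, channels=3).eval()
x = torch.randn(1, 3, 512, 512)  # DRIVE-sized RGB input
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 1, 512, 512]), sigmoid probabilities
```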
/model/csnet_3d.py:
--------------------------------------------------------------------------------
1 | """
2 | 3D Channel and Spatial Attention Network (CSA-Net 3D).
3 | """
4 | from __future__ import division
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | def downsample():
11 | return nn.MaxPool3d(kernel_size=2, stride=2)
12 |
13 |
14 | def deconv(in_channels, out_channels):
15 | return nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)
16 |
17 |
18 | def initialize_weights(*models):
19 | for model in models:
20 | for m in model.modules():
21 | if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
22 |                 nn.init.kaiming_normal_(m.weight)
23 | if m.bias is not None:
24 | m.bias.data.zero_()
25 | elif isinstance(m, nn.BatchNorm3d):
26 | m.weight.data.fill_(1)
27 | m.bias.data.zero_()
28 |
29 |
30 | class ResEncoder3d(nn.Module):
31 | def __init__(self, in_channels, out_channels):
32 | super(ResEncoder3d, self).__init__()
33 | self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
34 | self.bn1 = nn.BatchNorm3d(out_channels)
35 | self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
36 | self.bn2 = nn.BatchNorm3d(out_channels)
37 | self.relu = nn.ReLU(inplace=False)
38 | self.conv1x1 = nn.Conv3d(in_channels, out_channels, kernel_size=1)
39 |
40 | def forward(self, x):
41 | residual = self.conv1x1(x)
42 | out = self.relu(self.bn1(self.conv1(x)))
43 | out = self.relu(self.bn2(self.conv2(out)))
44 | out += residual
45 | out = self.relu(out)
46 | return out
47 |
48 |
49 | class Decoder3d(nn.Module):
50 | def __init__(self, in_channels, out_channels):
51 | super(Decoder3d, self).__init__()
52 | self.conv = nn.Sequential(
53 | nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
54 | nn.BatchNorm3d(out_channels),
55 | nn.ReLU(inplace=False),
56 | nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
57 | nn.BatchNorm3d(out_channels),
58 | nn.ReLU(inplace=False)
59 | )
60 |
61 | def forward(self, x):
62 | out = self.conv(x)
63 | return out
64 |
65 |
66 | class SpatialAttentionBlock3d(nn.Module):
67 | def __init__(self, in_channels):
68 | super(SpatialAttentionBlock3d, self).__init__()
69 | self.query = nn.Conv3d(in_channels, in_channels // 8, kernel_size=(1, 3, 1), padding=(0, 1, 0))
70 | self.key = nn.Conv3d(in_channels, in_channels // 8, kernel_size=(3, 1, 1), padding=(1, 0, 0))
71 | self.judge = nn.Conv3d(in_channels, in_channels // 8, kernel_size=(1, 1, 3), padding=(0, 0, 1))
72 | self.value = nn.Conv3d(in_channels, in_channels, kernel_size=1)
73 | self.gamma = nn.Parameter(torch.zeros(1))
74 | self.softmax = nn.Softmax(dim=-1)
75 |
76 | def forward(self, x):
77 | """
78 | :param x: input( BxCxHxWxZ )
79 | :return: affinity value + x
80 | B: batch size
81 | C: channels
82 | H: height
83 | W: width
84 | D: slice number (depth)
85 | """
86 | B, C, H, W, D = x.size()
87 |         # project and flatten: N = H*W*D spatial positions
88 |         proj_query = self.query(x).view(B, -1, W * H * D).permute(0, 2, 1)  # -> [B, N, C//8]
89 |         proj_key = self.key(x).view(B, -1, W * H * D)  # -> [B, C//8, N]
90 |         proj_judge = self.judge(x).view(B, -1, W * H * D).permute(0, 2, 1)  # -> [B, N, C//8]
91 |
92 | affinity1 = torch.matmul(proj_query, proj_key)
93 | affinity2 = torch.matmul(proj_judge, proj_key)
94 | affinity = torch.matmul(affinity1, affinity2)
95 | affinity = self.softmax(affinity)
96 |
97 | proj_value = self.value(x).view(B, -1, H * W * D) # -> C*N
98 | weights = torch.matmul(proj_value, affinity)
99 | weights = weights.view(B, C, H, W, D)
100 | out = self.gamma * weights + x
101 | return out
102 |
103 |
104 | class ChannelAttentionBlock3d(nn.Module):
105 | def __init__(self, in_channels):
106 | super(ChannelAttentionBlock3d, self).__init__()
107 | self.gamma = nn.Parameter(torch.zeros(1))
108 | self.softmax = nn.Softmax(dim=-1)
109 |
110 | def forward(self, x):
111 | """
112 | :param x: input( BxCxHxWxD )
113 | :return: affinity value + x
114 | """
115 | B, C, H, W, D = x.size()
116 | proj_query = x.view(B, C, -1).permute(0, 2, 1)
117 | proj_key = x.view(B, C, -1)
118 | proj_judge = x.view(B, C, -1).permute(0, 2, 1)
119 | affinity1 = torch.matmul(proj_key, proj_query)
120 | affinity2 = torch.matmul(proj_key, proj_judge)
121 | affinity = torch.matmul(affinity1, affinity2)
122 | affinity_new = torch.max(affinity, -1, keepdim=True)[0].expand_as(affinity) - affinity
123 | affinity_new = self.softmax(affinity_new)
124 | proj_value = x.view(B, C, -1)
125 | weights = torch.matmul(affinity_new, proj_value)
126 | weights = weights.view(B, C, H, W, D)
127 | out = self.gamma * weights + x
128 | return out
129 |
130 |
131 | class AffinityAttention3d(nn.Module):
132 | """ Affinity attention module """
133 |
134 | def __init__(self, in_channels):
135 | super(AffinityAttention3d, self).__init__()
136 | self.sab = SpatialAttentionBlock3d(in_channels)
137 | self.cab = ChannelAttentionBlock3d(in_channels)
138 | # self.conv1x1 = nn.Conv2d(in_channels * 2, in_channels, kernel_size=1)
139 |
140 | def forward(self, x):
141 | """
142 | sab: spatial attention block
143 | cab: channel attention block
144 | :param x: input tensor
145 | :return: sab + cab
146 | """
147 | sab = self.sab(x)
148 | cab = self.cab(x)
149 | out = sab + cab + x
150 | return out
151 |
152 |
153 | class CSNet3D(nn.Module):
154 | def __init__(self, classes, channels):
155 | """
156 | :param classes: the object classes number.
157 | :param channels: the channels of the input image.
158 | """
159 | super(CSNet3D, self).__init__()
160 | self.enc_input = ResEncoder3d(channels, 16)
161 | self.encoder1 = ResEncoder3d(16, 32)
162 | self.encoder2 = ResEncoder3d(32, 64)
163 | self.encoder3 = ResEncoder3d(64, 128)
164 | self.encoder4 = ResEncoder3d(128, 256)
165 | self.downsample = downsample()
166 | self.affinity_attention = AffinityAttention3d(256)
167 | self.attention_fuse = nn.Conv3d(256 * 2, 256, kernel_size=1)
168 | self.decoder4 = Decoder3d(256, 128)
169 | self.decoder3 = Decoder3d(128, 64)
170 | self.decoder2 = Decoder3d(64, 32)
171 | self.decoder1 = Decoder3d(32, 16)
172 | self.deconv4 = deconv(256, 128)
173 | self.deconv3 = deconv(128, 64)
174 | self.deconv2 = deconv(64, 32)
175 | self.deconv1 = deconv(32, 16)
176 | self.final = nn.Conv3d(16, classes, kernel_size=1)
177 | initialize_weights(self)
178 |
179 | def forward(self, x):
180 | enc_input = self.enc_input(x)
181 | down1 = self.downsample(enc_input)
182 |
183 | enc1 = self.encoder1(down1)
184 | down2 = self.downsample(enc1)
185 |
186 | enc2 = self.encoder2(down2)
187 | down3 = self.downsample(enc2)
188 |
189 | enc3 = self.encoder3(down3)
190 | down4 = self.downsample(enc3)
191 |
192 | input_feature = self.encoder4(down4)
193 |
194 |         # Do attention operations here
195 | attention = self.affinity_attention(input_feature)
196 | attention_fuse = input_feature + attention
197 |
198 | # Do decoder operations here
199 | up4 = self.deconv4(attention_fuse)
200 | up4 = torch.cat((enc3, up4), dim=1)
201 | dec4 = self.decoder4(up4)
202 |
203 | up3 = self.deconv3(dec4)
204 | up3 = torch.cat((enc2, up3), dim=1)
205 | dec3 = self.decoder3(up3)
206 |
207 | up2 = self.deconv2(dec3)
208 | up2 = torch.cat((enc1, up2), dim=1)
209 | dec2 = self.decoder2(up2)
210 |
211 | up1 = self.deconv1(dec2)
212 | up1 = torch.cat((enc_input, up1), dim=1)
213 | dec1 = self.decoder1(up1)
214 |
215 | final = self.final(dec1)
216 |         final = torch.sigmoid(final)
217 | return final
218 |
--------------------------------------------------------------------------------
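And the 3D counterpart (a sketch; each volume dimension must be divisible by 16 because of the four poolings, and the quadratic attention maps make large volumes memory-hungry):

```python
import torch
from model.csnet_3d import CSNet3D

net = CSNet3D(classes=2, channels=1).eval()
x = torch.randn(1, 1, 64, 96, 96)  # single-channel volume
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 2, 64, 96, 96])
```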
/predict.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms
3 | from PIL import Image, ImageOps
4 |
5 | import numpy as np
6 | import imageio  # replaces the removed scipy.misc.imsave
7 | import os
8 | import glob
9 |
10 | from utils.misc import thresh_OTSU, ReScaleSize, Crop
11 | # from utils.model_eval import eval  # unused; utils/model_eval.py is not in this repo
12 |
13 | DATABASE = './DRIVE/'
14 | #
15 | args = {
16 | 'root' : './dataset/' + DATABASE,
17 | 'test_path': './dataset/' + DATABASE + 'test/',
18 | 'pred_path': 'assets/' + 'DRIVE/',
19 | 'img_size' : 512
20 | }
21 |
22 | if not os.path.exists(args['pred_path']):
23 | os.makedirs(args['pred_path'])
24 |
25 |
26 | def rescale(img):
27 | w, h = img.size
28 | min_len = min(w, h)
29 | new_w, new_h = min_len, min_len
30 | scale_w = (w - new_w) // 2
31 | scale_h = (h - new_h) // 2
32 | box = (scale_w, scale_h, scale_w + new_w, scale_h + new_h)
33 | img = img.crop(box)
34 | return img
35 |
36 |
37 | def ReScaleSize_DRIVE(image, re_size=512):
38 | w, h = image.size
39 | min_len = min(w, h)
40 | new_w, new_h = min_len, min_len
41 | scale_w = (w - new_w) // 2
42 | scale_h = (h - new_h) // 2
43 | box = (scale_w, scale_h, scale_w + new_w, scale_h + new_h)
44 | image = image.crop(box)
45 | image = image.resize((re_size, re_size))
46 | return image # , origin_w, origin_h
47 |
48 |
49 | def ReScaleSize_STARE(image, re_size=512):
50 | w, h = image.size
51 | max_len = max(w, h)
52 | new_w, new_h = max_len, max_len
53 | delta_w = new_w - w
54 | delta_h = new_h - h
55 | padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
56 | image = ImageOps.expand(image, padding, fill=0)
57 | # origin_w, origin_h = w, h
58 | image = image.resize((re_size, re_size))
59 | return image # , origin_w, origin_h
60 |
61 |
62 | def load_nerve():
63 | test_images = []
64 | test_labels = []
65 | for file in glob.glob(os.path.join(args['test_path'], 'orig', '*.tif')):
66 | basename = os.path.basename(file)
67 | file_name = basename[:-4]
68 | image_name = os.path.join(args['test_path'], 'orig', basename)
69 | label_name = os.path.join(args['test_path'], 'mask2', file_name + '_centerline_overlay.tif')
70 | test_images.append(image_name)
71 | test_labels.append(label_name)
72 | return test_images, test_labels
73 |
74 |
75 | def load_drive():
76 | test_images = []
77 | test_labels = []
78 | for file in glob.glob(os.path.join(args['test_path'], 'images', '*.tif')):
79 | basename = os.path.basename(file)
80 | file_name = basename[:3]
81 | image_name = os.path.join(args['test_path'], 'images', basename)
82 | label_name = os.path.join(args['test_path'], '1st_manual', file_name + 'manual1.gif')
83 | test_images.append(image_name)
84 | test_labels.append(label_name)
85 | return test_images, test_labels
86 |
87 |
88 | def load_stare():
89 | test_images = []
90 | test_labels = []
91 | for file in glob.glob(os.path.join(args['test_path'], 'images', '*.ppm')):
92 | basename = os.path.basename(file)
93 | file_name = basename[:-4]
94 | image_name = os.path.join(args['test_path'], 'images', basename)
95 | label_name = os.path.join(args['test_path'], 'labels-ah', file_name + '.ah.ppm')
96 | test_images.append(image_name)
97 | test_labels.append(label_name)
98 | return test_images, test_labels
99 |
100 |
101 | def load_padova1():
102 | test_images = []
103 | test_labels = []
104 | for file in glob.glob(os.path.join(args['test_path'], 'images', '*.tif')):
105 | basename = os.path.basename(file)
106 | file_name = basename[:-4]
107 | image_name = os.path.join(args['test_path'], 'images', basename)
108 | label_name = os.path.join(args['test_path'], 'label2', file_name + '_centerline_overlay.tif')
109 | test_images.append(image_name)
110 | test_labels.append(label_name)
111 | return test_images, test_labels
112 |
113 |
114 | def load_octa():
115 | test_images = []
116 | test_labels = []
117 | for file in glob.glob(os.path.join(args['test_path'], 'images', '*.png')):
118 | basename = os.path.basename(file)
119 | file_name = basename[:-4]
120 | image_name = os.path.join(args['test_path'], 'images', basename)
121 | label_name = os.path.join(args['test_path'], 'label', file_name + '_nerve_ann.tif')
122 | test_images.append(image_name)
123 | test_labels.append(label_name)
124 | return test_images, test_labels
125 |
126 |
127 | def load_net():
128 | net = torch.load('./checkpoint/xxxx.pkl')
129 | return net
130 |
131 |
132 | def save_prediction(pred, filename=''):
133 | save_path = args['pred_path'] + 'pred/'
134 | if not os.path.exists(save_path):
135 | os.makedirs(save_path)
136 | print("Make dirs success!")
137 | mask = pred.data.cpu().numpy() * 255
138 | mask = np.transpose(np.squeeze(mask, axis=0), [1, 2, 0])
139 | mask = np.squeeze(mask, axis=-1)
140 |     imageio.imwrite(save_path + filename + '.png', mask.astype(np.uint8))
141 |
142 |
143 | def predict():
144 | net = load_net()
145 | # images, labels = load_nerve()
146 | images, labels = load_drive()
147 | # images, labels = load_stare()
148 | # images, labels = load_padova1()
149 | # images, labels = load_octa()
150 |
151 | transform = transforms.Compose([
152 | transforms.ToTensor()
153 | ])
154 |
155 | with torch.no_grad():
156 | net.eval()
157 | for i in range(len(images)):
158 | print(images[i])
159 | name_list = images[i].split('/')
160 | index = name_list[-1][:-4]
161 | image = Image.open(images[i])
162 | # image=image.convert("RGB")
163 | label = Image.open(labels[i])
164 |             image = ReScaleSize_DRIVE(image, re_size=args['img_size'])  # center_crop was undefined
165 |             label = ReScaleSize_DRIVE(label, re_size=args['img_size'])
166 | # for other retinal vessel
167 | # image = rescale(image)
168 | # label = rescale(label)
169 | # image = ReScaleSize_STARE(image, re_size=args['img_size'])
170 | # label = ReScaleSize_DRIVE(label, re_size=args['img_size'])
171 |
172 | # for OCTA
173 | # image = Crop(image)
174 | # image = ReScaleSize(image)
175 | # label = Crop(label)
176 | # label = ReScaleSize(label)
177 |
178 | # label = label.resize((args['img_size'], args['img_size']))
179 | # if cuda
180 | image = transform(image).cuda()
181 | # image = transform(image)
182 | image = image.unsqueeze(0)
183 | output = net(image)
184 |
185 | save_prediction(output, filename=index + '_pred')
186 |             print("output saved successfully")
187 |
188 |
189 | if __name__ == '__main__':
190 | predict()
191 | thresh_OTSU(args['pred_path'] + 'pred/')
192 |
--------------------------------------------------------------------------------
/predict3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import os
5 | import glob
6 | from tqdm import tqdm
7 | import SimpleITK as sitk
8 | from utils.misc import get_spacing
9 |
10 | os.environ['CUDA_VISIBLE_DEVICES'] = "1"
11 |
12 | DATABASE = 'VascuSynth3/'
13 | #
14 | args = {
15 | 'root' : './dataset/' + DATABASE,
16 | 'test_path': './dataset/' + DATABASE + 'test/',
17 | 'pred_path': 'assets/' + 'VascuSynth3/',
18 | 'img_size' : 512
19 | }
20 |
21 | if not os.path.exists(args['pred_path']):
22 | os.makedirs(args['pred_path'])
23 |
24 |
25 | def rescale(img):
26 | w, h = img.size
27 | min_len = min(w, h)
28 | new_w, new_h = min_len, min_len
29 | scale_w = (w - new_w) // 2
30 | scale_h = (h - new_h) // 2
31 | box = (scale_w, scale_h, scale_w + new_w, scale_h + new_h)
32 | img = img.crop(box)
33 | return img
34 |
35 |
36 | def load_3d():
37 | test_images = []
38 | test_labels = []
39 | for file in glob.glob(os.path.join(args['test_path'], 'images', '*.mha')):
40 | basename = os.path.basename(file)
41 | file_name = basename[:-8]
42 | image_name = os.path.join(args['test_path'], 'images', basename)
43 | label_name = os.path.join(args['test_path'], 'label', file_name + 'gt.mha')
44 | test_images.append(image_name)
45 | test_labels.append(label_name)
46 | return test_images, test_labels
47 |
48 |
49 | def load_net():
50 | net = torch.load('/home/imed/Research/Attention/checkpoint/model.pkl')
51 | return net
52 |
53 |
54 | def save_prediction(pred, filename='', spacing=None):
55 | pred = torch.argmax(pred, dim=1)
56 | save_path = args['pred_path'] + 'pred/'
57 | if not os.path.exists(save_path):
58 | os.makedirs(save_path)
59 | print("Make dirs success!")
60 | # for MSELoss()
61 | mask = (pred.data.cpu().numpy() * 255).astype(np.uint8)
62 |
63 | # thresholding
64 | # mask[mask >= 100] = 255
65 | # mask[mask < 100] = 0
66 |
67 | # mask = (mask.squeeze(0)).squeeze(0) # 3D numpy array
68 | mask = mask.squeeze(0) # for CE Loss
69 | # image = nib.Nifti1Image(np.int32(mask), affine)
70 | # nib.save(image, save_path + filename + ".nii.gz")
71 | mask = sitk.GetImageFromArray(mask)
72 | # if spacing is not None:
73 | # mask.SetSpacing(spacing)
74 | sitk.WriteImage(mask, os.path.join(save_path + filename + ".mha"))
75 |
76 |
77 | def save_probability(pred, label, filename=""):
78 | save_path = args['pred_path'] + 'pred/'
79 | if not os.path.exists(save_path):
80 | os.makedirs(save_path)
81 | print("Make dirs success!")
82 | # # for MSELoss()
83 | # mask = (pred.data.cpu().numpy() * 255) # .astype(np.uint8)
84 | #
85 | # mask = mask.squeeze(0)
86 | # class0 = mask[0, :, :, :]
87 | # class1 = mask[1, :, :, :]
88 | # label = label / 255
89 | # class0 = class0 * label
90 | # class1 = class1 * label
91 | #
92 | # probability = class0 + class1
93 |
94 | probability = F.softmax(pred, dim=1)
95 | probability.squeeze_(0)
96 | class0 = probability[0, :, :, :]
97 | class1 = probability[1, :, :, :]
98 | class0 = sitk.GetImageFromArray(class0)
99 | class1 = sitk.GetImageFromArray(class1)
100 | sitk.WriteImage(class1, os.path.join(save_path + filename + "class1.mha"))
101 |
102 |
103 | def save_label(label, index, spacing=None):
104 | label_path = args['pred_path'] + 'label/'
105 | if not os.path.exists(label_path):
106 | os.makedirs(label_path)
107 | label = sitk.GetImageFromArray(label)
108 | if spacing is not None:
109 | label.SetSpacing(spacing)
110 | sitk.WriteImage(label, os.path.join(label_path, index + ".mha"))
111 |
112 |
113 | def predict():
114 | net = load_net()
115 | images, labels = load_3d()
116 | with torch.no_grad():
117 | net.eval()
118 | for i in tqdm(range(len(images))):
119 | name_list = images[i].split('/')
120 | index = name_list[-1][:-4]
121 | image = sitk.ReadImage(images[i])
122 | image = sitk.GetArrayFromImage(image).astype(np.float32)
123 | image = image / 255
124 | label = sitk.ReadImage(labels[i])
125 | label = sitk.GetArrayFromImage(label).astype(np.int64)
126 | # label = label / 255
127 | # VascuSynth
128 | # image = image[2:98, 2:98, 2:98]
129 | # label = label[2:98, 2:98, 2:98]
130 | save_label(label, index)
131 | # if cuda
132 | image = torch.from_numpy(np.ascontiguousarray(image)).unsqueeze(0).unsqueeze(0)
133 | image = image.cuda()
134 | output = net(image)
135 | save_prediction(output, filename=index + '_pred', spacing=None)
136 |
137 |
138 | if __name__ == '__main__':
139 | predict()
140 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Training script for CS-Net
3 | """
4 | import os
5 | import torch
6 | import torch.nn as nn
7 | from torch import optim
8 | from torch.utils.data import DataLoader
9 | import visdom
10 | import numpy as np
11 | from model.csnet import CSNet
12 | from dataloader.drive import Data
13 | from utils.train_metrics import metrics
14 | from utils.visualize import init_visdom_line, update_lines
15 | from utils.dice_loss_single_class import dice_coeff_loss
16 |
17 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1"
18 |
19 | args = {
20 | 'root' : '',
21 | 'data_path' : 'dataset/DRIVE/',
22 | 'epochs' : 1000,
23 | 'lr' : 0.0001,
24 | 'snapshot' : 100,
25 | 'test_step' : 1,
26 | 'ckpt_path' : 'checkpoint/',
27 | 'batch_size': 8,
28 | }
29 |
30 | # # Visdom---------------------------------------------------------
31 | X, Y = 0, 0.5 # for visdom
32 | x_acc, y_acc = 0, 0
33 | x_sen, y_sen = 0, 0
34 | env, panel = init_visdom_line(X, Y, title='Train Loss', xlabel="iters", ylabel="loss")
35 | env1, panel1 = init_visdom_line(x_acc, y_acc, title="Accuracy", xlabel="iters", ylabel="accuracy")
36 | env2, panel2 = init_visdom_line(x_sen, y_sen, title="Sensitivity", xlabel="iters", ylabel="sensitivity")
37 | # # ---------------------------------------------------------------
38 |
39 | def save_ckpt(net, iter):
40 | if not os.path.exists(args['ckpt_path']):
41 | os.makedirs(args['ckpt_path'])
42 | torch.save(net, args['ckpt_path'] + 'CS_Net_DRIVE_' + str(iter) + '.pkl')
43 | print('--->saved model:{}<--- '.format(args['root'] + args['ckpt_path']))
44 |
45 |
46 | # adjust learning rate (poly)
47 | def adjust_lr(optimizer, base_lr, iter, max_iter, power=0.9):
48 | lr = base_lr * (1 - float(iter) / max_iter) ** power
49 | for param_group in optimizer.param_groups:
50 | param_group['lr'] = lr
51 |
52 |
53 | def train():
54 | # set the channels to 3 when the format is RGB, otherwise 1.
55 | net = CSNet(classes=1, channels=3).cuda()
56 | net = nn.DataParallel(net, device_ids=[0, 1]).cuda()
57 | optimizer = optim.Adam(net.parameters(), lr=args['lr'], weight_decay=0.0005)
58 |     criterion = nn.MSELoss().cuda()
59 |     # criterion = nn.CrossEntropyLoss().cuda()
60 | print("---------------start training------------------")
61 | # load train dataset
62 | train_data = Data(args['data_path'], train=True)
63 | batchs_data = DataLoader(train_data, batch_size=args['batch_size'], num_workers=2, shuffle=True)
64 |
65 | iters = 1
66 | accuracy = 0.
67 |     sensitivity = 0.
68 | for epoch in range(args['epochs']):
69 | net.train()
70 | for idx, batch in enumerate(batchs_data):
71 | image = batch[0].cuda()
72 | label = batch[1].cuda()
73 | optimizer.zero_grad()
74 | pred = net(image)
75 | # pred = pred.squeeze_(1)
76 |             loss1 = criterion(pred, label)
77 | loss2 = dice_coeff_loss(pred, label)
78 | loss = loss1 + loss2
79 | loss.backward()
80 | optimizer.step()
81 | acc, sen = metrics(pred, label, pred.shape[0])
82 | print('[{0:d}:{1:d}] --- loss:{2:.10f}\tacc:{3:.4f}\tsen:{4:.4f}'.format(epoch + 1,
83 | iters, loss.item(),
84 | acc / pred.shape[0],
85 | sen / pred.shape[0]))
86 | iters += 1
87 | # # ---------------------------------- visdom --------------------------------------------------
88 | X, x_acc, x_sen = iters, iters, iters
89 | Y, y_acc, y_sen = loss.item(), acc / pred.shape[0], sen / pred.shape[0]
90 | update_lines(env, panel, X, Y)
91 | update_lines(env1, panel1, x_acc, y_acc)
92 | update_lines(env2, panel2, x_sen, y_sen)
93 | # # --------------------------------------------------------------------------------------------
94 |
95 | adjust_lr(optimizer, base_lr=args['lr'], iter=epoch, max_iter=args['epochs'], power=0.9)
96 | if (epoch + 1) % args['snapshot'] == 0:
97 | save_ckpt(net, epoch + 1)
98 |
99 | # model eval
100 | if (epoch + 1) % args['test_step'] == 0:
101 | test_acc, test_sen = model_eval(net)
102 | print("Average acc:{0:.4f}, average sen:{1:.4f}".format(test_acc, test_sen))
103 |
104 |             if (test_acc > accuracy) and (test_sen > sensitivity):  # save when both metrics improve (original comparison was inverted)
105 |                 save_ckpt(net, epoch + 1 + 8888888)  # large offset tags best-model checkpoints
106 |                 accuracy = test_acc
107 |                 sensitivity = test_sen
108 |
109 |
110 | def model_eval(net):
111 |     print("Start testing model...")
112 |     test_data = Data(args['data_path'], train=False)
113 |     batchs_data = DataLoader(test_data, batch_size=1)
114 |
115 |     net.eval()
116 |     Acc, Sen = [], []
117 |     with torch.no_grad():  # no gradients needed during evaluation
118 |         for idx, batch in enumerate(batchs_data):
119 |             image = batch[0].float().cuda()
120 |             label = batch[1].float().cuda()
121 |             pred_val = net(image)
122 |             acc, sen = metrics(pred_val, label, pred_val.shape[0])
123 |             print("\t---\t test acc:{0:.4f} test sen:{1:.4f}".format(acc, sen))
124 |             Acc.append(acc)
125 |             Sen.append(sen)
126 |
127 |     # for better view, add testing visdom here.
128 |     return np.mean(Acc), np.mean(Sen)
129 |
130 |
131 | if __name__ == '__main__':
132 | train()
133 |
--------------------------------------------------------------------------------
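For reference, a standalone illustration (same formula as `adjust_lr` above, assuming this script's defaults) of how the poly schedule decays the learning rate over training:

```python
base_lr, max_iter, power = 1e-4, 1000, 0.9  # values taken from args above


def poly_lr(it):
    # identical to adjust_lr, but returns the value instead of writing it into the optimizer
    return base_lr * (1 - it / max_iter) ** power


for it in (0, 250, 500, 999):
    print(it, poly_lr(it))
# 0 -> 1.0e-4, 250 -> ~7.7e-5, 500 -> ~5.4e-5, 999 -> ~2.0e-7
```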
/train3d.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Author : Lei Mou
4 | # @File : train3d.py
5 | """
6 | Training script for CS-Net 3D
7 | """
8 | import os
9 | import torch
10 | import torch.nn as nn
11 | from torch import optim
12 | from torch.utils.data import DataLoader
13 | import datetime
14 | import numpy as np
15 |
16 | from model.csnet_3d import CSNet3D
17 | from dataloader.MRABrainLoader import Data
18 |
19 | from utils.train_metrics import metrics3d
20 | from utils.losses import WeightedCrossEntropyLoss, DiceLoss
21 | from utils.visualize import init_visdom_line, update_lines
22 |
23 | args = {
24 | 'root' : '/home/user/name/Projects/',
25 | 'data_path' : 'dataset/data dir(your own data path)/',
26 | 'epochs' : 200,
27 | 'lr' : 0.0001,
28 | 'snapshot' : 100,
29 | 'test_step' : 1,
30 | 'ckpt_path' : './checkpoint/',
31 | 'batch_size': 2,
32 | }
33 |
34 | # # Visdom---------------------------------------------------------
35 | # initial placeholder values for the visdom line plots
36 | X, Y = 0, 1.0 # for visdom
37 | x_tp, y_tp = 0, 0
38 | x_fn, y_fn = 0.4, 0.4
39 | x_fp, y_fp = 0.4, 0.4
40 | x_testtp, y_testtp = 0.0, 0.0
41 | x_testdc, y_testdc = 0.0, 0.0
42 | env, panel = init_visdom_line(X, Y, title='Train Loss', xlabel="iters", ylabel="loss", env="wce")
43 | env1, panel1 = init_visdom_line(x_tp, y_tp, title="TPR", xlabel="iters", ylabel="TPR", env="wce")
44 | env2, panel2 = init_visdom_line(x_fn, y_fn, title="FNR", xlabel="iters", ylabel="FNR", env="wce")
45 | env3, panel3 = init_visdom_line(x_fp, y_fp, title="FPR", xlabel="iters", ylabel="FPR", env="wce")
46 | env6, panel6 = init_visdom_line(x_testtp, y_testtp, title="Train IoU", xlabel="iters", ylabel="IoU", env="wce")
47 | env4, panel4 = init_visdom_line(x_testtp, y_testtp, title="Test Loss", xlabel="iters", ylabel="Test Loss", env="wce")
48 | env5, panel5 = init_visdom_line(x_testdc, y_testdc, title="Test TP", xlabel="iters", ylabel="Test TP", env="wce")
49 | env7, panel7 = init_visdom_line(x_testdc, y_testdc, title="Test IoU", xlabel="iters", ylabel="Test IoU", env="wce")
50 |
51 |
52 | def save_ckpt(net, iter):
53 | if not os.path.exists(args['ckpt_path']):
54 | os.makedirs(args['ckpt_path'])
55 | date = datetime.datetime.now().strftime("%Y-%m-%d-")
56 | torch.save(net, args['ckpt_path'] + 'CSNet3D_' + date + iter + '.pkl')
57 | print("{} Saved model to:{}".format("\u2714", args['ckpt_path']))
58 |
59 |
60 | # adjust learning rate (poly)
61 | def adjust_lr(optimizer, base_lr, iter, max_iter, power=0.9):
62 | lr = base_lr * (1 - float(iter) / max_iter) ** power
63 | for param_group in optimizer.param_groups:
64 | param_group['lr'] = lr
65 |
66 |
67 | def train():
68 | net = CSNet3D(classes=2, channels=1).cuda()
69 | net = nn.DataParallel(net, device_ids=[0, 1]).cuda()
70 | optimizer = optim.Adam(net.parameters(), lr=args['lr'], weight_decay=0.0005)
71 |
72 | # load train dataset
73 | train_data = Data(args['data_path'], train=True)
74 | batchs_data = DataLoader(train_data, batch_size=args['batch_size'], num_workers=4, shuffle=True)
75 |
76 |     criterion_wce = WeightedCrossEntropyLoss().cuda()
77 |     criterion_ce = nn.CrossEntropyLoss().cuda()
78 |     criterion_dice = DiceLoss().cuda()
79 | # Start training
80 | print("\033[1;30;44m {} Start training ... {}\033[0m".format("*" * 8, "*" * 8))
81 |
82 | iters = 1
83 | for epoch in range(args['epochs']):
84 | net.train()
85 | for idx, batch in enumerate(batchs_data):
86 | image = batch[0].cuda()
87 | label = batch[1].cuda()
88 | optimizer.zero_grad()
89 | pred = net(image)
90 |             loss_dice = criterion_dice(pred, label)  # DiceLoss consumes the (N, 1, ...) index label
91 |             label = label.squeeze(1)                 # (N, D, H, W), as the CE-style losses expect
92 |             loss_ce = criterion_ce(pred, label)
93 |             loss_wce = criterion_wce(pred, label)
94 |             loss = (loss_ce + 0.6 * loss_wce + 0.4 * loss_dice) / 3
95 | loss.backward()
96 | optimizer.step()
97 | tp, fn, fp, iou = metrics3d(pred, label, pred.shape[0])
98 |             # alternate the console color by epoch parity (the original printed two identical lines)
99 |             color = '\033[1;36m' if (epoch % 2) == 0 else '\033[1;32m'
100 |             print(
101 |                 color + ' [{0:d}:{1:d}] \u2501\u2501\u2501 loss:{2:.10f}\tTP:{3:.4f}\tFN:{4:.4f}\tFP:{5:.4f}\tIoU:{6:.4f} '.format(
102 |                     epoch + 1, iters, loss.item(), tp / pred.shape[0], fn / pred.shape[0], fp / pred.shape[0],
103 |                     iou / pred.shape[0]))
108 |
109 | iters += 1
110 | # # ---------------------------------- visdom --------------------------------------------------
111 |             X, x_tp, x_fn, x_fp, x_dc = iters, iters, iters, iters, iters
112 |             Y, y_tp, y_fn, y_fp, y_dc = (loss.item(), tp / pred.shape[0], fn / pred.shape[0],
113 |                                          fp / pred.shape[0], iou / pred.shape[0])
114 |
115 | update_lines(env, panel, X, Y)
116 | update_lines(env1, panel1, x_tp, y_tp)
117 | update_lines(env2, panel2, x_fn, y_fn)
118 | update_lines(env3, panel3, x_fp, y_fp)
119 | update_lines(env6, panel6, x_dc, y_dc)
120 |
121 | # # --------------------------------------------------------------------------------------------
122 |
123 | adjust_lr(optimizer, base_lr=args['lr'], iter=epoch, max_iter=args['epochs'], power=0.9)
124 |
125 | if (epoch + 1) % args['snapshot'] == 0:
126 | save_ckpt(net, str(epoch + 1))
127 |
128 | # model eval
129 | if (epoch + 1) % args['test_step'] == 0:
130 |             test_tp, test_fn, test_fp, test_iou = model_eval(net, criterion_ce, iters)
131 |             print("Average TP:{0:.4f}, average FN:{1:.4f}, average FP:{2:.4f}, "
132 |                   "average IoU:{3:.4f}".format(test_tp, test_fn, test_fp, test_iou))
133 |
134 | def model_eval(net, criterion, iters):
135 |     print("\033[1;30;43m {} Start testing ... {}\033[0m".format("*" * 8, "*" * 8))
136 | test_data = Data(args['data_path'], train=False)
137 | batchs_data = DataLoader(test_data, batch_size=1)
138 |
139 | net.eval()
140 | TP, FN, FP, IoU = [], [], [], []
141 | file_num = 0
142 | with torch.no_grad():
143 | for idx, batch in enumerate(batchs_data):
144 | image = batch[0].float().cuda()
145 | label = batch[1].cuda()
146 | pred_val = net(image)
147 | label = label.squeeze(1)
148 |             loss = criterion(pred_val, label)
149 | tp, fn, fp, iou = metrics3d(pred_val, label, pred_val.shape[0])
150 | print(
151 | "--- test TP:{0:.4f} test FN:{1:.4f} test FP:{2:.4f} test IoU:{3:.4f}".format(tp, fn, fp, iou))
152 | TP.append(tp)
153 | FN.append(fn)
154 | FP.append(fp)
155 | IoU.append(iou)
156 | file_num += 1
157 | # # start visdom images
158 | X, x_testtp, x_testdc = iters, iters, iters
159 | Y, y_testtp, y_testdc = loss.item(), tp / pred_val.shape[0], iou / pred_val.shape[0]
160 | update_lines(env4, panel4, X, Y)
161 | update_lines(env5, panel5, x_testtp, y_testtp)
162 | update_lines(env7, panel7, x_testdc, y_testdc)
163 | return np.mean(TP), np.mean(FN), np.mean(FP), np.mean(IoU)
164 |
165 |
166 | if __name__ == '__main__':
167 | train()
168 |
--------------------------------------------------------------------------------
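To make the label shapes in `train()` concrete, a self-contained sketch (dummy tensors, my own addition; requires a CUDA device because `DiceLoss` moves its one-hot target to the GPU):

```python
import torch
import torch.nn as nn

from utils.losses import WeightedCrossEntropyLoss, DiceLoss

pred = torch.randn(2, 2, 8, 8, 8).cuda()             # (N, C, D, H, W) logits
label = torch.randint(0, 2, (2, 1, 8, 8, 8)).cuda()  # (N, 1, D, H, W) class indices

loss_dice = DiceLoss()(pred, label)  # one-hot encodes the (N, 1, ...) label internally
label = label.squeeze(1)             # (N, D, H, W), as the CE-style losses expect
loss_ce = nn.CrossEntropyLoss()(pred, label)
loss_wce = WeightedCrossEntropyLoss()(pred, label)
loss = (loss_ce + 0.6 * loss_wce + 0.4 * loss_dice) / 3
print(loss.item())
```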
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iMED-Lab/CS-Net/25079c377f8db4b57f25c0adc7b70d1a02a3ee62/utils/__init__.py
--------------------------------------------------------------------------------
/utils/dice_loss_single_class.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
6 |
7 |
8 | class DiceCoeff(Function):
9 | """Dice coeff for individual examples"""
10 |     # NOTE: legacy-style autograd Function; dice_coeff() below calls .forward() directly
11 | def forward(self, input, target):
12 | # target = _make_one_hot(target, 2)
13 | self.save_for_backward(input, target)
14 | eps = 0.0001
15 |         # torch.dot computes the dot product of the two flattened tensors
16 |         # example values of inter / union from a debug run: 10506.6 / 164867.2
17 | self.inter = torch.dot(input.view(-1), target.view(-1))
18 | self.union = torch.sum(input) + torch.sum(target) + eps
19 | # print("inter,uniun:",self.inter,self.union)
20 |
21 | t = (2 * self.inter.float() + eps) / self.union.float()
22 | return t
23 |
24 | # This function has only a single output, so it gets only one gradient
25 | def backward(self, grad_output):
26 |
27 | input, target = self.saved_variables
28 | grad_input = grad_target = None
29 |
30 | if self.needs_input_grad[0]:
31 | grad_input = grad_output * 2 * (target * self.union - self.inter) \
32 | / (self.union * self.union)
33 | if self.needs_input_grad[1]:
34 | grad_target = None
35 |
36 |         # debug note: this never printed -- is backward() ever reached?
37 |         # print("grad_input, grad_target:", grad_input, grad_target)
38 |
39 | return grad_input, grad_target
40 |
41 |
42 | def dice_coeff(input, target):
43 | """Dice coeff for batches"""
44 | if input.is_cuda:
45 | s = torch.FloatTensor(1).cuda().zero_()
46 | else:
47 | s = torch.FloatTensor(1).zero_()
48 |
49 | # print("size of input, target:", input.shape, target.shape)
50 |
51 | for i, c in enumerate(zip(input, target)):
52 |         # c[0] and c[1] are single full-size images, e.g. torch.Size([1, 576, 544])
53 |         # print("size of c0 c1:", c[0].shape, c[1].shape)
54 | s = s + DiceCoeff().forward(c[0], c[1])
55 |
56 | return s / (i + 1)
57 |
58 |
59 | def dice_coeff_loss(input, target):
60 | return 1 - dice_coeff(input, target)
61 |
--------------------------------------------------------------------------------
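A minimal usage sketch (my own addition), assuming a PyTorch version that still tolerates the legacy-style `Function` above; `dice_coeff_loss` expects per-pixel probabilities and a binary target of the same shape:

```python
import torch

from utils.dice_loss_single_class import dice_coeff_loss

pred = torch.sigmoid(torch.randn(4, 1, 64, 64))    # (N, 1, H, W) probabilities in [0, 1]
target = (torch.rand(4, 1, 64, 64) > 0.5).float()  # binary ground truth, same shape
loss = dice_coeff_loss(pred, target)               # scalar tensor in [0, 1]
print(loss.item())
```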
/utils/evaluation_metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluation metrics
3 | """
4 |
5 | import numpy as np
6 | import sklearn.metrics as metrics
7 | import os
8 | import glob
9 | import cv2
10 | from PIL import Image
11 |
12 |
13 | def numeric_score(pred, gt):
14 |     FP = float(np.sum((pred == 1) & (gt == 0)))
15 |     FN = float(np.sum((pred == 0) & (gt == 1)))
16 |     TP = float(np.sum((pred == 1) & (gt == 1)))
17 |     TN = float(np.sum((pred == 0) & (gt == 0)))
18 |     return FP, FN, TP, TN
19 |
20 |
21 | def numeric_score_fov(pred, gt, mask):
22 |     FP = float(np.sum((pred == 1) & (gt == 0) & (mask == 1)))
23 |     FN = float(np.sum((pred == 0) & (gt == 1) & (mask == 1)))
24 |     TP = float(np.sum((pred == 1) & (gt == 1) & (mask == 1)))
25 |     TN = float(np.sum((pred == 0) & (gt == 0) & (mask == 1)))
26 |     return FP, FN, TP, TN
27 |
28 |
29 | def AUC(path):
30 | all_auc = 0.
31 | file_num = 0
32 | for file in glob.glob(os.path.join(path, 'pred', '*pred.png')):
33 | base_name = os.path.basename(file)
34 | label_name = base_name[:-9] + '.png'
35 | label_path = os.path.join(path, 'label', label_name)
36 |
37 | mask_path = '/path/to/FOV/mask/'
38 |
39 | pred_image = cv2.imread(file, flags=-1)
40 | label = cv2.imread(label_path, flags=-1)
41 | mask = cv2.imread(mask_path, flags=-1)
42 |
43 |         # keep only pixels inside the field-of-view (FOV) mask
44 |         fov = mask == 255
45 |         pred_image = pred_image[fov] / 255
46 |         label = np.uint8(label[fov] / 255)
47 |
48 |         # without FOV:
49 |         # pred_image = pred_image.flatten() / 255
50 |         # label = np.uint8(label.flatten() / 255)
57 |
58 | auc_score = metrics.roc_auc_score(label, pred_image)
59 | all_auc += auc_score
60 | file_num += 1
61 | avg_auc = all_auc / file_num
62 | return avg_auc
63 |
64 |
65 | def DSC(path):
66 | all_dsc = 0.
67 | file_num = 0
68 | for file in glob.glob(os.path.join(path, 'pred', '*otsu.png')):
69 | base_name = os.path.basename(file)
70 | label_name = base_name[:-14] + '.png'
71 | label_path = os.path.join(path, 'label', label_name)
72 |
73 | pred = cv2.imread(file, flags=-1)
74 | label = cv2.imread(label_path, flags=-1)
75 |
76 | pred = pred // 255
77 | label = label // 255
78 |
79 | FP, FN, TP, TN = numeric_score(pred, label)
80 | dsc = 2 * TP / (FP + 2 * TP + FN)
81 | all_dsc += dsc
82 | file_num += 1
83 | avg_dsc = all_dsc / file_num
84 | return avg_dsc
85 |
86 |
87 | def AccSenSpe(path):
88 | all_sen = []
89 | all_acc = []
90 | all_spe = []
91 | for file in glob.glob(os.path.join(path, 'pred', '*otsu.png')):
92 | base_name = os.path.basename(file)
93 | label_name = base_name[:-14] + '.png'
94 | label_path = os.path.join(path, 'label', label_name)
95 |
96 | mask_path = '/path/to/FOV/mask/'
97 |
98 | pred = cv2.imread(file, flags=-1)
99 | label = cv2.imread(label_path, flags=-1)
100 | mask = cv2.imread(mask_path, flags=-1)
101 |
102 | pred = pred // 255
103 | label = label // 255
104 |         mask = mask // 255  # loaded for optional FOV restriction (see numeric_score_fov); unused below
105 |
106 | FP, FN, TP, TN = numeric_score(pred, label)
107 | acc = (TP + TN) / (TP + FP + TN + FN)
108 | sen = TP / (TP + FN)
109 | spe = TN / (TN + FP)
110 | all_acc.append(acc)
111 | all_sen.append(sen)
112 | all_spe.append(spe)
113 | avg_acc, avg_sen, avg_spe = np.mean(all_acc), np.mean(all_sen), np.mean(all_spe)
114 | var_acc, var_sen, var_spe = np.var(all_acc), np.var(all_sen), np.var(all_spe)
115 | return avg_acc, var_acc, avg_sen, var_sen, avg_spe, var_spe
116 |
117 |
118 | def FDR(path):
119 | all_fdr = []
120 | for file in glob.glob(os.path.join(path, 'pred', '*otsu.png')):
121 | base_name = os.path.basename(file)
122 | label_name = base_name[:-14] + '.png'
123 | label_path = os.path.join(path, 'label', label_name)
124 |
125 | pred = cv2.imread(file, flags=-1)
126 | label = cv2.imread(label_path, flags=-1)
127 |
128 | pred = pred // 255
129 | label = label // 255
130 |
131 | FP, FN, TP, TN = numeric_score(pred, label)
132 | fdr = FP / (FP + TP)
133 | all_fdr.append(fdr)
134 | return np.mean(all_fdr), np.var(all_fdr)
135 |
136 |
137 | if __name__ == '__main__':
138 | # predicted root path
139 | path = './assets/Padova1/'
140 | # auc = AUC(path)
141 | acc, var_acc, sen, var_sen, spe, var_spe = AccSenSpe(path)
142 | fdr, var_fdr = FDR(path)
143 |     print("sen:{0:.4f} (var {1:.4f})".format(sen, var_sen))
144 |     print("fdr:{0:.4f} (var {1:.4f})".format(fdr, var_fdr))
145 | # print("acc:{0:.4f}".format(acc))
146 | # print("sen:{0:.4f}".format(sen))
147 | # print("spe:{0:.4f}".format(spe))
148 | # print("auc:{0:.4f}".format(auc))
149 |
--------------------------------------------------------------------------------
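A tiny worked example (synthetic arrays, my own addition) of how the pixel counts from `numeric_score` turn into the metrics used in this file:

```python
import numpy as np

from utils.evaluation_metrics import numeric_score

pred = np.array([[1, 1, 0, 0]])
gt = np.array([[1, 0, 1, 0]])
FP, FN, TP, TN = numeric_score(pred, gt)  # 1.0, 1.0, 1.0, 1.0
acc = (TP + TN) / (TP + FP + TN + FN)     # 0.50
sen = TP / (TP + FN)                      # 0.50
spe = TN / (TN + FP)                      # 0.50
dsc = 2 * TP / (FP + 2 * TP + FN)         # 0.50
print(acc, sen, spe, dsc)
```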
/utils/evaluation_metrics3D.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # ╔═════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
4 | # ║ ║
5 | # ║ __ __ ____ __ ║
6 | # ║ /\ \/\ \ /\ _`\ /\ \ __ ║
7 | # ║ \ \ \_\ \ __ _____ _____ __ __ \ \ \/\_\ ___ \_\ \/\_\ ___ __ ║
8 | # ║ \ \ _ \ /'__`\ /\ '__`\/\ '__`\/\ \/\ \ \ \ \/_/_ / __`\ /'_` \/\ \ /' _ `\ /'_ `\ ║
9 | # ║ \ \ \ \ \/\ \L\.\_\ \ \L\ \ \ \L\ \ \ \_\ \ \ \ \L\ \/\ \L\ \/\ \L\ \ \ \/\ \/\ \/\ \L\ \ ║
10 | # ║ \ \_\ \_\ \__/.\_\\ \ ,__/\ \ ,__/\/`____ \ \ \____/\ \____/\ \___,_\ \_\ \_\ \_\ \____ \ ║
11 | # ║ \/_/\/_/\/__/\/_/ \ \ \/ \ \ \/ `/___/> \ \/___/ \/___/ \/__,_ /\/_/\/_/\/_/\/___L\ \ ║
12 | # ║ \ \_\ \ \_\ /\___/ /\____/ ║
13 | # ║ \/_/ \/_/ \/__/ \_/__/ ║
14 | # ║ ║
15 | # ║ 49 4C 6F 76 65 59 6F 75 2C 42 75 74 59 6F 75 4B 6E 6F 77 4E 6F 74 68 69 6E 67 2E ║
16 | # ║ ║
17 | # ╚═════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
18 | # @Author : Lei Mou
19 | # @File : evaluation_metrics3D.py
20 | import numpy as np
21 | import SimpleITK as sitk
22 | import glob
23 | import os
24 | from scipy.spatial import distance
25 | from sklearn.metrics import f1_score
26 |
27 |
28 | def numeric_score(pred, gt):
29 |     FP = float(np.sum((pred == 255) & (gt == 0)))
30 |     FN = float(np.sum((pred == 0) & (gt == 255)))
31 |     TP = float(np.sum((pred == 255) & (gt == 255)))
32 |     TN = float(np.sum((pred == 0) & (gt == 0)))
33 |     return FP, FN, TP, TN
34 |
35 |
36 | def Dice(pred, gt):
37 | pred = np.int64(pred / 255)
38 | gt = np.int64(gt / 255)
39 | dice = np.sum(pred[gt == 1]) * 2.0 / (np.sum(pred) + np.sum(gt))
40 | return dice
41 |
42 |
43 | def IoU(pred, gt):
44 | pred = np.int64(pred / 255)
45 | gt = np.int64(gt / 255)
46 | m1 = np.sum(pred[gt == 1])
47 | m2 = np.sum(pred == 1) + np.sum(gt == 1) - m1
48 | iou = m1 / m2
49 | return iou
50 |
51 |
52 | def metrics_3d(pred, gt):
53 |     FP, FN, TP, TN = numeric_score(pred, gt)
54 |     tpr = TP / (TP + FN + 1e-10)
55 |     fnr = FN / (FN + TP + 1e-10)
56 |     fpr = FP / (FP + TN + 1e-10)  # was FN / (FP + TN): FPR is false positives over all negatives
57 |     iou = TP / (TP + FN + FP + 1e-10)
58 |     return tpr, fnr, fpr, iou
59 |
60 |
61 | def over_rate(pred, gt):
62 |     # pred = np.int64(pred / 255)
63 |     # gt = np.int64(gt / 255)
64 |     Rs = float(np.sum(gt == 255))
65 |     Os = float(np.sum((pred == 255) & (gt == 0)))
66 |     OR = Os / (Rs + Os)
67 |     return OR
68 |
69 |
70 | def under_rate(pred, gt):
71 |     # pred = np.int64(pred / 255)
72 |     # gt = np.int64(gt / 255)
73 |     Rs = float(np.sum(gt == 255))
74 |     Us = float(np.sum((pred == 0) & (gt == 255)))
75 |     Os = float(np.sum((pred == 255) & (gt == 0)))
76 |     UR = Us / (Rs + Os)
77 |     return UR
78 |
--------------------------------------------------------------------------------
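A small sanity check (synthetic volumes, my own addition) for `metrics_3d`; with the corrected FPR, a perfect prediction scores TPR ~ 1, FNR ~ 0, FPR = 0, IoU ~ 1:

```python
import numpy as np

from utils.evaluation_metrics3D import metrics_3d

gt = np.zeros((4, 4, 4), dtype=np.uint8)
gt[:2] = 255                 # foreground half of the volume
pred = gt.copy()             # perfect prediction
print(metrics_3d(pred, gt))  # approximately (1.0, 0.0, 0.0, 1.0)
```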
/utils/losses.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn as nn
6 | from torch.autograd import Variable, Function
7 | from torch.nn import MSELoss, SmoothL1Loss, L1Loss
8 | import numpy as np
9 |
10 |
11 | def make_one_hot(input, num_classes):
12 | """Convert class index tensor to one hot encoding tensor.
13 | Args:
14 | input: A tensor of shape [N, 1, *]
15 | num_classes: An int of number of class
16 | Returns:
17 | A tensor of shape [N, num_classes, *]
18 | """
19 | shape = np.array(input.shape)
20 | shape[1] = num_classes
21 | shape = tuple(shape)
22 | result = torch.zeros(shape)
23 | result = result.scatter_(1, input.cpu(), 1)
24 |
25 | return result
26 |
27 |
28 | class BinaryDiceLoss(nn.Module):
29 | """Dice loss of binary class
30 | Args:
31 | smooth: A float number to smooth loss, and avoid NaN error, default: 1
32 | p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
33 | predict: A tensor of shape [N, *]
34 | target: A tensor of shape same with predict
35 | Returns:
36 | Loss tensor according to arg reduction
37 | Raise:
38 | Exception if unexpected reduction
39 | """
40 |
41 | def __init__(self, smooth=1, p=2):
42 | super(BinaryDiceLoss, self).__init__()
43 | self.smooth = smooth
44 | self.p = p
45 |
46 | def forward(self, predict, target):
47 | assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
48 | predict = predict.contiguous().view(predict.shape[0], -1)
49 | target = target.contiguous().view(target.shape[0], -1)
50 |
51 | num = torch.sum(torch.mul(predict, target)) * 2 + self.smooth
52 | den = torch.sum(predict.pow(self.p) + target.pow(self.p)) + self.smooth
53 |
54 | dice = num / den
55 | loss = 1 - dice
56 | return loss
57 |
58 |
59 | class DiceLoss(nn.Module):
60 | """Dice loss, need one hot encode input
61 | Args:
62 | weight: An array of shape [num_classes,]
63 | ignore_index: class index to ignore
64 | predict: A tensor of shape [N, C, *]
65 | target: A tensor of same shape with predict
66 | other args pass to BinaryDiceLoss
67 | Return:
68 | same as BinaryDiceLoss
69 | """
70 |
71 | def __init__(self, weight=None, ignore_index=None, **kwargs):
72 | super(DiceLoss, self).__init__()
73 | self.kwargs = kwargs
74 | self.weight = weight
75 | self.ignore_index = ignore_index
76 |
77 | def forward(self, predict, target):
78 | target = make_one_hot(target, num_classes=predict.shape[1])
79 | target = target.cuda()
80 | assert predict.shape == target.shape, 'predict & target shape do not match'
81 | dice = BinaryDiceLoss(**self.kwargs)
82 | total_loss = 0
83 | predict = F.softmax(predict, dim=1)
84 |
85 | for i in range(target.shape[1]):
86 | if i != self.ignore_index:
87 | dice_loss = dice(predict[:, i], target[:, i])
88 | if self.weight is not None:
89 | assert self.weight.shape[0] == target.shape[1], \
90 | 'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0])
91 |                     dice_loss *= self.weight[i]
92 | total_loss += dice_loss
93 |
94 | return total_loss / target.shape[1]
95 |
96 |
97 | # ---------------------------------------------------------------------------------------------------------
98 |
99 |
100 | def flatten(tensor):
101 | """Flattens a given tensor such that the channel axis is first.
102 | The shapes are transformed as follows:
103 | (N, C, D, H, W) -> (C, N * D * H * W)
104 | """
105 | C = tensor.size(1)
106 | # new axis order
107 | axis_order = (1, 0) + tuple(range(2, tensor.dim()))
108 | # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
109 | transposed = tensor.permute(axis_order)
110 | # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
111 | return transposed.contiguous().view(C, -1)
112 |
113 |
114 | class WeightedCrossEntropyLoss(nn.Module):
115 | """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf
116 | """
117 |
118 | def __init__(self, weight=None, ignore_index=-1):
119 | super(WeightedCrossEntropyLoss, self).__init__()
120 | self.register_buffer('weight', weight)
121 | self.ignore_index = ignore_index
122 |
123 | def forward(self, input, target):
124 | class_weights = self._class_weights(input)
125 | if self.weight is not None:
126 | weight = Variable(self.weight, requires_grad=False)
127 | class_weights = class_weights * weight
128 | return F.cross_entropy(input, target, weight=class_weights, ignore_index=self.ignore_index)
129 |
130 | @staticmethod
131 | def _class_weights(input):
132 | # normalize the input first
133 |         input = F.softmax(input, dim=1)
134 | flattened = flatten(input)
135 | nominator = (1. - flattened).sum(-1)
136 | denominator = flattened.sum(-1)
137 | class_weights = Variable(nominator / denominator, requires_grad=False)
138 | return class_weights
139 |
140 | # ---------------------------------------------------------------------------------------------
141 |
--------------------------------------------------------------------------------
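A quick sketch (my own addition) of what `make_one_hot` produces, since `DiceLoss` relies on it internally:

```python
import torch

from utils.losses import make_one_hot

idx = torch.tensor([[[[0, 1], [1, 0]]]])   # (1, 1, 2, 2) integer class indices
onehot = make_one_hot(idx, num_classes=2)  # -> (1, 2, 2, 2)
print(onehot[0, 0])  # channel 0 is 1 where idx == 0
print(onehot[0, 1])  # channel 1 is 1 where idx == 1
```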
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import cv2
4 | from PIL import ImageOps
5 |
13 |
14 | def ReScaleSize(image, re_size=512):
15 | w, h = image.size
16 | max_len = max(w, h)
17 | new_w, new_h = max_len, max_len
18 | delta_w = new_w - w
19 | delta_h = new_h - h
20 | padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
21 | image = ImageOps.expand(image, padding, fill=0)
22 | # origin_w, origin_h = w, h
23 | image = image.resize((re_size, re_size))
24 | return image # , origin_w, origin_h
25 |
26 |
27 | def Crop(image):
28 | left = 261
29 | top = 1
30 | right = 1110
31 | bottom = 850
32 | image = image.crop((left, top, right, bottom))
33 | return image
34 |
35 |
36 | def thresh_OTSU(path):
37 | for file in glob.glob(os.path.join(path, '*pred.png')):
38 | index = os.path.basename(file)[:-4]
39 | image = cv2.imread(file)
40 |         gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # cv2.imread loads BGR, not RGB
41 | thresh, img = cv2.threshold(gray, 0, 255, cv2.THRESH_OTSU)
42 | cv2.imwrite(os.path.join(path, index + '_otsu.png'), img)
43 | #cv2.imwrite(file, img)
44 | print(file, '\tdone!')
--------------------------------------------------------------------------------
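Usage sketch for `ReScaleSize` (my own addition): the image is zero-padded to a square before resizing, so structures are not stretched anisotropically:

```python
from PIL import Image

from utils.misc import ReScaleSize

img = Image.new("L", (849, 565))    # dummy grayscale image, w != h
out = ReScaleSize(img, re_size=512)
print(out.size)                     # (512, 512)
```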
/utils/train_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from utils.evaluation_metrics3D import metrics_3d, Dice
10 |
11 |
12 | def threshold(image):
13 |     # t = filters.threshold_otsu(image, nbins=256)  # needs skimage.filters if re-enabled
14 | image[image >= 100] = 255
15 | image[image < 100] = 0
16 | return image
17 |
18 |
19 | def numeric_score(pred, gt):
20 |     FP = float(np.sum((pred == 255) & (gt == 0)))
21 |     FN = float(np.sum((pred == 0) & (gt == 255)))
22 |     TP = float(np.sum((pred == 255) & (gt == 255)))
23 |     TN = float(np.sum((pred == 0) & (gt == 0)))
24 |     return FP, FN, TP, TN
25 |
26 |
27 | def metrics(pred, label, batch_size):
28 | # pred = torch.argmax(pred, dim=1) # for CE Loss series
29 | outputs = (pred.data.cpu().numpy() * 255).astype(np.uint8)
30 | labels = (label.data.cpu().numpy() * 255).astype(np.uint8)
31 | outputs = outputs.squeeze(1) # for MSELoss()
32 | labels = labels.squeeze(1) # for MSELoss()
33 | outputs = threshold(outputs) # for MSELoss()
34 |
35 | Acc, SEn = 0., 0.
36 | for i in range(batch_size):
37 | img = outputs[i, :, :]
38 | gt = labels[i, :, :]
39 | acc, sen = get_acc(img, gt)
40 | Acc += acc
41 | SEn += sen
42 | return Acc, SEn
43 |
44 |
45 | def metrics3dmse(pred, label, batch_size):
46 | outputs = (pred.data.cpu().numpy() * 255).astype(np.uint8)
47 | labels = (label.data.cpu().numpy() * 255).astype(np.uint8)
48 | outputs = outputs.squeeze(1) # for MSELoss()
49 | labels = labels.squeeze(1) # for MSELoss()
50 | outputs = threshold(outputs) # for MSELoss()
51 |
52 | tp, fn, fp, IoU = 0, 0, 0, 0
53 | for i in range(batch_size):
54 | img = outputs[i, :, :, :]
55 | gt = labels[i, :, :, :]
56 | tpr, fnr, fpr, iou = metrics_3d(img, gt)
57 | # dcr = Dice(img, gt)
58 | tp += tpr
59 | fn += fnr
60 | fp += fpr
61 | IoU += iou
62 | return tp, fn, fp, IoU
63 |
64 |
65 | def metrics3d(pred, label, batch_size):
66 | pred = torch.argmax(pred, dim=1) # for CE loss series
67 | outputs = (pred.data.cpu().numpy() * 255).astype(np.uint8)
68 | labels = (label.data.cpu().numpy() * 255).astype(np.uint8)
69 | # outputs = outputs.squeeze(1) # for MSELoss()
70 | # labels = labels.squeeze(1) # for MSELoss()
71 | # outputs = threshold(outputs) # for MSELoss()
72 |
73 | tp, fn, fp, IoU = 0, 0, 0, 0
74 | for i in range(batch_size):
75 | img = outputs[i, :, :, :]
76 | gt = labels[i, :, :, :]
77 | tpr, fnr, fpr, iou = metrics_3d(img, gt)
78 | # dcr = Dice(img, gt)
79 | tp += tpr
80 | fn += fnr
81 | fp += fpr
82 | IoU += iou
83 | return tp, fn, fp, IoU
84 |
85 |
86 | def get_acc(image, label):
87 | image = threshold(image)
88 |
89 | FP, FN, TP, TN = numeric_score(image, label)
90 | acc = (TP + TN) / (TP + FN + TN + FP + 1e-10)
91 | sen = (TP) / (TP + FN + 1e-10)
92 | return acc, sen
93 |
--------------------------------------------------------------------------------
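A dummy-tensor sketch (my own addition) of how `metrics` is called from train.py: it expects probabilities with a singleton channel and returns accuracy/sensitivity summed over the batch:

```python
import torch

from utils.train_metrics import metrics

pred = torch.rand(2, 1, 32, 32)                   # (N, 1, H, W) probabilities
label = (torch.rand(2, 1, 32, 32) > 0.5).float()  # binary ground truth
acc, sen = metrics(pred, label, pred.shape[0])    # sums over the batch
print(acc / pred.shape[0], sen / pred.shape[0])   # per-image averages
```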
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import visdom
3 |
4 |
5 | def init_visdom_line(x, y, title, xlabel, ylabel, env="default"):
6 | env = visdom.Visdom(env=env)
7 | panel = env.line(
8 | X=np.array([x]),
9 | Y=np.array([y]),
10 | opts=dict(title=title, showlegend=True, xlabel=xlabel, ylabel=ylabel)
11 | )
12 | return env, panel
13 |
14 |
15 | def update_lines(env, panel, x, y, update_type='append'):
16 | env.line(
17 | X=np.array([x]),
18 | Y=np.array([y]),
19 | win=panel,
20 | update=update_type
21 | )
22 |
--------------------------------------------------------------------------------
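Minimal loop sketch (my own addition; assumes a visdom server is running via `python -m visdom.server`):

```python
from utils.visualize import init_visdom_line, update_lines

env, panel = init_visdom_line(0, 0.0, title="Demo", xlabel="iter", ylabel="value", env="demo")
for it in range(1, 11):
    update_lines(env, panel, it, 1.0 / it)  # appends one point per call
```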