├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── cs2net.md ├── dataloader ├── MRABrainLoader.py ├── __init__.py ├── drive.py ├── octa.py ├── padova1.py ├── padova2.py └── stare.py ├── model ├── __init__.py ├── csnet.py └── csnet_3d.py ├── predict.py ├── predict3d.py ├── train.py ├── train3d.py └── utils ├── __init__.py ├── dice_loss_single_class.py ├── evaluation_metrics.py ├── evaluation_metrics3D.py ├── losses.py ├── misc.py ├── train_metrics.py └── visualize.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | .DS_Store 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 ineedzx 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CS-Net: Channel and Spatial Attention Network for Curvilinear Structure Segmentation 2 | 3 | Implementation of [CS-Net: Channel and Spatial Attention Network for Curvilinear Structure Segmentation](https://link.springer.com/chapter/10.1007/978-3-030-32239-7_80) 4 | 5 | For the details of 3D extended version of CS-Net, please refer to [CS2-Net: Deep Learning Segmentation of Curvilinear Structures in Medical Imaging](cs2net.md) 6 | 7 | --- 8 | 9 | ## Overview 10 | 11 |
12 | 14 | 15 | The main contribution of this work is the publication of two scarce datasets in the medical imaging field. Please click the link below to access the details and source data. [![](https://img.shields.io/badge/Download-CORN--1-green)](http://www.imed-lab.com/?p=16073) 16 | 17 | ## Requirements 18 | 19 | ![](https://img.shields.io/badge/PyTorch-%3E%3D0.4.1-orange) ![](https://img.shields.io/badge/tqdm-latest-orange) ![](https://img.shields.io/badge/cv2-latest-orange) ![](https://img.shields.io/badge/visdom-%3E%3D0.2.0-orange) ![](https://img.shields.io/badge/sklearn-latest-orange) 20 | 21 | The attention module was implemented based on [DANet](https://github.com/junfu1115/DANet). The difference between the proposed module and the original block is that we added new 1x3 and 3x1 kernel convolution layers to the spatial attention module. Please refer to the paper for details. 22 | 23 | ## Get Started 24 | 25 | Use ```train.py``` and ```predict.py``` to train and test the model on your own dataset, respectively. 26 | 27 | ## Examples 28 | 29 | - Vessel segmentation on fundus images 30 | 31 |
32 | 33 |
34 | 35 | - Vessel segmentation on OCT-A images 36 | 37 |
38 | 39 |
40 | 41 | - Nerve fiber tracing on CCM 42 | 43 |
44 | 45 |
46 | 47 | ## Citation 48 | 49 | ``` 50 | @inproceedings{mou2019cs, 51 | title={CS-Net: channel and spatial attention network for curvilinear structure segmentation}, 52 | author={Mou, Lei and Zhao, Yitian and Chen, Li and Cheng, Jun and Gu, Zaiwang and Hao, Huaying and Qi, Hong and Zheng, Yalin and Frangi, Alejandro and Liu, Jiang}, 53 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 54 | pages={721--730}, 55 | year={2019}, 56 | organization={Springer} 57 | } 58 | ``` 59 | 60 | 61 | 62 | ## Useful Links 63 | 64 | | DRIVE | http://www.isi.uu.nl/Research/Databases/DRIVE/ | 65 | | :------------- | :---------------------------------------------------------- | 66 | | **STARE** | **http://www.ces.clemson.edu/ahoover/stare/** | 67 | | **IOSTAR** | **http://www.retinacheck.org/** | 68 | | **ToF MIDAS** | **http://insight-journal.org/midas/community/view/21** | 69 | | **Synthetic** | **https://github.com/giesekow/deepvesselnet/wiki/Datasets** | 70 | | **VascuSynth** | **http://vascusynth.cs.sfu.ca/Data.html** | 71 | -------------------------------------------------------------------------------- /cs2net.md: -------------------------------------------------------------------------------- 1 | 2 | # CS2-Net: Deep Learning Segmentation of Curvilinear Structures in Medical Imaging 3 | 4 | Implementation of [CS2-Net MedIA 2020](https://www.sciencedirect.com/science/article/pii/S1361841520302383) 5 | 6 | --- 7 | 8 | ## Overview 9 | 10 |
11 | 13 | 14 | ## Requirements 15 | 16 | ![](https://img.shields.io/badge/PyTorch-0.4.1-orange) ![](https://img.shields.io/badge/visdom-0.2.0-orange) ![](https://img.shields.io/badge/SimpleITK-latest-orange) 17 | 18 | ## Get Started 19 | 20 | - ```train3d.py``` is used to train the 3D segmentation network. 21 | 22 | - ```predict3d.py``` is used to test the trained model. 23 | 24 | - Please note that you should change the dataloader definition in ```train3d.py```. 25 | 26 | ## Examples 27 | 28 | - MRA brain vessel segmentation 29 | 30 |
31 | 32 |
33 | 34 | - Synthetic & VascuSynth 35 | 36 |
37 | 38 |
39 | 40 | ## Citation 41 | 42 | ``` 43 | @article{mou2020cs2, 44 | title={CS2-Net: Deep Learning Segmentation of Curvilinear Structures in Medical Imaging}, 45 | author={Mou, Lei and Zhao, Yitian and Fu, Huazhu and Liux, Yonghuai and Cheng, Jun and Zheng, Yalin and Su, Pan and Yang, Jianlong and Chen, Li and Frangi, Alejandro F and others}, 46 | journal={Medical Image Analysis}, 47 | pages={101874}, 48 | year={2020}, 49 | publisher={Elsevier} 50 | } 51 | ``` 52 | 53 | 54 | 55 | #### Corrections to: CS2-Net- Deep learning segmentation of curvilinear structures in medical imaging 56 | 57 | The original comparison results in Table 8 on page 14 are: 58 | 59 | 60 | 61 | The corrected comparison results are: 62 | 63 | 64 | 65 | ## Useful Links 66 | 67 | | DRIVE | http://www.isi.uu.nl/Research/Databases/DRIVE/ | 68 | | :------------- | :---------------------------------------------------------- | 69 | | **STARE** | **http://www.ces.clemson.edu/ahoover/stare/** | 70 | | **IOSTAR** | **http://www.retinacheck.org/** | 71 | | **ToF MIDAS** | **http://insight-journal.org/midas/community/view/21** | 72 | | **Synthetic** | **https://github.com/giesekow/deepvesselnet/wiki/Datasets** | 73 | | **VascuSynth** | **http://vascusynth.cs.sfu.ca/Data.html** | 74 | 75 | -------------------------------------------------------------------------------- /dataloader/MRABrainLoader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import glob 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | import random 8 | import warnings 9 | import SimpleITK as sitk 10 | import numpy as np 11 | from scipy.ndimage import rotate, map_coordinates, gaussian_filter 12 | 13 | warnings.filterwarnings('ignore') 14 | 15 | 16 | def load_dataset(root_dir, train=True): 17 | images = [] 18 | groundtruth = [] 19 | if train: 20 | sub_dir = 'training' 21 | else: 22 | 
sub_dir = 'test' 23 | images_path = os.path.join(root_dir, sub_dir, 'images') 24 | groundtruth_path = os.path.join(root_dir, sub_dir, 'mesh_label') 25 | 26 | for file in glob.glob(os.path.join(images_path, '*.mha')): 27 | image_name = os.path.basename(file)[:-8] 28 | groundtruth_name = image_name + '.mha' 29 | 30 | images.append(file) 31 | groundtruth.append(os.path.join(groundtruth_path, groundtruth_name)) 32 | 33 | return images, groundtruth 34 | 35 | 36 | class Data(Dataset): 37 | def __init__(self, 38 | root_dir, 39 | train=True, 40 | rotate=40, 41 | flip=True, 42 | random_crop=True, 43 | scale1=512): 44 | 45 | self.root_dir = root_dir 46 | self.train = train 47 | self.rotate = rotate 48 | self.flip = flip 49 | self.random_crop = random_crop 50 | self.transform = transforms.ToTensor() 51 | self.resize = scale1 52 | self.images, self.groundtruth = load_dataset(self.root_dir, self.train) 53 | 54 | def __len__(self): 55 | return len(self.images) 56 | 57 | def RandomCrop(self, image, label, crop_factor=(0, 0, 0)): 58 | """ 59 | Make a random crop of the whole volume 60 | :param image: 61 | :param label: 62 | :param crop_factor: The crop size that you want to crop 63 | :return: 64 | """ 65 | w, h, d = image.shape 66 | z = random.randint(0, w - crop_factor[0]) 67 | y = random.randint(0, h - crop_factor[1]) 68 | x = random.randint(0, d - crop_factor[2]) 69 | 70 | image = image[z:z + crop_factor[0], y:y + crop_factor[1], x:x + crop_factor[2]] 71 | label = label[z:z + crop_factor[0], y:y + crop_factor[1], x:x + crop_factor[2]] 72 | return image, label 73 | 74 | def __getitem__(self, idx): 75 | img_path = self.images[idx] 76 | gt_path = self.groundtruth[idx] 77 | 78 | image = sitk.ReadImage(img_path) 79 | image = sitk.GetArrayFromImage(image).astype(np.float32) # [x,y,z] -> [z,y,x] 80 | 81 | label = sitk.ReadImage(gt_path) 82 | # if use CE loss, type: astype(np.int64), or use MSE type: astype(np.float32) 83 | label = sitk.GetArrayFromImage(label).astype(np.int64) # 
[x,y,z] -> [z,y,x] 84 | 85 | image, label = self.RandomCrop(image, label, crop_factor=(64, 104, 112)) # [z,y,x] 86 | 87 | if self.train: 88 | image = torch.from_numpy(np.ascontiguousarray(image)).unsqueeze(0) 89 | label = torch.from_numpy(np.ascontiguousarray(label)).unsqueeze(0) 90 | 91 | else: 92 | image = torch.from_numpy(np.ascontiguousarray(image)).unsqueeze(0) 93 | label = torch.from_numpy(np.ascontiguousarray(label)).unsqueeze(0) 94 | 95 | image = image / 255 96 | label = label // 255 97 | 98 | return image, label 99 | -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMED-Lab/CS-Net/25079c377f8db4b57f25c0adc7b70d1a02a3ee62/dataloader/__init__.py -------------------------------------------------------------------------------- /dataloader/drive.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import glob 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | from PIL import Image, ImageEnhance 7 | from utils.misc import ReScaleSize 8 | import random 9 | import warnings 10 | import numpy as np 11 | import scipy.misc as misc 12 | 13 | warnings.filterwarnings('ignore') 14 | 15 | 16 | def load_dataset(root_dir, train=True): 17 | images = [] 18 | groundtruth = [] 19 | if train: 20 | sub_dir = 'training' 21 | else: 22 | sub_dir = 'test' 23 | images_path = os.path.join(root_dir, sub_dir, 'images') 24 | groundtruth_path = os.path.join(root_dir, sub_dir, '1st_manual') 25 | 26 | for file in glob.glob(os.path.join(images_path, '*.tif')): 27 | image_name = os.path.basename(file) 28 | groundtruth_name = image_name[:3] + 'manual1.gif' 29 | 30 | images.append(os.path.join(images_path, image_name)) 31 | groundtruth.append(os.path.join(groundtruth_path, groundtruth_name)) 32 
| 33 | return images, groundtruth 34 | 35 | 36 | class Data(Dataset): 37 | def __init__(self, 38 | root_dir, 39 | train=True, 40 | rotate=40, 41 | flip=True, 42 | random_crop=True, 43 | scale1=512): 44 | 45 | self.root_dir = root_dir 46 | self.train = train 47 | self.rotate = rotate 48 | self.flip = flip 49 | self.random_crop = random_crop 50 | self.transform = transforms.ToTensor() 51 | self.resize = scale1 52 | self.images, self.groundtruth = load_dataset(self.root_dir, self.train) 53 | 54 | def __len__(self): 55 | return len(self.images) 56 | 57 | def RandomCrop(self, image, label, crop_size): 58 | crop_width, crop_height = crop_size 59 | w, h = image.size 60 | left = random.randint(0, w - crop_width) 61 | top = random.randint(0, h - crop_height) 62 | right = left + crop_width 63 | bottom = top + crop_height 64 | new_image = image.crop((left, top, right, bottom)) 65 | new_label = label.crop((left, top, right, bottom)) 66 | return new_image, new_label 67 | 68 | def RandomEnhance(self, image): 69 | value = random.uniform(-2, 2) 70 | random_seed = random.randint(1, 4) 71 | if random_seed == 1: 72 | img_enhanceed = ImageEnhance.Brightness(image) 73 | elif random_seed == 2: 74 | img_enhanceed = ImageEnhance.Color(image) 75 | elif random_seed == 3: 76 | img_enhanceed = ImageEnhance.Contrast(image) 77 | else: 78 | img_enhanceed = ImageEnhance.Sharpness(image) 79 | image = img_enhanceed.enhance(value) 80 | return image 81 | 82 | def rescale(self, img, re_size): 83 | w, h = img.size 84 | min_len = min(w, h) 85 | new_w, new_h = min_len, min_len 86 | scale_w = (w - new_w) // 2 87 | scale_h = (h - new_h) // 2 88 | box = (scale_w, scale_h, scale_w + new_w, scale_h + new_h) 89 | img = img.crop(box) 90 | img = img.resize((re_size, re_size)) 91 | return img 92 | 93 | def __getitem__(self, idx): 94 | img_path = self.images[idx] 95 | gt_path = self.groundtruth[idx] 96 | image = Image.open(img_path) 97 | label = Image.open(gt_path) 98 | 99 | image = self.rescale(image, 
self.resize) 100 | label = self.rescale(label, self.resize) 101 | 102 | if self.train: 103 | # augumentation 104 | angel = random.randint(-self.rotate, self.rotate) 105 | image = image.rotate(angel) 106 | label = label.rotate(angel) 107 | 108 | if random.random() > 0.5: 109 | image = self.RandomEnhance(image) 110 | 111 | image, label = self.RandomCrop(image, label, crop_size=[self.resize, self.resize]) 112 | 113 | # flip 114 | if self.flip and random.random() > 0.5: 115 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 116 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 117 | 118 | # img_size = image.size 119 | # if img_size[0] != self.resize: 120 | # image = image.resize((self.resize, self.resize)) 121 | # label = label.resize((self.resize, self.resize)) 122 | else: 123 | img_size = image.size 124 | if img_size[0] != self.resize: 125 | image = image.resize((self.resize, self.resize)) 126 | label = label.resize((self.resize, self.resize)) 127 | 128 | image = self.transform(image) 129 | label = self.transform(label) 130 | 131 | return image, label 132 | -------------------------------------------------------------------------------- /dataloader/octa.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import glob 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | from PIL import Image, ImageEnhance, ImageOps 7 | import random 8 | import warnings 9 | 10 | warnings.filterwarnings('ignore') 11 | 12 | 13 | def load_dataset(root_dir, train=True): 14 | labels = [] 15 | images = [] 16 | if train: 17 | sub_dir = 'training' 18 | else: 19 | sub_dir = 'test' 20 | label_path = os.path.join(root_dir, sub_dir, 'label') 21 | image_path = os.path.join(root_dir, sub_dir, 'images') 22 | 23 | for file in glob.glob(os.path.join(image_path, '*.tif')): 24 | image_name = os.path.basename(file) 25 | label_name = image_name[:-4] + '_nerve_ann.tif' 26 | 
labels.append(os.path.join(label_path, label_name)) 27 | images.append(os.path.join(image_path, image_name)) 28 | return images, labels 29 | 30 | 31 | class Data(Dataset): 32 | def __init__(self, 33 | root_dir, 34 | train=True, 35 | rotate=45, 36 | flip=True, 37 | random_crop=True, 38 | scale1=512): 39 | 40 | self.root_dir = root_dir 41 | self.train = train 42 | self.rotate = rotate 43 | self.flip = flip 44 | self.random_crop = random_crop 45 | self.transform = transforms.ToTensor() 46 | self.resize = scale1 47 | self.images, self.groundtruth = load_dataset(self.root_dir, self.train) 48 | 49 | def __len__(self): 50 | return len(self.images) 51 | 52 | def RandomCrop(self, image, label, crop_size): 53 | crop_width, crop_height = crop_size 54 | w, h = image.size 55 | left = random.randint(0, w - crop_width) 56 | top = random.randint(0, h - crop_height) 57 | right = left + crop_width 58 | bottom = top + crop_height 59 | new_image = image.crop((left, top, right, bottom)) 60 | new_label = label.crop((left, top, right, bottom)) 61 | return new_image, new_label 62 | 63 | def RandomEnhance(self, image): 64 | value = random.uniform(-2, 2) 65 | random_seed = random.randint(1, 4) 66 | if random_seed == 1: 67 | img_enhanceed = ImageEnhance.Brightness(image) 68 | elif random_seed == 2: 69 | img_enhanceed = ImageEnhance.Color(image) 70 | elif random_seed == 3: 71 | img_enhanceed = ImageEnhance.Contrast(image) 72 | else: 73 | img_enhanceed = ImageEnhance.Sharpness(image) 74 | image = img_enhanceed.enhance(value) 75 | return image 76 | 77 | def Crop(self, image): 78 | left = 261 79 | top = 1 80 | right = 1110 81 | bottom = 850 82 | image = image.crop((left, top, right, bottom)) 83 | return image 84 | 85 | def ReScaleSize(self, image, re_size=512): 86 | w, h = image.size 87 | max_len = max(w, h) 88 | new_w, new_h = max_len, max_len 89 | delta_w = new_w - w 90 | delta_h = new_h - h 91 | padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2)) 92 | 
image = ImageOps.expand(image, padding, fill=0) 93 | # origin_w, origin_h = w, h 94 | image = image.resize((re_size, re_size)) 95 | return image # , origin_w, origin_h 96 | 97 | def __getitem__(self, idx): 98 | img_path = self.images[idx] 99 | gt_path = self.groundtruth[idx] 100 | 101 | image = Image.open(img_path) 102 | label = Image.open(gt_path) 103 | image = self.Crop(image) 104 | label = self.Crop(label) 105 | image = self.ReScaleSize(image, self.resize) 106 | label = self.ReScaleSize(label, self.resize) 107 | 108 | if self.train: 109 | # augumentation 110 | angel = random.randint(-self.rotate, self.rotate) 111 | image = image.rotate(angel) 112 | label = label.rotate(angel) 113 | 114 | if random.random() > 0.5: 115 | image = self.RandomEnhance(image) 116 | 117 | image, label = self.RandomCrop(image, label, crop_size=[self.resize, self.resize]) 118 | 119 | # flip 120 | if self.flip and random.random() > 0.5: 121 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 122 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 123 | 124 | else: 125 | img_size = image.size 126 | if img_size[0] != self.resize: 127 | image = image.resize((self.resize, self.resize)) 128 | label = label.resize((self.resize, self.resize)) 129 | 130 | image = self.transform(image) 131 | label = self.transform(label) 132 | 133 | return image, label 134 | -------------------------------------------------------------------------------- /dataloader/padova1.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import glob 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | from PIL import Image, ImageEnhance 7 | import random 8 | import warnings 9 | 10 | warnings.filterwarnings('ignore') 11 | 12 | 13 | def load_dataset(root_dir, train=True): 14 | labels = [] 15 | images = [] 16 | if train: 17 | sub_dir = 'training' 18 | else: 19 | sub_dir = 'test' 20 | label_path = 
os.path.join(root_dir, sub_dir, 'label2') 21 | image_path = os.path.join(root_dir, sub_dir, 'images') 22 | 23 | for file in glob.glob(os.path.join(image_path, '*.tif')): 24 | image_name = os.path.basename(file) 25 | label_name = image_name[:-4] + '_centerline_overlay.tif' 26 | labels.append(os.path.join(label_path, label_name)) 27 | images.append(os.path.join(image_path, image_name)) 28 | return images, labels 29 | 30 | 31 | class Data(Dataset): 32 | def __init__(self, 33 | root_dir, 34 | train=True, 35 | rotate=45, 36 | flip=True, 37 | random_crop=True, 38 | scale1=384): 39 | 40 | self.root_dir = root_dir 41 | self.train = train 42 | self.rotate = rotate 43 | self.flip = flip 44 | self.random_crop = random_crop 45 | self.transform = transforms.ToTensor() 46 | self.resize = scale1 47 | self.images, self.groundtruth = load_dataset(self.root_dir, self.train) 48 | 49 | def __len__(self): 50 | return len(self.images) 51 | 52 | def RandomCrop(self, image, label, crop_size): 53 | crop_width, crop_height = crop_size 54 | w, h = image.size 55 | left = random.randint(0, w - crop_width) 56 | top = random.randint(0, h - crop_height) 57 | right = left + crop_width 58 | bottom = top + crop_height 59 | new_image = image.crop((left, top, right, bottom)) 60 | new_label = label.crop((left, top, right, bottom)) 61 | return new_image, new_label 62 | 63 | def RandomEnhance(self, image): 64 | value = random.uniform(-2, 2) 65 | random_seed = random.randint(1, 4) 66 | if random_seed == 1: 67 | img_enhanceed = ImageEnhance.Brightness(image) 68 | elif random_seed == 2: 69 | img_enhanceed = ImageEnhance.Color(image) 70 | elif random_seed == 3: 71 | img_enhanceed = ImageEnhance.Contrast(image) 72 | else: 73 | img_enhanceed = ImageEnhance.Sharpness(image) 74 | image = img_enhanceed.enhance(value) 75 | return image 76 | 77 | def __getitem__(self, idx): 78 | img_path = self.images[idx] 79 | gt_path = self.groundtruth[idx] 80 | 81 | image = Image.open(img_path) 82 | label = Image.open(gt_path) 
83 | 84 | if self.train: 85 | # augumentation 86 | angel = random.randint(-self.rotate, self.rotate) 87 | image = image.rotate(angel) 88 | label = label.rotate(angel) 89 | 90 | if random.random() > 0.5: 91 | image = self.RandomEnhance(image) 92 | 93 | image, label = self.RandomCrop(image, label, crop_size=[self.resize, self.resize]) 94 | 95 | # flip 96 | if self.flip and random.random() > 0.5: 97 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 98 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 99 | 100 | else: 101 | img_size = image.size 102 | if img_size[0] != self.resize: 103 | image = image.resize((self.resize, self.resize)) 104 | label = label.resize((self.resize, self.resize)) 105 | 106 | image = self.transform(image) 107 | label = self.transform(label) 108 | 109 | return image, label 110 | -------------------------------------------------------------------------------- /dataloader/padova2.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import glob 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | from PIL import Image, ImageEnhance 7 | import random 8 | import warnings 9 | 10 | warnings.filterwarnings('ignore') 11 | 12 | 13 | def load_dataset(root_dir, train=True): 14 | labels = [] 15 | images = [] 16 | if train: 17 | sub_dir = 'training' 18 | else: 19 | sub_dir = 'test' 20 | label_path = os.path.join(root_dir, sub_dir, 'label2') 21 | image_path = os.path.join(root_dir, sub_dir, 'images') 22 | 23 | for file in glob.glob(os.path.join(image_path, '*.tif')): 24 | image_name = os.path.basename(file) 25 | label_name = image_name[:-4] + '_centerline_overlay.tif' 26 | labels.append(os.path.join(label_path, label_name)) 27 | images.append(os.path.join(image_path, image_name)) 28 | return images, labels 29 | 30 | 31 | class Data(Dataset): 32 | def __init__(self, 33 | root_dir, 34 | train=True, 35 | rotate=45, 36 | flip=True, 37 
| random_crop=True, 38 | scale1=384): 39 | 40 | self.root_dir = root_dir 41 | self.train = train 42 | self.rotate = rotate 43 | self.flip = flip 44 | self.random_crop = random_crop 45 | self.transform = transforms.ToTensor() 46 | self.resize = scale1 47 | self.images, self.groundtruth = load_dataset(self.root_dir, self.train) 48 | 49 | def __len__(self): 50 | return len(self.images) 51 | 52 | def RandomCrop(self, image, label, crop_size): 53 | crop_width, crop_height = crop_size 54 | w, h = image.size 55 | left = random.randint(0, w - crop_width) 56 | top = random.randint(0, h - crop_height) 57 | right = left + crop_width 58 | bottom = top + crop_height 59 | new_image = image.crop((left, top, right, bottom)) 60 | new_label = label.crop((left, top, right, bottom)) 61 | return new_image, new_label 62 | 63 | def RandomEnhance(self, image): 64 | value = random.uniform(-2, 2) 65 | random_seed = random.randint(1, 4) 66 | if random_seed == 1: 67 | img_enhanceed = ImageEnhance.Brightness(image) 68 | elif random_seed == 2: 69 | img_enhanceed = ImageEnhance.Color(image) 70 | elif random_seed == 3: 71 | img_enhanceed = ImageEnhance.Contrast(image) 72 | else: 73 | img_enhanceed = ImageEnhance.Sharpness(image) 74 | image = img_enhanceed.enhance(value) 75 | return image 76 | 77 | def __getitem__(self, idx): 78 | img_path = self.images[idx] 79 | gt_path = self.groundtruth[idx] 80 | 81 | image = Image.open(img_path) 82 | label = Image.open(gt_path) 83 | 84 | # image = ReScaleSize(image, self.resize) 85 | # label = ReScaleSize(label, self.resize) 86 | 87 | if self.train: 88 | # augumentation 89 | angel = random.randint(-self.rotate, self.rotate) 90 | image = image.rotate(angel) 91 | label = label.rotate(angel) 92 | 93 | if random.random() > 0.5: 94 | image = self.RandomEnhance(image) 95 | 96 | image, label = self.RandomCrop(image, label, crop_size=[self.resize, self.resize]) 97 | 98 | # flip 99 | if self.flip and random.random() > 0.5: 100 | image = 
image.transpose(Image.FLIP_LEFT_RIGHT) 101 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 102 | 103 | else: 104 | img_size = image.size 105 | if img_size[0] != self.resize: 106 | image = image.resize((self.resize, self.resize)) 107 | label = label.resize((self.resize, self.resize)) 108 | 109 | image = self.transform(image) 110 | label = self.transform(label) 111 | 112 | return image, label 113 | -------------------------------------------------------------------------------- /dataloader/stare.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import glob 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | from PIL import Image, ImageEnhance 7 | from utils.misc import ReScaleSize 8 | import random 9 | import warnings 10 | 11 | warnings.filterwarnings('ignore') 12 | 13 | 14 | def load_dataset(root_dir, train=True): 15 | images = [] 16 | groundtruth = [] 17 | if train: 18 | sub_dir = 'training' 19 | else: 20 | sub_dir = 'test' 21 | images_path = os.path.join(root_dir, sub_dir, 'images') 22 | groundtruth_path = os.path.join(root_dir, sub_dir, 'labels-ah') 23 | 24 | for file in glob.glob(os.path.join(images_path, '*.ppm')): 25 | image_name = os.path.basename(file) 26 | groundtruth_name = image_name[:-4] + '.ah.ppm' 27 | images.append(os.path.join(images_path, image_name)) 28 | groundtruth.append(os.path.join(groundtruth_path, groundtruth_name)) 29 | 30 | return images, groundtruth 31 | 32 | 33 | class Data(Dataset): 34 | def __init__(self, 35 | root_dir, 36 | train=True, 37 | rotate=40, 38 | flip=True, 39 | random_crop=True, 40 | scale1=688): 41 | 42 | self.root_dir = root_dir 43 | self.train = train 44 | self.rotate = rotate 45 | self.flip = flip 46 | self.random_crop = random_crop 47 | self.transform = transforms.ToTensor() 48 | self.resize = scale1 49 | self.images, self.groundtruth = load_dataset(self.root_dir, self.train) 50 | 51 
| def __len__(self): 52 | return len(self.images) 53 | 54 | def RandomCrop(self, image, label, crop_size): 55 | crop_width, crop_height = crop_size 56 | w, h = image.size 57 | left = random.randint(0, w - crop_width) 58 | top = random.randint(0, h - crop_height) 59 | right = left + crop_width 60 | bottom = top + crop_height 61 | new_image = image.crop((left, top, right, bottom)) 62 | new_label = label.crop((left, top, right, bottom)) 63 | return new_image, new_label 64 | 65 | def RandomEnhance(self, image): 66 | value = random.uniform(-2, 2) 67 | random_seed = random.randint(1, 4) 68 | if random_seed == 1: 69 | img_enhanceed = ImageEnhance.Brightness(image) 70 | elif random_seed == 2: 71 | img_enhanceed = ImageEnhance.Color(image) 72 | elif random_seed == 3: 73 | img_enhanceed = ImageEnhance.Contrast(image) 74 | else: 75 | img_enhanceed = ImageEnhance.Sharpness(image) 76 | image = img_enhanceed.enhance(value) 77 | return image 78 | 79 | def __getitem__(self, idx): 80 | img_path = self.images[idx] 81 | gt_path = self.groundtruth[idx] 82 | image = Image.open(img_path) 83 | label = Image.open(gt_path) 84 | image = ReScaleSize(image, self.resize) 85 | label = ReScaleSize(label, self.resize) 86 | 87 | if self.train: 88 | # augumentation 89 | angel = random.randint(-self.rotate, self.rotate) 90 | image = image.rotate(angel) 91 | label = label.rotate(angel) 92 | 93 | if random.random() > 0.5: 94 | image = self.RandomEnhance(image) 95 | 96 | image, label = self.RandomCrop(image, label, crop_size=[self.resize, self.resize]) 97 | 98 | # flip 99 | if self.flip and random.random() > 0.5: 100 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 101 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 102 | 103 | else: 104 | img_size = image.size 105 | if img_size[0] != self.resize: 106 | image = image.resize((self.resize, self.resize)) 107 | label = label.resize((self.resize, self.resize)) 108 | 109 | image = self.transform(image) 110 | label = self.transform(label) 111 | 112 | return 
image, label 113 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMED-Lab/CS-Net/25079c377f8db4b57f25c0adc7b70d1a02a3ee62/model/__init__.py -------------------------------------------------------------------------------- /model/csnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Channel and Spatial CSNet Network (CS-Net). 3 | """ 4 | from __future__ import division 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | def downsample(): 11 | return nn.MaxPool2d(kernel_size=2, stride=2) 12 | 13 | 14 | def deconv(in_channels, out_channels): 15 | return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2) 16 | 17 | 18 | def initialize_weights(*models): 19 | for model in models: 20 | for m in model.modules(): 21 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 22 | nn.init.kaiming_normal(m.weight) 23 | if m.bias is not None: 24 | m.bias.data.zero_() 25 | elif isinstance(m, nn.BatchNorm2d): 26 | m.weight.data.fill_(1) 27 | m.bias.data.zero_() 28 | 29 | 30 | class ResEncoder(nn.Module): 31 | def __init__(self, in_channels, out_channels): 32 | super(ResEncoder, self).__init__() 33 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) 34 | self.bn1 = nn.BatchNorm2d(out_channels) 35 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) 36 | self.bn2 = nn.BatchNorm2d(out_channels) 37 | self.relu = nn.ReLU(inplace=False) 38 | self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1) 39 | 40 | def forward(self, x): 41 | residual = self.conv1x1(x) 42 | out = self.relu(self.bn1(self.conv1(x))) 43 | out = self.relu(self.bn2(self.conv2(out))) 44 | out += residual 45 | out = self.relu(out) 46 | return out 47 | 48 | 49 | class 
Decoder(nn.Module): 50 | def __init__(self, in_channels, out_channels): 51 | super(Decoder, self).__init__() 52 | self.conv = nn.Sequential( 53 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 54 | nn.BatchNorm2d(out_channels), 55 | nn.ReLU(inplace=True), 56 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 57 | nn.BatchNorm2d(out_channels), 58 | nn.ReLU(inplace=True) 59 | ) 60 | 61 | def forward(self, x): 62 | out = self.conv(x) 63 | return out 64 | 65 | 66 | class SpatialAttentionBlock(nn.Module): 67 | def __init__(self, in_channels): 68 | super(SpatialAttentionBlock, self).__init__() 69 | self.query = nn.Sequential( 70 | nn.Conv2d(in_channels,in_channels//8,kernel_size=(1,3), padding=(0,1)), 71 | nn.BatchNorm2d(in_channels//8), 72 | nn.ReLU(inplace=True) 73 | ) 74 | self.key = nn.Sequential( 75 | nn.Conv2d(in_channels, in_channels//8, kernel_size=(3,1), padding=(1,0)), 76 | nn.BatchNorm2d(in_channels//8), 77 | nn.ReLU(inplace=True) 78 | ) 79 | self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1) 80 | self.gamma = nn.Parameter(torch.zeros(1)) 81 | self.softmax = nn.Softmax(dim=-1) 82 | 83 | def forward(self, x): 84 | """ 85 | :param x: input( BxCxHxW ) 86 | :return: affinity value + x 87 | """ 88 | B, C, H, W = x.size() 89 | # compress x: [B,C,H,W]-->[B,H*W,C], make a matrix transpose 90 | proj_query = self.query(x).view(B, -1, W * H).permute(0, 2, 1) 91 | proj_key = self.key(x).view(B, -1, W * H) 92 | affinity = torch.matmul(proj_query, proj_key) 93 | affinity = self.softmax(affinity) 94 | proj_value = self.value(x).view(B, -1, H * W) 95 | weights = torch.matmul(proj_value, affinity.permute(0, 2, 1)) 96 | weights = weights.view(B, C, H, W) 97 | out = self.gamma * weights + x 98 | return out 99 | 100 | 101 | class ChannelAttentionBlock(nn.Module): 102 | def __init__(self, in_channels): 103 | super(ChannelAttentionBlock, self).__init__() 104 | self.gamma = nn.Parameter(torch.zeros(1)) 105 | self.softmax = 
nn.Softmax(dim=-1) 106 | 107 | def forward(self, x): 108 | """ 109 | :param x: input( BxCxHxW ) 110 | :return: affinity value + x 111 | """ 112 | B, C, H, W = x.size() 113 | proj_query = x.view(B, C, -1) 114 | proj_key = x.view(B, C, -1).permute(0, 2, 1) 115 | affinity = torch.matmul(proj_query, proj_key) 116 | affinity_new = torch.max(affinity, -1, keepdim=True)[0].expand_as(affinity) - affinity 117 | affinity_new = self.softmax(affinity_new) 118 | proj_value = x.view(B, C, -1) 119 | weights = torch.matmul(affinity_new, proj_value) 120 | weights = weights.view(B, C, H, W) 121 | out = self.gamma * weights + x 122 | return out 123 | 124 | 125 | class AffinityAttention(nn.Module): 126 | """ Affinity attention module """ 127 | 128 | def __init__(self, in_channels): 129 | super(AffinityAttention, self).__init__() 130 | self.sab = SpatialAttentionBlock(in_channels) 131 | self.cab = ChannelAttentionBlock(in_channels) 132 | # self.conv1x1 = nn.Conv2d(in_channels * 2, in_channels, kernel_size=1) 133 | 134 | def forward(self, x): 135 | """ 136 | sab: spatial attention block 137 | cab: channel attention block 138 | :param x: input tensor 139 | :return: sab + cab 140 | """ 141 | sab = self.sab(x) 142 | cab = self.cab(x) 143 | out = sab + cab 144 | return out 145 | 146 | 147 | class CSNet(nn.Module): 148 | def __init__(self, classes, channels): 149 | """ 150 | :param classes: the object classes number. 151 | :param channels: the channels of the input image. 
152 | """ 153 | super(CSNet, self).__init__() 154 | self.enc_input = ResEncoder(channels, 32) 155 | self.encoder1 = ResEncoder(32, 64) 156 | self.encoder2 = ResEncoder(64, 128) 157 | self.encoder3 = ResEncoder(128, 256) 158 | self.encoder4 = ResEncoder(256, 512) 159 | self.downsample = downsample() 160 | self.affinity_attention = AffinityAttention(512) 161 | self.attention_fuse = nn.Conv2d(512 * 2, 512, kernel_size=1) 162 | self.decoder4 = Decoder(512, 256) 163 | self.decoder3 = Decoder(256, 128) 164 | self.decoder2 = Decoder(128, 64) 165 | self.decoder1 = Decoder(64, 32) 166 | self.deconv4 = deconv(512, 256) 167 | self.deconv3 = deconv(256, 128) 168 | self.deconv2 = deconv(128, 64) 169 | self.deconv1 = deconv(64, 32) 170 | self.final = nn.Conv2d(32, classes, kernel_size=1) 171 | initialize_weights(self) 172 | 173 | def forward(self, x): 174 | enc_input = self.enc_input(x) 175 | down1 = self.downsample(enc_input) 176 | 177 | enc1 = self.encoder1(down1) 178 | down2 = self.downsample(enc1) 179 | 180 | enc2 = self.encoder2(down2) 181 | down3 = self.downsample(enc2) 182 | 183 | enc3 = self.encoder3(down3) 184 | down4 = self.downsample(enc3) 185 | 186 | input_feature = self.encoder4(down4) 187 | 188 | # Do Attenttion operations here 189 | attention = self.affinity_attention(input_feature) 190 | 191 | # attention_fuse = self.attention_fuse(torch.cat((input_feature, attention), dim=1)) 192 | attention_fuse = input_feature + attention 193 | 194 | # Do decoder operations here 195 | up4 = self.deconv4(attention_fuse) 196 | up4 = torch.cat((enc3, up4), dim=1) 197 | dec4 = self.decoder4(up4) 198 | 199 | up3 = self.deconv3(dec4) 200 | up3 = torch.cat((enc2, up3), dim=1) 201 | dec3 = self.decoder3(up3) 202 | 203 | up2 = self.deconv2(dec3) 204 | up2 = torch.cat((enc1, up2), dim=1) 205 | dec2 = self.decoder2(up2) 206 | 207 | up1 = self.deconv1(dec2) 208 | up1 = torch.cat((enc_input, up1), dim=1) 209 | dec1 = self.decoder1(up1) 210 | 211 | final = self.final(dec1) 212 | final = 
F.sigmoid(final) 213 | return final 214 | -------------------------------------------------------------------------------- /model/csnet_3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3D Channel and Spatial Attention Network (CSA-Net 3D). 3 | """ 4 | from __future__ import division 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | def downsample(): 11 | return nn.MaxPool3d(kernel_size=2, stride=2) 12 | 13 | 14 | def deconv(in_channels, out_channels): 15 | return nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2) 16 | 17 | 18 | def initialize_weights(*models): 19 | for model in models: 20 | for m in model.modules(): 21 | if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear): 22 | nn.init.kaiming_normal(m.weight) 23 | if m.bias is not None: 24 | m.bias.data.zero_() 25 | elif isinstance(m, nn.BatchNorm3d): 26 | m.weight.data.fill_(1) 27 | m.bias.data.zero_() 28 | 29 | 30 | class ResEncoder3d(nn.Module): 31 | def __init__(self, in_channels, out_channels): 32 | super(ResEncoder3d, self).__init__() 33 | self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1) 34 | self.bn1 = nn.BatchNorm3d(out_channels) 35 | self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1) 36 | self.bn2 = nn.BatchNorm3d(out_channels) 37 | self.relu = nn.ReLU(inplace=False) 38 | self.conv1x1 = nn.Conv3d(in_channels, out_channels, kernel_size=1) 39 | 40 | def forward(self, x): 41 | residual = self.conv1x1(x) 42 | out = self.relu(self.bn1(self.conv1(x))) 43 | out = self.relu(self.bn2(self.conv2(out))) 44 | out += residual 45 | out = self.relu(out) 46 | return out 47 | 48 | 49 | class Decoder3d(nn.Module): 50 | def __init__(self, in_channels, out_channels): 51 | super(Decoder3d, self).__init__() 52 | self.conv = nn.Sequential( 53 | nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1), 54 | nn.BatchNorm3d(out_channels), 55 | 
nn.ReLU(inplace=False), 56 | nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1), 57 | nn.BatchNorm3d(out_channels), 58 | nn.ReLU(inplace=False) 59 | ) 60 | 61 | def forward(self, x): 62 | out = self.conv(x) 63 | return out 64 | 65 | 66 | class SpatialAttentionBlock3d(nn.Module): 67 | def __init__(self, in_channels): 68 | super(SpatialAttentionBlock3d, self).__init__() 69 | self.query = nn.Conv3d(in_channels, in_channels // 8, kernel_size=(1, 3, 1), padding=(0, 1, 0)) 70 | self.key = nn.Conv3d(in_channels, in_channels // 8, kernel_size=(3, 1, 1), padding=(1, 0, 0)) 71 | self.judge = nn.Conv3d(in_channels, in_channels // 8, kernel_size=(1, 1, 3), padding=(0, 0, 1)) 72 | self.value = nn.Conv3d(in_channels, in_channels, kernel_size=1) 73 | self.gamma = nn.Parameter(torch.zeros(1)) 74 | self.softmax = nn.Softmax(dim=-1) 75 | 76 | def forward(self, x): 77 | """ 78 | :param x: input( BxCxHxWxZ ) 79 | :return: affinity value + x 80 | B: batch size 81 | C: channels 82 | H: height 83 | W: width 84 | D: slice number (depth) 85 | """ 86 | B, C, H, W, D = x.size() 87 | # compress x: [B,C,H,W,Z]-->[B,H*W*Z,C], make a matrix transpose 88 | proj_query = self.query(x).view(B, -1, W * H * D).permute(0, 2, 1) # -> [B,W*H*D,C] 89 | proj_key = self.key(x).view(B, -1, W * H * D) # -> [B,H*W*D,C] 90 | proj_judge = self.judge(x).view(B, -1, W * H * D).permute(0, 2, 1) # -> [B,C,H*W*D] 91 | 92 | affinity1 = torch.matmul(proj_query, proj_key) 93 | affinity2 = torch.matmul(proj_judge, proj_key) 94 | affinity = torch.matmul(affinity1, affinity2) 95 | affinity = self.softmax(affinity) 96 | 97 | proj_value = self.value(x).view(B, -1, H * W * D) # -> C*N 98 | weights = torch.matmul(proj_value, affinity) 99 | weights = weights.view(B, C, H, W, D) 100 | out = self.gamma * weights + x 101 | return out 102 | 103 | 104 | class ChannelAttentionBlock3d(nn.Module): 105 | def __init__(self, in_channels): 106 | super(ChannelAttentionBlock3d, self).__init__() 107 | self.gamma = 
nn.Parameter(torch.zeros(1)) 108 | self.softmax = nn.Softmax(dim=-1) 109 | 110 | def forward(self, x): 111 | """ 112 | :param x: input( BxCxHxWxD ) 113 | :return: affinity value + x 114 | """ 115 | B, C, H, W, D = x.size() 116 | proj_query = x.view(B, C, -1).permute(0, 2, 1) 117 | proj_key = x.view(B, C, -1) 118 | proj_judge = x.view(B, C, -1).permute(0, 2, 1) 119 | affinity1 = torch.matmul(proj_key, proj_query) 120 | affinity2 = torch.matmul(proj_key, proj_judge) 121 | affinity = torch.matmul(affinity1, affinity2) 122 | affinity_new = torch.max(affinity, -1, keepdim=True)[0].expand_as(affinity) - affinity 123 | affinity_new = self.softmax(affinity_new) 124 | proj_value = x.view(B, C, -1) 125 | weights = torch.matmul(affinity_new, proj_value) 126 | weights = weights.view(B, C, H, W, D) 127 | out = self.gamma * weights + x 128 | return out 129 | 130 | 131 | class AffinityAttention3d(nn.Module): 132 | """ Affinity attention module """ 133 | 134 | def __init__(self, in_channels): 135 | super(AffinityAttention3d, self).__init__() 136 | self.sab = SpatialAttentionBlock3d(in_channels) 137 | self.cab = ChannelAttentionBlock3d(in_channels) 138 | # self.conv1x1 = nn.Conv2d(in_channels * 2, in_channels, kernel_size=1) 139 | 140 | def forward(self, x): 141 | """ 142 | sab: spatial attention block 143 | cab: channel attention block 144 | :param x: input tensor 145 | :return: sab + cab 146 | """ 147 | sab = self.sab(x) 148 | cab = self.cab(x) 149 | out = sab + cab + x 150 | return out 151 | 152 | 153 | class CSNet3D(nn.Module): 154 | def __init__(self, classes, channels): 155 | """ 156 | :param classes: the object classes number. 157 | :param channels: the channels of the input image. 
158 | """ 159 | super(CSNet3D, self).__init__() 160 | self.enc_input = ResEncoder3d(channels, 16) 161 | self.encoder1 = ResEncoder3d(16, 32) 162 | self.encoder2 = ResEncoder3d(32, 64) 163 | self.encoder3 = ResEncoder3d(64, 128) 164 | self.encoder4 = ResEncoder3d(128, 256) 165 | self.downsample = downsample() 166 | self.affinity_attention = AffinityAttention3d(256) 167 | self.attention_fuse = nn.Conv3d(256 * 2, 256, kernel_size=1) 168 | self.decoder4 = Decoder3d(256, 128) 169 | self.decoder3 = Decoder3d(128, 64) 170 | self.decoder2 = Decoder3d(64, 32) 171 | self.decoder1 = Decoder3d(32, 16) 172 | self.deconv4 = deconv(256, 128) 173 | self.deconv3 = deconv(128, 64) 174 | self.deconv2 = deconv(64, 32) 175 | self.deconv1 = deconv(32, 16) 176 | self.final = nn.Conv3d(16, classes, kernel_size=1) 177 | initialize_weights(self) 178 | 179 | def forward(self, x): 180 | enc_input = self.enc_input(x) 181 | down1 = self.downsample(enc_input) 182 | 183 | enc1 = self.encoder1(down1) 184 | down2 = self.downsample(enc1) 185 | 186 | enc2 = self.encoder2(down2) 187 | down3 = self.downsample(enc2) 188 | 189 | enc3 = self.encoder3(down3) 190 | down4 = self.downsample(enc3) 191 | 192 | input_feature = self.encoder4(down4) 193 | 194 | # Do Attenttion operations here 195 | attention = self.affinity_attention(input_feature) 196 | attention_fuse = input_feature + attention 197 | 198 | # Do decoder operations here 199 | up4 = self.deconv4(attention_fuse) 200 | up4 = torch.cat((enc3, up4), dim=1) 201 | dec4 = self.decoder4(up4) 202 | 203 | up3 = self.deconv3(dec4) 204 | up3 = torch.cat((enc2, up3), dim=1) 205 | dec3 = self.decoder3(up3) 206 | 207 | up2 = self.deconv2(dec3) 208 | up2 = torch.cat((enc1, up2), dim=1) 209 | dec2 = self.decoder2(up2) 210 | 211 | up1 = self.deconv1(dec2) 212 | up1 = torch.cat((enc_input, up1), dim=1) 213 | dec1 = self.decoder1(up1) 214 | 215 | final = self.final(dec1) 216 | final = F.sigmoid(final) 217 | return final 218 | 
-------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | from PIL import Image, ImageOps 4 | 5 | import numpy as np 6 | import scipy.misc as misc 7 | import os 8 | import glob 9 | 10 | from utils.misc import thresh_OTSU, ReScaleSize, Crop 11 | from utils.model_eval import eval 12 | 13 | DATABASE = './DRIVE/' 14 | # 15 | args = { 16 | 'root' : './dataset/' + DATABASE, 17 | 'test_path': './dataset/' + DATABASE + 'test/', 18 | 'pred_path': 'assets/' + 'DRIVE/', 19 | 'img_size' : 512 20 | } 21 | 22 | if not os.path.exists(args['pred_path']): 23 | os.makedirs(args['pred_path']) 24 | 25 | 26 | def rescale(img): 27 | w, h = img.size 28 | min_len = min(w, h) 29 | new_w, new_h = min_len, min_len 30 | scale_w = (w - new_w) // 2 31 | scale_h = (h - new_h) // 2 32 | box = (scale_w, scale_h, scale_w + new_w, scale_h + new_h) 33 | img = img.crop(box) 34 | return img 35 | 36 | 37 | def ReScaleSize_DRIVE(image, re_size=512): 38 | w, h = image.size 39 | min_len = min(w, h) 40 | new_w, new_h = min_len, min_len 41 | scale_w = (w - new_w) // 2 42 | scale_h = (h - new_h) // 2 43 | box = (scale_w, scale_h, scale_w + new_w, scale_h + new_h) 44 | image = image.crop(box) 45 | image = image.resize((re_size, re_size)) 46 | return image # , origin_w, origin_h 47 | 48 | 49 | def ReScaleSize_STARE(image, re_size=512): 50 | w, h = image.size 51 | max_len = max(w, h) 52 | new_w, new_h = max_len, max_len 53 | delta_w = new_w - w 54 | delta_h = new_h - h 55 | padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2)) 56 | image = ImageOps.expand(image, padding, fill=0) 57 | # origin_w, origin_h = w, h 58 | image = image.resize((re_size, re_size)) 59 | return image # , origin_w, origin_h 60 | 61 | 62 | def load_nerve(): 63 | test_images = [] 64 | test_labels = [] 65 | for file in 
glob.glob(os.path.join(args['test_path'], 'orig', '*.tif')): 66 | basename = os.path.basename(file) 67 | file_name = basename[:-4] 68 | image_name = os.path.join(args['test_path'], 'orig', basename) 69 | label_name = os.path.join(args['test_path'], 'mask2', file_name + '_centerline_overlay.tif') 70 | test_images.append(image_name) 71 | test_labels.append(label_name) 72 | return test_images, test_labels 73 | 74 | 75 | def load_drive(): 76 | test_images = [] 77 | test_labels = [] 78 | for file in glob.glob(os.path.join(args['test_path'], 'images', '*.tif')): 79 | basename = os.path.basename(file) 80 | file_name = basename[:3] 81 | image_name = os.path.join(args['test_path'], 'images', basename) 82 | label_name = os.path.join(args['test_path'], '1st_manual', file_name + 'manual1.gif') 83 | test_images.append(image_name) 84 | test_labels.append(label_name) 85 | return test_images, test_labels 86 | 87 | 88 | def load_stare(): 89 | test_images = [] 90 | test_labels = [] 91 | for file in glob.glob(os.path.join(args['test_path'], 'images', '*.ppm')): 92 | basename = os.path.basename(file) 93 | file_name = basename[:-4] 94 | image_name = os.path.join(args['test_path'], 'images', basename) 95 | label_name = os.path.join(args['test_path'], 'labels-ah', file_name + '.ah.ppm') 96 | test_images.append(image_name) 97 | test_labels.append(label_name) 98 | return test_images, test_labels 99 | 100 | 101 | def load_padova1(): 102 | test_images = [] 103 | test_labels = [] 104 | for file in glob.glob(os.path.join(args['test_path'], 'images', '*.tif')): 105 | basename = os.path.basename(file) 106 | file_name = basename[:-4] 107 | image_name = os.path.join(args['test_path'], 'images', basename) 108 | label_name = os.path.join(args['test_path'], 'label2', file_name + '_centerline_overlay.tif') 109 | test_images.append(image_name) 110 | test_labels.append(label_name) 111 | return test_images, test_labels 112 | 113 | 114 | def load_octa(): 115 | test_images = [] 116 | test_labels = [] 117 | 
for file in glob.glob(os.path.join(args['test_path'], 'images', '*.png')): 118 | basename = os.path.basename(file) 119 | file_name = basename[:-4] 120 | image_name = os.path.join(args['test_path'], 'images', basename) 121 | label_name = os.path.join(args['test_path'], 'label', file_name + '_nerve_ann.tif') 122 | test_images.append(image_name) 123 | test_labels.append(label_name) 124 | return test_images, test_labels 125 | 126 | 127 | def load_net(): 128 | net = torch.load('./checkpoint/xxxx.pkl') 129 | return net 130 | 131 | 132 | def save_prediction(pred, filename=''): 133 | save_path = args['pred_path'] + 'pred/' 134 | if not os.path.exists(save_path): 135 | os.makedirs(save_path) 136 | print("Make dirs success!") 137 | mask = pred.data.cpu().numpy() * 255 138 | mask = np.transpose(np.squeeze(mask, axis=0), [1, 2, 0]) 139 | mask = np.squeeze(mask, axis=-1) 140 | misc.imsave(save_path + filename + '.png', mask) 141 | 142 | 143 | def predict(): 144 | net = load_net() 145 | # images, labels = load_nerve() 146 | images, labels = load_drive() 147 | # images, labels = load_stare() 148 | # images, labels = load_padova1() 149 | # images, labels = load_octa() 150 | 151 | transform = transforms.Compose([ 152 | transforms.ToTensor() 153 | ]) 154 | 155 | with torch.no_grad(): 156 | net.eval() 157 | for i in range(len(images)): 158 | print(images[i]) 159 | name_list = images[i].split('/') 160 | index = name_list[-1][:-4] 161 | image = Image.open(images[i]) 162 | # image=image.convert("RGB") 163 | label = Image.open(labels[i]) 164 | image, label = center_crop(image, label) 165 | 166 | # for other retinal vessel 167 | # image = rescale(image) 168 | # label = rescale(label) 169 | # image = ReScaleSize_STARE(image, re_size=args['img_size']) 170 | # label = ReScaleSize_DRIVE(label, re_size=args['img_size']) 171 | 172 | # for OCTA 173 | # image = Crop(image) 174 | # image = ReScaleSize(image) 175 | # label = Crop(label) 176 | # label = ReScaleSize(label) 177 | 178 | # label = 
label.resize((args['img_size'], args['img_size'])) 179 | # if cuda 180 | image = transform(image).cuda() 181 | # image = transform(image) 182 | image = image.unsqueeze(0) 183 | output = net(image) 184 | 185 | save_prediction(output, filename=index + '_pred') 186 | print("output saving successfully") 187 | 188 | 189 | if __name__ == '__main__': 190 | predict() 191 | thresh_OTSU(args['pred_path'] + 'pred/') 192 | -------------------------------------------------------------------------------- /predict3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import os 5 | import glob 6 | from tqdm import tqdm 7 | import SimpleITK as sitk 8 | from utils.misc import get_spacing 9 | 10 | os.environ['CUDA_VISIBLE_DEVICES'] = "1" 11 | 12 | DATABASE = 'VascuSynth3/' 13 | # 14 | args = { 15 | 'root' : './dataset/' + DATABASE, 16 | 'test_path': './dataset/' + DATABASE + 'test/', 17 | 'pred_path': 'assets/' + 'VascuSynth3/', 18 | 'img_size' : 512 19 | } 20 | 21 | if not os.path.exists(args['pred_path']): 22 | os.makedirs(args['pred_path']) 23 | 24 | 25 | def rescale(img): 26 | w, h = img.size 27 | min_len = min(w, h) 28 | new_w, new_h = min_len, min_len 29 | scale_w = (w - new_w) // 2 30 | scale_h = (h - new_h) // 2 31 | box = (scale_w, scale_h, scale_w + new_w, scale_h + new_h) 32 | img = img.crop(box) 33 | return img 34 | 35 | 36 | def load_3d(): 37 | test_images = [] 38 | test_labels = [] 39 | for file in glob.glob(os.path.join(args['test_path'], 'images', '*.mha')): 40 | basename = os.path.basename(file) 41 | file_name = basename[:-8] 42 | image_name = os.path.join(args['test_path'], 'images', basename) 43 | label_name = os.path.join(args['test_path'], 'label', file_name + 'gt.mha') 44 | test_images.append(image_name) 45 | test_labels.append(label_name) 46 | return test_images, test_labels 47 | 48 | 49 | def load_net(): 50 | net = 
torch.load('/home/imed/Research/Attention/checkpoint/model.pkl') 51 | return net 52 | 53 | 54 | def save_prediction(pred, filename='', spacing=None): 55 | pred = torch.argmax(pred, dim=1) 56 | save_path = args['pred_path'] + 'pred/' 57 | if not os.path.exists(save_path): 58 | os.makedirs(save_path) 59 | print("Make dirs success!") 60 | # for MSELoss() 61 | mask = (pred.data.cpu().numpy() * 255).astype(np.uint8) 62 | 63 | # thresholding 64 | # mask[mask >= 100] = 255 65 | # mask[mask < 100] = 0 66 | 67 | # mask = (mask.squeeze(0)).squeeze(0) # 3D numpy array 68 | mask = mask.squeeze(0) # for CE Loss 69 | # image = nib.Nifti1Image(np.int32(mask), affine) 70 | # nib.save(image, save_path + filename + ".nii.gz") 71 | mask = sitk.GetImageFromArray(mask) 72 | # if spacing is not None: 73 | # mask.SetSpacing(spacing) 74 | sitk.WriteImage(mask, os.path.join(save_path + filename + ".mha")) 75 | 76 | 77 | def save_probability(pred, label, filename=""): 78 | save_path = args['pred_path'] + 'pred/' 79 | if not os.path.exists(save_path): 80 | os.makedirs(save_path) 81 | print("Make dirs success!") 82 | # # for MSELoss() 83 | # mask = (pred.data.cpu().numpy() * 255) # .astype(np.uint8) 84 | # 85 | # mask = mask.squeeze(0) 86 | # class0 = mask[0, :, :, :] 87 | # class1 = mask[1, :, :, :] 88 | # label = label / 255 89 | # class0 = class0 * label 90 | # class1 = class1 * label 91 | # 92 | # probability = class0 + class1 93 | 94 | probability = F.softmax(pred, dim=1) 95 | probability.squeeze_(0) 96 | class0 = probability[0, :, :, :] 97 | class1 = probability[1, :, :, :] 98 | class0 = sitk.GetImageFromArray(class0) 99 | class1 = sitk.GetImageFromArray(class1) 100 | sitk.WriteImage(class1, os.path.join(save_path + filename + "class1.mha")) 101 | 102 | 103 | def save_label(label, index, spacing=None): 104 | label_path = args['pred_path'] + 'label/' 105 | if not os.path.exists(label_path): 106 | os.makedirs(label_path) 107 | label = sitk.GetImageFromArray(label) 108 | if spacing is not 
None: 109 | label.SetSpacing(spacing) 110 | sitk.WriteImage(label, os.path.join(label_path, index + ".mha")) 111 | 112 | 113 | def predict(): 114 | net = load_net() 115 | images, labels = load_3d() 116 | with torch.no_grad(): 117 | net.eval() 118 | for i in tqdm(range(len(images))): 119 | name_list = images[i].split('/') 120 | index = name_list[-1][:-4] 121 | image = sitk.ReadImage(images[i]) 122 | image = sitk.GetArrayFromImage(image).astype(np.float32) 123 | image = image / 255 124 | label = sitk.ReadImage(labels[i]) 125 | label = sitk.GetArrayFromImage(label).astype(np.int64) 126 | # label = label / 255 127 | # VascuSynth 128 | # image = image[2:98, 2:98, 2:98] 129 | # label = label[2:98, 2:98, 2:98] 130 | save_label(label, index) 131 | # if cuda 132 | image = torch.from_numpy(np.ascontiguousarray(image)).unsqueeze(0).unsqueeze(0) 133 | image = image.cuda() 134 | output = net(image) 135 | save_prediction(output, filename=index + '_pred', spacing=None) 136 | 137 | 138 | if __name__ == '__main__': 139 | predict() 140 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training script for CS-Net 3 | """ 4 | import os 5 | import torch 6 | import torch.nn as nn 7 | from torch import optim 8 | from torch.utils.data import DataLoader 9 | import visdom 10 | import numpy as np 11 | from model.csnet import CSNet 12 | from dataloader.drive import Data 13 | from utils.train_metrics import metrics 14 | from utils.visualize import init_visdom_line, update_lines 15 | from utils.dice_loss_single_class import dice_coeff_loss 16 | 17 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1" 18 | 19 | args = { 20 | 'root' : '', 21 | 'data_path' : 'dataset/DRIVE/', 22 | 'epochs' : 1000, 23 | 'lr' : 0.0001, 24 | 'snapshot' : 100, 25 | 'test_step' : 1, 26 | 'ckpt_path' : 'checkpoint/', 27 | 'batch_size': 8, 28 | } 29 | 30 | # # 
Visdom--------------------------------------------------------- 31 | X, Y = 0, 0.5 # for visdom 32 | x_acc, y_acc = 0, 0 33 | x_sen, y_sen = 0, 0 34 | env, panel = init_visdom_line(X, Y, title='Train Loss', xlabel="iters", ylabel="loss") 35 | env1, panel1 = init_visdom_line(x_acc, y_acc, title="Accuracy", xlabel="iters", ylabel="accuracy") 36 | env2, panel2 = init_visdom_line(x_sen, y_sen, title="Sensitivity", xlabel="iters", ylabel="sensitivity") 37 | # # --------------------------------------------------------------- 38 | 39 | def save_ckpt(net, iter): 40 | if not os.path.exists(args['ckpt_path']): 41 | os.makedirs(args['ckpt_path']) 42 | torch.save(net, args['ckpt_path'] + 'CS_Net_DRIVE_' + str(iter) + '.pkl') 43 | print('--->saved model:{}<--- '.format(args['root'] + args['ckpt_path'])) 44 | 45 | 46 | # adjust learning rate (poly) 47 | def adjust_lr(optimizer, base_lr, iter, max_iter, power=0.9): 48 | lr = base_lr * (1 - float(iter) / max_iter) ** power 49 | for param_group in optimizer.param_groups: 50 | param_group['lr'] = lr 51 | 52 | 53 | def train(): 54 | # set the channels to 3 when the format is RGB, otherwise 1. 55 | net = CSNet(classes=1, channels=3).cuda() 56 | net = nn.DataParallel(net, device_ids=[0, 1]).cuda() 57 | optimizer = optim.Adam(net.parameters(), lr=args['lr'], weight_decay=0.0005) 58 | critrion = nn.MSELoss().cuda() 59 | # critrion = nn.CrossEntropyLoss().cuda() 60 | print("---------------start training------------------") 61 | # load train dataset 62 | train_data = Data(args['data_path'], train=True) 63 | batchs_data = DataLoader(train_data, batch_size=args['batch_size'], num_workers=2, shuffle=True) 64 | 65 | iters = 1 66 | accuracy = 0. 67 | sensitivty = 0. 
68 | for epoch in range(args['epochs']): 69 | net.train() 70 | for idx, batch in enumerate(batchs_data): 71 | image = batch[0].cuda() 72 | label = batch[1].cuda() 73 | optimizer.zero_grad() 74 | pred = net(image) 75 | # pred = pred.squeeze_(1) 76 | loss1 = critrion(pred, label) 77 | loss2 = dice_coeff_loss(pred, label) 78 | loss = loss1 + loss2 79 | loss.backward() 80 | optimizer.step() 81 | acc, sen = metrics(pred, label, pred.shape[0]) 82 | print('[{0:d}:{1:d}] --- loss:{2:.10f}\tacc:{3:.4f}\tsen:{4:.4f}'.format(epoch + 1, 83 | iters, loss.item(), 84 | acc / pred.shape[0], 85 | sen / pred.shape[0])) 86 | iters += 1 87 | # # ---------------------------------- visdom -------------------------------------------------- 88 | X, x_acc, x_sen = iters, iters, iters 89 | Y, y_acc, y_sen = loss.item(), acc / pred.shape[0], sen / pred.shape[0] 90 | update_lines(env, panel, X, Y) 91 | update_lines(env1, panel1, x_acc, y_acc) 92 | update_lines(env2, panel2, x_sen, y_sen) 93 | # # -------------------------------------------------------------------------------------------- 94 | 95 | adjust_lr(optimizer, base_lr=args['lr'], iter=epoch, max_iter=args['epochs'], power=0.9) 96 | if (epoch + 1) % args['snapshot'] == 0: 97 | save_ckpt(net, epoch + 1) 98 | 99 | # model eval 100 | if (epoch + 1) % args['test_step'] == 0: 101 | test_acc, test_sen = model_eval(net) 102 | print("Average acc:{0:.4f}, average sen:{1:.4f}".format(test_acc, test_sen)) 103 | 104 | if (accuracy > test_acc) & (sensitivty > test_sen): 105 | save_ckpt(net, epoch + 1 + 8888888) 106 | accuracy = test_acc 107 | sensitivty = test_sen 108 | 109 | 110 | def model_eval(net): 111 | print("Start testing model...") 112 | test_data = Data(args['data_path'], train=False) 113 | batchs_data = DataLoader(test_data, batch_size=1) 114 | 115 | net.eval() 116 | Acc, Sen = [], [] 117 | file_num = 0 118 | for idx, batch in enumerate(batchs_data): 119 | image = batch[0].float().cuda() 120 | label = batch[1].float().cuda() 121 | 
pred_val = net(image) 122 | acc, sen = metrics(pred_val, label, pred_val.shape[0]) 123 | print("\t---\t test acc:{0:.4f} test sen:{1:.4f}".format(acc, sen)) 124 | Acc.append(acc) 125 | Sen.append(sen) 126 | file_num += 1 127 | # for better view, add testing visdom here. 128 | return np.mean(Acc), np.mean(Sen) 129 | 130 | 131 | if __name__ == '__main__': 132 | train() 133 | -------------------------------------------------------------------------------- /train3d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Author : Lei Mou 4 | # @File : train3d.py 5 | """ 6 | Training script for CS-Net 3D 7 | """ 8 | import os 9 | import torch 10 | import torch.nn as nn 11 | from torch import optim 12 | from torch.utils.data import DataLoader 13 | import datetime 14 | import numpy as np 15 | 16 | from model.csnet_3d import CSNet3D 17 | from dataloader.MRABrainLoader import Data 18 | 19 | from utils.train_metrics import metrics3d 20 | from utils.losses import WeightedCrossEntropyLoss, DiceLoss 21 | from utils.visualize import init_visdom_line, update_lines 22 | 23 | args = { 24 | 'root' : '/home/user/name/Projects/', 25 | 'data_path' : 'dataset/data dir(your own data path)/', 26 | 'epochs' : 200, 27 | 'lr' : 0.0001, 28 | 'snapshot' : 100, 29 | 'test_step' : 1, 30 | 'ckpt_path' : './checkpoint/', 31 | 'batch_size': 2, 32 | } 33 | 34 | # # Visdom--------------------------------------------------------- 35 | # The initial values are defined by myself 36 | X, Y = 0, 1.0 # for visdom 37 | x_tp, y_tp = 0, 0 38 | x_fn, y_fn = 0.4, 0.4 39 | x_fp, y_fp = 0.4, 0.4 40 | x_testtp, y_testtp = 0.0, 0.0 41 | x_testdc, y_testdc = 0.0, 0.0 42 | env, panel = init_visdom_line(X, Y, title='Train Loss', xlabel="iters", ylabel="loss", env="wce") 43 | env1, panel1 = init_visdom_line(x_tp, y_tp, title="TPR", xlabel="iters", ylabel="TPR", env="wce") 44 | env2, panel2 = init_visdom_line(x_fn, y_fn, 
title="FNR", xlabel="iters", ylabel="FNR", env="wce") 45 | env3, panel3 = init_visdom_line(x_fp, y_fp, title="FPR", xlabel="iters", ylabel="FPR", env="wce") 46 | env6, panel6 = init_visdom_line(x_testtp, y_testtp, title="DSC", xlabel="iters", ylabel="DSC", env="wce") 47 | env4, panel4 = init_visdom_line(x_testtp, y_testtp, title="Test Loss", xlabel="iters", ylabel="Test Loss", env="wce") 48 | env5, panel5 = init_visdom_line(x_testdc, y_testdc, title="Test TP", xlabel="iters", ylabel="Test TP", env="wce") 49 | env7, panel7 = init_visdom_line(x_testdc, y_testdc, title="Test IoU", xlabel="iters", ylabel="Test IoU", env="wce") 50 | 51 | 52 | def save_ckpt(net, iter): 53 | if not os.path.exists(args['ckpt_path']): 54 | os.makedirs(args['ckpt_path']) 55 | date = datetime.datetime.now().strftime("%Y-%m-%d-") 56 | torch.save(net, args['ckpt_path'] + 'CSNet3D_' + date + iter + '.pkl') 57 | print("{} Saved model to:{}".format("\u2714", args['ckpt_path'])) 58 | 59 | 60 | # adjust learning rate (poly) 61 | def adjust_lr(optimizer, base_lr, iter, max_iter, power=0.9): 62 | lr = base_lr * (1 - float(iter) / max_iter) ** power 63 | for param_group in optimizer.param_groups: 64 | param_group['lr'] = lr 65 | 66 | 67 | def train(): 68 | net = CSNet3D(classes=2, channels=1).cuda() 69 | net = nn.DataParallel(net, device_ids=[0, 1]).cuda() 70 | optimizer = optim.Adam(net.parameters(), lr=args['lr'], weight_decay=0.0005) 71 | 72 | # load train dataset 73 | train_data = Data(args['data_path'], train=True) 74 | batchs_data = DataLoader(train_data, batch_size=args['batch_size'], num_workers=4, shuffle=True) 75 | 76 | critrion2 = WeightedCrossEntropyLoss().cuda() 77 | critrion = nn.CrossEntropyLoss().cuda() 78 | critrion3 = DiceLoss().cuda() 79 | # Start training 80 | print("\033[1;30;44m {} Start training ... 
{}\033[0m".format("*" * 8, "*" * 8)) 81 | 82 | iters = 1 83 | for epoch in range(args['epochs']): 84 | net.train() 85 | for idx, batch in enumerate(batchs_data): 86 | image = batch[0].cuda() 87 | label = batch[1].cuda() 88 | optimizer.zero_grad() 89 | pred = net(image) 90 | loss_dice = critrion3(pred, label) 91 | label = label.squeeze(1) 92 | loss_ce = critrion(pred, label) 93 | loss_wce = critrion2(pred, label) 94 | loss = (loss_ce + 0.6 * loss_wce + 0.4 * loss_dice) / 3 95 | loss.backward() 96 | optimizer.step() 97 | tp, fn, fp, iou = metrics3d(pred, label, pred.shape[0]) 98 | if (epoch % 2) == 0: 99 | print( 100 | '\033[1;36m [{0:d}:{1:d}] \u2501\u2501\u2501 loss:{2:.10f}\tTP:{3:.4f}\tFN:{4:.4f}\tFP:{5:.4f}\tIoU:{6:.4f} '.format( 101 | epoch + 1, iters, loss.item(), tp / pred.shape[0], fn / pred.shape[0], fp / pred.shape[0], 102 | iou / pred.shape[0])) 103 | else: 104 | print( 105 | '\033[1;32m [{0:d}:{1:d}] \u2501\u2501\u2501 loss:{2:.10f}\tTP:{3:.4f}\tFN:{4:.4f}\tFP:{5:.4f}\tIoU:{6:.4f} '.format( 106 | epoch + 1, iters, loss.item(), tp / pred.shape[0], fn / pred.shape[0], fp / pred.shape[0], 107 | iou / pred.shape[0])) 108 | 109 | iters += 1 110 | # # ---------------------------------- visdom -------------------------------------------------- 111 | X, x_tp, x_fn, x_fp, x_dc = iters, iters, iters, iters, iters 112 | Y, y_tp, y_fn, y_fp, y_dc = loss.item(), tp / pred.shape[0], fn / pred.shape[0], fp / pred.shape[0], iou / \ 113 | pred.shape[0] 114 | 115 | update_lines(env, panel, X, Y) 116 | update_lines(env1, panel1, x_tp, y_tp) 117 | update_lines(env2, panel2, x_fn, y_fn) 118 | update_lines(env3, panel3, x_fp, y_fp) 119 | update_lines(env6, panel6, x_dc, y_dc) 120 | 121 | # # -------------------------------------------------------------------------------------------- 122 | 123 | adjust_lr(optimizer, base_lr=args['lr'], iter=epoch, max_iter=args['epochs'], power=0.9) 124 | 125 | if (epoch + 1) % args['snapshot'] == 0: 126 | save_ckpt(net, str(epoch + 1)) 127 | 
128 | # model eval 129 | if (epoch + 1) % args['test_step'] == 0: 130 | test_tp, test_fn, test_fp, test_dc = model_eval(net, critrion, iters) 131 | print("Average TP:{0:.4f}, average FN:{1:.4f}, average FP:{2:.4f}".format(test_tp, test_fn, test_fp)) 132 | 133 | 134 | def model_eval(net, critrion, iters): 135 | print("\033[1;30;43m {} Start training ... {}\033[0m".format("*" * 8, "*" * 8)) 136 | test_data = Data(args['data_path'], train=False) 137 | batchs_data = DataLoader(test_data, batch_size=1) 138 | 139 | net.eval() 140 | TP, FN, FP, IoU = [], [], [], [] 141 | file_num = 0 142 | with torch.no_grad(): 143 | for idx, batch in enumerate(batchs_data): 144 | image = batch[0].float().cuda() 145 | label = batch[1].cuda() 146 | pred_val = net(image) 147 | label = label.squeeze(1) 148 | loss = critrion(pred_val, label) 149 | tp, fn, fp, iou = metrics3d(pred_val, label, pred_val.shape[0]) 150 | print( 151 | "--- test TP:{0:.4f} test FN:{1:.4f} test FP:{2:.4f} test IoU:{3:.4f}".format(tp, fn, fp, iou)) 152 | TP.append(tp) 153 | FN.append(fn) 154 | FP.append(fp) 155 | IoU.append(iou) 156 | file_num += 1 157 | # # start visdom images 158 | X, x_testtp, x_testdc = iters, iters, iters 159 | Y, y_testtp, y_testdc = loss.item(), tp / pred_val.shape[0], iou / pred_val.shape[0] 160 | update_lines(env4, panel4, X, Y) 161 | update_lines(env5, panel5, x_testtp, y_testtp) 162 | update_lines(env7, panel7, x_testdc, y_testdc) 163 | return np.mean(TP), np.mean(FN), np.mean(FP), np.mean(IoU) 164 | 165 | 166 | if __name__ == '__main__': 167 | train() 168 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMED-Lab/CS-Net/25079c377f8db4b57f25c0adc7b70d1a02a3ee62/utils/__init__.py -------------------------------------------------------------------------------- /utils/dice_loss_single_class.py: 
# ------------------------------------------------------------------------------
import torch
from torch.autograd import Function, Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def _dice_single(input, target, eps=0.0001):
    """Compute the Dice coefficient of one example.

    dice = (2 * <input, target> + eps) / (sum(input) + sum(target) + eps),
    where <., .> is the dot product of the flattened tensors.

    Args:
        input: predicted map of one example (any shape).
        target: reference map with the same number of elements. It is cast to
            ``input``'s dtype so that integer label maps work with ``torch.dot``.
        eps: smoothing term added to numerator and denominator.

    Returns:
        ``(dice, inter, union)``: the float32 scalar Dice value, the dot
        product, and the smoothed denominator.
    """
    target = target.to(input.dtype)
    # Dot product of the flattened tensors = overlap between prediction and reference.
    inter = torch.dot(input.reshape(-1), target.reshape(-1))
    union = torch.sum(input) + torch.sum(target) + eps
    dice = (2 * inter.float() + eps) / union.float()
    return dice, inter, union


class DiceCoeff(Function):
    """Dice coefficient for an individual example (legacy instance-style Function).

    NOTE(review): ``forward``/``backward`` are instance methods, i.e. the legacy
    autograd API that newer PyTorch releases deprecate. When ``forward`` is
    called directly (as the original batch helper did), autograd differentiates
    the tensor operations inside it and ``backward`` is not invoked. The
    original authors also left a note that ``backward`` never seemed to run.
    """

    def forward(self, input, target):
        """Return the Dice coefficient of ``input`` and ``target``."""
        self.save_for_backward(input, target)
        dice, self.inter, self.union = _dice_single(input, target)
        return dice

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        """Analytic gradient of the Dice coefficient w.r.t. ``input``."""
        input, target = self.saved_tensors
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            grad_input = grad_output * 2 * (target * self.union - self.inter) \
                / (self.union * self.union)
        if self.needs_input_grad[1]:
            grad_target = None

        return grad_input, grad_target


def dice_coeff(input, target):
    """Mean Dice coefficient over a batch.

    Iterates over the first dimension of ``input`` and ``target`` in lockstep.
    Like ``zip``, it stops at the shorter of the two. An earlier debug note
    recorded per-example shapes such as ``torch.Size([1, 576, 544])``.

    Args:
        input: batch of predicted maps.
        target: batch of reference maps.

    Returns:
        A float32 tensor of shape ``(1,)`` on ``input``'s device holding the
        mean per-example Dice.

    Raises:
        ValueError: if there is no example to average over.
    """
    total = torch.zeros(1, dtype=torch.float32, device=input.device)
    count = 0
    for sample_pred, sample_true in zip(input, target):
        dice, _, _ = _dice_single(sample_pred, sample_true)
        total = total + dice
        count += 1
    if count == 0:
        raise ValueError("dice_coeff() needs at least one example")
    return total / count


def dice_coeff_loss(input, target):
    """Return ``1 - dice_coeff(input, target)``."""
    return 1 - dice_coeff(input, target)
# ------------------------------------------------------------------------------
# /utils/evaluation_metrics.py
# ------------------------------------------------------------------------------
"""
Evaluation metrics
"""

import numpy as np
import sklearn.metrics as metrics
import os
import glob
import cv2
from PIL import Image


def numeric_score(pred, gt):
    """Return the confusion counts ``(FP, FN, TP, TN)`` of two binary maps.

    Args:
        pred: predicted map; foreground pixels equal 1, background 0.
        gt: ground-truth map with the same shape and encoding.

    Returns:
        ``(FP, FN, TP, TN)`` as Python floats.
    """
    # np.float was only an alias of the builtin float and no longer exists in
    # recent NumPy, so the builtin is used directly.
    FP = float(np.sum((pred == 1) & (gt == 0)))
    FN = float(np.sum((pred == 0) & (gt == 1)))
    TP = float(np.sum((pred == 1) & (gt == 1)))
    TN = float(np.sum((pred == 0) & (gt == 0)))
    return FP, FN, TP, TN


def numeric_score_fov(pred, gt, mask):
    """Like ``numeric_score`` but only counts pixels where ``mask == 1``."""
    fov = mask == 1
    FP = float(np.sum((pred == 1) & (gt == 0) & fov))
    FN = float(np.sum((pred == 0) & (gt == 1) & fov))
    TP = float(np.sum((pred == 1) & (gt == 1) & fov))
    TN = float(np.sum((pred == 0) & (gt == 0) & fov))
    return FP, FN, TP, TN


def AUC(path, mask_path='/path/to/FOV/mask/'):
    """Mean ROC-AUC of the soft predictions in ``<path>/pred`` inside a FOV mask.

    Each ``<path>/pred/*pred.png`` is paired with the label
    ``<path>/label/<basename without its last 9 characters>.png``, so a
    ``_pred.png`` suffix becomes ``.png``. Only pixels where the FOV mask
    equals 255 are scored.

    Args:
        path: root directory containing ``pred/`` and ``label/``.
        mask_path: FOV mask image shared by all predictions. The default is a
            placeholder and must be replaced with a real file.

    Returns:
        Mean AUC over all predictions. The result is NaN, with a NumPy
        warning, if no prediction file matches.

    Raises:
        FileNotFoundError: if the FOV mask cannot be read.
    """
    # The mask does not depend on the prediction, so read it once.
    mask = cv2.imread(mask_path, flags=-1)
    if mask is None:
        raise FileNotFoundError("cannot read FOV mask: {}".format(mask_path))
    in_fov = mask == 255

    scores = []
    for file in glob.glob(os.path.join(path, 'pred', '*pred.png')):
        label_name = os.path.basename(file)[:-9] + '.png'
        label_path = os.path.join(path, 'label', label_name)

        pred_image = cv2.imread(file, flags=-1)
        label = cv2.imread(label_path, flags=-1)

        # Boolean indexing yields the FOV pixels in row-major order, the same
        # order a nested row/column loop produces, without per-pixel Python work.
        pred_fov = pred_image[in_fov] / 255
        label_fov = np.uint8(label[in_fov] / 255)

        scores.append(metrics.roc_auc_score(label_fov, pred_fov))
    return np.mean(scores)


def DSC(path):
    """Mean Dice similarity coefficient of the Otsu-binarised predictions.

    Each ``<path>/pred/*otsu.png`` is paired with
    ``<path>/label/<basename without its last 14 characters>.png``, so a
    ``_pred_otsu.png`` suffix becomes ``.png``. Images are presumably {0, 255};
    ``// 255`` maps them to {0, 1}.

    Returns:
        Mean of ``2TP / (FP + 2TP + FN)``. The result is NaN, with a NumPy
        warning, if no prediction file matches.
    """
    scores = []
    for file in glob.glob(os.path.join(path, 'pred', '*otsu.png')):
        label_name = os.path.basename(file)[:-14] + '.png'
        label_path = os.path.join(path, 'label', label_name)

        pred = cv2.imread(file, flags=-1) // 255
        label = cv2.imread(label_path, flags=-1) // 255

        FP, FN, TP, TN = numeric_score(pred, label)
        scores.append(2 * TP / (FP + 2 * TP + FN))
    return np.mean(scores)


def AccSenSpe(path):
    """Mean and variance of accuracy, sensitivity and specificity.

    Uses the same ``*otsu.png`` / label pairing as ``DSC``.

    NOTE(review): counts are taken over the whole image via ``numeric_score``.
    Use ``numeric_score_fov`` with a FOV mask if FOV-restricted values are intended.

    Returns:
        ``(avg_acc, var_acc, avg_sen, var_sen, avg_spe, var_spe)``.
    """
    all_acc, all_sen, all_spe = [], [], []
    for file in glob.glob(os.path.join(path, 'pred', '*otsu.png')):
        label_name = os.path.basename(file)[:-14] + '.png'
        label_path = os.path.join(path, 'label', label_name)

        pred = cv2.imread(file, flags=-1) // 255
        label = cv2.imread(label_path, flags=-1) // 255

        FP, FN, TP, TN = numeric_score(pred, label)
        all_acc.append((TP + TN) / (TP + FP + TN + FN))
        all_sen.append(TP / (TP + FN))
        all_spe.append(TN / (TN + FP))
    avg_acc, avg_sen, avg_spe = np.mean(all_acc), np.mean(all_sen), np.mean(all_spe)
    var_acc, var_sen, var_spe = np.var(all_acc), np.var(all_sen), np.var(all_spe)
    return avg_acc, var_acc, avg_sen, var_sen, avg_spe, var_spe


def FDR(path):
    """Mean and variance of the false discovery rate ``FP / (FP + TP)``.

    Uses the same ``*otsu.png`` / label pairing as ``DSC``.
    """
    all_fdr = []
    for file in glob.glob(os.path.join(path, 'pred', '*otsu.png')):
        label_name = os.path.basename(file)[:-14] + '.png'
        label_path = os.path.join(path, 'label', label_name)

        pred = cv2.imread(file, flags=-1) // 255
        label = cv2.imread(label_path, flags=-1) // 255

        FP, FN, TP, TN = numeric_score(pred, label)
        all_fdr.append(FP / (FP + TP))
    return np.mean(all_fdr), np.var(all_fdr)


if __name__ == '__main__':
    # predicted root path
    path = './assets/Padova1/'
    # auc = AUC(path)
    acc, var_acc, sen, var_sen, spe, var_spe = AccSenSpe(path)
    fdr, var_fdr = FDR(path)
    # NOTE(review): the value printed after "+-" is the variance, not the std.
    print("sen:{0:.4f} +- {1:.4f}".format(sen, var_sen))
    print("fdr:{0:.4f} +- {1:.4f}".format(fdr, var_fdr))
    # print("acc:{0:.4f}".format(acc))
    # print("sen:{0:.4f}".format(sen))
    # print("spe:{0:.4f}".format(spe))
    # print("auc:{0:.4f}".format(auc))

# ------------------------------------------------------------------------------
# /utils/evaluation_metrics3D.py
# ------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Author : Lei Mou
# @File : evaluation_metrics3D.py
import numpy as np
import SimpleITK as sitk
import glob
import os
from scipy.spatial import distance
from sklearn.metrics import f1_score


def numeric_score(pred, gt):
    """Return the confusion counts ``(FP, FN, TP, TN)`` of two {0, 255} maps."""
    FP = float(np.sum((pred == 255) & (gt == 0)))
    FN = float(np.sum((pred == 0) & (gt == 255)))
    TP = float(np.sum((pred == 255) & (gt == 255)))
    TN = float(np.sum((pred == 0) & (gt == 0)))
    return FP, FN, TP, TN


def Dice(pred, gt):
    """Dice coefficient of two {0, 255} maps, which are first mapped to {0, 1}."""
    pred = np.int64(pred / 255)
    gt = np.int64(gt / 255)
    return np.sum(pred[gt == 1]) * 2.0 / (np.sum(pred) + np.sum(gt))


def IoU(pred, gt):
    """Intersection over union of two {0, 255} maps, first mapped to {0, 1}."""
    pred = np.int64(pred / 255)
    gt = np.int64(gt / 255)
    intersection = np.sum(pred[gt == 1])
    union = np.sum(pred == 1) + np.sum(gt == 1) - intersection
    return intersection / union


def metrics_3d(pred, gt):
    """Return ``(TPR, FNR, FPR, IoU)`` for two {0, 255} volumes.

    A ``1e-10`` term in every denominator avoids division by zero.
    """
    FP, FN, TP, TN = numeric_score(pred, gt)
    tpr = TP / (TP + FN + 1e-10)
    fnr = FN / (FN + TP + 1e-10)
    # False-positive rate = false positives over all actual negatives (FP + TN).
    fpr = FP / (FP + TN + 1e-10)
    iou = TP / (TP + FN + FP + 1e-10)
    return tpr, fnr, fpr, iou


def over_rate(pred, gt):
    """Over-segmentation rate ``Os / (Rs + Os)`` for {0, 255} maps.

    ``Rs`` counts reference foreground voxels. ``Os`` counts voxels that are
    predicted foreground but are background in the reference.
    """
    Rs = float(np.sum(gt == 255))
    Os = float(np.sum((pred == 255) & (gt == 0)))
    return Os / (Rs + Os)


def under_rate(pred, gt):
    """Under-segmentation rate ``Us / (Rs + Os)`` for {0, 255} maps.

    ``Us`` counts reference foreground voxels that the prediction missed.
    ``Rs`` and ``Os`` are as in ``over_rate``.
    """
    Rs = float(np.sum(gt == 255))
    Us = float(np.sum((pred == 0) & (gt == 255)))
    Os = float(np.sum((pred == 255) & (gt == 0)))
    return Us / (Rs + Os)

# ------------------------------------------------------------------------------
# /utils/losses.py
# ------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F
from torch import nn as nn
from torch.autograd import Variable, Function
from torch.nn import MSELoss, SmoothL1Loss, L1Loss
import numpy as np


def make_one_hot(input, num_classes):
    """Convert class index tensor to one hot encoding tensor.
    Args:
        input: A tensor of shape [N, 1, *]
        num_classes: An int of number of class
    Returns:
        A float tensor of shape [N, num_classes, *] on the CPU
    """
    shape = list(input.shape)
    shape[1] = num_classes
    result = torch.zeros(tuple(shape))
    # scatter_ needs int64 indices on the same device as ``result`` (CPU).
    result = result.scatter_(1, input.cpu().long(), 1)
    return result


class BinaryDiceLoss(nn.Module):
    r"""Dice loss of binary class
    Args:
        smooth: A float number to smooth loss, and avoid NaN error, default: 1
        p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
        predict: A tensor of shape [N, *]
        target: A tensor of shape same with predict
    Returns:
        Scalar loss ``1 - dice``, with the sums taken over the whole batch
    Raise:
        AssertionError if the batch sizes of predict and target differ
    """

    def __init__(self, smooth=1, p=2):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        predict = predict.contiguous().view(predict.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1)

        num = torch.sum(torch.mul(predict, target)) * 2 + self.smooth
        den = torch.sum(predict.pow(self.p) + target.pow(self.p)) + self.smooth

        return 1 - num / den


class DiceLoss(nn.Module):
    """Dice loss, need one hot encode input
    Args:
        weight: An array of shape [num_classes,]
        ignore_index: class index to ignore
        predict: A tensor of shape [N, C, *] (raw scores; softmax is applied here)
        target: A class-index tensor of shape [N, 1, *]
        other args pass to BinaryDiceLoss
    Return:
        Sum of per-class (optionally weighted) Dice losses divided by C
    """

    def __init__(self, weight=None, ignore_index=None, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, predict, target):
        target = make_one_hot(target, num_classes=predict.shape[1])
        # Follow the prediction's device rather than hard-coding CUDA.
        target = target.to(predict.device)
        assert predict.shape == target.shape, 'predict & target shape do not match'
        if self.weight is not None:
            assert self.weight.shape[0] == target.shape[1], \
                'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0])
        dice = BinaryDiceLoss(**self.kwargs)
        total_loss = 0
        predict = F.softmax(predict, dim=1)

        for i in range(target.shape[1]):
            if i == self.ignore_index:
                continue
            dice_loss = dice(predict[:, i], target[:, i])
            if self.weight is not None:
                dice_loss *= self.weight[i]
            total_loss += dice_loss

        return total_loss / target.shape[1]


# ---------------------------------------------------------------------------------------------------------


def flatten(tensor):
    """Flattens a given tensor such that the channel axis is first.
    The shapes are transformed as follows:
       (N, C, D, H, W) -> (C, N * D * H * W)
    """
    C = tensor.size(1)
    # new axis order
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
    transposed = tensor.permute(axis_order)
    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
    return transposed.contiguous().view(C, -1)


class WeightedCrossEntropyLoss(nn.Module):
    """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf

    Class weights are recomputed for every batch from the softmax
    probabilities ``p_c`` as ``sum(1 - p_c) / sum(p_c)``. If ``weight`` is
    given, the weights are also multiplied element-wise by it.
    """

    def __init__(self, weight=None, ignore_index=-1):
        super(WeightedCrossEntropyLoss, self).__init__()
        self.register_buffer('weight', weight)
        self.ignore_index = ignore_index

    def forward(self, input, target):
        class_weights = self._class_weights(input)
        if self.weight is not None:
            class_weights = class_weights * self.weight.detach()
        return F.cross_entropy(input, target, weight=class_weights, ignore_index=self.ignore_index)

    @staticmethod
    def _class_weights(input):
        # Softmax over axis 1, the class axis that flatten() moves to the front;
        # passing dim explicitly avoids PyTorch's implicit-dim fallback.
        input = F.softmax(input, dim=1)
        flattened = flatten(input)
        nominator = (1. - flattened).sum(-1)
        denominator = flattened.sum(-1)
        # The weights are constants for the loss: cut them from the autograd graph.
        return (nominator / denominator).detach()

# ---------------------------------------------------------------------------------------------

# ------------------------------------------------------------------------------
# /utils/misc.py
# ------------------------------------------------------------------------------
import numpy as np
import os
import glob
import cv2
import torch.nn as nn
import torch
from PIL import ImageOps, Image
from sklearn.metrics import confusion_matrix
import SimpleITK as sitk
import tqdm
import vtk


def ReScaleSize(image, re_size=512):
    """Zero-pad a PIL image to a centred square, then resize it to ``re_size`` x ``re_size``."""
    w, h = image.size
    side = max(w, h)
    delta_w, delta_h = side - w, side - h
    # (left, top, right, bottom); an odd remainder goes to the right/bottom edge.
    padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
    image = ImageOps.expand(image, padding, fill=0)
    return image.resize((re_size, re_size))


def Crop(image, box=(261, 1, 1110, 850)):
    """Crop a PIL image to ``box`` = (left, top, right, bottom).

    The default is the fixed region this helper always used. Its origin
    (dataset or device) is not stated here.
    """
    return image.crop(box)


def thresh_OTSU(path):
    """Otsu-binarise every ``*pred.png`` in ``path``; write ``<name>_otsu.png`` beside it."""
    for file in glob.glob(os.path.join(path, '*pred.png')):
        index = os.path.splitext(os.path.basename(file))[0]
        image = cv2.imread(file)
        # cv2.imread returns channels in BGR order.
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        _, img = cv2.threshold(gray, 0, 255, cv2.THRESH_OTSU)
        cv2.imwrite(os.path.join(path, index + '_otsu.png'), img)
        print(file, '\tdone!')

# ------------------------------------------------------------------------------
# /utils/train_metrics.py
# ------------------------------------------------------------------------------
import numpy as np
import os
import torch.nn as nn
import torch
from PIL import ImageOps, Image
from sklearn.metrics import confusion_matrix
from skimage import filters

from utils.evaluation_metrics3D import metrics_3d, Dice


def threshold(image, level=100):
    """Binarise ``image`` in place to {0, 255} at ``level`` and return it."""
    foreground = image >= level
    image[foreground] = 255
    image[~foreground] = 0
    return image


def numeric_score(pred, gt):
    """Return the confusion counts ``(FP, FN, TP, TN)`` of two {0, 255} maps."""
    FP = float(np.sum((pred == 255) & (gt == 0)))
    FN = float(np.sum((pred == 0) & (gt == 255)))
    TP = float(np.sum((pred == 255) & (gt == 255)))
    TN = float(np.sum((pred == 0) & (gt == 0)))
    return FP, FN, TP, TN


def metrics(pred, label, batch_size):
    """Accuracy and sensitivity summed (not averaged) over the first ``batch_size`` images.

    Assumes ``pred``/``label`` are (N, 1, H, W) with values in [0, 1], as the
    ``* 255`` and ``squeeze(1)`` suggest (MSE-style outputs). For
    cross-entropy-style outputs, take ``torch.argmax(pred, dim=1)`` first.
    """
    outputs = (pred.data.cpu().numpy() * 255).astype(np.uint8)
    labels = (label.data.cpu().numpy() * 255).astype(np.uint8)
    outputs = threshold(outputs.squeeze(1))
    labels = labels.squeeze(1)

    Acc, SEn = 0., 0.
    for i in range(batch_size):
        acc, sen = get_acc(outputs[i, :, :], labels[i, :, :])
        Acc += acc
        SEn += sen
    return Acc, SEn


def _accumulate_3d(outputs, labels, batch_size):
    """Sum ``metrics_3d`` (TPR, FNR, FPR, IoU) over the first ``batch_size`` volumes."""
    tp, fn, fp, IoU = 0, 0, 0, 0
    for i in range(batch_size):
        tpr, fnr, fpr, iou = metrics_3d(outputs[i], labels[i])
        tp += tpr
        fn += fnr
        fp += fpr
        IoU += iou
    return tp, fn, fp, IoU


def metrics3dmse(pred, label, batch_size):
    """Summed 3-D metrics for single-channel [0, 1] outputs (MSE-style), thresholded at 100/255."""
    outputs = (pred.data.cpu().numpy() * 255).astype(np.uint8)
    labels = (label.data.cpu().numpy() * 255).astype(np.uint8)
    outputs = threshold(outputs.squeeze(1))
    labels = labels.squeeze(1)
    return _accumulate_3d(outputs, labels, batch_size)


def metrics3d(pred, label, batch_size):
    """Summed 3-D metrics for multi-class scores: ``argmax`` over dim 1 gives the mask.

    Assumes ``label`` holds class indices {0, 1}; both are scaled to {0, 255}.
    The sums are not averaged: callers divide by the batch size.
    """
    pred = torch.argmax(pred, dim=1)
    outputs = (pred.data.cpu().numpy() * 255).astype(np.uint8)
    labels = (label.data.cpu().numpy() * 255).astype(np.uint8)
    return _accumulate_3d(outputs, labels, batch_size)


def get_acc(image, label):
    """Return ``(accuracy, sensitivity)`` of a {0, 255} label and an image thresholded in place."""
    image = threshold(image)

    FP, FN, TP, TN = numeric_score(image, label)
    acc = (TP + TN) / (TP + FN + TN + FP + 1e-10)
    sen = TP / (TP + FN + 1e-10)
    return acc, sen

# ------------------------------------------------------------------------------
# /utils/visualize.py
# ------------------------------------------------------------------------------
import numpy as np
import visdom


def init_visdom_line(x, y, title, xlabel, ylabel, env="default"):
    """Create a visdom line plot seeded with the point ``(x, y)``.

    Returns:
        ``(viz, window)``: the ``visdom.Visdom`` client and the window id
        returned by ``line``, for use with ``update_lines``.
    """
    viz = visdom.Visdom(env=env)
    window = viz.line(
        X=np.array([x]),
        Y=np.array([y]),
        opts=dict(title=title, showlegend=True, xlabel=xlabel, ylabel=ylabel)
    )
    return viz, window


def update_lines(env, panel, x, y, update_type='append'):
    """Add the point ``(x, y)`` to the plot ``panel`` using visdom's ``update`` mode."""
    env.line(
        X=np.array([x]),
        Y=np.array([y]),
        win=panel,
        update=update_type
    )

# ------------------------------------------------------------------------------