├── README.md
├── data_preprocess
│   ├── __init__.py
│   ├── data_augmentation.py
│   └── pre_process.py
├── datasets
│   ├── BettiMatching.py
│   ├── ER_dataset.py
│   ├── MITO_dataset.py
│   ├── NUCLEUS_dataset.py
│   ├── ROAD_dataset.py
│   ├── ROSE_dataset.py
│   ├── STARE_dataset.py
│   ├── dataset.py
│   └── metric.py
├── evaluation.py
├── figures
│   ├── FFM.png
│   └── MDNet.png
├── fractal_analysis.py
├── get_edge_skeleton.py
├── inference.py
├── models
│   ├── hrnet.py
│   ├── md_net.py
│   ├── optimize.py
│   ├── unet.py
│   └── utils.py
├── requirements.txt
├── train_hrnet.py
├── train_mdnet.py
├── train_mdnet_weighted.py
├── train_tdnet.py
├── train_tdnet_weighted.py
└── train_unet.py

/README.md:
--------------------------------------------------------------------------------
1 | # FFM-Multi-Decoder-Network
2 | The official code repository for the ECCV 2024 paper "Representing Topological Self-Similarity Using Fractal Feature Maps for Accurate Segmentation of Tubular Structures".
3 | 
4 | # Representing Topological Self-Similarity Using Fractal Feature Maps for Accurate Segmentation of Tubular Structures
5 | 
6 | ## Overview
7 | Workflow of computing the FFM of an image.
8 | ![FFM](figures/FFM.png)
9 | Overview and details of our proposed multi-decoder network (MD-Net).
10 | ![MDNet](figures/MDNet.png)
11 | 
12 | 
13 | ## Preparing Dataset
14 | ### Files of training and testing
15 | Please generate ".txt" files for the training data and the test data separately. In each ".txt" file, every line consists of the path of an image and the path of its corresponding mask, separated by a space.
16 | 
17 | For example:
18 | 
19 | ```
20 | /datasets/directory/train_mito.txt
21 | └── .../MITO/train/images/20_h384_w384.tif .../MITO/train/masks/20_h384_w384.tif
22 | /datasets/directory/test_mito.txt
23 | └── .../MITO/test/images/20_h384_w384.tif .../MITO/test/masks/20_h384_w384.tif
24 | ```
25 | 
26 | ### Computation of FFMs
27 | 
28 | Please utilize the functions in [fractal_analysis.py](./fractal_analysis.py) to compute the FFMs of images.
29 | 
30 | The first function, compute_FFM, computes the Fractal Feature Map of a single image.
31 | ```
32 | Input:
33 |     image: a 2D array containing a grayscale image;
34 |     window_size: the size of the sliding window;
35 |     step_size: the size of the sliding step;
36 | Output:
37 |     FFM: the fractal feature map of the image.
38 | 
39 | def compute_FFM(image, step_size, window_size):
40 | ```
41 | The second function, compute_FMM_Pool, computes and saves the Fractal Feature Maps of a list of images (an illustrative sketch of the FFM computation is given below, after the next subsection).
42 | ```
43 | Input:
44 |     file_path: the root path of the images;
45 |     window_size: the size of the sliding window;
46 |     step_size: the size of the sliding step;
47 | 
48 | def compute_FMM_Pool(file_path, window_size, step_size):
49 | ```
50 | 
51 | ### Extraction of edge and skeleton
52 | 
53 | To extract the edges and skeletons of segmentation objects, please utilize the functions in [get_edge_skeleton.py](./get_edge_skeleton.py).
54 | ```
55 | def edge_extract(root):
56 | 
57 | def skeleton_extract(root):
58 | ```
59 | Provide the file path of the labels to these functions to obtain the edges and skeletons of the labels.
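As a rough illustration of what `edge_extract` and `skeleton_extract` can do internally — here with OpenCV morphology and scikit-image's `skeletonize`; the actual bodies in get_edge_skeleton.py may differ — a minimal sketch, assuming 0/255 ".tif" masks under a "masks" folder, is:

```
import glob
import os
import cv2
import numpy as np
from skimage.morphology import skeletonize

def edge_extract(root):
    # Illustrative body: edge = mask minus its 3x3 erosion (one-pixel boundary).
    os.makedirs(os.path.join(root, 'edges'), exist_ok=True)
    for mask_path in sorted(glob.glob(os.path.join(root, 'masks', '*.tif'))):
        mask = cv2.imread(mask_path, -1)
        edge = cv2.subtract(mask, cv2.erode(mask, np.ones((3, 3), np.uint8)))
        cv2.imwrite(mask_path.replace('masks', 'edges'), edge)

def skeleton_extract(root):
    # Illustrative body: skeleton = one-pixel-wide medial axis of the foreground.
    os.makedirs(os.path.join(root, 'skeletons'), exist_ok=True)
    for mask_path in sorted(glob.glob(os.path.join(root, 'masks', '*.tif'))):
        mask = cv2.imread(mask_path, -1)
        skeleton = skeletonize(mask > 0).astype(np.uint8) * 255
        cv2.imwrite(mask_path.replace('masks', 'skeletons'), skeleton)
```

Likewise, for the FFM computation described above, a minimal sliding-window box-counting sketch — all names and the box-counting variant here are assumptions; the real `compute_FFM` in fractal_analysis.py may differ — is:

```
import numpy as np

def box_counting_dim(window, box_sizes=(2, 4, 8)):
    # Differential box counting on one grayscale window: for each box size s,
    # count the gray-level boxes of height s needed to cover the intensity
    # surface, then fit log(count) against log(1/s). The slope approximates
    # the fractal dimension (FD); the intercept is kept as a stand-in for the
    # second feature channel (FL) that the dataset classes load.
    counts = []
    for s in box_sizes:
        n = 0
        for i in range(0, window.shape[0] - s + 1, s):
            for j in range(0, window.shape[1] - s + 1, s):
                block = window[i:i + s, j:j + s]
                n += int(np.ceil((block.max() - block.min() + 1) / s))
        counts.append(n)
    slope, intercept = np.polyfit(np.log(1.0 / np.asarray(box_sizes)),
                                  np.log(np.asarray(counts, dtype=float)), 1)
    return slope, intercept

def compute_FFM_sketch(image, step_size=1, window_size=16):
    # Slide a window over the image and stack the two fitted values, yielding
    # a 2-channel (FD, FL) map like the ".npy" arrays used by the datasets.
    pad = window_size // 2
    padded = np.pad(image.astype(np.float64), pad, mode='reflect')
    ffm = np.zeros((2,) + image.shape)
    for y in range(0, image.shape[0], step_size):
        for x in range(0, image.shape[1], step_size):
            ffm[:, y, x] = box_counting_dim(padded[y:y + window_size, x:x + window_size])
    return ffm
```

With `step_size` greater than 1 this sketch fills only the sampled positions; how striding is actually handled is up to the repository's implementation.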
60 | ## Setup
61 | 
62 | Instructions for setting up this project.
63 | 
64 | ### Installing dependencies
65 | 
66 | To install all the dependencies, please run one of the following:
67 | 
68 | ```
69 | pip install -r requirements.txt   # or: conda install --yes --file requirements.txt
70 | ```
71 | 
72 | ## Running
73 | 
74 | ### Training
75 | 
76 | In this project, we used the U-Net, HR-Net, and MD-Net models in the experimental phase. To facilitate the training of the different models, we created six training files.
77 | 
78 | The following lines will run the training code with the default settings in each file. The value of warmup_step can be adjusted according to the size of the dataset.
79 | 
80 | ```
81 | python train_unet.py
82 | python train_hrnet.py
83 | python train_mdnet.py
84 | python train_mdnet_weighted.py
85 | python train_tdnet.py
86 | python train_tdnet_weighted.py
87 | ```
88 | 
89 | Before training, you need to assign the file paths of the FFMs, edges, and skeletons in the dataset files.
90 | 
91 | For example, in [ER_dataset.py](./datasets/ER_dataset.py):
92 | ```
93 | # the path of the FFM of the image
94 | npy_path = img_path.replace("images", self.fractal_dir)
95 | # the path of the FFM of the label
96 | weight_path = mask_path.replace('masks', 'masks' + self.weight_dir)
97 | 
98 | edge_path = mask_path.replace('masks', self.edge_dir)
99 | skeleton_path = mask_path.replace('masks', self.skeleton_dir)
100 | ```
101 | ### Inference and Evaluation
102 | 
103 | To obtain segmentation results and evaluate the model's performance under different thresholds, you can run the following lines:
104 | 
105 | ```
106 | # Set up
107 | model_choice = ['unet', 'hrnet', 'Two_decoder_Net', 'Multi_decoder_Net']
108 | dataset_list = ['er', 'er_fractal', 'er_fractal_two_decoder', 'nucleus_fractal_two_decoder', 'nucleus_fractal_two_decoder_weighted']
109 | txt_choice = ['test_mito.txt', 'test_er.txt', 'test_stare.txt']
110 | # Run
111 | python inference.py
112 | python evaluation.py
113 | ```
--------------------------------------------------------------------------------
/data_preprocess/__init__.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | 
4 | 
5 | def generate_train_txt():
6 |     root_path = '/data/ldap_shared/home/'
7 |     txt_save_path = os.path.join('../datasets', "train_mito.txt")
8 |     file_path = '../mito_dataset/train/'
9 |     masks_dir = file_path + 'masks_aug/'
10 |     images_dir = file_path + 'images_aug/'
11 |     img_list = glob.glob(os.path.join(images_dir, "*.tif"))
12 |     img_list.sort()
13 |     with open(txt_save_path, 'w') as f:
14 |         for i, p in enumerate(img_list):
15 |             img_name = os.path.split(p)[-1]
16 |             print("==> Process image: %s." % (img_name))
17 |             f.writelines(root_path + images_dir[2:] + img_name + " " + root_path + masks_dir[2:] + img_name + "\n")
18 | 
19 | 
20 | def generate_test_txt():
21 |     root_path = '/data/ldap_shared/home/'
22 |     txt_save_path = os.path.join('../datasets', "test_mito.txt")
23 |     file_path = '../mito_dataset/test/'
24 |     masks_dir = file_path + 'masks/'
25 |     images_dir = file_path + 'images/'
26 | 
27 |     img_list = glob.glob(os.path.join(images_dir, "*.tif"))
28 |     img_list.sort()
29 |     with open(txt_save_path, 'w') as f:
30 |         for i, p in enumerate(img_list):
31 |             img_name = os.path.split(p)[-1]
32 |             print("==> Process image: %s." % (img_name))
33 |             f.writelines(
34 |                 root_path + images_dir[2:] + img_name + " " + root_path + masks_dir[2:] + img_name + "\n")
35 | 
36 | 
37 | if __name__ == '__main__':
38 |     generate_train_txt()
39 | 
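Note that the `__main__` block above only writes the training list. A tiny hypothetical driver that produces both lists — assuming the relative paths hard-coded in the module and that it is run from the repository root — could be:

```
# Hypothetical driver, not part of the repository: build both txt lists.
from data_preprocess import generate_train_txt, generate_test_txt

if __name__ == '__main__':
    generate_train_txt()
    generate_test_txt()
```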
--------------------------------------------------------------------------------
/data_preprocess/data_augmentation.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import cv2
3 | import numpy as np
4 | from random import randint, random, randrange
5 | import glob
6 | import os
7 | 
8 | 
9 | def random_flip(image, mask):
10 |     flip_seed = randint(-1, 2)  # -1, 0, 1 select a flip axis; 2 means no flip
11 |     print(flip_seed)
12 |     if flip_seed != 2:
13 |         image = cv2.flip(image, flip_seed)
14 |         mask = cv2.flip(mask, flip_seed)
15 |     return image, mask
16 | 
17 | def flip(image, mask):
18 |     flip_seed = 1
19 | 
20 |     image = cv2.flip(image, flip_seed)
21 |     mask = cv2.flip(mask, flip_seed)
22 |     return image, mask
23 | 
24 | 
25 | def rotation_image(image, mask, angle=0, scale=1):
26 |     height, width = image.shape
27 |     M = cv2.getRotationMatrix2D((height/2, width/2), angle, scale)
28 |     out_image = cv2.warpAffine(image, M, (height, width))
29 |     out_mask = cv2.warpAffine(mask, M, (height, width), flags=cv2.INTER_NEAREST, borderValue=(255,255,255))
30 |     return out_image, out_mask
31 | 
32 | 
33 | def random_rotation_scale(image, mask, angle_min=-45, angle_max=45, scale_var=True):
34 |     if scale_var:
35 |         scale = randint(80, 120) / 100.  # scale from 0.8-1.2
36 |     else:
37 |         scale = 1
38 |     if angle_min < angle_max:
39 |         angle = randint(angle_min, angle_max)  # rotation angle from [-45, 45]
40 |     out_image, out_mask = rotation_image(image, mask, angle=angle, scale=scale)
41 |     return out_image, out_mask
42 | 
43 | 
44 | def random_shift(image, mask, shift_range=20):
45 |     rand_seed_1 = -1 if random() < 0.5 else 1
46 |     trans_x = rand_seed_1 * random() * shift_range  # shift direction, shift percentage based on shift range
47 |     rand_seed_2 = -1 if random() < 0.5 else 1
48 |     trans_y = rand_seed_2 * random() * shift_range
49 | 
50 |     M = np.array([[1, 0, trans_x], [0, 1, trans_y]], dtype=np.float32)
51 |     out_image = cv2.warpAffine(image, M, image.shape)
52 |     out_mask = cv2.warpAffine(mask, M, mask.shape, flags=cv2.INTER_NEAREST, borderValue=(255,255,255))
53 |     return out_image, out_mask
54 | 
55 | 
56 | def random_shear(image, mask, shear_range=0.2):
57 |     rand_seed_1 = -1 if random() < 0.5 else 1
58 |     shear_factor = rand_seed_1 * random() * shear_range
59 | 
60 |     w, h = image.shape[1], image.shape[0]
61 |     if shear_factor < 0:
62 |         image, mask = cv2.flip(image, 1), cv2.flip(mask, 1)
63 |     M = np.array([[1, abs(shear_factor), 0],
64 |                   [0, 1, 0]])
65 |     nW = w + abs(shear_factor * h)
66 |     image = cv2.warpAffine(image, M, (int(nW), h))
67 |     mask = cv2.warpAffine(mask, M, (int(nW), h), flags=cv2.INTER_NEAREST, borderValue=(255,255,255))
68 |     if shear_factor < 0:
69 |         image, mask = cv2.flip(image, 1), cv2.flip(mask, 1)
70 |     image_out, mask_out = cv2.resize(image, (w, h)), cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
71 |     return image_out, mask_out
72 | 
73 | 
74 | def random_contrast(image):
75 |     factor = randint(7, 10) / 10
76 |     mean = np.uint16(np.mean(image) + 0.5)
77 |     mean_img = (np.ones(image.shape) * mean).astype(np.uint16)
78 |     out_image = image.astype(np.uint16) * factor + mean_img * (1.0 - factor)
79 |     if factor < 0 or factor > 1:
80 |         out_image = clip_image(out_image.astype(float))  # np.float is removed in recent NumPy; use the builtin
81 |     return out_image.astype(np.uint16)
82 | 
83 | 
84 | def random_brightness(image):
85 |     noise_scale = randint(7, 13) / 10.
86 | noise_img = image * noise_scale 87 | out_image = clip_image(noise_img) 88 | return out_image 89 | 90 | 91 | def random_noise(image): 92 | noise_seed = randint(0,1) 93 | if noise_seed == 0: 94 | noise_img = cv2.GaussianBlur(image, (5,5), 0) 95 | else: 96 | noise_img = image 97 | return noise_img 98 | 99 | 100 | def clip_image(image): 101 | image[image > 65535.] = 65535 102 | image[image < 0.] = 0 103 | image = image.astype(np.uint16) 104 | return image 105 | 106 | 107 | def convert_mask(mask): 108 | mask[mask >= 127.5] = 255 109 | mask[mask < 127.5] = 0 110 | mask = mask.astype(np.uint8) 111 | return mask 112 | 113 | 114 | if __name__ == '__main__': 115 | file_path = '../nucleus/' 116 | masks_dir = file_path + 'masks_train' 117 | images_dir = file_path + 'images_train' 118 | img_list = glob.glob(os.path.join(images_dir, "*.tif")) 119 | img_list.sort() 120 | 121 | for i, p in enumerate(img_list): 122 | img_name = os.path.split(p)[-1] 123 | 124 | print("==> Process image: %s." % (img_name)) 125 | 126 | now_image = cv2.imread(p, -1) 127 | now_mask = cv2.imread(os.path.join(masks_dir, img_name), -1) 128 | flip_image, flip_mask = flip(now_image, now_mask) 129 | 130 | now_image_90, now_mask_90 = rotation_image(now_image, now_mask, angle=90, scale=1) 131 | now_image_180, now_mask_180 = rotation_image(now_image, now_mask, angle=180, scale=1) 132 | now_image_270, now_mask_270 = rotation_image(now_image, now_mask, angle=270, scale=1) 133 | 134 | flip_image_90, flip_mask_90 = rotation_image(flip_image, flip_mask, angle=90, scale=1) 135 | flip_image_180, flip_mask_180 = rotation_image(flip_image, flip_mask, angle=180, scale=1) 136 | flip_image_270, flip_mask_270 = rotation_image(flip_image, flip_mask, angle=270, scale=1) 137 | 138 | cv2.imwrite(images_dir + '_aug/' + img_name[:-4] + '.tif', now_image) 139 | cv2.imwrite(images_dir + '_aug/' + img_name[:-4] + '_90.tif', now_image_90) 140 | cv2.imwrite(images_dir + '_aug/' + img_name[:-4] + '_180.tif', now_image_180) 141 | cv2.imwrite(images_dir + '_aug/' + img_name[:-4] + '_270.tif', now_image_270) 142 | 143 | cv2.imwrite(images_dir + '_aug/flip_' + img_name[:-4] + '.tif', flip_image) 144 | cv2.imwrite(images_dir + '_aug/flip_' + img_name[:-4] + '_90.tif', flip_image_90) 145 | cv2.imwrite(images_dir + '_aug/flip_' + img_name[:-4] + '_180.tif', flip_image_180) 146 | cv2.imwrite(images_dir + '_aug/flip_' + img_name[:-4] + '_270.tif', flip_image_270) 147 | 148 | cv2.imwrite(masks_dir + '_aug/' + img_name[:-4] + '.tif', now_mask) 149 | cv2.imwrite(masks_dir + '_aug/' + img_name[:-4] + '_90.tif', now_mask_90) 150 | cv2.imwrite(masks_dir + '_aug/' + img_name[:-4] + '_180.tif', now_mask_180) 151 | cv2.imwrite(masks_dir + '_aug/' + img_name[:-4] + '_270.tif', now_mask_270) 152 | 153 | cv2.imwrite(masks_dir + '_aug/flip_' + img_name[:-4] + '.tif', flip_mask) 154 | cv2.imwrite(masks_dir + '_aug/flip_' + img_name[:-4] + '_90.tif', flip_mask_90) 155 | cv2.imwrite(masks_dir + '_aug/flip_' + img_name[:-4] + '_180.tif', flip_mask_180) 156 | cv2.imwrite(masks_dir + '_aug/flip_' + img_name[:-4] + '_270.tif', flip_mask_270) 157 | -------------------------------------------------------------------------------- /data_preprocess/pre_process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | def img_clahe(img): 5 | clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8)) 6 | img = clahe.apply(img) 7 | return img 8 | 9 | def img_clahe_cm(img): 10 | b,g,r = cv2.split(img) 11 | clahe = 
cv2.createCLAHE(clipLimit=1.0, tileGridSize=(8,8)) 12 | b = clahe.apply(b) 13 | g = clahe.apply(g) 14 | r = clahe.apply(r) 15 | output = cv2.merge((b,g,r)) 16 | return output 17 | 18 | def img_normalized(img): 19 | std = np.std(img) 20 | mean = np.mean(img) 21 | img_normalized = (img - mean) / (std + 1e-10) 22 | return img_normalized 23 | 24 | 25 | def convert_16to8(img): 26 | img = (img - np.mean(img)) / np.std(img) 27 | img = (img - np.min(img)) / (np.max(img) - np.min(img)) 28 | img = (img * 255).astype(np.uint8) 29 | return img 30 | 31 | def convert_8to16(img): 32 | img = (img - np.mean(img)) / np.std(img) 33 | img = (img - np.min(img)) / (np.max(img) - np.min(img)) 34 | img = (img * 65535).astype(np.uint16) 35 | return img 36 | 37 | def sober_filter(img): 38 | if img.dtype == "uint16": 39 | dx = np.array(cv2.Sobel(img, cv2.CV_32F, 1, 0)) 40 | dy = np.array(cv2.Sobel(img, cv2.CV_32F, 0, 1)) 41 | elif img.dtype == "uint8": 42 | dx = np.array(cv2.Sobel(img, cv2.CV_16S, 1, 0)) 43 | dy = np.array(cv2.Sobel(img, cv2.CV_16S, 0, 1)) 44 | dx = np.abs(dx) 45 | dy = np.abs(dy) 46 | edge = cv2.addWeighted(dx, 0.5, dy, 0.5, 0) 47 | return edge 48 | 49 | 50 | def standardization(data): 51 | mu = np.mean(data) 52 | sigma = np.std(data) 53 | return (data - mu) / sigma 54 | 55 | def npy_PreProc(npy): 56 | img_FD = npy[0] 57 | img_FL = npy[1] 58 | FD_min = np.min(img_FD) 59 | FD_max = np.max(img_FD) 60 | img_FD = (img_FD - FD_min) / (FD_max - FD_min) 61 | 62 | FL_min = np.min(img_FL) 63 | FL_max = np.max(img_FL) 64 | img_FL = (img_FL - FL_min) / (FL_max - FL_min) 65 | sd_FD = standardization(img_FD) 66 | sd_FL = standardization(img_FL) 67 | return sd_FD, sd_FL -------------------------------------------------------------------------------- /datasets/ER_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | from data_preprocess.pre_process import * 5 | 6 | 7 | def img_PreProc_er(img, pro_type): 8 | if pro_type == "clahe": 9 | img = img_clahe(img) 10 | img = img / 65535. 11 | sd_img = standardization(img) 12 | return sd_img 13 | 14 | elif pro_type == "invert": 15 | img = 65535 - img 16 | return img / 65535. 17 | 18 | elif pro_type == "edgeEnhance": 19 | edge = sober_filter(img) 20 | edge = edge / np.max(edge) 21 | return ((img / 65535.) + edge) * 0.5 22 | 23 | elif pro_type == "norm": 24 | img = img / 65535. 25 | img = (img - np.mean(img)) / (np.std(img) + 1e-8) 26 | return img 27 | 28 | elif pro_type == "clahe_norm": 29 | img = img_clahe(img) 30 | img = img / 65535. 
31 | img = (img - np.mean(img)) / (np.std(img) + 1e-8) 32 | return img 33 | 34 | 35 | class ER_Dataset(Dataset): 36 | def __init__(self, txt, dataset_type, train, fractal_dir='', weight_dir='',edge_dir='', skeleton_dir='', decoder_type='', 37 | log_file='', epoch=0, update_d=5, img_size=256): 38 | 39 | self.img_size = img_size 40 | self.dataset_type = dataset_type 41 | self.train = train 42 | self.fractal_dir = fractal_dir 43 | self.weight_dir = weight_dir 44 | self.decoder_type = decoder_type 45 | self.log_file = log_file 46 | self.epoch = epoch 47 | self.update_d = update_d 48 | self.edge_dir = edge_dir 49 | self.skeleton_dir = skeleton_dir 50 | 51 | with open(txt, "r") as fid: 52 | lines = fid.readlines() 53 | 54 | img_mask_paths = [] 55 | for line in lines: 56 | line = line.strip('\n') 57 | line = line.rstrip() 58 | words = line.split(" ") 59 | img_mask_paths.append((words[0], words[1])) 60 | 61 | self.img_mask_paths = img_mask_paths 62 | 63 | def __getitem__(self, index): 64 | img_path, mask_path = self.img_mask_paths[index] 65 | 66 | # initialize input 67 | if self.dataset_type == 'er': 68 | img = cv2.imread(img_path, -1) 69 | mask = cv2.imread(mask_path, -1) 70 | img_ = img_PreProc_er(img, pro_type='clahe') 71 | 72 | img_ = torch.from_numpy(img_).unsqueeze(dim=0).float() 73 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 74 | 75 | sample = {"image": img_, 76 | "mask": mask_, 77 | "ID": os.path.split(img_path)[1]} 78 | 79 | elif self.dataset_type == 'er_fractal': 80 | npy_path = img_path.replace("images", self.fractal_dir) 81 | npy_path = npy_path.replace(".tif", ".npy") 82 | fractal_info = np.load(npy_path) 83 | img_FD, img_FL = npy_PreProc(fractal_info) 84 | img = cv2.imread(img_path, -1) 85 | mask = cv2.imread(mask_path, -1) 86 | img_ = img_PreProc_er(img, pro_type='clahe') 87 | 88 | fractal_img = np.stack((img_, img_FD, img_FL), axis=2) 89 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 90 | image_chw = torch.from_numpy(image_chw).float() 91 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 92 | 93 | sample = {"image": image_chw, 94 | "mask": mask_, 95 | "ID": os.path.split(img_path)[1]} 96 | elif self.dataset_type == 'er_copy': 97 | img = cv2.imread(img_path, -1) 98 | mask = cv2.imread(mask_path, -1) 99 | img_ = img_PreProc_er(img, pro_type='clahe') 100 | 101 | fractal_img = np.stack((img_, img_, img_), axis=2) 102 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 103 | image_chw = torch.from_numpy(image_chw).float() 104 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 105 | 106 | sample = {"image": image_chw, 107 | "mask": mask_, 108 | "ID": os.path.split(img_path)[1]} 109 | elif self.dataset_type == 'er_fractal_two_decoder': 110 | npy_path = img_path.replace("images", self.fractal_dir) 111 | npy_path = npy_path.replace(".tif", ".npy") 112 | 113 | fractal_info = np.load(npy_path) 114 | img_FD, img_FL = npy_PreProc(fractal_info) 115 | 116 | img = cv2.imread(img_path, -1) 117 | mask = cv2.imread(mask_path, -1) 118 | 119 | img_ = img_PreProc_er(img, pro_type='clahe') 120 | 121 | fractal_img = np.stack((img_, img_FD, img_FL), axis=2) 122 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 123 | image_chw = torch.from_numpy(image_chw).float() 124 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 125 | 126 | if self.train: 127 | if self.decoder_type != '': 128 | if self.decoder_type == 'skeleton': 129 | skeleton_path = mask_path.replace('masks', self.skeleton_dir) 130 | skeleton = cv2.imread(skeleton_path, -1) 131 
| skeleton_ = torch.from_numpy(skeleton / 255.).unsqueeze_(dim=0).float() 132 | 133 | sample = {"image": image_chw, 134 | "mask": mask_, 135 | "skeleton": skeleton_, 136 | "ID": os.path.split(img_path)[1]} 137 | elif self.decoder_type == 'edge': 138 | edge_path = mask_path.replace('masks', self.edge_dir) 139 | edge = cv2.imread(edge_path, -1) 140 | edge_ = torch.from_numpy(edge / 255.).unsqueeze_(dim=0).float() 141 | sample = {"image": image_chw, 142 | "mask": mask_, 143 | "edge": edge_, 144 | "ID": os.path.split(img_path)[1]} 145 | else: 146 | skeleton_path = mask_path.replace('masks', self.skeleton_dir) 147 | skeleton = cv2.imread(skeleton_path, -1) 148 | skeleton_ = torch.from_numpy(skeleton / 255.).unsqueeze_(dim=0).float() 149 | 150 | sample = {"image": image_chw, 151 | "mask": mask_, 152 | "skeleton": skeleton_, 153 | "ID": os.path.split(img_path)[1]} 154 | 155 | else: 156 | sample = {"image": image_chw, 157 | "mask": mask_, 158 | "ID": os.path.split(img_path)[1]} 159 | elif self.dataset_type == 'er_fractal_three_decoder': 160 | npy_path = img_path.replace("images", self.fractal_dir) 161 | npy_path = npy_path.replace(".tif", ".npy") 162 | 163 | fractal_info = np.load(npy_path) 164 | img_FD, img_FL = npy_PreProc(fractal_info) 165 | 166 | img = cv2.imread(img_path, -1) 167 | mask = cv2.imread(mask_path, -1) 168 | 169 | img_ = img_PreProc_er(img, pro_type='clahe') 170 | 171 | fractal_img = np.stack((img_, img_FD, img_FL), axis=2) 172 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 173 | image_chw = torch.from_numpy(image_chw).float() 174 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 175 | 176 | if self.train: 177 | edge_path = mask_path.replace('masks', self.edge_dir) 178 | skeleton_path = mask_path.replace('masks', self.skeleton_dir) 179 | edge = cv2.imread(edge_path, -1) 180 | skeleton = cv2.imread(skeleton_path, -1) 181 | skeleton_ = torch.from_numpy(skeleton / 255.).unsqueeze_(dim=0).float() 182 | edge_ = torch.from_numpy(edge / 255.).unsqueeze_(dim=0).float() 183 | sample = {"image": image_chw, 184 | "mask": mask_, 185 | "skeleton": skeleton_, 186 | "edge": edge_, 187 | "ID": os.path.split(img_path)[1]} 188 | else: 189 | sample = {"image": image_chw, 190 | "mask": mask_, 191 | "ID": os.path.split(img_path)[1]} 192 | elif self.dataset_type == 'er_three_decoder': 193 | img = cv2.imread(img_path, -1) 194 | mask = cv2.imread(mask_path, -1) 195 | img_ = img_PreProc_er(img, pro_type='clahe') 196 | image_chw = torch.from_numpy(img_).unsqueeze_(dim=0).float() 197 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 198 | 199 | if self.train: 200 | edge_path = mask_path.replace('masks', self.edge_dir) 201 | skeleton_path = mask_path.replace('masks', self.skeleton_dir) 202 | edge = cv2.imread(edge_path, -1) 203 | skeleton = cv2.imread(skeleton_path, -1) 204 | skeleton_ = torch.from_numpy(skeleton / 255.).unsqueeze_(dim=0).float() 205 | edge_ = torch.from_numpy(edge / 255.).unsqueeze_(dim=0).float() 206 | sample = {"image": image_chw, 207 | "mask": mask_, 208 | "skeleton": skeleton_, 209 | "edge": edge_, 210 | "ID": os.path.split(img_path)[1]} 211 | else: 212 | sample = {"image": image_chw, 213 | "mask": mask_, 214 | "ID": os.path.split(img_path)[1]} 215 | elif self.dataset_type == 'er_fractal_three_decoder_weighted': 216 | npy_path = img_path.replace("images", self.fractal_dir) 217 | npy_path = npy_path.replace(".tif", ".npy") 218 | 219 | fractal_info = np.load(npy_path) 220 | img_FD, img_FL = npy_PreProc(fractal_info) 221 | 222 | img = 
cv2.imread(img_path, -1) 223 | mask = cv2.imread(mask_path, -1) 224 | 225 | weight_path = mask_path.replace('masks', 'masks' + self.weight_dir) 226 | weighted_npy = weight_path.replace(".tif", ".npy") 227 | weight1 = np.ones_like(img) 228 | weight2 = np.load(weighted_npy) 229 | weight = weight1 + (weight2 / np.max(weight2)) 230 | 231 | 232 | img_ = img_PreProc_er(img, pro_type='clahe') 233 | 234 | fractal_img = np.stack((img_, img_FD, img_FL), axis=2) 235 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 236 | image_chw = torch.from_numpy(image_chw).float() 237 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 238 | weight_ = torch.from_numpy(weight / 1.0).unsqueeze_(dim=0).float() 239 | 240 | if self.train: 241 | edge_path = mask_path.replace('masks', self.edge_dir) 242 | skeleton_path = mask_path.replace('masks', self.skeleton_dir) 243 | edge = cv2.imread(edge_path, -1) 244 | skeleton = cv2.imread(skeleton_path, -1) 245 | skeleton_ = torch.from_numpy(skeleton / 255.).unsqueeze_(dim=0).float() 246 | edge_ = torch.from_numpy(edge / 255.).unsqueeze_(dim=0).float() 247 | sample = {"image": image_chw, 248 | "mask": mask_, 249 | "skeleton": skeleton_, 250 | "edge": edge_, 251 | "weight": weight_, 252 | "ID": os.path.split(img_path)[1]} 253 | else: 254 | sample = {"image": image_chw, 255 | "mask": mask_, 256 | "ID": os.path.split(img_path)[1]} 257 | return sample 258 | 259 | def __len__(self): 260 | return len(self.img_mask_paths) 261 | -------------------------------------------------------------------------------- /datasets/MITO_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset 4 | from data_preprocess.pre_process import * 5 | 6 | 7 | def img_PreProc_mito(img, pro_type): 8 | if pro_type == "clahe": 9 | img = img_clahe(img) 10 | return img / 65535. 11 | 12 | elif pro_type == "invert": 13 | img = 65535 - img 14 | return img / 65535. 15 | 16 | elif pro_type == "edgeEnhance": 17 | edge = sober_filter(img) 18 | edge = edge / np.max(edge) 19 | return ((img / 65535.) + edge) * 0.5 20 | 21 | elif pro_type == "norm": 22 | img = img / 65535. 23 | img = (img - np.mean(img)) / (np.std(img) + 1e-8) 24 | return img 25 | 26 | elif pro_type == "clahe_norm": 27 | img = img_clahe(img) 28 | img = img / 65535. 
29 | img = (img - np.mean(img)) / (np.std(img) + 1e-8) 30 | return img 31 | 32 | 33 | class MITO_Dataset(Dataset): 34 | def __init__(self, txt, dataset_type, train, fractal_dir='',weight_dir='', edge_dir='', skeleton_dir='', decoder_type='', 35 | log_file='', epoch=0, update_d=5, img_size=256): 36 | 37 | self.img_size = img_size 38 | self.dataset_type = dataset_type 39 | self.train = train 40 | self.fractal_dir = fractal_dir 41 | self.weight_dir = weight_dir 42 | self.decoder_type = decoder_type 43 | self.log_file = log_file 44 | self.epoch = epoch 45 | self.update_d = update_d 46 | self.edge_dir = edge_dir 47 | self.skeleton_dir = skeleton_dir 48 | 49 | with open(txt, "r") as fid: 50 | lines = fid.readlines() 51 | 52 | img_mask_paths = [] 53 | for line in lines: 54 | line = line.strip('\n') 55 | line = line.rstrip() 56 | words = line.split(" ") 57 | img_mask_paths.append((words[0], words[1])) 58 | 59 | self.img_mask_paths = img_mask_paths 60 | 61 | def __getitem__(self, index): 62 | img_path, mask_path = self.img_mask_paths[index] 63 | 64 | # initialize input 65 | if self.dataset_type == 'mito': 66 | img = cv2.imread(img_path, -1) 67 | mask = cv2.imread(mask_path, -1) 68 | img_ = img_PreProc_mito(img, pro_type='clahe_norm') 69 | 70 | img_ = torch.from_numpy(img_).unsqueeze(dim=0).float() 71 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 72 | 73 | sample = {"image": img_, 74 | "mask": mask_, 75 | "ID": os.path.split(img_path)[1]} 76 | elif self.dataset_type == 'mito_copy': 77 | img = cv2.imread(img_path, -1) 78 | mask = cv2.imread(mask_path, -1) 79 | img_ = img_PreProc_mito(img, pro_type='clahe_norm') 80 | 81 | fractal_img = np.stack((img_, img_, img_), axis=2) 82 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 83 | image_chw = torch.from_numpy(image_chw).float() 84 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 85 | 86 | sample = {"image": image_chw, 87 | "mask": mask_, 88 | "ID": os.path.split(img_path)[1]} 89 | elif self.dataset_type == 'mito_fractal': 90 | npy_path = img_path.replace("images", self.fractal_dir) 91 | npy_path = npy_path.replace(".tif", ".npy") 92 | fractal_info = np.load(npy_path) 93 | img_FD, img_FL = npy_PreProc(fractal_info) 94 | img = cv2.imread(img_path, -1) 95 | mask = cv2.imread(mask_path, -1) 96 | img_ = img_PreProc_mito(img, pro_type='clahe_norm') 97 | 98 | fractal_img = np.stack((img_, img_FD, img_FL), axis=2) 99 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 100 | image_chw = torch.from_numpy(image_chw).float() 101 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 102 | 103 | sample = {"image": image_chw, 104 | "mask": mask_, 105 | "ID": os.path.split(img_path)[1]} 106 | elif self.dataset_type == 'mito_fractal_three_decoder': 107 | npy_path = img_path.replace("images", self.fractal_dir) 108 | npy_path = npy_path.replace(".tif", ".npy") 109 | fractal_info = np.load(npy_path) 110 | img_FD, img_FL = npy_PreProc(fractal_info) 111 | img = cv2.imread(img_path, -1) 112 | mask = cv2.imread(mask_path, -1) 113 | img_ = img_PreProc_mito(img, pro_type='clahe_norm') 114 | 115 | fractal_img = np.stack((img_, img_FD, img_FL), axis=2) 116 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 117 | image_chw = torch.from_numpy(image_chw).float() 118 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 119 | 120 | if self.train: 121 | edge_path = mask_path.replace('masks_aug', self.edge_dir) 122 | skeleton_path = mask_path.replace('masks_aug', self.skeleton_dir) 123 | edge = cv2.imread(edge_path, -1) 124 | 
skeleton = cv2.imread(skeleton_path, -1) 125 | skeleton_ = torch.from_numpy(skeleton / 255.).unsqueeze_(dim=0).float() 126 | edge_ = torch.from_numpy(edge / 255.).unsqueeze_(dim=0).float() 127 | sample = {"image": image_chw, 128 | "mask": mask_, 129 | "skeleton": skeleton_, 130 | "edge": edge_, 131 | "ID": os.path.split(img_path)[1]} 132 | else: 133 | sample = {"image": image_chw, 134 | "mask": mask_, 135 | "ID": os.path.split(img_path)[1]} 136 | elif self.dataset_type == 'mito_fractal_three_decoder_weighted': 137 | npy_path = img_path.replace("images", self.fractal_dir) 138 | npy_path = npy_path.replace(".tif", ".npy") 139 | fractal_info = np.load(npy_path) 140 | img_FD, img_FL = npy_PreProc(fractal_info) 141 | 142 | img = cv2.imread(img_path, -1) 143 | mask = cv2.imread(mask_path, -1) 144 | 145 | weight_path = mask_path.replace('masks_aug', 'masks_aug' + self.weight_dir) 146 | weighted_npy = weight_path.replace(".tif", ".npy") 147 | weight1 = np.ones_like(img) 148 | weight2 = np.load(weighted_npy) 149 | weight = weight1 + (weight2 / np.max(weight2)) 150 | 151 | 152 | img_ = img_PreProc_mito(img, pro_type='clahe_norm') 153 | fractal_img = np.stack((img_, img_FD, img_FL), axis=2) 154 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 155 | image_chw = torch.from_numpy(image_chw).float() 156 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 157 | weight_ = torch.from_numpy(weight / 1.0).unsqueeze_(dim=0).float() 158 | 159 | if self.train: 160 | edge_path = mask_path.replace('masks_aug', self.edge_dir) 161 | skeleton_path = mask_path.replace('masks_aug', self.skeleton_dir) 162 | edge = cv2.imread(edge_path, -1) 163 | skeleton = cv2.imread(skeleton_path, -1) 164 | skeleton_ = torch.from_numpy(skeleton / 255.).unsqueeze_(dim=0).float() 165 | edge_ = torch.from_numpy(edge / 255.).unsqueeze_(dim=0).float() 166 | sample = {"image": image_chw, 167 | "mask": mask_, 168 | "skeleton": skeleton_, 169 | "edge": edge_, 170 | "weight": weight_, 171 | "ID": os.path.split(img_path)[1]} 172 | else: 173 | sample = {"image": image_chw, 174 | "mask": mask_, 175 | "ID": os.path.split(img_path)[1]} 176 | 177 | return sample 178 | 179 | def __len__(self): 180 | return len(self.img_mask_paths) 181 | -------------------------------------------------------------------------------- /datasets/NUCLEUS_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset 4 | from data_preprocess.pre_process import * 5 | 6 | 7 | def img_PreProc_nucleus(img, pro_type): 8 | if pro_type == "clahe": 9 | img = img_clahe(img) 10 | img = img / 255. 11 | sd_img = standardization(img) 12 | return sd_img 13 | 14 | elif pro_type == "clahe_new": 15 | img = img_clahe(img) 16 | return img / 255. 17 | 18 | 19 | elif pro_type == "invert": 20 | img = 255 - img 21 | return img / 255. 22 | 23 | elif pro_type == "edgeEnhance": 24 | edge = sober_filter(img) 25 | edge = edge / np.max(edge) 26 | return ((img / 255.) + edge) * 0.5 27 | 28 | elif pro_type == "norm_single": 29 | img = img / 255. 30 | img = (img - np.mean(img)) / (np.std(img) + 1e-8) 31 | return img 32 | elif pro_type == "norm_dataset": 33 | img = (img - (1, 1, 1)) / (1, 1, 1) 34 | return img 35 | elif pro_type == "clahe_norm_single": 36 | img = img_clahe(img) 37 | img = img / 255. 38 | img = (img - np.mean(img)) / (np.std(img) + 1e-8) 39 | return img 40 | elif pro_type == "clahe_norm": 41 | img = img_clahe(img) 42 | img = img / 255. 
43 | img = (img - np.mean(img)) / (np.std(img) + 1e-8) 44 | return img 45 | 46 | 47 | class NUCLEUS_Dataset(Dataset): 48 | def __init__(self, txt, dataset_type, train, fractal_dir='', weight_dir='', edge_dir='',skeleton_dir='',decoder_type='', log_file='', epoch=0, update_d=5, 49 | img_size=256): 50 | self.img_size = img_size 51 | self.dataset_type = dataset_type 52 | self.train = train 53 | self.fractal_dir = fractal_dir 54 | self.weight_dir = weight_dir 55 | self.decoder_type = decoder_type 56 | self.log_file = log_file 57 | self.epoch = epoch 58 | self.update_d = update_d 59 | self.edge_dir =edge_dir 60 | self.skeleton_dir = skeleton_dir 61 | 62 | with open(txt, "r") as fid: 63 | lines = fid.readlines() 64 | 65 | img_mask_paths = [] 66 | for line in lines: 67 | line = line.strip('\n') 68 | line = line.rstrip() 69 | words = line.split(" ") 70 | img_mask_paths.append((words[0], words[1])) 71 | 72 | self.img_mask_paths = img_mask_paths 73 | 74 | def __getitem__(self, index): 75 | img_path, mask_path = self.img_mask_paths[index] 76 | 77 | # initialize input 78 | if self.dataset_type == 'nucleus': 79 | img = cv2.imread(img_path, 0) 80 | mask = cv2.imread(mask_path, 0) 81 | img_ = img_PreProc_nucleus(img, pro_type='clahe_norm') 82 | 83 | image_chw = torch.from_numpy(img_).unsqueeze(dim=0).float() 84 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 85 | 86 | sample = {"image": image_chw, 87 | "mask": mask_, 88 | "ID": os.path.split(img_path)[1]} 89 | elif self.dataset_type == 'nucleus_fractal': 90 | npy_path = img_path.replace("images", self.fractal_dir) 91 | npy_path = npy_path[:-4] + ".npy" 92 | fractal_info = np.load(npy_path) 93 | img_FD, img_FL = npy_PreProc(fractal_info) 94 | img = cv2.imread(img_path, 0) 95 | mask = cv2.imread(mask_path, 0) 96 | img_ = img_PreProc_nucleus(img, pro_type='clahe_norm') 97 | 98 | fractal_img = np.stack((img_, img_FD, img_FL), axis=2) 99 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 100 | image_chw = torch.from_numpy(image_chw).float() 101 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 102 | 103 | sample = {"image": image_chw, 104 | "mask": mask_, 105 | "ID": os.path.split(img_path)[1]} 106 | elif self.dataset_type == 'nucleus_fractal_two_decoder': 107 | npy_path = img_path.replace("images", self.fractal_dir) 108 | npy_path = npy_path[:-4] + ".npy" 109 | fractal_info = np.load(npy_path) 110 | img_FD, img_FL = npy_PreProc(fractal_info) 111 | img = cv2.imread(img_path, 0) 112 | mask = cv2.imread(mask_path, 0) 113 | img_ = img_PreProc_nucleus(img, pro_type='clahe_norm') 114 | 115 | fractal_img = np.stack((img_, img_FD, img_FL), axis=2) 116 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 117 | image_chw = torch.from_numpy(image_chw).float() 118 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 119 | 120 | if self.train: 121 | edge_path = mask_path.replace('masks_train_aug', self.edge_dir) 122 | edge = cv2.imread(edge_path, -1) 123 | edge_ = torch.from_numpy(edge / 255.).unsqueeze_(dim=0).float() 124 | sample = {"image": image_chw, 125 | "mask": mask_, 126 | "edge": edge_, 127 | "ID": os.path.split(img_path)[1]} 128 | else: 129 | sample = {"image": image_chw, 130 | "mask": mask_, 131 | "ID": os.path.split(img_path)[1]} 132 | 133 | elif self.dataset_type == 'nucleus_fractal_two_decoder_weighted': 134 | train_imgpth_list = img_path.split('/') 135 | image_name = train_imgpth_list[-1] 136 | train_imgpth_list[-2] = train_imgpth_list[-2] + '_' + self.fractal_dir 137 | npy_path = '/'.join(train_imgpth_list) 138 
| npy_path = npy_path[:-4] + ".npy" 139 | 140 | fractal_info = np.load(npy_path) 141 | img_FD, img_FL = npy_PreProc(fractal_info) 142 | 143 | img = cv2.imread(img_path, 0) 144 | mask = cv2.imread(mask_path, 0) 145 | 146 | weight_path = mask_path.replace('masks_train_aug', 'masks_train_aug' + self.weight_dir) 147 | weighted_npy = weight_path.replace(".tif", ".npy") 148 | weight1 = np.ones_like(img) 149 | weight2 = np.load(weighted_npy) 150 | weight = weight1 + (weight2 / np.max(weight2)) 151 | 152 | img_ = img_PreProc_nucleus(img, pro_type='clahe_norm') 153 | 154 | fractal_img = np.stack((img_, img_FD, img_FL), axis=2) 155 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 156 | image_chw = torch.from_numpy(image_chw).float() 157 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 158 | weight_ = torch.from_numpy(weight / 1.0).unsqueeze_(dim=0).float() 159 | 160 | if self.train: 161 | edge_path = mask_path.replace('masks_train_aug', self.edge_dir) 162 | edge = cv2.imread(edge_path, -1) 163 | edge_ = torch.from_numpy(edge / 255.).unsqueeze_(dim=0).float() 164 | sample = {"image": image_chw, 165 | "mask": mask_, 166 | "edge": edge_, 167 | "weight": weight_, 168 | "ID": os.path.split(img_path)[1]} 169 | else: 170 | sample = {"image": image_chw, 171 | "mask": mask_, 172 | "ID": os.path.split(img_path)[1]} 173 | 174 | return sample 175 | 176 | def __len__(self): 177 | return len(self.img_mask_paths) 178 | -------------------------------------------------------------------------------- /datasets/ROAD_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | from data_preprocess.pre_process import * 5 | import torchvision.transforms as transforms 6 | import torchvision.transforms.functional as F 7 | 8 | 9 | def img_PreProc_road(img, pro_type): 10 | if pro_type == "clahe": 11 | img = img_clahe(img) 12 | img = img / 255. 13 | sd_img = standardization(img) 14 | return sd_img 15 | 16 | elif pro_type == "clahe_new": 17 | img = img_clahe(img) 18 | return img / 255. 19 | 20 | 21 | elif pro_type == "invert": 22 | img = 255 - img 23 | return img / 255. 24 | 25 | elif pro_type == "edgeEnhance": 26 | edge = sober_filter(img) 27 | edge = edge / np.max(edge) 28 | return ((img / 255.) + edge) * 0.5 29 | 30 | elif pro_type == "norm_single": 31 | img = img / 255. 32 | img = (img - np.mean(img)) / (np.std(img) + 1e-8) 33 | return img 34 | elif pro_type == "norm_dataset": 35 | img = (img - (109.244987, 110.007784, 100.735999)) / (74.424917, 72.786659, 75.802716) 36 | # img = (img - (109.139851, 109.901530, 100.629362)) / (73.187326, 71.655098, 75.030130) 37 | # img = (img - (109.135833, 109.898700, 100.626847)) / (73.189280, 71.656865, 75.031641) 38 | return img 39 | elif pro_type == "clahe_norm_single": 40 | img = img_clahe(img) 41 | img = img / 255. 42 | img = (img - np.mean(img)) / (np.std(img) + 1e-8) 43 | return img 44 | elif pro_type == "clahe_norm": 45 | img = img_clahe(img) 46 | img = img / 255. 
47 | img = (img - np.mean(img)) / (np.std(img) + 1e-8) 48 | return img 49 | 50 | 51 | def random_crop(images, labels, aim_size): 52 | trans = transforms.Compose([transforms.RandomCrop(aim_size)]) 53 | seed = torch.random.seed() 54 | torch.random.manual_seed(seed) 55 | cropped_img = trans(images) 56 | torch.random.manual_seed(seed) 57 | cropped_label = trans(labels) 58 | return cropped_img, cropped_label 59 | 60 | 61 | def random_crop_with_edge_skeleton_weight(images, labels, edge, skeleton, weight, aim_size): 62 | trans = transforms.Compose([transforms.RandomCrop(aim_size)]) 63 | seed = torch.random.seed() 64 | 65 | torch.random.manual_seed(seed) 66 | cropped_img = trans(images) 67 | torch.random.manual_seed(seed) 68 | cropped_label = trans(labels) 69 | torch.random.manual_seed(seed) 70 | cropped_edge = trans(edge) 71 | torch.random.manual_seed(seed) 72 | cropped_skeleton = trans(skeleton) 73 | torch.random.manual_seed(seed) 74 | cropped_weight = trans(weight) 75 | return cropped_img, cropped_label, cropped_edge, cropped_skeleton, cropped_weight 76 | 77 | 78 | def random_crop_with_edge_skeleton(images, labels, edge, skeleton, aim_size): 79 | trans = transforms.Compose([transforms.RandomCrop(aim_size)]) 80 | seed = torch.random.seed() 81 | 82 | torch.random.manual_seed(seed) 83 | cropped_img = trans(images) 84 | torch.random.manual_seed(seed) 85 | cropped_label = trans(labels) 86 | torch.random.manual_seed(seed) 87 | cropped_edge = trans(edge) 88 | torch.random.manual_seed(seed) 89 | cropped_skeleton = trans(skeleton) 90 | 91 | return cropped_img, cropped_label, cropped_edge, cropped_skeleton 92 | 93 | 94 | def regular_crop(images, labels, aim_size): 95 | cropped_image = F.crop(images, 100, 100, 256, 256) 96 | cropped_label = F.crop(labels, 100, 100, 256, 256) 97 | return cropped_image, cropped_label 98 | 99 | 100 | class ROAD_Dataset(Dataset): 101 | def __init__(self, txt, dataset_type, train, fractal_dir='', weight_dir='', edge_dir='', skeleton_dir='', decoder_type='',log_file='', epoch=0, update_d=5,img_size=256): 102 | 103 | self.img_size = img_size 104 | self.dataset_type = dataset_type 105 | self.train = train 106 | self.fractal_dir = fractal_dir 107 | self.weight_dir = weight_dir 108 | self.decoder_type = decoder_type 109 | self.log_file = log_file 110 | self.epoch = epoch 111 | self.update_d = update_d 112 | self.edge_dir = edge_dir 113 | self.skeleton_dir = skeleton_dir 114 | 115 | with open(txt, "r") as fid: 116 | lines = fid.readlines() 117 | 118 | img_mask_paths = [] 119 | for line in lines: 120 | line = line.strip('\n') 121 | line = line.rstrip() 122 | words = line.split(" ") 123 | img_mask_paths.append((words[0], words[1])) 124 | 125 | self.img_mask_paths = img_mask_paths 126 | 127 | def __getitem__(self, index): 128 | img_path, mask_path = self.img_mask_paths[index] 129 | 130 | # initialize input 131 | if self.dataset_type == 'road': 132 | if self.train: 133 | img = cv2.imread(img_path, -1) 134 | img_ = img_PreProc_road(img, pro_type='norm_dataset') 135 | image_chw = np.transpose(img_, (2, 0, 1)) 136 | else: 137 | test_path = img_path.replace('cropped_test_input', 'cropped_test_pre') 138 | image_chw = np.load(test_path).astype(np.float32) 139 | 140 | mask = cv2.imread(mask_path, 0) 141 | mask[mask > 3] = 255 142 | mask[mask <= 3] = 0 143 | 144 | image_chw_ = torch.from_numpy(image_chw).float() 145 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 146 | 147 | if self.train: 148 | cropped_img, cropped_mask = random_crop(image_chw_, mask_, 256) 149 | else: 150 | 
cropped_img = image_chw_ 151 | cropped_mask = mask_ 152 | sample = {"image": cropped_img, 153 | "mask": cropped_mask, 154 | "ID": os.path.split(img_path)[1]} 155 | elif self.dataset_type == 'road_copy': 156 | if self.train: 157 | npy_path = img_path.replace('train_val_aug', 'pre_train_val_aug') 158 | image_chw = np.load(npy_path[:-3] + 'npy').astype(np.float32) 159 | else: 160 | image_chw = np.load(img_path).astype(np.float32) 161 | 162 | mask = cv2.imread(mask_path, 0) 163 | mask[mask > 3] = 255 164 | mask[mask <= 3] = 0 165 | 166 | fractal_img = np.stack((image_chw[0], image_chw[1], image_chw[2], image_chw[0], image_chw[1], image_chw[2]), 167 | axis=2) 168 | image_chw2 = np.transpose(fractal_img, (2, 0, 1)) 169 | 170 | image_chw_ = torch.from_numpy(image_chw2).float() 171 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 172 | if self.train: 173 | cropped_img, cropped_mask = random_crop(image_chw_, mask_, 256) 174 | else: 175 | cropped_img = image_chw_ 176 | cropped_mask = mask_ 177 | 178 | sample = {"image": cropped_img, 179 | "mask": cropped_mask, 180 | "ID": os.path.split(img_path)[1]} 181 | elif self.dataset_type == 'road_fractal': 182 | if self.train: 183 | npy_path = img_path.replace('train_val_aug', 'pre_train_val_aug') 184 | image_chw = np.load(npy_path[:-3] + 'npy').astype(np.float32) 185 | 186 | else: 187 | image_chw = np.load(img_path).astype(np.float32) 188 | 189 | mask = cv2.imread(mask_path, 0) 190 | mask[mask > 3] = 255 191 | mask[mask <= 3] = 0 192 | 193 | image_chw_ = torch.from_numpy(image_chw).float() 194 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 195 | if self.train: 196 | cropped_img, cropped_mask = random_crop(image_chw_, mask_, 256) 197 | else: 198 | cropped_img = image_chw 199 | cropped_mask = mask_ 200 | 201 | sample = {"image": cropped_img, 202 | "mask": cropped_mask, 203 | "ID": os.path.split(img_path)[1]} 204 | elif self.dataset_type == 'road_fractal_three_decoder': 205 | if self.train: 206 | npy_path = img_path.replace('train_val_aug', 'pre_train_val_aug') 207 | image_chw = np.load(npy_path[:-3] + 'npy').astype(np.float32) 208 | else: 209 | image_chw = np.load(img_path).astype(np.float32) 210 | mask = cv2.imread(mask_path, 0) 211 | mask[mask > 3] = 255 212 | mask[mask <= 3] = 0 213 | image_chw_ = torch.from_numpy(image_chw).float() 214 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 215 | if self.train: 216 | edge_path = mask_path.replace('train_val_labels_aug', self.edge_dir) 217 | skeleton_path = mask_path.replace('train_val_labels_aug', self.skeleton_dir) 218 | edge = cv2.imread(edge_path, -1) 219 | skeleton = cv2.imread(skeleton_path, -1) 220 | skeleton_ = torch.from_numpy(skeleton / 255.).unsqueeze_(dim=0).float() 221 | edge_ = torch.from_numpy(edge / 255.).unsqueeze_(dim=0).float() 222 | cropped_img, cropped_mask, cropped_edge, cropped_skeleton = random_crop_with_edge_skeleton(image_chw_, 223 | mask_, 224 | edge_, 225 | skeleton_, 226 | 256) 227 | sample = {"image": cropped_img, 228 | "mask": cropped_mask, 229 | "skeleton": cropped_skeleton, 230 | "edge": cropped_edge, 231 | "ID": os.path.split(img_path)[1]} 232 | else: 233 | cropped_img = image_chw 234 | cropped_mask = mask_ 235 | sample = {"image": cropped_img, 236 | "mask": cropped_mask, 237 | "ID": os.path.split(img_path)[1]} 238 | elif self.dataset_type == 'road_fractal_three_decoder_weighted': 239 | train_imgpth_list = img_path.split('/') 240 | image_name = train_imgpth_list[-1] 241 | if self.train: 242 | npy_path = 
img_path.replace('train_val_aug', 'pre_train_val_aug') 243 | image_chw = np.load(npy_path[:-3] + 'npy').astype(np.float32) 244 | else: 245 | image_chw = np.load(img_path).astype(np.float32) 246 | mask = cv2.imread(mask_path, 0) 247 | mask[mask > 3] = 255 248 | mask[mask <= 3] = 0 249 | image_chw_ = torch.from_numpy(image_chw).float() 250 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 251 | 252 | weight_path = mask_path.replace('train_val_labels_aug', 'train_val_labels_aug' + self.weight_dir) 253 | weighted_npy = weight_path.replace(".tif", ".npy") 254 | weight1 = np.ones_like(img) 255 | weight2 = np.load(weighted_npy) 256 | weight = weight1 + (weight2 / np.max(weight2)) 257 | 258 | 259 | if self.train: 260 | edge_path = mask_path.replace('train_val_labels_aug', self.edge_dir) 261 | skeleton_path = mask_path.replace('train_val_labels_aug', self.skeleton_dir) 262 | edge = cv2.imread(edge_path, -1) 263 | skeleton = cv2.imread(skeleton_path, -1) 264 | skeleton_ = torch.from_numpy(skeleton / 255.).unsqueeze_(dim=0).float() 265 | edge_ = torch.from_numpy(edge / 255.).unsqueeze_(dim=0).float() 266 | weight_ = torch.from_numpy(weight / 1.0).unsqueeze_(dim=0).float() 267 | cropped_img, cropped_mask, cropped_edge, cropped_skeleton, cropped_weight = random_crop_with_edge_skeleton_weight( 268 | image_chw_, mask_, 269 | edge_, skeleton_, weight_, 256) 270 | 271 | sample = {"image": cropped_img, 272 | "mask": cropped_mask, 273 | "skeleton": cropped_skeleton, 274 | "edge": cropped_edge, 275 | "weight": weight_, 276 | "ID": os.path.split(img_path)[1]} 277 | else: 278 | cropped_img = image_chw 279 | cropped_mask = mask_ 280 | sample = {"image": cropped_img, 281 | "mask": cropped_mask, 282 | "ID": os.path.split(img_path)[1]} 283 | 284 | return sample 285 | 286 | def __len__(self): 287 | return len(self.img_mask_paths) 288 | -------------------------------------------------------------------------------- /datasets/ROSE_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | from data_preprocess.pre_process import * 5 | 6 | 7 | def img_PreProc_rose(img, pro_type): 8 | if pro_type == "clahe": 9 | img = img_clahe(img) 10 | img = img / 255. 11 | sd_img = standardization(img) 12 | return sd_img 13 | 14 | elif pro_type == "invert": 15 | img = 255 - img 16 | return img / 255. 17 | 18 | elif pro_type == "edgeEnhance": 19 | edge = sober_filter(img) 20 | edge = edge / np.max(edge) 21 | return ((img / 255.) + edge) * 0.5 22 | 23 | elif pro_type == "norm": 24 | img = img / 255. 25 | img = (img - np.mean(img)) / (np.std(img) + 1e-8) 26 | return img 27 | 28 | elif pro_type == "clahe_norm": 29 | img = img_clahe(img) 30 | img = img / 255. 
31 | img = (img - np.mean(img)) / (np.std(img) + 1e-8) 32 | return img 33 | 34 | 35 | class ROSE_Dataset(Dataset): 36 | def __init__(self, txt, dataset_type, train, fractal_dir='', weight_dir='', edge_dir='',skeleton_dir='',decoder_type='', log_file='', epoch=0, update_d=5, 37 | img_size=256): 38 | self.img_size = img_size 39 | self.dataset_type = dataset_type 40 | self.train = train 41 | self.fractal_dir = fractal_dir 42 | self.weight_dir = weight_dir 43 | self.decoder_type = decoder_type 44 | self.log_file = log_file 45 | self.epoch = epoch 46 | self.update_d = update_d 47 | self.edge_dir = edge_dir 48 | self.skeleton_dir = skeleton_dir 49 | 50 | with open(txt, "r") as fid: 51 | lines = fid.readlines() 52 | 53 | img_mask_paths = [] 54 | for line in lines: 55 | line = line.strip('\n') 56 | line = line.rstrip() 57 | words = line.split(" ") 58 | img_mask_paths.append((words[0], words[1])) 59 | 60 | self.img_mask_paths = img_mask_paths 61 | 62 | def __getitem__(self, index): 63 | 64 | img_path, mask_path = self.img_mask_paths[index] 65 | 66 | # initialize input 67 | if self.dataset_type == 'rose': 68 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 69 | mask = cv2.imread(mask_path, -1) 70 | img_ = img_PreProc_rose(img, pro_type='norm') 71 | 72 | img_ = torch.from_numpy(img_).unsqueeze(dim=0).float() 73 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 74 | 75 | sample = {"image": img_, 76 | "mask": mask_, 77 | "ID": os.path.split(img_path)[1]} 78 | elif self.dataset_type == 'rose_copy': 79 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 80 | mask = cv2.imread(mask_path, -1) 81 | img_ = img_PreProc_rose(img, pro_type='norm') 82 | 83 | fractal_img = np.stack((img_, img_, img_), axis=2) 84 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 85 | image_chw = torch.from_numpy(image_chw).float() 86 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 87 | 88 | sample = {"image": image_chw, 89 | "mask": mask_, 90 | "ID": os.path.split(img_path)[1]} 91 | elif self.dataset_type == 'rose_fractal': 92 | train_imgpth_list = img_path.split('/') 93 | train_imgpth_list[-2] = self.fractal_dir 94 | npy_path = '/'.join(train_imgpth_list) 95 | npy_path = npy_path[:-4] + ".npy" 96 | fractal_info = np.load(npy_path) 97 | img_FD, img_FL = npy_PreProc(fractal_info) 98 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 99 | mask = cv2.imread(mask_path, -1) 100 | img_ = img_PreProc_rose(img, pro_type='norm') 101 | 102 | fractal_img = np.stack((img_, img_FD, img_FL), axis=2) 103 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 104 | image_chw = torch.from_numpy(image_chw).float() 105 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 106 | 107 | sample = {"image": image_chw, 108 | "mask": mask_, 109 | "ID": os.path.split(img_path)[1]} 110 | elif self.dataset_type == 'rose_fractal_three_decoder': 111 | train_imgpth_list = img_path.split('/') 112 | if self.train: 113 | train_imgpth_list[-2] = 'aug_' + self.fractal_dir 114 | else: 115 | train_imgpth_list[-2] = self.fractal_dir 116 | npy_path = '/'.join(train_imgpth_list) 117 | npy_path = npy_path[:-4] + ".npy" 118 | fractal_info = np.load(npy_path) 119 | img_FD, img_FL = npy_PreProc(fractal_info) 120 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 121 | mask = cv2.imread(mask_path, -1) 122 | img_ = img_PreProc_rose(img, pro_type='norm') 123 | 124 | fractal_img = np.stack((img_, img_FD, img_FL), axis=2) 125 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 126 | image_chw = torch.from_numpy(image_chw).float() 127 | 
mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 128 | 129 | if self.train: 130 | edge_path = mask_path.replace('aug_gt', self.edge_dir) 131 | skeleton_path = mask_path.replace('aug_gt', self.skeleton_dir) 132 | edge = cv2.imread(edge_path, -1) 133 | skeleton = cv2.imread(skeleton_path, -1) 134 | skeleton_ = torch.from_numpy(skeleton / 255.).unsqueeze_(dim=0).float() 135 | edge_ = torch.from_numpy(edge / 255.).unsqueeze_(dim=0).float() 136 | sample = {"image": image_chw, 137 | "mask": mask_, 138 | "skeleton": skeleton_, 139 | "edge": edge_, 140 | "ID": os.path.split(img_path)[1]} 141 | else: 142 | sample = {"image": image_chw, 143 | "mask": mask_, 144 | "ID": os.path.split(img_path)[1]} 145 | elif self.dataset_type == 'rose_fractal_three_decoder_weighted': 146 | train_imgpth_list[-2] = self.fractal_dir 147 | npy_path = '/'.join(train_imgpth_list) 148 | npy_path = npy_path[:-4] + ".npy" 149 | 150 | fractal_info = np.load(npy_path) 151 | img_FD, img_FL = npy_PreProc(fractal_info) 152 | 153 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 154 | mask = cv2.imread(mask_path, -1) 155 | 156 | weight_path = mask_path.replace('aug_gt', 'aug_gt' + self.weight_dir) 157 | weighted_npy = weight_path.replace(".tif", ".npy") 158 | weight1 = np.ones_like(img) 159 | weight2 = np.load(weighted_npy) 160 | weight = weight1 + (weight2 / np.max(weight2)) 161 | 162 | img_ = img_PreProc_rose(img, pro_type='norm') 163 | 164 | fractal_img = np.stack((img_, img_FD, img_FL), axis=2) 165 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 166 | image_chw = torch.from_numpy(image_chw).float() 167 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 168 | weight_ = torch.from_numpy(weight / 1.0).unsqueeze_(dim=0).float() 169 | 170 | if self.train: 171 | edge_path = mask_path.replace('aug_gt', self.edge_dir) 172 | skeleton_path = mask_path.replace('aug_gt', self.skeleton_dir) 173 | edge = cv2.imread(edge_path, -1) 174 | skeleton = cv2.imread(skeleton_path, -1) 175 | skeleton_ = torch.from_numpy(skeleton / 255.).unsqueeze_(dim=0).float() 176 | edge_ = torch.from_numpy(edge / 255.).unsqueeze_(dim=0).float() 177 | sample = {"image": image_chw, 178 | "mask": mask_, 179 | "skeleton": skeleton_, 180 | "edge": edge_, 181 | "weight": weight_, 182 | "ID": os.path.split(img_path)[1]} 183 | else: 184 | sample = {"image": image_chw, 185 | "mask": mask_, 186 | "ID": os.path.split(img_path)[1]} 187 | 188 | return sample 189 | 190 | def __len__(self): 191 | return len(self.img_mask_paths) 192 | -------------------------------------------------------------------------------- /datasets/STARE_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | from data_preprocess.pre_process import * 5 | 6 | 7 | def img_PreProc_retina(img, pro_type): 8 | if pro_type == "clahe": 9 | img = img_clahe(img) 10 | img = img / 255. 11 | sd_img = standardization(img) 12 | return sd_img 13 | 14 | elif pro_type == "clahe_new": 15 | img = img_clahe(img) 16 | return img / 255. 17 | 18 | 19 | elif pro_type == "invert": 20 | img = 255 - img 21 | return img / 255. 22 | 23 | elif pro_type == "edgeEnhance": 24 | edge = sober_filter(img) 25 | edge = edge / np.max(edge) 26 | return ((img / 255.) + edge) * 0.5 27 | 28 | elif pro_type == "norm": 29 | img = img / 255. 
30 | img = (img - np.mean(img)) / (np.std(img) + 1e-8) 31 | return img 32 | 33 | elif pro_type == "clahe_norm": 34 | img = img_clahe(img) 35 | img = img / 255. 36 | img = (img - np.mean(img)) / (np.std(img) + 1e-8) 37 | return img 38 | 39 | 40 | class STARE_Dataset(Dataset): 41 | def __init__(self, txt, dataset_type, train, fractal_dir='', weight_dir='', edge_dir='',skeleton_dir='', decoder_type='', log_file='', epoch=0, update_d=5, 42 | img_size=256): 43 | self.img_size = img_size 44 | self.dataset_type = dataset_type 45 | self.train = train 46 | self.fractal_dir = fractal_dir 47 | self.weight_dir = weight_dir 48 | self.decoder_type = decoder_type 49 | self.log_file = log_file 50 | self.epoch = epoch 51 | self.update_d = update_d 52 | self.edge_dir = edge_dir 53 | self.skeleton_dir = skeleton_dir 54 | 55 | with open(txt, "r") as fid: 56 | lines = fid.readlines() 57 | 58 | img_mask_paths = [] 59 | for line in lines: 60 | line = line.strip('\n') 61 | line = line.rstrip() 62 | words = line.split(" ") 63 | img_mask_paths.append((words[0], words[1])) 64 | 65 | self.img_mask_paths = img_mask_paths 66 | 67 | def __getitem__(self, index): 68 | img_path, mask_path = self.img_mask_paths[index] 69 | 70 | # initialize input 71 | if self.dataset_type == 'stare': 72 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 73 | mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) 74 | img_ = img_PreProc_retina(img, pro_type='clahe') 75 | 76 | img_ = torch.from_numpy(img_).unsqueeze(dim=0).float() 77 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 78 | 79 | sample = {"image": img_, 80 | "mask": mask_, 81 | "ID": os.path.split(img_path)[1]} 82 | elif self.dataset_type == 'stare_fractal': 83 | npy_path = img_path.replace("images", self.fractal_dir) 84 | npy_path = npy_path[:-4] + ".npy" 85 | fractal_info = np.load(npy_path) 86 | img_FD, img_FL = npy_PreProc(fractal_info) 87 | img = cv2.imread(img_path, 0) 88 | mask = cv2.imread(mask_path, 0) 89 | img_ = img_PreProc_retina(img, pro_type='clahe') 90 | 91 | fractal_img = np.stack((img_, img_FD, img_FL), axis=2) 92 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 93 | image_chw = torch.from_numpy(image_chw).float() 94 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 95 | 96 | sample = {"image": image_chw, 97 | "mask": mask_, 98 | "ID": os.path.split(img_path)[1]} 99 | elif self.dataset_type == 'stare_copy': 100 | img = cv2.imread(img_path, 0) 101 | mask = cv2.imread(mask_path, 0) 102 | img_ = img_PreProc_retina(img, pro_type='clahe') 103 | 104 | fractal_img = np.stack((img_, img_, img_), axis=2) 105 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 106 | image_chw = torch.from_numpy(image_chw).float() 107 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 108 | 109 | sample = {"image": image_chw, 110 | "mask": mask_, 111 | "ID": os.path.split(img_path)[1]} 112 | elif self.dataset_type == 'stare_fractal_three_decoder': 113 | npy_path = img_path.replace("images", self.fractal_dir) 114 | npy_path = npy_path[:-4] + ".npy" 115 | fractal_info = np.load(npy_path) 116 | img_FD, img_FL = npy_PreProc(fractal_info) 117 | img = cv2.imread(img_path, 0) 118 | mask = cv2.imread(mask_path, 0) 119 | img_ = img_PreProc_retina(img, pro_type='clahe') 120 | 121 | fractal_img = np.stack((img_, img_FD, img_FL), axis=2) 122 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 123 | image_chw = torch.from_numpy(image_chw).float() 124 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 125 | 126 | if self.train: 127 | edge_path = 
mask_path.replace('masks_aug', self.edge_dir) 128 | skeleton_path = mask_path.replace('masks_aug', self.skeleton_dir) 129 | edge = cv2.imread(edge_path, -1) 130 | skeleton = cv2.imread(skeleton_path, -1) 131 | skeleton_ = torch.from_numpy(skeleton / 255.).unsqueeze_(dim=0).float() 132 | edge_ = torch.from_numpy(edge / 255.).unsqueeze_(dim=0).float() 133 | sample = {"image": image_chw, 134 | "mask": mask_, 135 | "skeleton": skeleton_, 136 | "edge": edge_, 137 | "ID": os.path.split(img_path)[1]} 138 | else: 139 | sample = {"image": image_chw, 140 | "mask": mask_, 141 | "ID": os.path.split(img_path)[1]} 142 | elif self.dataset_type == 'stare_fractal_two_decoder': 143 | npy_path = img_path.replace("images", self.fractal_dir) 144 | npy_path = npy_path[:-4] + ".npy" 145 | fractal_info = np.load(npy_path) 146 | img_FD, img_FL = npy_PreProc(fractal_info) 147 | img = cv2.imread(img_path, 0) 148 | mask = cv2.imread(mask_path, 0) 149 | img_ = img_PreProc_retina(img, pro_type='clahe') 150 | 151 | fractal_img = np.stack((img_, img_FD, img_FL), axis=2) 152 | image_chw = np.transpose(fractal_img, (2, 0, 1)) 153 | image_chw = torch.from_numpy(image_chw).float() 154 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 155 | 156 | if self.train: 157 | edge_path = mask_path.replace('masks_aug', self.edge_dir) 158 | skeleton_path = mask_path.replace('masks_aug', self.skeleton_dir) 159 | edge = cv2.imread(edge_path, -1) 160 | skeleton = cv2.imread(skeleton_path, -1) 161 | skeleton_ = torch.from_numpy(skeleton / 255.).unsqueeze_(dim=0).float() 162 | edge_ = torch.from_numpy(edge / 255.).unsqueeze_(dim=0).float() 163 | if self.decoder_type=='edge': 164 | sample = {"image": image_chw, 165 | "mask": mask_, 166 | "edge": edge_, 167 | "ID": os.path.split(img_path)[1]} 168 | else: 169 | sample = {"image": image_chw, 170 | "mask": mask_, 171 | "skeleton": skeleton_, 172 | "ID": os.path.split(img_path)[1]} 173 | else: 174 | sample = {"image": image_chw, 175 | "mask": mask_, 176 | "ID": os.path.split(img_path)[1]} 177 | elif self.dataset_type == 'stare_three_decoder': 178 | img = cv2.imread(img_path, 0) 179 | mask = cv2.imread(mask_path, 0) 180 | img_ = img_PreProc_retina(img, pro_type='clahe') 181 | image_chw = torch.from_numpy(img_).unsqueeze_(dim=0).float() 182 | mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float() 183 | 184 | if self.train: 185 | edge_path = mask_path.replace('masks_aug', self.edge_dir) 186 | skeleton_path = mask_path.replace('masks_aug', self.skeleton_dir) 187 | edge = cv2.imread(edge_path, -1) 188 | skeleton = cv2.imread(skeleton_path, -1) 189 | skeleton_ = torch.from_numpy(skeleton / 255.).unsqueeze_(dim=0).float() 190 | edge_ = torch.from_numpy(edge / 255.).unsqueeze_(dim=0).float() 191 | sample = {"image": image_chw, 192 | "mask": mask_, 193 | "skeleton": skeleton_, 194 | "edge": edge_, 195 | "ID": os.path.split(img_path)[1]} 196 | else: 197 | sample = {"image": image_chw, 198 | "mask": mask_, 199 | "ID": os.path.split(img_path)[1]} 200 | elif self.dataset_type == 'stare_fractal_three_decoder_weighted': 201 | train_imgpth_list = img_path.split('/') 202 | image_name = train_imgpth_list[-1] 203 | npy_path = img_path.replace("images", self.fractal_dir) 204 | npy_path = npy_path[:-4] + ".npy" 205 | 206 | fractal_info = np.load(npy_path) 207 | img_FD, img_FL = npy_PreProc(fractal_info) 208 | 209 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 210 | mask = cv2.imread(mask_path, -1) 211 | 212 | weight_path = mask_path.replace('masks_aug', 'masks_aug' + self.weight_dir) 213 | 
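            # A hedged note on this step: self.weight_dir appears to be a suffix appended to the
            # mask folder name, so the .npy stored next to each mask holds a precomputed per-pixel
            # weight map (the FFM of the label). It is rescaled to [0, 1] and offset by 1 below,
            # so a weighted loss emphasizes high-weight pixels by at most a factor of 2.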
weighted_npy = weight_path.replace(".tif", ".npy")
214 |             weight1 = np.ones_like(img)
215 |             weight2 = np.load(weighted_npy)
216 |             weight = weight1 + (weight2 / np.max(weight2))
217 | 
218 |             img_ = img_PreProc_retina(img, pro_type='clahe')
219 | 
220 |             fractal_img = np.stack((img_, img_FD, img_FL), axis=2)
221 |             image_chw = np.transpose(fractal_img, (2, 0, 1))
222 |             image_chw = torch.from_numpy(image_chw).float()
223 |             mask_ = torch.from_numpy(mask / 255.).unsqueeze_(dim=0).float()
224 |             weight_ = torch.from_numpy(weight / 1.0).unsqueeze_(dim=0).float()
225 | 
226 |             if self.train:
227 |                 edge_path = mask_path.replace('masks_aug', self.edge_dir)
228 |                 skeleton_path = mask_path.replace('masks_aug', self.skeleton_dir)
229 |                 edge = cv2.imread(edge_path, -1)
230 |                 skeleton = cv2.imread(skeleton_path, -1)
231 |                 skeleton_ = torch.from_numpy(skeleton / 255.).unsqueeze_(dim=0).float()
232 |                 edge_ = torch.from_numpy(edge / 255.).unsqueeze_(dim=0).float()
233 |                 sample = {"image": image_chw,
234 |                           "mask": mask_,
235 |                           "skeleton": skeleton_,
236 |                           "edge": edge_,
237 |                           "weight": weight_,
238 |                           "ID": os.path.split(img_path)[1]}
239 |             else:
240 |                 sample = {"image": image_chw,
241 |                           "mask": mask_,
242 |                           "ID": os.path.split(img_path)[1]}
243 | 
244 |         return sample
245 | 
246 |     def __len__(self):
247 |         return len(self.img_mask_paths)
248 | 
-------------------------------------------------------------------------------- /datasets/dataset.py: --------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | 
3 | from datasets.ER_dataset import ER_Dataset
4 | from datasets.MITO_dataset import MITO_Dataset
5 | from datasets.ROSE_dataset import ROSE_Dataset
6 | from datasets.STARE_dataset import STARE_Dataset
7 | from datasets.ROAD_dataset import ROAD_Dataset
8 | from datasets.NUCLEUS_dataset import NUCLEUS_Dataset
9 | 
10 | DATASETs = {'MITO': MITO_Dataset, 'ER': ER_Dataset, 'ROSE': ROSE_Dataset,
11 |             'STARE': STARE_Dataset, 'ROAD': ROAD_Dataset, 'NUCLEUS': NUCLEUS_Dataset}
12 | 
13 | 
14 | def build_data_loader(data_name, data_list, batch_size, dataset_type, is_train=True, fractal_dir='', weight_dir='',
15 |                       edge_dir='', skeleton_dir='', decoder_type='',
16 |                       log_file='', epoch=0, update_d=0):
17 |     train_data = DATASETs[data_name](txt=data_list, dataset_type=dataset_type, train=is_train, fractal_dir=fractal_dir,
18 |                                      weight_dir=weight_dir, edge_dir=edge_dir, skeleton_dir=skeleton_dir,
19 |                                      decoder_type=decoder_type, log_file=log_file, epoch=epoch, update_d=update_d)
20 | 
21 |     data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=is_train, num_workers=8)
22 | 
23 |     return data_loader
24 | 
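# A minimal usage sketch (hypothetical paths; the .txt list files are produced by
# data_preprocess/__init__.py and the FFM folder by fractal_analysis.py):
if __name__ == "__main__":
    loader = build_data_loader('STARE', './dataset_txts/train_stare.txt', batch_size=4,
                               dataset_type='stare_fractal', is_train=True,
                               fractal_dir='Fractal_info_5')
    for batch in loader:
        print(batch["image"].shape, batch["mask"].shape)  # (4, 3, H, W) and (4, 1, H, W)
        break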
-------------------------------------------------------------------------------- /datasets/metric.py: --------------------------------------------------------------------------------
1 | from sklearn.metrics import roc_curve, auc, f1_score, accuracy_score, roc_auc_score
2 | from scipy.spatial.distance import directed_hausdorff
3 | from skimage.morphology import skeletonize, skeletonize_3d
4 | from datasets.BettiMatching import *
5 | import cv2
6 | import math
7 | import numpy as np, torch  # np and torch are used throughout this module
8 | def extract_mask(pred_arr, gt_arr, mask_arr=None):
9 |     # we want to make them into vectors
10 |     pred_vec = pred_arr.flatten()
11 |     gt_vec = gt_arr.flatten()
12 | 
13 |     if mask_arr is not None:
14 |         mask_vec = mask_arr.flatten()
15 |         idx = list(np.where(mask_vec == 0)[0])
16 | 
17 |         pred_vec = np.delete(pred_vec, idx)
18 |         gt_vec = np.delete(gt_vec, idx)
19 | 
20 |     return pred_vec, gt_vec
21 | 
22 | 
23 | def calc_auc(pred_arr, gt_arr, mask_arr=None):
24 |     pred_vec, gt_vec = extract_mask(pred_arr, gt_arr, mask_arr=mask_arr)
25 |     roc_auc = roc_auc_score(gt_vec, pred_vec)
26 | 
27 |     return roc_auc
28 | 
29 | 
30 | def numeric_score(pred_arr, gt_arr, kernel_size=(1, 1)):  # DCC & ROSE-2: kernel_size=(3, 3)
31 |     """Computation of statistical numerical scores:
32 | 
33 |     * FP = False Positives
34 |     * FN = False Negatives
35 |     * TP = True Positives
36 |     * TN = True Negatives
37 | 
38 |     return: tuple (FP, FN, TP, TN)
39 |     """
40 |     kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_size)
41 |     dilated_gt_arr = cv2.dilate(gt_arr, kernel, iterations=1)
42 | 
43 |     FP = float(np.sum(np.logical_and(pred_arr == 1, dilated_gt_arr == 0)))
44 |     FN = float(np.sum(np.logical_and(pred_arr == 0, gt_arr == 1)))
45 |     TP = float(np.sum(np.logical_and(pred_arr == 1, dilated_gt_arr == 1)))
46 |     TN = float(np.sum(np.logical_and(pred_arr == 0, gt_arr == 0)))
47 | 
48 |     return FP, FN, TP, TN
49 | 
50 | 
51 | def calc_acc(pred_arr, gt_arr, kernel_size=(1, 1)):  # DCC & ROSE-2: kernel_size=(3, 3)
52 |     FP, FN, TP, TN = numeric_score(pred_arr, gt_arr, kernel_size)
53 |     acc = (TP + TN) / (FP + FN + TP + TN)
54 | 
55 |     return acc
56 | 
57 | 
58 | def calc_sen(pred_arr, gt_arr, kernel_size=(1, 1)):  # DCC & ROSE-2: kernel_size=(3, 3)
59 |     FP, FN, TP, TN = numeric_score(pred_arr, gt_arr, kernel_size)
60 |     sen = TP / (FN + TP + 1e-12)
61 | 
62 |     return sen
63 | 
64 | 
65 | def calc_fdr(pred_arr, gt_arr, kernel_size=(1, 1)):  # DCC & ROSE-2: kernel_size=(3, 3)
66 |     FP, FN, TP, TN = numeric_score(pred_arr, gt_arr, kernel_size)
67 |     fdr = FP / (FP + TP + 1e-12)
68 | 
69 |     return fdr
70 | 
71 | 
72 | def calc_spe(pred_arr, gt_arr, kernel_size=(1, 1)):  # DCC & ROSE-2: kernel_size=(3, 3)
73 |     FP, FN, TP, TN = numeric_score(pred_arr, gt_arr, kernel_size)
74 |     spe = TN / (FP + TN + 1e-12)
75 | 
76 |     return spe
77 | 
78 | 
79 | def calc_gmean(pred_arr, gt_arr, kernel_size=(1, 1)):  # DCC & ROSE-2: kernel_size=(3, 3)
80 |     sen = calc_sen(pred_arr, gt_arr, kernel_size=kernel_size)
81 |     spe = calc_spe(pred_arr, gt_arr, kernel_size=kernel_size)
82 | 
83 |     return math.sqrt(sen * spe)
84 | 
85 | 
86 | def calc_kappa(pred_arr, gt_arr, kernel_size=(1, 1)):  # DCC & ROSE-2: kernel_size=(3, 3)
87 |     FP, FN, TP, TN = numeric_score(pred_arr, gt_arr, kernel_size=kernel_size)
88 |     matrix = np.array([[TP, FP],
89 |                        [FN, TN]])
90 |     n = np.sum(matrix)
91 | 
92 |     sum_po = 0
93 |     sum_pe = 0
94 |     for i in range(len(matrix[0])):
95 |         sum_po += matrix[i][i]
96 |         row = np.sum(matrix[i, :])
97 |         col = np.sum(matrix[:, i])
98 |         sum_pe += row * col
99 | 
100 |     po = sum_po / n
101 |     pe = sum_pe / (n * n)
102 |     # print(po, pe)
103 | 
104 |     return (po - pe) / (1 - pe)
105 | 
106 | 
107 | def calc_iou(pred_arr, gt_arr, kernel_size=(1, 1)):  # DCC & ROSE-2: kernel_size=(3, 3)
108 |     FP, FN, TP, TN = numeric_score(pred_arr, gt_arr, kernel_size)
109 |     iou = TP / (FP + FN + TP + 1e-12)
110 | 
111 |     return iou
112 | 
113 | 
114 | def calc_dice(pred_arr, gt_arr, kernel_size=(1, 1)):  # DCC & ROSE-2: kernel_size=(3, 3)
115 |     FP, FN, TP, TN = numeric_score(pred_arr, gt_arr, kernel_size)
116 |     dice = 2.0 * TP / (FP + FN + 2.0 * TP + 1e-12)
117 | 
118 |     return dice
119 | 
120 | 
121 | def hausdorff_distance_single(seg, label):
122 |     # segmentation = seg.squeeze(1)
123 |     # mask = label.squeeze(1)
124 |     segmentation = seg
125 |     mask = label
126 | 
127 |     non_zero_seg = np.transpose(np.nonzero(segmentation))
128 |     non_zero_mask = np.transpose(np.nonzero(mask))
129 |     h_dist = 
max(directed_hausdorff(non_zero_seg, non_zero_mask)[0], 130 | directed_hausdorff(non_zero_mask, non_zero_seg)[0]) 131 | 132 | return h_dist 133 | 134 | 135 | def cl_score(v, s): 136 | """[this function computes the skeleton volume overlap] 137 | 138 | Args: 139 | v ([bool]): [image] 140 | s ([bool]): [skeleton] 141 | 142 | Returns: 143 | [float]: [computed skeleton volume intersection] 144 | """ 145 | return np.sum(v * s) / np.sum(s) 146 | 147 | 148 | def clDice(v_p, v_l): 149 | """[this function computes the cldice metric] 150 | 151 | Args: 152 | v_p ([bool]): [predicted image] 153 | v_l ([bool]): [ground truth image] 154 | 155 | Returns: 156 | [float]: [cldice metric] 157 | """ 158 | if len(v_p.shape) == 2: 159 | tprec = cl_score(v_p, skeletonize(v_l)) 160 | tsens = cl_score(v_l, skeletonize(v_p)) 161 | elif len(v_p.shape) == 3: 162 | tprec = cl_score(v_p, skeletonize_3d(v_l)) 163 | tsens = cl_score(v_l, skeletonize_3d(v_p)) 164 | return 2 * tprec * tsens / (tprec + tsens) 165 | 166 | 167 | def compute_metrics(t, relative=False, comparison='union', filtration='superlevel', construction='V'): 168 | BM = BettiMatching(t[0], t[1], relative=relative, comparison=comparison, filtration=filtration, 169 | construction=construction) 170 | return BM.loss(dimensions=[0, 1]), BM.loss(dimensions=[0]), BM.loss(dimensions=[1]), BM.Betti_number_error( 171 | threshold=0.5, dimensions=[0, 1]), BM.Betti_number_error(threshold=0.5, dimensions=[0]), BM.Betti_number_error( 172 | threshold=0.5, dimensions=[1]) 173 | 174 | 175 | def acc(seg, label): 176 | now_num = seg.shape[0] 177 | # seg, label = np.array(seg), np.array(label) 178 | seg_one = seg.reshape(-1) 179 | label_one = label.reshape(-1) 180 | 181 | label_T = label_one > 0 182 | corrects = torch.eq(seg_one, label_T).sum() 183 | all_num = seg_one.numel() 184 | 185 | # corrects = (seg.int() == label.int()) 186 | acc = corrects / all_num 187 | return acc * now_num 188 | 189 | 190 | def roc(pred, label): 191 | pred, label = np.array(pred), np.array(label) 192 | preds_roc = np.reshape(pred, -1) 193 | labels_roc = np.reshape(label, -1) 194 | fpr, tpr, thresholds = roc_curve(labels_roc, preds_roc) 195 | roc_auc = auc(fpr, tpr) 196 | return fpr, tpr, roc_auc 197 | 198 | 199 | def dice_cof(pred, label, reduce=False): 200 | matrix_sum = pred.int() + label.int() 201 | i = torch.sum(matrix_sum == 2, dim=(1, 2, 3)) 202 | x1 = torch.sum(pred == 1, dim=(1, 2, 3)) 203 | x2 = torch.sum(label == 1, dim=(1, 2, 3)) 204 | dice_score = 2. 
* i.float() / (x1.float() + x2.float()) 205 | if reduce: 206 | return torch.mean(dice_score) 207 | else: 208 | return torch.sum(dice_score) 209 | 210 | 211 | def IoU(preds, labels, reduce=False): 212 | matrix_sum = preds.int() + labels.int() 213 | i = torch.sum(matrix_sum == 2, dim=(1, 2, 3)) 214 | u = torch.sum(matrix_sum == 1, dim=(1, 2, 3)) 215 | iou = i.float() / (i.float() + u.float() + 1e-9) 216 | if reduce: 217 | iou = torch.mean(iou) 218 | else: 219 | iou = torch.sum(iou) 220 | return iou 221 | 222 | 223 | def IoU_r(preds, labels, reduce=False): 224 | matrix_sum = preds.int() + labels.int() 225 | i = torch.sum(matrix_sum == 2, dim=(1, 2)) 226 | u = torch.sum(matrix_sum == 1, dim=(1, 2)) 227 | iou = i.float() / (i.float() + u.float() + 1e-9) 228 | if reduce: 229 | iou = torch.mean(iou) 230 | else: 231 | iou = torch.sum(iou) 232 | return iou 233 | 234 | 235 | def acc_list(seg, label): 236 | total_acc = 0.0 237 | img_num = len(seg) 238 | for auc_index in range(img_num): 239 | now_pred = seg[auc_index] 240 | now_labels = label[auc_index] 241 | val_acc = acc(now_pred[0], now_labels[0]) 242 | total_acc += val_acc 243 | 244 | return total_acc 245 | 246 | 247 | def dIoU(preds, labels, reduce=False): 248 | matrix_sum = preds.int() + labels.int() 249 | i = torch.sum(matrix_sum == 2) 250 | u = torch.sum(matrix_sum == 1) 251 | iou = i.float() / (i.float() + u.float() + 1e-9) 252 | if reduce: 253 | iou = torch.mean(iou) 254 | else: 255 | iou = iou 256 | return iou 257 | 258 | 259 | def mIoU(preds, labels, reduce=False): 260 | matrix_sum = preds.int() + labels.int() 261 | f_i = torch.sum(matrix_sum == 2, dim=(1, 2, 3)) 262 | u = torch.sum(matrix_sum == 1, dim=(1, 2, 3)) 263 | b_i = torch.sum(matrix_sum == 0, dim=(1, 2, 3)) 264 | f_iou = f_i.float() / (f_i.float() + u.float() + 1e-9) 265 | b_iou = b_i.float() / (b_i.float() + u.float() + 1e-9) 266 | miou = 0.5 * (f_iou + b_iou) 267 | if reduce: 268 | miou = torch.mean(miou) 269 | else: 270 | miou = torch.sum(miou) 271 | return miou 272 | 273 | 274 | def dmIoU(preds, labels, reduce=False): 275 | matrix_sum = preds.int() + labels.int() 276 | f_i = torch.sum(matrix_sum == 2) 277 | u = torch.sum(matrix_sum == 1) 278 | b_i = torch.sum(matrix_sum == 0) 279 | f_iou = f_i.float() / (f_i.float() + u.float() + 1e-9) 280 | b_iou = b_i.float() / (b_i.float() + u.float() + 1e-9) 281 | miou = 0.5 * (f_iou + b_iou) 282 | if reduce: 283 | miou = torch.mean(miou) 284 | else: 285 | miou = miou 286 | return miou 287 | 288 | 289 | def F1_score(pred, label, reduce=False): 290 | pred, label = pred.int(), label.int() 291 | p = torch.sum((label == 1).int(), dim=(1, 2, 3)) 292 | tp = torch.sum((pred == 1).int() & (label == 1).int(), dim=(1, 2, 3)) 293 | fp = torch.sum((pred == 1).int() & (label == 0).int(), dim=(1, 2, 3)) 294 | recall = tp.float() / (p.float() + 1e-9) 295 | precision = tp.float() / (tp.float() + fp.float() + 1e-9) 296 | f1 = (2 * recall * precision) / (recall + precision + 1e-9) 297 | if reduce: 298 | f1 = torch.mean(f1) 299 | else: 300 | f1 = torch.sum(f1) 301 | return f1 302 | 303 | 304 | def dF1_score(pred, label, reduce=False): 305 | pred, label = pred.int(), label.int() 306 | p = torch.sum((label == 1).int()) 307 | tp = torch.sum((pred == 1).int() & (label == 1).int()) 308 | fp = torch.sum((pred == 1).int() & (label == 0).int()) 309 | recall = tp.float() / (p.float() + 1e-9) 310 | precision = tp.float() / (tp.float() + fp.float() + 1e-9) 311 | f1 = (2 * recall * precision) / (recall + precision + 1e-9) 312 | if reduce: 313 | f1 = torch.mean(f1) 314 
| else: 315 | f1 = f1 316 | return f1 317 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 5 | from datasets.metric import * 6 | import cv2 7 | from sklearn.metrics import roc_auc_score, confusion_matrix 8 | from hausdorff import hausdorff_distance 9 | from multiprocessing import Pool 10 | 11 | print("PyTorch Version: ", torch.__version__) 12 | 13 | ''' 14 | evaluation 15 | ''' 16 | def compute_metrics(y_scores, y_true, relative=True, comparison='union', filtration='superlevel', construction='V'): 17 | BM = BettiMatching(y_scores, y_true, relative=relative, comparison=comparison, filtration=filtration, 18 | construction=construction) 19 | 20 | return [BM.loss(dimensions=[0, 1]), BM.loss(dimensions=[0]), BM.loss(dimensions=[1]), BM.Betti_number_error( 21 | threshold=0.5, dimensions=[0, 1]), BM.Betti_number_error(threshold=0.5, dimensions=[0]), BM.Betti_number_error( 22 | threshold=0.5, dimensions=[1])] 23 | 24 | 25 | def infer_metric_threshold(infer_path, mask_path, low, end, size, no_betti): 26 | filenames = os.listdir(mask_path) 27 | img_num = len(filenames) 28 | 29 | thresholds = np.arange(low, end, size) 30 | for threshold in thresholds: 31 | total_img = 0 32 | total_iou = 0.0 33 | total_f1 = 0.0 34 | total_acc = 0.0 35 | total_sen = 0.0 36 | total_auc = 0.0 37 | total_spec = 0.0 38 | cldices = [] 39 | hds = [] 40 | betti_losses = [] 41 | pool = Pool(8) 42 | for i in range(img_num): 43 | now_img = cv2.imread(os.path.join(infer_path, filenames[i][:-3] + 'tif'), -1) 44 | if now_img is None: 45 | # print('not exist') 46 | continue 47 | 48 | # now_img = now_img/255.0 49 | now_mask = cv2.imread(os.path.join(mask_path, filenames[i]), 0) 50 | gt_arr = now_mask // 255 51 | 52 | best_iou = 0.00 53 | # enable evaluation mode 54 | y_scores = np.zeros_like(now_img) 55 | y_true = np.zeros_like(now_mask) 56 | y_true[now_mask > 0.01] = 1 57 | y_scores[now_img > threshold] = 1 58 | hd = hausdorff_distance(y_scores, y_true) 59 | if 'nucleus' in mask_path: 60 | cldice = 0 61 | else: 62 | cldice = clDice(y_scores, y_true) 63 | if no_betti: 64 | loss = loss_0 = loss_1 = betti_err = betti_0_err = betti_1_err = 0 65 | else: 66 | betti_losses.append(pool.apply_async(compute_metrics, args=(y_scores, y_true,))) 67 | cldices.append(cldice) 68 | y_scores1 = y_scores.flatten() 69 | # y_pred = y_scores > threshold 70 | y_true1 = y_true.flatten() 71 | 72 | hds.append(hd) 73 | 74 | confusion = confusion_matrix(y_true1, y_scores1) 75 | tp = float(confusion[1, 1]) 76 | fn = float(confusion[1, 0]) 77 | fp = float(confusion[0, 1]) 78 | tn = float(confusion[0, 0]) 79 | 80 | val_acc = (tp + tn) / (tp + fn + fp + tn) 81 | sensitivity = tp / (tp + fn) 82 | specificity = tn / (tn + fp) 83 | precision = tp / (tp + fp) 84 | f1 = 2 * sensitivity * precision / (sensitivity + precision + 1e-9) 85 | iou = tp / (tp + fn + fp + 1e-9) 86 | auc = calc_auc(now_img, gt_arr) 87 | total_iou += iou 88 | total_acc += val_acc 89 | total_f1 += f1 90 | total_auc += auc 91 | total_sen += sensitivity 92 | total_spec += specificity 93 | total_img += 1 94 | 95 | epoch_iou = (total_iou) / total_img 96 | if epoch_iou > best_iou: 97 | best_iou = epoch_iou 98 | epoch_f1 = total_f1 / total_img 99 | epoch_acc = total_acc / total_img 100 | epoch_auc = total_auc / total_img 101 | epoch_sen = total_sen / total_img 102 | 
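        # per-threshold averages over all images that were successfully read; the Betti
        # losses were submitted to the worker pool above and are aggregated after pool.join()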
epoch_spec = total_spec / total_img 103 | epoch_clDice = np.mean(cldices) 104 | epoch_hd = np.mean(hds) 105 | message = "inference =====>threshold: {:.4f}: Evaluation ACC: {:.4f}; IOU: {:.4f}; F1_score: {:.4f}; Auc: {:.4f} ;Sen: {:.4f}; Spec: {:.4f}; clDice: {:.4f}; hausdorff_distance: {:.4f};".format( 106 | threshold, 107 | epoch_acc, 108 | epoch_iou, 109 | epoch_f1, epoch_auc, epoch_sen, epoch_spec, epoch_clDice, epoch_hd) 110 | 111 | print("==> %s" % (message)) 112 | 113 | pool.close() 114 | pool.join() 115 | if no_betti: 116 | Betti_error = Betti_error_std = Betti_0_error = Betti_0_error_std = Betti_1_error = Betti_1_error_std = 0 117 | else: 118 | betti_results = [] 119 | for if_index in range(total_img): 120 | betti_result_now = betti_losses[if_index].get() 121 | betti_results.append(betti_result_now) 122 | 123 | betti_losses_array2 = np.array(betti_results) 124 | betti_mean = np.mean(betti_losses_array2, axis=0) 125 | Betti_error = betti_mean[3] 126 | Betti_error_std = betti_mean[3] 127 | Betti_0_error = betti_mean[4] 128 | Betti_0_error_std = betti_mean[4] 129 | Betti_1_error = betti_mean[5] 130 | Betti_1_error_std = betti_mean[5] 131 | 132 | print("Betti number error", Betti_error) 133 | # print("Betti number error std", Betti_error_std) 134 | print("Betti number error dim 0", Betti_0_error) 135 | # print("Betti number error dim 0 std", Betti_0_error_std) 136 | print("Betti number error dim 1", Betti_1_error) 137 | # print("Betti number error dim 1 std", Betti_1_error_std) 138 | 139 | 140 | if __name__ == "__main__": 141 | er_end_path = '/predict_score/ER_best/' 142 | er_mask_dir = '/mnt/data1/ER/test/masks/' 143 | 144 | 145 | model_dir = 'er_fractal_HRNet_iou_32_0.05_50_0.3_1000_20240312_warmup' 146 | infer_path = './train_logs/' + model_dir + er_end_path 147 | 148 | infer_metric_threshold(infer_path, er_mask_dir, 0.3, 0.31, 0.01, no_betti=False) 149 | -------------------------------------------------------------------------------- /figures/FFM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbmi-group/FFM-Multi-Decoder-Network/3c534eba174a702c5e609fb8192f062f98063ce0/figures/FFM.png -------------------------------------------------------------------------------- /figures/MDNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbmi-group/FFM-Multi-Decoder-Network/3c534eba174a702c5e609fb8192f062f98063ce0/figures/MDNet.png -------------------------------------------------------------------------------- /fractal_analysis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | import csv 5 | import math 6 | import os 7 | from PIL import Image, ImageDraw, ImageOps 8 | import matplotlib.pyplot as plt 9 | from multiprocessing import Pool 10 | from functools import partial 11 | import torch 12 | 13 | ''' 14 | The function is a modified box-counting algorithm to compute Fractal Dimension for image, as described by Wen-Li Lee and Kai-Sheng Hsieh. 
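The estimate follows the differential box-counting idea: each grid block's intensity range is
covered by gray-level boxes of height h, the box counts N_r are accumulated per scale r = L/M,
and the fractal dimension is the slope of the log(N_r) versus log(1/r) fit. As a sanity check,
a perfectly flat patch needs one box per block, so N_r ~ (M/L)^2 and the slope approaches 2.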
15 | 
16 | Input:
17 |     image: A 2D array containing a grayscale image;
18 | Output:
19 |     D: the (slope, intercept) pair of the log-log fit returned by np.polyfit; the slope D[0]
20 |        is the fractal-dimension estimate and the intercept D[1] is used as the "Length" feature below.
21 | '''
22 | 
23 | 
24 | def Box_counting_modified(image):
25 |     M = image.shape[0]  # image is assumed square (M x M)
26 |     G_min = int(image.min())  # lowest gray level (0=white)
27 |     G_max = int(image.max())  # highest gray level (255=black)
28 |     G = G_max - G_min + 1  # number of gray levels, typically 256
29 |     prev = -1  # used to check for plateaus
30 |     r_Nr = []
31 | 
32 |     for L in range(2, (M // 2) + 1):
33 |         h = max(1, G // (M // L))  # minimum box height is 1
34 |         N_r = 0
35 |         r = L / M
36 |         for i in range(0, M, L):
37 |             for j in range(0, M, L):  # scan every L x L grid block, not only the diagonal ones
38 |                 boxes = [[] for _ in range((G + h - 1) // h)]  # independent lists ([[]] * n would alias a single list)
39 |                 for row in image[i:i + L]:  # blocks that exceed bounds are shrunk to fit
40 |                     for pixel in row[j:j + L]:
41 |                         height = (pixel - G_min) // h  # lowest box is at G_min and each is h gray levels tall
42 |                         boxes[height].append(pixel)  # assign the pixel intensity to the correct box
43 |                 stddev = np.array([np.std(b) for b in boxes if b])  # per-box standard deviation; empty boxes are skipped
44 |                 nBox_r = 2 * (stddev // h) + 1
45 |                 N_r += sum(nBox_r)
46 |         if N_r != prev:  # check for plateauing
47 |             r_Nr.append([r, N_r])
48 |             prev = N_r
49 |     x = np.array([np.log(1 / point[0]) for point in r_Nr])  # log(1/r)
50 |     y = np.array([np.log(point[1]) for point in r_Nr])  # log(Nr)
51 |     D = np.polyfit(x, y, 1)  # D = lim r -> 0 log(Nr)/log(1/r)
52 |     return D
53 | 
54 | 
55 | '''
56 | The function to compute Fractal Feature Map for image.
57 | 
58 | Input:
59 |     image: A 2D array containing a grayscale image;
60 |     window_size: the size of sliding window;
61 |     step_size: the size of sliding step;
62 | Output:
63 |     FFM: the fractal feature map of image.
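Example (a hedged sketch; file name and sizes are illustrative):
    img = cv2.imread("20_h384_w384.tif", 0)            # (384, 384) uint8 grayscale
    ffm = compute_FFM(img, step_size=1, window_size=5)
    # ffm.shape == (2, 384, 384): ffm[0] is the fractal-dimension (FD) map and
    # ffm[1] is the log-log fit intercept ("Length") map used as the second FFM channel.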
64 | 
65 | '''
66 | 
67 | 
68 | def compute_FFM(image, step_size, window_size):
69 |     img_shape = np.shape(image)
70 |     result_x = math.ceil(img_shape[0] / step_size)
71 |     result_y = math.ceil(img_shape[1] / step_size)
72 |     FD = np.zeros((result_x, result_y))
73 |     Length = np.zeros((result_x, result_y))
74 |     H = img_shape[0]  # note: the loops below assume a square image (H x H)
75 |     pad_size = math.floor(window_size / 2)
76 |     padded_img = np.pad(image, ((pad_size, pad_size)), 'linear_ramp')
77 |     for i in range(0, H, step_size):
78 |         for j in range(0, H, step_size):
79 |             selected_img = padded_img[i:i + window_size, j:j + window_size]
80 |             selected_img_info = Box_counting_modified(selected_img)
81 |             save_coor_x = int(i / step_size)
82 |             save_coor_y = int(j / step_size)
83 |             FD[save_coor_x, save_coor_y] = selected_img_info[0]
84 |             Length[save_coor_x, save_coor_y] = selected_img_info[1]
85 |     FFM = np.zeros((2, result_x, result_y))
86 |     FFM[0] = FD
87 |     FFM[1] = Length
88 |     return FFM
89 | 
90 | '''
91 | The function to compute Fractal Feature Map for images using Pool
92 | 
93 | Input:
94 |     file_path: the root path of images;
95 |     window_size: the size of sliding window;
96 |     step_size: the size of sliding step;
97 | '''
98 | 
99 | 
100 | def compute_FMM_Pool(file_path, window_size, step_size):
101 |     save_path = file_path[:-1] + '_Fractal_info_' + str(window_size) + '_' + str(step_size) + '/'
102 |     if step_size > 1:
103 |         up_save_path = file_path[:-1] + '_Fractal_info_' + str(window_size) + '_' + str(step_size) + '_up/'
104 |         os.makedirs(up_save_path, exist_ok=True)
105 |     os.makedirs(save_path, exist_ok=True)
106 |     filenames = os.listdir(file_path)
107 |     img_num = len(filenames)
108 |     images_fractal = []
109 |     pool = Pool(16)
110 |     for i in range(img_num):
111 |         now_img = (cv2.imread(os.path.join(file_path, filenames[i]), 0)).astype(np.uint8)
112 |         images_fractal.append(
113 |             pool.apply_async(compute_FFM, args=(now_img, step_size, window_size,)))  # argument order follows compute_FFM(image, step_size, window_size)
114 |     pool.close()
115 |     pool.join()
116 |     now_img = (cv2.imread(os.path.join(file_path, filenames[i]), 0)).astype(np.uint8)  # any image works here: all are assumed to share one shape
117 |     img_shape = np.shape(now_img)
118 |     m = torch.nn.Upsample(size=img_shape, mode='bilinear')
119 |     for if_index in range(img_num):
120 |         image_weight_now = images_fractal[if_index].get()
121 |         img_ = torch.from_numpy(image_weight_now).unsqueeze(dim=0).float()
122 |         up_now_file = m(img_)
123 |         up_now_npy = up_now_file.squeeze().numpy()
124 |         if step_size > 1:
125 |             np.save(os.path.join(up_save_path, filenames[if_index].split('.')[0] + '.npy'), up_now_npy)
126 |         np.save(os.path.join(save_path, filenames[if_index].split('.')[0] + '.npy'), image_weight_now)
127 | 
-------------------------------------------------------------------------------- /get_edge_skeleton.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import itertools
4 | import matplotlib.pyplot as plt
5 | import os
6 | from skimage import morphology
7 | 
8 | def edge_extract(root):
9 |     img_root = os.path.join(root, 'masks')
10 |     edge_root = os.path.join(root, 'masks_edges')
11 | 
12 |     if not os.path.exists(edge_root):
13 |         os.mkdir(edge_root)
14 | 
15 |     file_names = os.listdir(img_root)
16 | 
17 |     index = 0
18 |     for name in file_names:
19 |         img = cv2.imread(os.path.join(img_root, name), 0)
20 |         edge, _ = cv2.findContours(img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
21 |         contour_img = np.zeros_like(img)
22 |         cv2.drawContours(contour_img, edge, -1, (255), 1)
23 |         cv2.imwrite(os.path.join(edge_root, name), contour_img)
24 |         index += 1
25 |     return 0
26 | 
27 | 
28 
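# e.g., edge_extract("./data/ER/train/") reads <root>/masks and writes one-pixel contour
# maps to <root>/masks_edges (see the __main__ block at the bottom of this file)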
| 29 | def skeleton_extract(root): 30 | img_root = os.path.join(root, 'train_val_labels_aug') 31 | skeleton_root = os.path.join(root, 'train_val_labels_aug_bone') 32 | if not os.path.exists(skeleton_root): 33 | os.mkdir(skeleton_root) 34 | 35 | file_names = os.listdir(img_root) 36 | for name in file_names: 37 | img = cv2.imread(os.path.join(img_root, name), -1) 38 | img[img <= 100] = 0 39 | img[img > 100] = 1 40 | skeleton0 = morphology.skeletonize(img) 41 | skeleton = skeleton0.astype(np.uint8) * 255 42 | cv2.imwrite(os.path.join(skeleton_root, name), skeleton) 43 | 44 | return 0 45 | 46 | 47 | 48 | if __name__ == '__main__': 49 | train_er = "./data/ER/train/" 50 | edge_extract(train_er) 51 | skeleton_extract(train_er) 52 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 5 | import numpy as np 6 | import torch 7 | from datasets.dataset import build_data_loader 8 | from models.unet import UNet as u_net 9 | from models.hrnet import HRNetV2 10 | from datasets.metric import * 11 | from models.md_net import Multi_decoder_Net, Two_decoder_Net 12 | import cv2 13 | import torch.nn.functional as F 14 | import torchvision.transforms as transforms 15 | from sklearn.metrics import roc_auc_score, confusion_matrix 16 | 17 | print("PyTorch Version: ", torch.__version__) 18 | 19 | ''' 20 | inference 21 | ''' 22 | 23 | 24 | def infer_model(opts): 25 | val_batch_size = opts["eval_batch_size"] 26 | dataset_type = opts['dataset_type'] 27 | load_epoch = opts['load_epoch'] 28 | gpus = opts["gpu_list"].split(',') 29 | gpu_list = [] 30 | for str_id in gpus: 31 | id = int(str_id) 32 | gpu_list.append(id) 33 | os.environ['CUDA_VISIBLE_DEVICE'] = opts["gpu_list"] 34 | 35 | eval_data_dir = opts["eval_data_dir"] 36 | 37 | train_dir = opts["train_dir"] 38 | model_type = opts['model_type'] 39 | fractal_dir = opts['fractal_dir'] 40 | dataset_name = opts["dataset_name"] 41 | 42 | model_score_dir = os.path.join(str(os.path.split(train_dir)[0]), 43 | 'predict_score/' + dataset_name + '_' + str(load_epoch)) 44 | if not os.path.exists(model_score_dir): os.makedirs(model_score_dir) 45 | 46 | # dataloader 47 | print("==> Create dataloader") 48 | dataloader = build_data_loader(dataset_name, eval_data_dir, val_batch_size, dataset_type, is_train=False, 49 | fractal_dir=fractal_dir) 50 | 51 | # define network 52 | print("==> Create network") 53 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 54 | model = None 55 | if 'fractal' in opts["dataset_type"]: 56 | if 'road' in opts["dataset_type"]: 57 | if 'RGB' in fractal_dir: 58 | num_channels = 6 59 | else: 60 | num_channels = 5 61 | else: 62 | num_channels = 3 63 | else: 64 | if 'road' in opts["dataset_type"] or 'copy' in opts["dataset_type"]: 65 | num_channels = 3 66 | else: 67 | num_channels = 1 68 | 69 | num_classes = 1 70 | if model_type == 'unet': 71 | model = u_net(num_channels, num_classes) 72 | elif model_type == 'hrnet': 73 | model = HRNetV2(n_channels=num_channels, n_class=num_classes) 74 | elif model_type == 'Two_decoder_Net': 75 | model = Two_decoder_Net(num_channels, num_classes) 76 | elif model_type == 'Multi_decoder_Net': 77 | model = Multi_decoder_Net(num_channels, num_classes) 78 | 79 | 80 | # load trained model 81 | pretrain_model = os.path.join(train_dir, str(load_epoch) + ".pth") 82 | # print(pretrain_model) 83 
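    # the commented alternative below matches checkpoints saved as "checkpoints_<epoch>.pth"
    # instead of "<epoch>.pth"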
| # pretrain_model = os.path.join(train_dir, "checkpoints_" + str(load_epoch) + ".pth") 84 | 85 | if os.path.isfile(pretrain_model): 86 | c_checkpoint = torch.load(pretrain_model) 87 | model.load_state_dict(c_checkpoint["model_state_dict"]) 88 | print("==> Loaded pretrianed model checkpoint '{}'.".format(pretrain_model)) 89 | else: 90 | print("==> No trained model.") 91 | return 0 92 | 93 | # set model to gpu mode 94 | print("==> Set to GPU mode") 95 | 96 | model.cuda() 97 | model = torch.nn.DataParallel(model) 98 | 99 | # enable evaluation mode 100 | with torch.no_grad(): 101 | model.eval() 102 | total_img = 0 103 | for inputs in dataloader: 104 | images = inputs["image"].cuda() 105 | img_name = inputs['ID'] 106 | # print('now process image is %s' % (img_name)) 107 | total_img += len(images) 108 | # unet 109 | if model_type == 'unet': 110 | p_seg = model(images) 111 | elif model_type == 'hrnet': 112 | outputs_list = model(images) 113 | p_seg = outputs_list[0] 114 | elif model_type == 'Two_decoder_Net': 115 | p_seg, pred_bone = model(images) 116 | elif model_type == 'Multi_decoder_Net': 117 | p_seg, pred_bone, pred_edge = model(images) 118 | 119 | 120 | for i in range(len(images)): 121 | # print('predict image: {}'.format(img_name[i])) 122 | now_dir = model_score_dir 123 | os.makedirs(now_dir, exist_ok=True) 124 | np.save(os.path.join(now_dir, img_name[i].split('.')[0] + '.npy'), 125 | p_seg[i][0].cpu().numpy().astype(np.float32)) 126 | cv2.imwrite(os.path.join(now_dir, img_name[i].split('.')[0] + '.tif'), 127 | p_seg[i][0].cpu().numpy().astype(np.float32)) 128 | 129 | 130 | 131 | 132 | if __name__ == "__main__": 133 | model_choice = ['unet', 'hrnet', 'Two_decoder_Net', 'Multi_decoder_Net'] 134 | dataset_list = ['er', 'er_fractal', 'er_fractal_two_decoder', 'nucleus_fractal_two_decoder','nucleus_fractal_two_decoder_weighted'] 135 | txt_choice = ['train_mito.txt', 'test_mito.txt', 'train_er.txt', 'test_er.txt', 'test_stare.txt', 'train_stare.txt'] 136 | 137 | opts = dict() 138 | opts["dataset_name"] = "ER" 139 | opts['dataset_type'] = 'er_fractal' 140 | opts["eval_batch_size"] = 1 141 | opts["gpu_list"] = "0,1,2,3" 142 | opts["train_dir"] = "./train_logs/er_fractal_HRNet_iou_32_0.05_50_0.3_1000_20240312_warmup/checkpoints" 143 | opts["eval_data_dir"] = "./dataset_txts/test_er.txt" 144 | opts["decoder_type"] = "edge" 145 | opts['model_type'] = 'hrnet' 146 | opts["load_epoch"] = 'best' 147 | 148 | opts["fractal_dir"] = 'Fractal_info_5' 149 | 150 | best_iou = 0.0 151 | infer_model(opts) 152 | 153 | -------------------------------------------------------------------------------- /models/md_net.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | import torch.nn as nn 4 | 5 | class DoubleConv(nn.Module): 6 | """(convolution => [BN] => ReLU) * 2""" 7 | 8 | def __init__(self, in_channels, out_channels): 9 | super().__init__() 10 | self.double_conv = nn.Sequential( 11 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 12 | nn.BatchNorm2d(out_channels), 13 | nn.ReLU(inplace=True), 14 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 15 | nn.BatchNorm2d(out_channels), 16 | nn.ReLU(inplace=True) 17 | ) 18 | 19 | def forward(self, x): 20 | return self.double_conv(x) 21 | 22 | 23 | class Down(nn.Module): 24 | """Downscaling with maxpool then double conv""" 25 | 26 | def __init__(self, in_channels, out_channels): 27 | super().__init__() 28 | self.maxpool_conv = nn.Sequential( 29 | 
nn.MaxPool2d(2), 30 | DoubleConv(in_channels, out_channels) 31 | ) 32 | 33 | def forward(self, x): 34 | return self.maxpool_conv(x) 35 | 36 | 37 | class Up(nn.Module): 38 | """Upscaling then double conv""" 39 | 40 | def __init__(self, in_channels, out_channels, bilinear=False): 41 | super().__init__() 42 | 43 | # if bilinear, use the normal convolutions to reduce the number of channels 44 | if bilinear: 45 | self.up = nn.Sequential( 46 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 47 | nn.Conv2d(in_channels, in_channels // 2, kernel_size=(1, 1), stride=1) 48 | ) 49 | 50 | else: 51 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 52 | 53 | self.conv = DoubleConv(in_channels, out_channels) 54 | 55 | def forward(self, x1, x2): 56 | x1 = self.up(x1) 57 | 58 | cat_x = torch.cat((x1, x2), 1) 59 | output = self.conv(cat_x) 60 | return output 61 | 62 | 63 | class up_conv(nn.Module): 64 | """ 65 | Up Convolution Block 66 | """ 67 | 68 | def __init__(self, in_ch, out_ch): 69 | super(up_conv, self).__init__() 70 | self.up = nn.Sequential( 71 | nn.Upsample(scale_factor=2), 72 | nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True), 73 | nn.BatchNorm2d(out_ch), 74 | nn.ReLU(inplace=True) 75 | ) 76 | 77 | def forward(self, x): 78 | x = self.up(x) 79 | return x 80 | 81 | 82 | class OutConv(nn.Module): 83 | def __init__(self, in_channels, out_channels): 84 | super(OutConv, self).__init__() 85 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 86 | 87 | def forward(self, x): 88 | return self.conv(x) 89 | 90 | 91 | ''' 92 | model Two_decoder_Net 93 | ''' 94 | 95 | 96 | class Two_decoder_Net(nn.Module): 97 | def __init__(self, n_channels, n_classes=1, bilinear=False): 98 | super(Two_decoder_Net, self).__init__() 99 | self.n_channels = n_channels 100 | self.n_classes = n_classes 101 | self.bilinear = bilinear 102 | 103 | # encoder 104 | self.inc = DoubleConv(n_channels, 64) 105 | self.down1 = Down(64, 128) 106 | self.down2 = Down(128, 256) 107 | self.down3 = Down(256, 512) 108 | self.down4 = Down(512, 1024) 109 | 110 | # decoder 111 | self.up1_1 = Up(1024, 512, bilinear) 112 | self.up1_2 = Up(1024, 512, bilinear) 113 | self.up2_1 = Up(512, 256, bilinear) 114 | self.up2_2 = Up(512, 256, bilinear) 115 | self.up3_1 = Up(256, 128, bilinear) 116 | self.up3_2 = Up(256, 128, bilinear) 117 | self.up4_1 = Up(128, 64, bilinear) 118 | self.up4_2 = Up(128, 64, bilinear) 119 | self.out_1 = OutConv(64, n_classes) 120 | self.out_2 = OutConv(64, n_classes) 121 | 122 | def forward(self, x): 123 | # encoder 124 | x1 = self.inc(x) 125 | 126 | x2 = self.down1(x1) 127 | x3 = self.down2(x2) 128 | x4 = self.down3(x3) 129 | x5 = self.down4(x4) 130 | 131 | # decoder 132 | o_4_1 = self.up1_1(x5, x4) 133 | o_4_2 = self.up1_2(x5, x4) 134 | 135 | o_3_1 = self.up2_1(o_4_1, x3) 136 | o_3_2 = self.up2_2(o_4_2, x3) 137 | 138 | o_2_1 = self.up3_1(o_3_1, x2) 139 | o_2_2 = self.up3_2(o_3_2, x2) 140 | 141 | o_1_1 = self.up4_1(o_2_1, x1) 142 | o_1_2 = self.up4_2(o_2_2, x1) 143 | 144 | o_seg1 = self.out_1(o_1_1) 145 | o_seg2 = self.out_2(o_1_2) 146 | 147 | if self.n_classes > 1: 148 | seg1 = F.softmax(o_seg1, dim=1) 149 | seg2 = F.softmax(o_seg2, dim=1) 150 | return seg1, seg2 151 | elif self.n_classes == 1: 152 | seg1 = torch.sigmoid(o_seg1) 153 | seg2 = torch.sigmoid(o_seg2) 154 | return seg1, seg2 155 | 156 | 157 | ''' 158 | model Multi-decoder-Net 159 | ''' 160 | 161 | 162 | class Multi_decoder_Net(nn.Module): 163 | def __init__(self, n_channels, 
n_classes=1, bilinear=False): 164 | super(Multi_decoder_Net, self).__init__() 165 | self.n_channels = n_channels 166 | self.n_classes = n_classes 167 | self.bilinear = bilinear 168 | 169 | # encoder 170 | self.inc = DoubleConv(n_channels, 64) 171 | self.down1 = Down(64, 128) 172 | self.down2 = Down(128, 256) 173 | self.down3 = Down(256, 512) 174 | self.down4 = Down(512, 1024) 175 | 176 | # decoder 177 | self.up1_1 = Up(1024, 512, bilinear) 178 | self.up1_2 = Up(1024, 512, bilinear) 179 | self.up1_3 = Up(1024, 512, bilinear) 180 | 181 | self.up2_1 = Up(512, 256, bilinear) 182 | self.up2_2 = Up(512, 256, bilinear) 183 | self.up2_3 = Up(512, 256, bilinear) 184 | 185 | self.up3_1 = Up(256, 128, bilinear) 186 | self.up3_2 = Up(256, 128, bilinear) 187 | self.up3_3 = Up(256, 128, bilinear) 188 | 189 | self.up4_1 = Up(128, 64, bilinear) 190 | self.up4_2 = Up(128, 64, bilinear) 191 | self.up4_3 = Up(128, 64, bilinear) 192 | 193 | self.out_1 = OutConv(64, n_classes) 194 | self.out_2 = OutConv(64, n_classes) 195 | self.out_3 = OutConv(64, n_classes) 196 | 197 | def forward(self, x): 198 | # encoder 199 | x1 = self.inc(x) 200 | 201 | x2 = self.down1(x1) 202 | x3 = self.down2(x2) 203 | x4 = self.down3(x3) 204 | x5 = self.down4(x4) 205 | 206 | # decoder 207 | o_4_1 = self.up1_1(x5, x4) 208 | o_4_2 = self.up1_2(x5, x4) 209 | o_4_3 = self.up1_3(x5, x4) 210 | 211 | o_3_1 = self.up2_1(o_4_1, x3) 212 | o_3_2 = self.up2_2(o_4_2, x3) 213 | o_3_3 = self.up2_3(o_4_3, x3) 214 | 215 | o_2_1 = self.up3_1(o_3_1, x2) 216 | o_2_2 = self.up3_2(o_3_2, x2) 217 | o_2_3 = self.up3_3(o_3_3, x2) 218 | 219 | o_1_1 = self.up4_1(o_2_1, x1) 220 | o_1_2 = self.up4_2(o_2_2, x1) 221 | o_1_3 = self.up4_3(o_2_3, x1) 222 | 223 | o_seg1 = self.out_1(o_1_1) 224 | o_seg2 = self.out_2(o_1_2) 225 | o_seg3 = self.out_3(o_1_3) 226 | 227 | if self.n_classes > 1: 228 | seg1 = F.softmax(o_seg1, dim=1) 229 | seg2 = F.softmax(o_seg2, dim=1) 230 | seg3 = F.softmax(o_seg3, dim=1) 231 | return seg1, seg2, seg3 232 | elif self.n_classes == 1: 233 | seg1 = torch.sigmoid(o_seg1) 234 | seg2 = torch.sigmoid(o_seg2) 235 | seg3 = torch.sigmoid(o_seg3) 236 | return seg1, seg2, seg3 237 | 238 | 239 | -------------------------------------------------------------------------------- /models/optimize.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.optim as optim 3 | from torch.autograd import Variable 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import random 8 | import math 9 | import concurrent.futures 10 | 11 | 12 | def create_criterion(criterion="crossentropy"): 13 | if criterion == "crossentropy": 14 | return nn.CrossEntropyLoss() 15 | elif criterion == "bce": 16 | return nn.BCELoss() 17 | elif criterion == "bcelog": 18 | return nn.BCEWithLogitsLoss() 19 | elif criterion == "L1": 20 | return nn.L1Loss() 21 | elif criterion == "MSE": 22 | return nn.MSELoss() 23 | elif criterion == "focal": 24 | return FocalLoss2d() 25 | elif criterion == "wbce": 26 | return weighted_edge_loss() 27 | elif criterion == "iou": 28 | return soft_iou_loss() 29 | elif criterion == 'dice': 30 | return DiceLoss() 31 | elif criterion == 'weighted_soft_iou_loss': 32 | return weighted_soft_iou_loss() 33 | 34 | 35 | def create_optimizer(parameters, mode="SGD", lr=0.001, momentum=0.9, wd=0.0005, beta1=0.5, beta2=0.999): 36 | if mode == "SGD": 37 | return optim.SGD(parameters, lr=lr, momentum=momentum, weight_decay=wd) 38 | elif mode == "Adam": 39 | return 
optim.Adam(parameters, lr=lr, betas=(beta1, beta2), weight_decay=wd) 40 | 41 | 42 | def update_learning_rate(optimizer, epoch, lr, step=30, gamma=0.1): 43 | lr = lr * (gamma ** (epoch // step)) 44 | for param_group in optimizer.param_groups: 45 | param_group['lr'] = lr 46 | return lr 47 | 48 | 49 | def warmup_learning_rate(optimizer, train_steps, warmup_steps, lr, method): 50 | # gradual warmup_lr 51 | if warmup_steps and train_steps < warmup_steps: 52 | warmup_percent_done = train_steps / warmup_steps 53 | warmup_learning_rate = lr * warmup_percent_done 54 | learning_rate = warmup_learning_rate 55 | else: 56 | # after warm up, decay lr 57 | for param_group in optimizer.param_groups: 58 | now_lr = param_group['lr'] 59 | if method == 'sin': 60 | learning_rate = np.sin(now_lr) 61 | elif method == 'exp': 62 | learning_rate = now_lr ** 1.001 63 | if (train_steps + 1) % 100 == 0: 64 | print("train_steps:%.3f--warmup_steps:%.3f--learning_rate:%.3f" % ( 65 | train_steps + 1, warmup_steps, learning_rate)) 66 | for param_group in optimizer.param_groups: 67 | param_group['lr'] = learning_rate 68 | return learning_rate 69 | 70 | 71 | def get_lr(optimizer): 72 | for param_group in optimizer.param_groups: 73 | return param_group['lr'] 74 | 75 | 76 | def weight_softmax(x): 77 | e_x = np.exp(x - np.max(x)) 78 | return e_x / e_x.sum() 79 | 80 | 81 | class soft_iou_loss(nn.Module): 82 | def __init__(self): 83 | super(soft_iou_loss, self).__init__() 84 | 85 | def forward(self, pred, label): 86 | b = pred.size()[0] 87 | pred = pred.view(b, -1) 88 | label = label.view(b, -1) 89 | inter = torch.sum(torch.mul(pred, label), dim=-1, keepdim=False) 90 | unit = torch.sum(torch.mul(pred, pred) + label, dim=-1, keepdim=False) - inter 91 | return torch.mean(1 - inter / (unit + 1e-10)) 92 | 93 | 94 | class weighted_soft_iou_loss(nn.Module): 95 | def __init__(self): 96 | super(weighted_soft_iou_loss, self).__init__() 97 | 98 | def forward(self, pred, label, weit): 99 | # pred = torch.sigmoid(pred) 100 | b = pred.size()[0] 101 | pred = pred.view(b, -1) 102 | label = label.view(b, -1) 103 | weit = weit.view(b, -1) 104 | inter_ = torch.mul(pred, label) 105 | inter = torch.sum(torch.mul(inter_, weit), dim=-1, keepdim=False) 106 | union_ = torch.mul(torch.mul(pred, pred) + label, weit) 107 | unit = torch.sum(union_, dim=-1, keepdim=False) - inter 108 | return torch.mean(1 - inter / (unit + 1e-10)) 109 | 110 | 111 | 112 | class IoU_loss(torch.nn.Module): 113 | def __init__(self): 114 | super(IoU_loss, self).__init__() 115 | 116 | def forward(self, pred, target): 117 | b = pred.shape[0] 118 | IoU = 0.0 119 | for i in range(0, b): 120 | # compute the IoU of the foreground 121 | Iand1 = torch.sum(target[i, :, :, :] * pred[i, :, :, :]) 122 | Ior1 = torch.sum(target[i, :, :, :]) + torch.sum(pred[i, :, :, :]) - Iand1 123 | IoU1 = Iand1 / (Ior1 + 1e-5) 124 | # IoU loss is (1-IoU1) 125 | IoU = IoU + (1 - IoU1) 126 | return IoU / b 127 | 128 | 129 | class FocalLoss2d(nn.Module): 130 | 131 | def __init__(self, alpha=0.25, gamma=2, ignore_index=None, reduction='mean', **kwargs): 132 | super(FocalLoss2d, self).__init__() 133 | self.alpha = alpha 134 | self.gamma = gamma 135 | self.smooth = 1e-6 # set '1e-4' when train with FP16 136 | self.ignore_index = ignore_index 137 | self.reduction = reduction 138 | 139 | assert self.reduction in ['none', 'mean', 'sum'] 140 | 141 | def forward(self, prob, target): 142 | prob = torch.clamp(prob, self.smooth, 1.0 - self.smooth) 143 | 144 | valid_mask = None 145 | if self.ignore_index is not None: 146 
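            # pixels equal to ignore_index are masked out of both the positive and negative terms below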
| valid_mask = (target != self.ignore_index).float() 147 | 148 | pos_mask = (target == 1).float() 149 | neg_mask = (target == 0).float() 150 | if valid_mask is not None: 151 | pos_mask = pos_mask * valid_mask 152 | neg_mask = neg_mask * valid_mask 153 | 154 | pos_weight = (pos_mask * torch.pow(1 - prob, self.gamma)).detach() 155 | pos_loss = -self.alpha * (pos_weight * torch.log(prob)) 156 | 157 | neg_weight = (neg_mask * torch.pow(prob, self.gamma)).detach() 158 | neg_loss = -(1 - self.alpha) * (neg_weight * torch.log(1 - prob)) 159 | 160 | loss = (pos_loss + neg_loss) 161 | 162 | return loss.mean() 163 | 164 | 165 | class weighted_edge_loss(nn.Module): 166 | def __init__(self, beta_1=1, beta_2=1): 167 | super(weighted_edge_loss, self).__init__() 168 | self.beta_1 = beta_1 169 | self.beta_2 = beta_2 170 | 171 | def forward(self, pred, label): 172 | label = label.long() 173 | mask = label.float() 174 | num_positive = torch.sum((mask == 1).float()).float() 175 | num_negative = torch.sum((mask == 0).float()).float() 176 | 177 | mask[mask == 1] = self.beta_1 * num_negative / (num_positive + num_negative) 178 | mask[mask == 0] = self.beta_2 * num_positive / (num_positive + num_negative) 179 | 180 | cost = nn.functional.binary_cross_entropy(pred.float(), label.float(), weight=mask, reduction='sum') / ( 181 | num_negative + num_positive) 182 | return cost 183 | 184 | 185 | class DiceLoss(nn.Module): 186 | def __init__(self): 187 | super().__init__() 188 | 189 | def forward(self, input, target): 190 | smooth = 1e-5 191 | num = target.size(0) 192 | input = input.view(num, -1) 193 | target = target.view(num, -1) 194 | intersection = (input * target) 195 | dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth) 196 | dice = 1 - dice.sum() / num 197 | return dice 198 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | import torch.nn as nn 4 | 5 | class DoubleConv(nn.Module): 6 | """(convolution => [BN] => ReLU) * 2""" 7 | 8 | def __init__(self, in_channels, out_channels): 9 | super().__init__() 10 | self.double_conv = nn.Sequential( 11 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 12 | nn.BatchNorm2d(out_channels), 13 | nn.ReLU(inplace=True), 14 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 15 | nn.BatchNorm2d(out_channels), 16 | nn.ReLU(inplace=True) 17 | ) 18 | 19 | def forward(self, x): 20 | return self.double_conv(x) 21 | 22 | 23 | class Down(nn.Module): 24 | """Downscaling with maxpool then double conv""" 25 | 26 | def __init__(self, in_channels, out_channels): 27 | super().__init__() 28 | self.maxpool_conv = nn.Sequential( 29 | nn.MaxPool2d(2), 30 | DoubleConv(in_channels, out_channels) 31 | ) 32 | 33 | def forward(self, x): 34 | return self.maxpool_conv(x) 35 | 36 | 37 | class Up(nn.Module): 38 | """Upscaling then double conv""" 39 | 40 | def __init__(self, in_channels, out_channels, bilinear=False): 41 | super().__init__() 42 | 43 | # if bilinear, use the normal convolutions to reduce the number of channels 44 | if bilinear: 45 | self.up = nn.Sequential( 46 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 47 | nn.Conv2d(in_channels, in_channels // 2, kernel_size=(1, 1), stride=1) 48 | ) 49 | 50 | else: 51 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 52 | 53 | self.conv = 
DoubleConv(in_channels, out_channels) 54 | 55 | def forward(self, x1, x2): 56 | x1 = self.up(x1) 57 | 58 | cat_x = torch.cat((x1, x2), 1) 59 | output = self.conv(cat_x) 60 | return output 61 | 62 | 63 | class OutConv(nn.Module): 64 | def __init__(self, in_channels, out_channels): 65 | super(OutConv, self).__init__() 66 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 67 | 68 | def forward(self, x): 69 | return self.conv(x) 70 | 71 | 72 | ''' 73 | model 74 | ''' 75 | 76 | 77 | class UNet(nn.Module): 78 | def __init__(self, n_channels, n_classes=1, bilinear=False): 79 | super(UNet, self).__init__() 80 | self.n_channels = n_channels 81 | self.n_classes = n_classes 82 | self.bilinear = bilinear 83 | 84 | # encoder 85 | self.inc = DoubleConv(n_channels, 64) 86 | self.down1 = Down(64, 128) 87 | self.down2 = Down(128, 256) 88 | self.down3 = Down(256, 512) 89 | self.down4 = Down(512, 1024) 90 | 91 | # decoder 92 | self.up1 = Up(1024, 512, bilinear) 93 | self.up2 = Up(512, 256, bilinear) 94 | self.up3 = Up(256, 128, bilinear) 95 | self.up4 = Up(128, 64, bilinear) 96 | self.out = OutConv(64, n_classes) 97 | 98 | def forward(self, x): 99 | # encoder 100 | x1 = self.inc(x) 101 | 102 | x2 = self.down1(x1) 103 | x3 = self.down2(x2) 104 | x4 = self.down3(x3) 105 | x5 = self.down4(x4) 106 | 107 | # decoder 108 | o_4 = self.up1(x5, x4) 109 | o_3 = self.up2(o_4, x3) 110 | o_2 = self.up3(o_3, x2) 111 | o_1 = self.up4(o_2, x1) 112 | o_seg = self.out(o_1) 113 | 114 | if self.n_classes > 1: 115 | seg = F.softmax(o_seg, dim=1) 116 | return seg 117 | elif self.n_classes == 1: 118 | seg = torch.sigmoid(o_seg) 119 | return seg 120 | 121 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def make_activation_layer(inplanes, relu_type="relu"): 7 | if relu_type == "relu": 8 | act_func = nn.ReLU(inplace=True) 9 | elif relu_type == "prelu": 10 | act_func = nn.PReLU() 11 | elif relu_type == "leaky": 12 | act_func = nn.LeakyReLU(0.2) 13 | else: 14 | print("Not support this type of acitvation function.") 15 | return 0 16 | return act_func 17 | 18 | 19 | def make_a_conv_layer(inplanes, outplanes, ksize=3, stride=1, pad=1, bn=True): 20 | if bn: 21 | return nn.Sequential(nn.Conv2d(inplanes, outplanes, kernel_size=ksize, stride=stride, padding=pad, bias=False), 22 | nn.BatchNorm2d(outplanes)) 23 | else: 24 | return nn.Conv2d(inplanes, outplanes, kernel_size=ksize, stride=stride, padding=pad, bias=True) 25 | 26 | 27 | def make_a_conv_relu_layer(inplanes, outplanes, ksize=3, stride=1, pad=1, bn=True, relu_type="relu"): 28 | return nn.Sequential(make_a_conv_layer(inplanes, outplanes, ksize=ksize, stride=stride, pad=pad, bn=bn), 29 | make_activation_layer(outplanes, relu_type=relu_type)) 30 | 31 | 32 | ''' 33 | Make a sequence of conv layers 34 | ''' 35 | def make_conv_layers(repeats, in_dim, out_dim, make_layer=make_a_conv_relu_layer, relu_type="relu", expansion=1): 36 | layers = [make_layer(in_dim, out_dim, relu_type=relu_type)] # default 3x3@s1p1 with bn and relu 37 | for _ in range(1, repeats): 38 | layers.append(make_layer(out_dim, out_dim, relu_type=relu_type)) 39 | return nn.Sequential(*layers) 40 | 41 | 42 | ''' 43 | Make a residual block 44 | ''' 45 | class make_a_res_block(nn.Module): 46 | def __init__(self, in_dim, out_dim, stride=1, expansion=1): 47 | 
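        # Bottleneck layout: branch2a (1x1, may stride) -> branch2b (3x3) -> branch2c (1x1 to
        # out_dim * expansion); branch1 is the identity unless the stride or channel width
        # changes, in which case a 1x1 projection aligns shapes for the residual sum in forward().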
super(make_a_res_block, self).__init__()
48 |         self._expansion = expansion
49 |         # (2a) increase #channel using 1x1 conv or halve feature map size
50 |         self.branch2a = make_a_conv_relu_layer(in_dim, out_dim, ksize=1, stride=stride, pad=0)
51 |         # (2b) regular conv layer
52 |         self.branch2b = make_a_conv_relu_layer(out_dim, out_dim)
53 |         self._out_dim = out_dim * expansion
54 |         # (2c) decrease #channel using 1x1 conv
55 |         self.branch2c = make_a_conv_layer(out_dim, self._out_dim, ksize=1, pad=0)
56 |         # (1) increase #channel or halve feature map size
57 |         self.branch1 = make_a_conv_relu_layer(in_dim, self._out_dim, ksize=1, stride=stride, pad=0) \
58 |             if stride != 1 or in_dim != self._out_dim else nn.Sequential()
59 |         self.relu = nn.ReLU(inplace=True)
60 | 
61 |     def forward(self, x):
62 |         branch2 = self.branch2a(x)
63 |         branch2 = self.branch2b(branch2)
64 |         branch2 = self.branch2c(branch2)
65 |         branch1 = self.branch1(x)
66 |         y = self.relu(branch1 + branch2)
67 |         return y
68 | 
69 | 
70 | '''
71 | Make a basic residual block
72 | '''
73 | class make_a_res_block_basic(nn.Module):
74 |     def __init__(self, in_dim, out_dim, stride=1):
75 |         super(make_a_res_block_basic, self).__init__()
76 |         self.branch2a = make_a_conv_relu_layer(in_dim, out_dim, ksize=3, stride=stride)
77 |         self.branch2b = make_a_conv_relu_layer(out_dim, out_dim, ksize=3)
78 |         self.branch1 = make_a_conv_layer(in_dim, out_dim, ksize=1, stride=stride) \
79 |             if stride != 1 or in_dim != out_dim else nn.Sequential()
80 |         self.relu = nn.ReLU(inplace=True)
81 | 
82 |     def forward(self, x):
83 |         branch2 = self.branch2a(x)
84 |         branch2 = self.branch2b(branch2)
85 |         branch1 = self.branch1(x)
86 |         y = self.relu(branch1 + branch2)
87 |         return y
88 | 
89 | 
90 | '''
91 | Make a sequence of residual layers
92 | '''
93 | def make_res_layers(repeats, in_dim, out_dim, make_layer=make_a_res_block, expansion=1, reverse=False, relu_type="relu"):
94 |     if reverse:
95 |         layers = []
96 |         for _ in range(1, repeats):
97 |             layers.append(make_layer(in_dim, in_dim, expansion=expansion))
98 |         layers.append(make_layer(in_dim, out_dim, expansion=expansion))
99 |     else:
100 |         layers = [make_layer(in_dim, out_dim, expansion=expansion)]
101 |         for _ in range(1, repeats):
102 |             layers.append(make_layer(out_dim, out_dim, expansion=expansion))
103 |     return nn.Sequential(*layers)
104 | 
105 | 
106 | 
107 | '''
108 | Make an upsampling layer
109 | '''
110 | class make_upsample_layer(nn.Module):
111 |     def __init__(self):
112 |         super(make_upsample_layer, self).__init__()
113 | 
114 |     def forward(self, x):
115 |         return nn.functional.interpolate(x, scale_factor=2)
116 | 
117 | 
118 | '''
119 | Make an upsampling or transposed conv layer
120 | '''
121 | def make_up_layer(layer_type="upsample", in_dims=0, out_dims=0):
122 |     if layer_type == "upsample":
123 |         return make_upsample_layer()
124 |     elif layer_type == "transconv":
125 |         return nn.ConvTranspose2d(in_dims, out_dims, kernel_size=2, stride=2)
126 | 
127 | 
128 | '''
129 | Make a pooling layer
130 | '''
131 | def make_pool_layer(pool_type='avg'):
132 |     if pool_type == 'avg':
133 |         return nn.AvgPool2d(2, stride=2)
134 |     elif pool_type == 'max':
135 |         return nn.MaxPool2d(2, stride=2, ceil_mode=False)
136 | 
137 | 
138 | '''
139 | Make a merge layer
140 | '''
141 | class mergeup(nn.Module):
142 |     def __init__(self, merge_type):
143 |         super(mergeup, self).__init__()
144 |         self.merge_type = merge_type
145 | 
146 |     def forward(self, up1, up2):
147 |         if self.merge_type == "add":
148 |             return up1 + up2
149 |         elif self.merge_type == "prod":
150 |             return up1 * up2
151 
155 | def init_weights(net, init_type='normal', init_gain=0.02):
156 |     """Initialize network weights.
157 |     Parameters:
158 |         net (network)     -- network to be initialized
159 |         init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
160 |         init_gain (float) -- scaling factor for normal, xavier and orthogonal.
161 |     We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
162 |     work better for some applications. Feel free to try yourself.
163 |     """
164 |     def init_func(m):  # define the initialization function
165 |         classname = m.__class__.__name__
166 |         if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
167 |             if init_type == 'normal':
168 |                 nn.init.normal_(m.weight.data, 0.0, init_gain)
169 |             elif init_type == 'xavier':
170 |                 nn.init.xavier_normal_(m.weight.data, gain=init_gain)
171 |             elif init_type == 'kaiming':
172 |                 nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
173 |             elif init_type == 'orthogonal':
174 |                 nn.init.orthogonal_(m.weight.data, gain=init_gain)
175 |             else:
176 |                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
177 |             if hasattr(m, 'bias') and m.bias is not None:
178 |                 nn.init.constant_(m.bias.data, 0.0)
179 |         elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
180 |             nn.init.normal_(m.weight.data, 1.0, init_gain)
181 |             nn.init.constant_(m.bias.data, 0.0)
182 |
183 |     print('initialize network with %s' % init_type)
184 |     net.apply(init_func)  # apply the initialization function
185 |
186 |
187 | def set_requires_grad(nets, requires_grad=False):
188 |     if not isinstance(nets, list):
189 |         nets = [nets]
190 |     for net in nets:
191 |         if net is not None:
192 |             for param in net.parameters():
193 |                 param.requires_grad = requires_grad
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.12.0
2 | numpy>=1.21.5
3 | opencv-python>=4.5.1.48
4 | scikit-learn>=1.0.2
5 | scipy>=1.7.3
6 | scikit-image>=0.19.3
7 | Pillow>=9.2.0
8 | torchvision>=0.13.0
9 | matplotlib>=3.5.3
10 | panel>=0.14.4
11 | gudhi>=3.8.0
12 | hausdorff>=0.2.6
13 | requests>=2.28.1
--------------------------------------------------------------------------------
/train_hrnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 |
5 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
6 | import sys
7 | import shutil
8 | import numpy as np
9 | import time
10 | import torch
11 | import torch.nn as nn
12 | import torchvision
13 |
14 | root_dir = os.path.abspath(os.path.dirname(__file__))
15 | sys.path.append(root_dir)
16 | sys.path.append(os.path.join(root_dir, "datasets"))
17 | sys.path.append(os.path.join(root_dir, "models"))
18 |
19 | from datasets.dataset import build_data_loader
20 | from models.hrnet import HRNetV2
21 | from models.optimize import create_criterion, create_optimizer, update_learning_rate, warmup_learning_rate
22 | from datasets.metric import *
23 | from hausdorff import hausdorff_distance
24 |
25 | print("PyTorch Version: ", torch.__version__)
26 | print("Torchvision Version: ", torchvision.__version__)
27 |
28 |
29 | def print_table(data):
30 |     col_width = [max(len(item) for item in col) for col in data]
31 |     for row_idx in range(len(data[0])):
32 |         for col_idx, col in enumerate(data):
33 |             item = col[row_idx]
34 |             align = '<' if not col_idx == 0 else '>'
35 |             print(('{:' + align + str(col_width[col_idx]) + '}').format(item), end=" ")
36 |         print()
37 |
38 |
39 | def train_one_epoch(epoch, total_steps, dataloader, model,
40 |                     device, criterion, optimizer, lr,
41 |                     display_iter, log_file, warmup_step, warmup_method):
42 |     model.train()
43 |     smooth_loss = 0.0
44 |     current_step = 0
45 |     t0 = 0.0
46 |
47 |     for inputs in dataloader:
48 |         t1 = time.time()
49 |
50 |         images = inputs['image'].to(device)
51 |         labels = inputs['mask'].to(device)
52 |
53 |         # forward pass
54 |         seg = model(images)
55 |
56 |         # compute loss
57 |         loss = criterion(seg[0], labels)
58 |
59 |         # predictions
60 |         t0 += (time.time() - t1)
61 |
62 |         total_steps += 1
63 |         current_step += 1
64 |         smooth_loss += loss.item()
65 |
66 |         # backpropagate when training
67 |         optimizer.zero_grad()
68 |         # lr_update = update_learning_rate(optimizer, epoch, lr, step=lr_decay)
69 |         lr_update = warmup_learning_rate(optimizer, total_steps, warmup_step, lr, warmup_method)
70 |         loss.backward()
71 |         optimizer.step()
72 |
73 |         # torch.cuda.empty_cache()
74 |
75 |         # log_loss
76 |         if total_steps % display_iter == 0:
77 |             smooth_loss = smooth_loss / current_step
78 |             message = "Epoch: %d Step: %d LR: %.6f Loss: %.4f Runtime: %.2fs/%diters." % (
79 |                 epoch + 1, total_steps, lr_update, smooth_loss, t0, display_iter)
80 |             print("==> %s" % (message))
81 |             with open(log_file, "a+") as fid:
82 |                 fid.write('%s\n' % message)
83 |             t0 = 0.0
84 |             current_step = 0
85 |             smooth_loss = 0.0
86 |     return total_steps
87 |
88 |
89 | def eval_one_epoch(epoch, dataloader, model, device, log_file, threshold):
90 |     with torch.no_grad():
91 |         model.eval()
92 |
93 |         total_iou = 0.0
94 |         total_f1 = 0.0
95 |         total_acc = 0.0
96 |         total_img = 0
97 |
98 |         for inputs in dataloader:
99 |             images = inputs['image'].to(device)
100 |             labels = inputs['mask']
101 |
102 |             total_img += len(images)
103 |
104 |             outputs_list = model(images)
105 |             outputs = outputs_list[0]
106 |
107 |             preds = outputs > threshold
108 |             preds = preds.cpu()
109 |
110 |             val_acc = acc(preds, labels)
111 |             total_acc += val_acc
112 |
113 |             val_iou = IoU(preds, labels)
114 |             total_iou += val_iou
115 |
116 |             val_f1 = F1_score(preds, labels)
117 |             total_f1 += val_f1
118 |         # iou
119 |         epoch_iou = total_iou / total_img
120 |         epoch_f1 = total_f1 / total_img
121 |         epoch_acc = total_acc / total_img
122 |
123 |         message = "total Threshold: {:.3f} =====> Evaluation IOU: {:.4f}; F1_score: {:.4f}; ACC: {:.4f}".format(
124 |             threshold, epoch_iou, epoch_f1, epoch_acc)
125 |         print("==> %s" % (message))
126 |
127 |         with open(log_file, "a+") as fid:
128 |             fid.write('%s\n' % message)
129 |
130 |         # torch.cuda.empty_cache()
131 |         return epoch_acc, epoch_iou, epoch_f1
132 |
133 |
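`warmup_learning_rate`, called once per step in `train_one_epoch` above, lives in models/optimize.py, which this dump does not include. Purely as a hedged mental model of what an `'exp'`-style warmup over `warmup_step` steps could look like (an assumption, not the repository's actual implementation):

```
# Assumption: a plausible warmup schedule, NOT the code in models/optimize.py.
def warmup_lr_sketch(optimizer, step, warmup_step, base_lr, method="exp"):
    if step < warmup_step:
        scale = step / warmup_step
        lr = base_lr * (scale if method == "linear" else scale ** 2)  # 'exp': smooth ramp-up
    else:
        lr = base_lr
    for group in optimizer.param_groups:
        group["lr"] = lr  # write the scheduled rate into every parameter group
    return lr
```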
opts["fractal_dir"] 152 | 153 | train_data_dir = opts["train_data_dir"] 154 | eval_data_dir = opts["eval_data_dir"] 155 | 156 | pretrained = opts["pretrained_model"] 157 | resume = opts["resume"] 158 | display_iter = opts["display_iter"] 159 | threshold = opts["threshold"] 160 | 161 | save_epoch = opts["save_every_epoch"] 162 | 163 | log_file = os.path.join(train_dir, "log_file.txt") 164 | os.makedirs(train_dir, exist_ok=True) 165 | model_dir = os.path.join(train_dir, "code_backup") 166 | os.makedirs(model_dir, exist_ok=True) 167 | if resume is None and os.path.exists(log_file): os.remove(log_file) 168 | shutil.copy("./models/hrnet.py", os.path.join(model_dir, "hrnet.py")) 169 | shutil.copy("./train_hrnet.py", os.path.join(model_dir, "train_hrnet.py")) 170 | shutil.copy("./datasets/dataset.py", os.path.join(model_dir, "dataset.py")) 171 | 172 | ckt_dir = os.path.join(train_dir, "checkpoints") 173 | os.makedirs(ckt_dir, exist_ok=True) 174 | 175 | # format printing configs 176 | print("*" * 50) 177 | table_key = [] 178 | table_value = [] 179 | n = 0 180 | for key, value in opts.items(): 181 | table_key.append(key) 182 | table_value.append(str(value)) 183 | n += 1 184 | print_table([table_key, ["="] * n, table_value]) 185 | 186 | # format gpu list 187 | gpu_list = [] 188 | for str_id in gpus: 189 | id = int(str_id) 190 | gpu_list.append(id) 191 | 192 | # dataloader 193 | print("==> Create dataloader") 194 | dataloaders_dict = { 195 | "train": build_data_loader(dataset_name, train_data_dir, train_batch_size, dataset_type, is_train=True, 196 | fractal_dir=fractal_dir), 197 | "eval": build_data_loader(dataset_name, eval_data_dir, val_batch_size, dataset_type, is_train=False, 198 | fractal_dir=fractal_dir)} 199 | 200 | # define parameters of two networks 201 | print("==> Create network") 202 | 203 | num_classes = 1 204 | 205 | if 'fractal' in opts["dataset_type"]: 206 | if 'road' in opts["dataset_type"]: 207 | num_channels = 5 208 | else: 209 | num_channels = 3 210 | elif 'copy' in opts["dataset_type"]: 211 | if 'road' in opts["dataset_type"]: 212 | num_channels = 6 213 | else: 214 | num_channels = 3 215 | elif 'road' in opts["dataset_type"]: 216 | num_channels = 3 217 | else: 218 | num_channels = 1 219 | 220 | model = HRNetV2(n_channels=num_channels, n_class=num_classes) 221 | 222 | # loss layer 223 | criterion = create_criterion(criterion=loss_criterion) 224 | 225 | best_acc = 0.0 226 | best_iou = 0.0 227 | start_epoch = 0 228 | 229 | # load pretrained model 230 | if pretrained is not None and os.path.isfile(pretrained): 231 | print("==> Train from model '{}'".format(pretrained)) 232 | checkpoint_gan = torch.load(pretrained) 233 | model.load_state_dict(checkpoint_gan['model_state_dict']) 234 | print("==> Loaded checkpoint '{}')".format(pretrained)) 235 | for param in model.parameters(): 236 | param.requires_grad = False 237 | 238 | # resume training 239 | elif resume is not None and os.path.isfile(resume): 240 | print("==> Resume from checkpoint '{}'".format(resume)) 241 | checkpoint = torch.load(resume) 242 | start_epoch = checkpoint['epoch'] + 1 243 | best_acc = checkpoint['best_acc'] 244 | model_dict = model.state_dict() 245 | pretrained_dict = {k: v for k, v in checkpoint['model_state_dict'].items() if 246 | k in model_dict and v.size() == model_dict[k].size()} 247 | model_dict.update(pretrained_dict) 248 | model.load_state_dict(pretrained_dict) 249 | print("==> Loaded checkpoint '{}' (epoch {})".format(resume, checkpoint['epoch'] + 1)) 250 | 251 | # train from scratch 252 | else: 253 | 
print("==> Train from initial or random state.") 254 | 255 | # define mutiple-gpu mode 256 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 257 | model.cuda() 258 | hrnet_model = nn.DataParallel(model) 259 | 260 | # print learnable parameters 261 | print("==> List learnable parameters") 262 | for name, param in model.named_parameters(): 263 | if param.requires_grad == True: 264 | print("\t{}, size {}".format(name, param.size())) 265 | params_to_update = [{'params': model.parameters()}] 266 | 267 | # define optimizer 268 | print("==> Create optimizer") 269 | optimizer = create_optimizer(params_to_update, opti_mode, lr=lr, momentum=0.9, wd=wd) 270 | if resume is not None and os.path.isfile(resume): optimizer.load_state_dict(checkpoint['optimizer']) 271 | 272 | # start training 273 | since = time.time() 274 | 275 | # Each epoch has a training and validation phase 276 | print("==> Start training") 277 | total_steps = 0 278 | 279 | for epoch in range(start_epoch, num_epochs): 280 | 281 | print('-' * 50) 282 | print("==> Epoch {}/{}".format(epoch + 1, num_epochs)) 283 | 284 | total_steps = train_one_epoch(epoch, total_steps, 285 | dataloaders_dict['train'], 286 | hrnet_model, device, 287 | criterion, optimizer, lr, 288 | display_iter, log_file, warmup_step, warmup_method) 289 | 290 | epoch_acc, epoch_iou, epoch_f1 = eval_one_epoch(epoch, dataloaders_dict['eval'], hrnet_model, 291 | device, log_file, threshold) 292 | 293 | if best_iou < epoch_iou and epoch >= 3: 294 | best_iou = epoch_iou 295 | torch.save({'epoch': epoch, 296 | 'model_state_dict': hrnet_model.module.state_dict(), 297 | 'optimizer': optimizer.state_dict(), 298 | 'best_acc': best_acc}, 299 | os.path.join(ckt_dir, "best.pth")) 300 | 301 | if (epoch + 1) % save_epoch == 0 and (epoch + 1) >= 25: 302 | torch.save({'epoch': epoch, 303 | 'model_state_dict': hrnet_model.module.state_dict(), 304 | 'optimizer': optimizer.state_dict(), 305 | 'best_acc': epoch_acc, 306 | 'best_f1': epoch_f1, 307 | 'best_iou': epoch_iou}, 308 | os.path.join(ckt_dir, "checkpoints_" + str(epoch + 1) + ".pth")) 309 | time_elapsed = time.time() - since 310 | time_message = 'Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60) 311 | print(time_message) 312 | with open(log_file, "a+") as fid: 313 | fid.write('%s\n' % time_message) 314 | print('==> Best val Iou: {:4f}'.format(best_iou)) 315 | with open(log_file, "a+") as fid: 316 | fid.write('==> Best val Iou: {:4f}'.format(best_iou)) 317 | 318 | 319 | if __name__ == '__main__': 320 | dataset_name = ['ER', 'MITO', 'ROSE', 'STARE', 'ROAD', 'NUCLEUS'] 321 | dataset_list = ['er', 'er_fractal'] 322 | date = '20240312' 323 | 324 | opts = dict() 325 | opts['dataset_type'] = 'er_fractal' 326 | opts["dataset_name"] = 'ER' 327 | opts["model_type"] = 'HRNet' 328 | opts["fractal_dir"] = 'Fractal_info_5' 329 | opts["num_epochs"] = 50 330 | opts["train_data_dir"] = "./dataset_txts/train_er.txt" 331 | opts["eval_data_dir"] = "./dataset_txts/test_er.txt" 332 | opts["train_batch_size"] = 32 333 | opts["eval_batch_size"] = 32 334 | opts["optimizer"] = "SGD" 335 | opts["loss_criterion"] = "iou" 336 | opts["threshold"] = 0.3 337 | opts["lr"] = 0.05 338 | opts["warmup_step"] = 1000 339 | opts["warmup_method"] = 'exp' 340 | opts["weight_decay"] = 0.0005 341 | opts["gpu_list"] = "0,1,2,3" 342 | 343 | log_dir = "./train_logs/" + str(opts["dataset_type"]) + "_" + opts["model_type"] + "_" + \ 344 | opts["loss_criterion"] + "_" + str(opts["train_batch_size"]) + '_' + str(opts["lr"]) + \ 
345 | '_' + str(opts["num_epochs"]) + '_' + str(opts["threshold"]) + '_' + str( 346 | opts["warmup_step"]) + '_' + date + '_warmup_' + opts["fractal_dir"] 347 | 348 | opts["log_dir"] = log_dir 349 | opts["pretrained_model"] = None 350 | opts["resume"] = None 351 | opts["display_iter"] = 10 352 | opts["save_every_epoch"] = 5 353 | 354 | train_eval_model(opts) 355 | -------------------------------------------------------------------------------- /train_mdnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | 5 | os.environ["CUDA_VISIBLE_DEVICES"] = "2,3" 6 | 7 | import sys 8 | import shutil 9 | import time 10 | import torch.nn as nn 11 | import matplotlib.pyplot as plt 12 | 13 | plt.switch_backend('agg') 14 | root_dir = os.path.abspath(os.path.dirname(__file__)) 15 | sys.path.append(root_dir) 16 | sys.path.append(os.path.join(root_dir, "datasets")) 17 | sys.path.append(os.path.join(root_dir, "models")) 18 | sys.path.append(os.path.join(root_dir, "optim")) 19 | 20 | from datasets.dataset import build_data_loader 21 | from models.utils import init_weights 22 | from models.optimize import create_criterion, create_optimizer, update_learning_rate, warmup_learning_rate 23 | from models.md_net import Multi_decoder_Net, Two_decoder_Net 24 | from datasets.metric import * 25 | 26 | print("PyTorch Version: ", torch.__version__) 27 | 28 | 29 | def FrobeniusNorm(input): # [b,c,h,w] 30 | b, c, h, w = input.size() 31 | triu = torch.eye(h).cuda() 32 | triu = triu.unsqueeze(0).unsqueeze(0) 33 | triu = triu.repeat(b, c, 1, 1) 34 | 35 | x = torch.matmul(input, input.transpose(-2, -1)) 36 | tr = torch.mul(x, triu) 37 | y = torch.sum(tr) 38 | return y 39 | 40 | 41 | def print_table(data): 42 | col_width = [max(len(item) for item in col) for col in data] 43 | for row_idx in range(len(data[0])): 44 | for col_idx, col in enumerate(data): 45 | item = col[row_idx] 46 | align = '<' if not col_idx == 0 else '>' 47 | print(('{:' + align + str(col_width[col_idx]) + '}').format(item), end=" ") 48 | print() 49 | 50 | 51 | def gmm_loss(label, prd, mu_f, mu_b, std_f, std_b, f_k): 52 | b_k = 1 - f_k 53 | 54 | f_likelihood = - f_k * ( 55 | torch.log(np.sqrt(2 * 3.14) * std_f) + torch.pow((prd - mu_f), 2) / (2 * torch.pow(std_f, 2)) + 1e-10) 56 | b_likelihood = - b_k * ( 57 | torch.log(np.sqrt(2 * 3.14) * std_b) + torch.pow((prd - mu_b), 2) / (2 * torch.pow(std_b, 2)) + 1e-10) 58 | likelihood = f_likelihood + b_likelihood 59 | loss = torch.mean(torch.pow(label - torch.exp(likelihood), 2)) 60 | return loss 61 | 62 | 63 | def train_one_epoch(epoch, model_type, total_steps, dataloader, model, 64 | device, criterion_iou, criterion_bce, optimizer, lr, 65 | display_iter, log_file, warmup_step, warmup_method): 66 | model.train() 67 | 68 | smooth_loss = 0.0 69 | current_step = 0 70 | t0 = 0.0 71 | 72 | for inputs in dataloader: 73 | 74 | t1 = time.time() 75 | 76 | images = inputs['image'].to(device) 77 | labels = inputs['mask'].to(device) 78 | skeletons = inputs['skeleton'].to(device) 79 | edges = inputs['edge'].to(device) 80 | 81 | # forward pass 82 | pred, pred_skeleton, pred_edge = model(images) 83 | 84 | # compute loss 85 | loss1 = criterion_iou(pred, labels) 86 | loss2 = criterion_bce(pred_skeleton, skeletons) 87 | loss3 = criterion_bce(pred_edge, edges) 88 | loss = loss1 + 0.5 * loss2 + 0.5 * loss3 89 | 90 | # predictions 91 | t0 += (time.time() - t1) 92 | 93 | total_steps += 1 94 | current_step += 1 95 | smooth_loss += loss.item() 
96 |
97 |         # backpropagate when training
98 |         optimizer.zero_grad()
99 |         lr_update = warmup_learning_rate(optimizer, total_steps, warmup_step, lr, warmup_method)
100 |         # lr_update = update_learning_rate(optimizer, epoch, lr, step=lr_decay)
101 |         loss.backward()
102 |         # loss.backward(retain_graph = True)
103 |         optimizer.step()
104 |
105 |         # log loss
106 |         if total_steps % display_iter == 0:
107 |             smooth_loss = smooth_loss / current_step
108 |             message = "Epoch: %d Step: %d LR: %.6f Loss: %.4f Runtime: %.2fs/%diters." % (
109 |                 epoch + 1, total_steps, lr_update, smooth_loss, t0, display_iter)
110 |             print("==> %s" % (message))
111 |             with open(log_file, "a+") as fid:
112 |                 fid.write('%s\n' % message)
113 |
114 |             t0 = 0.0
115 |             current_step = 0
116 |             smooth_loss = 0.0
117 |
118 |     return total_steps
119 |
120 |
121 | def eval_one_epoch(epoch, model_type, threshold, dataloader, model, device, log_file):
122 |     with torch.no_grad():
123 |         model.eval()
124 |
125 |         total_iou = 0.0
126 |         total_f1 = 0.0
127 |         # total_distance = 0.0
128 |         total_acc = 0.0
129 |         total_img = 0
130 |
131 |         for inputs in dataloader:
132 |             images = inputs['image'].to(device)
133 |             labels = inputs['mask']
134 |
135 |             total_img += len(images)
136 |             outputs, pred_skeleton, pred_edge = model(images)
137 |             preds = outputs > threshold
138 |
139 |             preds = preds.cpu()
140 |
141 |             # metric
142 |             val_acc = acc(preds, labels)
143 |
144 |             total_acc += val_acc
145 |
146 |             val_iou = IoU(preds, labels)
147 |             total_iou += val_iou
148 |
149 |             val_f1 = F1_score(preds, labels)
150 |             total_f1 += val_f1
151 |
152 |         # iou
153 |         epoch_iou = total_iou / total_img
154 |         epoch_f1 = total_f1 / total_img
155 |         epoch_acc = total_acc / total_img
156 |
157 |         message = "total Threshold: {:.3f} =====> Evaluation IOU: {:.4f}; F1_score: {:.4f}; Acc: {:.4f}".format(
158 |             threshold, epoch_iou, epoch_f1, epoch_acc)
159 |         print("==> %s" % (message))
160 |         with open(log_file, "a+") as fid:
161 |             fid.write('%s\n' % message)
162 |
163 |         return epoch_acc, epoch_iou, epoch_f1
164 |
165 |
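For reference, `train_one_epoch` above weights MD-Net's three decoder heads 1 : 0.5 : 0.5 (mask IoU loss, skeleton BCE, edge BCE). Restated as a standalone helper for clarity (the function name is ours, not the repository's):

```
# Restates the composition used in train_one_epoch:
#   L = L_iou(pred, mask) + 0.5 * L_bce(skeleton) + 0.5 * L_bce(edge)
def md_net_loss(criterion_iou, criterion_bce, preds, targets):
    pred, pred_skeleton, pred_edge = preds
    mask, skeleton, edge = targets
    return (criterion_iou(pred, mask)
            + 0.5 * criterion_bce(pred_skeleton, skeleton)
            + 0.5 * criterion_bce(pred_edge, edge))
```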
os.path.join(model_dir, "train_mdnet.py")) 205 | shutil.copy("./datasets/dataset.py", os.path.join(model_dir, "dataset.py")) 206 | 207 | ckt_dir = os.path.join(train_dir, "checkpoints") 208 | os.makedirs(ckt_dir, exist_ok=True) 209 | 210 | # format printing configs 211 | print("*" * 50) 212 | table_key = [] 213 | table_value = [] 214 | n = 0 215 | for key, value in opts.items(): 216 | table_key.append(key) 217 | table_value.append(str(value)) 218 | n += 1 219 | print_table([table_key, ["="] * n, table_value]) 220 | 221 | # format gpu list 222 | gpu_list = [] 223 | for str_id in gpus: 224 | id = int(str_id) 225 | gpu_list.append(id) 226 | 227 | # dataloader 228 | print("==> Create dataloader") 229 | dataloaders_dict = { 230 | "train": build_data_loader(dataset_name, train_data_dir, train_batch_size, dataset_type, is_train=True, 231 | fractal_dir=fractal_dir, edge_dir=edge_dir, skeleton_dir=skeleton_dir), 232 | "eval": build_data_loader(dataset_name, eval_data_dir, val_batch_size, dataset_type, is_train=False, 233 | fractal_dir=fractal_dir)} 234 | 235 | # define parameters of two networks 236 | print("==> Create network") 237 | if 'fractal' in opts["dataset_type"]: 238 | if 'road' in opts["dataset_type"]: 239 | num_channels = 5 240 | else: 241 | num_channels = 3 242 | elif 'copy' in opts["dataset_type"]: 243 | if 'road' in opts["dataset_type"]: 244 | num_channels = 6 245 | else: 246 | num_channels = 3 247 | elif 'road' in opts["dataset_type"]: 248 | num_channels = 3 249 | else: 250 | num_channels = 1 251 | 252 | num_classes = 1 253 | 254 | model = Multi_decoder_Net(num_channels, num_classes) 255 | 256 | init_weights(model) 257 | 258 | # loss layer 259 | criterion_iou = create_criterion(criterion=loss_criterion) 260 | criterion_bce = create_criterion(criterion="bce") 261 | 262 | best_acc = 0.0 263 | start_epoch = 0 264 | 265 | # load pretrained model 266 | if pretrained is not None and os.path.isfile(pretrained): 267 | print("==> Train from model '{}'".format(pretrained)) 268 | checkpoint_gan = torch.load(pretrained) 269 | model.load_state_dict(checkpoint_gan['model_state_dict']) 270 | print("==> Loaded checkpoint '{}')".format(pretrained)) 271 | for param in model.parameters(): 272 | param.requires_grad = False 273 | 274 | # resume training 275 | elif resume is not None and os.path.isfile(resume): 276 | print("==> Resume from checkpoint '{}'".format(resume)) 277 | checkpoint = torch.load(resume) 278 | start_epoch = checkpoint['epoch'] + 1 279 | best_acc = checkpoint['best_acc'] 280 | model_dict = model.state_dict() 281 | pretrained_dict = {k: v for k, v in checkpoint['model_state_dict'].items() if 282 | k in model_dict and v.size() == model_dict[k].size()} 283 | model_dict.update(pretrained_dict) 284 | model.load_state_dict(pretrained_dict) 285 | print("==> Loaded checkpoint '{}' (epoch {})".format(resume, checkpoint['epoch'] + 1)) 286 | 287 | # train from scratch 288 | else: 289 | print("==> Train from initial or random state.") 290 | 291 | # define mutiple-gpu mode 292 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 293 | model.cuda() 294 | model = nn.DataParallel(model) 295 | 296 | # print learnable parameters 297 | print("==> List learnable parameters") 298 | for name, param in model.named_parameters(): 299 | if param.requires_grad == True: 300 | print("\t{}, size {}".format(name, param.size())) 301 | params_to_update = [{'params': model.parameters()}] 302 | 303 | # define optimizer 304 | print("==> Create optimizer") 305 | optimizer = 
create_optimizer(params_to_update, opti_mode, lr=lr, momentum=0.9, wd=wd) 306 | if resume is not None and os.path.isfile(resume): optimizer.load_state_dict(checkpoint['optimizer']) 307 | 308 | # start training 309 | since = time.time() 310 | 311 | # Each epoch has a training and validation phase 312 | print("==> Start training") 313 | total_steps = 0 314 | threshold = opts["threshold"] 315 | epochs = [] 316 | ious = [] 317 | best_iou = 0.0 318 | for epoch in range(start_epoch, num_epochs): 319 | print('-' * 50) 320 | print("==> Epoch {}/{}".format(epoch + 1, num_epochs)) 321 | 322 | total_steps = train_one_epoch(epoch, model_type, total_steps, 323 | dataloaders_dict['train'], 324 | model, device, 325 | criterion_iou, criterion_bce, optimizer, lr, 326 | display_iter, log_file, warmup_step, warmup_method) 327 | 328 | epoch_acc, epoch_iou, epoch_f1 = eval_one_epoch(epoch, model_type, threshold, dataloaders_dict['eval'], 329 | model, device, log_file) 330 | epochs.append(epoch) 331 | ious.append(epoch_iou) 332 | 333 | if best_iou < epoch_iou and epoch >= 5: 334 | best_iou = epoch_iou 335 | best_acc = epoch_acc 336 | torch.save({'epoch': epoch, 337 | 'model_state_dict': model.module.state_dict(), 338 | 'optimizer': optimizer.state_dict(), 339 | 'best_acc': best_acc}, 340 | os.path.join(ckt_dir, "best.pth")) 341 | 342 | if (epoch + 1) % save_epoch == 0 and (epoch + 1) >= 30: 343 | torch.save({'epoch': epoch, 344 | 'model_state_dict': model.module.state_dict(), 345 | 'optimizer': optimizer.state_dict(), 346 | 'best_iou': epoch_iou}, 347 | os.path.join(ckt_dir, "checkpoints_" + str(epoch + 1) + ".pth")) 348 | 349 | time_elapsed = time.time() - since 350 | time_message = 'Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60) 351 | print(time_message) 352 | plt.figure(figsize=(10, 10)) 353 | plt.plot(epochs, ious) 354 | plt.ylim(0, 0.9) 355 | # set the label of x and y 356 | plt.xlabel("epoch") 357 | plt.ylabel("iou") 358 | plt.title("Train model= " + str(model_type) + "; lr=" + str(lr)) 359 | plt.legend() 360 | plt.savefig(os.path.join(train_dir, 'lr_' + str(lr) + '_train_iou.png')) 361 | with open(log_file, "a+") as fid: 362 | fid.write('%s\n' % time_message) 363 | fid.write('==> Best val Acc: {:4f}; Iou: {:4f}'.format(best_acc, best_iou)) 364 | print('==> Best val Acc: {:4f}; Iou: {:4f}'.format(best_acc, best_iou)) 365 | 366 | 367 | if __name__ == '__main__': 368 | dataset_names = ['ER', 'MITO', 'ROSE', 'STARE', 'ROAD', 'NUCLEUS'] 369 | dataset_list = ['er', 'er_fractal', 'er_fractal_three_decoder', 'er_fractal_three_decoder_weighted'] 370 | model_choice = ['Two_decoder_Net', 'Multi_decoder_Net'] 371 | date = '20240312' 372 | 373 | opts = dict() 374 | opts['dataset_type'] = 'er_fractal_three_decoder' 375 | opts["dataset_name"] = 'ER' 376 | opts["num_epochs"] = 50 377 | opts["fractal_dir"] = 'Fractal_info_5' 378 | opts["edge_dir"] = 'mask_edge' 379 | opts["skeleton_dir"] = 'mask_skeleton' 380 | opts["train_data_dir"] = "./dataset_txts/train_er.txt" 381 | opts["eval_data_dir"] = "./dataset_txts/test_er.txt" 382 | opts["train_batch_size"] = 32 383 | opts["eval_batch_size"] = 32 384 | opts["optimizer"] = "SGD" 385 | opts["model_type"] = "Multi_decoder_Net" 386 | opts["loss_criterion"] = "iou" 387 | opts["lr"] = 0.05 388 | opts["threshold"] = 0.3 389 | opts["warmup_step"] = 1000 390 | opts["warmup_method"] = 'exp' 391 | opts["weight_decay"] = 0.0005 392 | opts["gpu_list"] = "0,1,2,3" 393 | log_dir = "./train_logs/" + str(opts["dataset_type"]) + str(opts["model_type"]) + 
"_iou_bce_loss_one_half_half" + \ 394 | '_' + str(opts["train_batch_size"]) + '_' + str(opts["lr"]) + '_' + str(opts["num_epochs"]) + '_' + str( 395 | opts["threshold"]) + '_' + \ 396 | str(opts["warmup_step"]) + '_' + date + '_warmup_' + opts["fractal_dir"] 397 | opts["log_dir"] = log_dir 398 | opts["pretrained_model"] = None 399 | opts["resume"] = None 400 | opts["display_iter"] = 10 401 | opts["save_every_epoch"] = 5 402 | 403 | train_eval_model(opts) 404 | -------------------------------------------------------------------------------- /train_mdnet_weighted.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | 5 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" 6 | import sys 7 | import shutil 8 | import time 9 | import torch.nn as nn 10 | import matplotlib.pyplot as plt 11 | from multiprocessing import Pool 12 | 13 | plt.switch_backend('agg') 14 | root_dir = os.path.abspath(os.path.dirname(__file__)) 15 | sys.path.append(root_dir) 16 | sys.path.append(os.path.join(root_dir, "datasets")) 17 | sys.path.append(os.path.join(root_dir, "models")) 18 | sys.path.append(os.path.join(root_dir, "optim")) 19 | 20 | from datasets.dataset import build_data_loader 21 | from models.utils import init_weights 22 | from models.optimize import create_criterion, create_optimizer, update_learning_rate, warmup_learning_rate 23 | from models.md_net import Multi_decoder_Net, Two_decoder_Net 24 | 25 | from fractal_analysis import compute_FFM 26 | 27 | from datasets.metric import * 28 | 29 | print("PyTorch Version: ", torch.__version__) 30 | 31 | 32 | def FrobeniusNorm(input): # [b,c,h,w] 33 | b, c, h, w = input.size() 34 | triu = torch.eye(h).cuda() 35 | triu = triu.unsqueeze(0).unsqueeze(0) 36 | triu = triu.repeat(b, c, 1, 1) 37 | 38 | x = torch.matmul(input, input.transpose(-2, -1)) 39 | tr = torch.mul(x, triu) 40 | y = torch.sum(tr) 41 | return y 42 | 43 | 44 | def print_table(data): 45 | col_width = [max(len(item) for item in col) for col in data] 46 | for row_idx in range(len(data[0])): 47 | for col_idx, col in enumerate(data): 48 | item = col[row_idx] 49 | align = '<' if not col_idx == 0 else '>' 50 | print(('{:' + align + str(col_width[col_idx]) + '}').format(item), end=" ") 51 | print() 52 | 53 | 54 | def gmm_loss(label, prd, mu_f, mu_b, std_f, std_b, f_k): 55 | b_k = 1 - f_k 56 | 57 | f_likelihood = - f_k * ( 58 | torch.log(np.sqrt(2 * 3.14) * std_f) + torch.pow((prd - mu_f), 2) / (2 * torch.pow(std_f, 2)) + 1e-10) 59 | b_likelihood = - b_k * ( 60 | torch.log(np.sqrt(2 * 3.14) * std_b) + torch.pow((prd - mu_b), 2) / (2 * torch.pow(std_b, 2)) + 1e-10) 61 | likelihood = f_likelihood + b_likelihood 62 | loss = torch.mean(torch.pow(label - torch.exp(likelihood), 2)) 63 | return loss 64 | 65 | 66 | def train_one_epoch(epoch, model_type, total_steps, dataloader, model, 67 | device, criterion_weight_iou, criterion_soft_iou, criterion_bce, optimizer, lr, 68 | display_iter, log_file, warmup_step, warmup_method): 69 | model.train() 70 | 71 | smooth_loss = 0.0 72 | current_step = 0 73 | t0 = 0.0 74 | 75 | for inputs in dataloader: 76 | 77 | t1 = time.time() 78 | 79 | images = inputs['image'].to(device) 80 | # c_images = inputs['c_masks'].to(device) 81 | labels = inputs['mask'].to(device) 82 | skeletons = inputs['skeleton'].to(device) 83 | edges = inputs['edge'].to(device) 84 | 85 | weights = inputs['weight'].to(device) 86 | 87 | # forward pass 88 | pred, pred_skeleton, pred_edge = model(images) 89 | 90 | # compute loss 91 | 
loss1 = criterion_weight_iou(pred, labels, weights)
92 |         loss2 = criterion_bce(pred_skeleton, skeletons)
93 |         loss3 = criterion_bce(pred_edge, edges)
94 |         loss = loss1 + 0.5 * loss2 + 0.5 * loss3
95 |
96 |         # predictions
97 |         t0 += (time.time() - t1)
98 |
99 |         total_steps += 1
100 |         current_step += 1
101 |         smooth_loss += loss.item()
102 |
103 |         # backpropagate when training
104 |         optimizer.zero_grad()
105 |         lr_update = warmup_learning_rate(optimizer, total_steps, warmup_step, lr, warmup_method)
106 |         # lr_update = update_learning_rate(optimizer, epoch, lr, step=lr_decay)
107 |         loss.backward()
108 |         # loss.backward(retain_graph = True)
109 |         optimizer.step()
110 |
111 |         # log loss
112 |         if total_steps % display_iter == 0:
113 |             smooth_loss = smooth_loss / current_step
114 |             message = "Epoch: %d Step: %d LR: %.6f Loss: %.4f Runtime: %.2fs/%diters." % (
115 |                 epoch + 1, total_steps, lr_update, smooth_loss, t0, display_iter)
116 |             print("==> %s" % (message))
117 |             with open(log_file, "a+") as fid:
118 |                 fid.write('%s\n' % message)
119 |
120 |             t0 = 0.0
121 |             current_step = 0
122 |             smooth_loss = 0.0
123 |
124 |     return total_steps
125 |
126 |
127 | def eval_one_epoch(epoch, model_type, threshold, dataloader, model, device, epoch_dir, log_file):
128 |     with torch.no_grad():
129 |         model.eval()
130 |
131 |         total_iou = 0.0
132 |         total_f1 = 0.0
133 |         # total_distance = 0.0
134 |         total_acc = 0.0
135 |         total_img = 0
136 |
137 |         for inputs in dataloader:
138 |             images = inputs['image'].to(device)
139 |             labels = inputs['mask']
140 |             img_name = inputs['ID']
141 |
142 |             total_img += len(images)
143 |             outputs, pred_skeleton, pred_edge = model(images)
144 |
145 |             preds = outputs > threshold
146 |             preds = preds.cpu()
147 |
148 |             # metric
149 |             val_acc = acc(preds, labels)
150 |             total_acc += val_acc
151 |
152 |             val_iou = IoU(preds, labels)
153 |             total_iou += val_iou
154 |
155 |             val_f1 = F1_score(preds, labels)
156 |             total_f1 += val_f1
157 |
158 |         # iou
159 |         epoch_iou = total_iou / total_img
160 |         epoch_f1 = total_f1 / total_img
161 |         epoch_acc = total_acc / total_img
162 |
163 |         message = "total Threshold: {:.3f} =====> Evaluation IOU: {:.4f}; F1_score: {:.4f}; Acc: {:.4f}".format(
164 |             threshold, epoch_iou, epoch_f1, epoch_acc)
165 |         print("==> %s" % (message))
166 |         with open(log_file, "a+") as fid:
167 |             fid.write('%s\n' % message)
168 |
169 |         return epoch_acc, epoch_iou, epoch_f1
170 |
171 | def train_eval_model(opts):
172 |     # parse model configuration
173 |     num_epochs = opts["num_epochs"]
174 |     train_batch_size = opts["train_batch_size"]
175 |     val_batch_size = opts["eval_batch_size"]
176 |     dataset_type = opts["dataset_type"]
177 |     dataset_name = opts['dataset_name']
178 |     model_type = opts["model_type"]
179 |     warmup_step = opts["warmup_step"]
180 |     warmup_method = opts["warmup_method"]
181 |
182 |     opti_mode = opts["optimizer"]
183 |     loss_criterion = opts["loss_criterion"]
184 |     lr = opts["lr"]
185 |     wd = opts["weight_decay"]
186 |     step_size = opts["step_size"]
187 |     window_size = opts["window_size"]
188 |     update_d = opts["update_d"]
189 |
190 |     gpus = opts["gpu_list"].split(',')
191 |     os.environ['CUDA_VISIBLE_DEVICES'] = opts["gpu_list"]
192 |     train_dir = opts["log_dir"]
193 |
194 |     train_data_dir = opts["train_data_dir"]
195 |     eval_data_dir = opts["eval_data_dir"]
196 |
197 |     pretrained = opts["pretrained_model"]
198 |     resume = opts["resume"]
199 |     display_iter = opts["display_iter"]
200 |     save_epoch = opts["save_every_epoch"]
201 |     weight_dir = opts['weight_dir']
202 |     fractal_dir = opts['fractal_dir']
| edge_dir = opts["edge_dir"] 204 | skeleton_dir= opts["skeleton_dir"] 205 | 206 | # backup train configs 207 | log_file = os.path.join(train_dir, "log_file.txt") 208 | os.makedirs(train_dir, exist_ok=True) 209 | model_dir = os.path.join(train_dir, "code_backup") 210 | os.makedirs(model_dir, exist_ok=True) 211 | infer_dir = os.path.join(train_dir, "inference") 212 | os.makedirs(infer_dir, exist_ok=True) 213 | if resume is None and os.path.exists(log_file): os.remove(log_file) 214 | shutil.copy("models/md_net.py", os.path.join(model_dir, "md_net.py")) 215 | shutil.copy("train_mdnet_weighted.py", os.path.join(model_dir, "train_mdnet_weighted.py")) 216 | shutil.copy("./datasets/dataset.py", os.path.join(model_dir, "dataset.py")) 217 | 218 | ckt_dir = os.path.join(train_dir, "checkpoints") 219 | os.makedirs(ckt_dir, exist_ok=True) 220 | 221 | # format printing configs 222 | print("*" * 50) 223 | table_key = [] 224 | table_value = [] 225 | n = 0 226 | for key, value in opts.items(): 227 | table_key.append(key) 228 | table_value.append(str(value)) 229 | n += 1 230 | print_table([table_key, ["="] * n, table_value]) 231 | 232 | # format gpu list 233 | gpu_list = [] 234 | for str_id in gpus: 235 | id = int(str_id) 236 | gpu_list.append(id) 237 | 238 | # dataloader 239 | print("==> Create dataloader") 240 | dataloaders_dict = { 241 | "train": build_data_loader(dataset_name, train_data_dir, train_batch_size, dataset_type, is_train=True, 242 | fractal_dir=fractal_dir,weight_dir=weight_dir,edge_dir=edge_dir,skeleton_dir=skeleton_dir, log_file=train_dir, update_d=update_d), 243 | "eval": build_data_loader(dataset_name, eval_data_dir, val_batch_size, dataset_type, is_train=False, 244 | fractal_dir=fractal_dir)} 245 | 246 | # define parameters of two networks 247 | print("==> Create network") 248 | 249 | if 'fractal' in opts["dataset_type"]: 250 | if 'road' in opts["dataset_type"]: 251 | num_channels = 5 252 | else: 253 | num_channels = 3 254 | elif 'copy' in opts["dataset_type"]: 255 | if 'road' in opts["dataset_type"]: 256 | num_channels = 6 257 | else: 258 | num_channels = 3 259 | elif 'road' in opts["dataset_type"]: 260 | num_channels = 3 261 | else: 262 | num_channels = 1 263 | 264 | num_classes = 1 265 | 266 | model = Multi_decoder_Net(num_channels, num_classes) 267 | init_weights(model) 268 | 269 | # loss layer 270 | criterion_weight_iou = create_criterion(criterion=loss_criterion) 271 | criterion_bce = create_criterion(criterion="bce") 272 | criterion_soft_iou = create_criterion(criterion="iou") 273 | 274 | best_acc = 0.0 275 | start_epoch = 0 276 | 277 | # load pretrained model 278 | if pretrained is not None and os.path.isfile(pretrained): 279 | print("==> Train from model '{}'".format(pretrained)) 280 | checkpoint_gan = torch.load(pretrained) 281 | model.load_state_dict(checkpoint_gan['model_state_dict']) 282 | print("==> Loaded checkpoint '{}')".format(pretrained)) 283 | for param in model.parameters(): 284 | param.requires_grad = False 285 | 286 | # resume training 287 | elif resume is not None and os.path.isfile(resume): 288 | print("==> Resume from checkpoint '{}'".format(resume)) 289 | checkpoint = torch.load(resume) 290 | start_epoch = checkpoint['epoch'] + 1 291 | best_acc = checkpoint['best_acc'] 292 | warmup_step = checkpoint['warmup_step'] 293 | lr = checkpoint['lr'] 294 | warmup_method = checkpoint['warmup_method'] 295 | total_steps = checkpoint['total_steps'] 296 | best_acc = checkpoint['best_acc'] 297 | model_dict = model.state_dict() 298 | pretrained_dict = {k: v for k, v in 
299 |                            k in model_dict and v.size() == model_dict[k].size()}
300 |         model_dict.update(pretrained_dict)
301 |         model.load_state_dict(model_dict)  # load the merged dict, not just the filtered subset
302 |         print("==> Loaded checkpoint '{}' (epoch {})".format(resume, checkpoint['epoch'] + 1))
303 |
304 |     # train from scratch
305 |     else:
306 |         print("==> Train from initial or random state.")
307 |         total_steps = 0
308 |
309 |     # define multiple-gpu mode
310 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
311 |     model.cuda()
312 |     model = nn.DataParallel(model)
313 |
314 |     # print learnable parameters
315 |     print("==> List learnable parameters")
316 |     for name, param in model.named_parameters():
317 |         if param.requires_grad == True:
318 |             print("\t{}, size {}".format(name, param.size()))
319 |     params_to_update = [{'params': model.parameters()}]
320 |
321 |     # define optimizer
322 |     print("==> Create optimizer")
323 |     optimizer = create_optimizer(params_to_update, opti_mode, lr=lr, momentum=0.9, wd=wd)
324 |     if resume is not None and os.path.isfile(resume):
325 |         optimizer.load_state_dict(checkpoint['optimizer'])
326 |
327 |     # start training
328 |     since = time.time()
329 |
330 |     # Each epoch has a training and validation phase
331 |     print("==> Start training")
332 |
333 |     threshold = opts["threshold"]
334 |     epochs = []
335 |     ious = []
336 |     best_iou = 0.0
337 |     for epoch in range(start_epoch, num_epochs):
338 |         dataloaders_dict['train'].dataset.epoch = epoch
339 |         print('-' * 50)
340 |         print("==> Epoch {}/{}".format(epoch + 1, num_epochs))
341 |
342 |         total_steps = train_one_epoch(epoch, model_type, total_steps,
343 |                                       dataloaders_dict['train'],
344 |                                       model, device,
345 |                                       criterion_weight_iou, criterion_soft_iou, criterion_bce, optimizer, lr,
346 |                                       display_iter, log_file, warmup_step, warmup_method)
347 |
348 |         epoch_acc, epoch_iou, epoch_f1 = eval_one_epoch(epoch, model_type, threshold, dataloaders_dict['eval'],
349 |                                                         model, device, infer_dir, log_file)
350 |         epochs.append(epoch)
351 |         ious.append(epoch_iou)
352 |
353 |         if best_iou < epoch_iou and (epoch + 1) >= 5:
354 |             best_iou = epoch_iou
355 |             best_acc = epoch_acc
356 |             torch.save({'epoch': epoch,
357 |                         'model_state_dict': model.module.state_dict(),
358 |                         'optimizer': optimizer.state_dict(),
359 |                         'best_acc': best_acc},
360 |                        os.path.join(ckt_dir, "best.pth"))
361 |
362 |         if (epoch + 1) % save_epoch == 0 and (epoch + 1) >= 5:
363 |             torch.save({'epoch': epoch,
364 |                         'model_state_dict': model.module.state_dict(),
365 |                         'optimizer': optimizer.state_dict(),
366 |                         'best_iou': epoch_iou,
367 |                         'warmup_step': warmup_step,
368 |                         'warmup_method': warmup_method,
369 |                         'lr': lr,
370 |                         'total_steps': total_steps},
371 |                        os.path.join(ckt_dir, "checkpoints_" + str(epoch + 1) + ".pth"))
372 |
373 |     time_elapsed = time.time() - since
374 |     time_message = 'Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)
375 |     print(time_message)
376 |     plt.figure(figsize=(10, 10))
377 |     plt.plot(epochs, ious)
378 |     plt.ylim(0, 0.9)
379 |     # set the label of x and y
380 |     plt.xlabel("epoch")
381 |     plt.ylabel("iou")
382 |     plt.title("Train model= " + str(model_type) + "; lr=" + str(lr))
383 |     plt.legend()
384 |     plt.savefig(os.path.join(train_dir, 'lr_' + str(lr) + '_train_iou.png'))
385 |     with open(log_file, "a+") as fid:
386 |         fid.write('%s\n' % time_message)
387 |         fid.write('==> Best val Acc: {:4f}; Iou: {:4f}'.format(best_acc, best_iou))
388 |     print('==> Best val Acc: {:4f}; Iou: {:4f}'.format(best_acc, best_iou))
389 |
390 |
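The `"weighted_soft_iou_loss"` criterion selected in the config below is created in models/optimize.py, which this dump does not include. As a hedged sketch only (an assumption about the formulation, not the repository's code), a per-pixel-weighted soft IoU over [b, c, h, w] tensors commonly looks like:

```
# Assumption: one common weighted soft-IoU formulation, NOT models/optimize.py.
import torch

def weighted_soft_iou_loss_sketch(pred, target, weight, eps=1e-6):
    inter = (weight * pred * target).sum(dim=(1, 2, 3))
    union = (weight * (pred + target - pred * target)).sum(dim=(1, 2, 3))
    return (1.0 - (inter + eps) / (union + eps)).mean()  # 1 - weighted soft IoU
```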
391 | if __name__ == '__main__':
392 |     dataset_names = ['ER', 'MITO', 'ROSE', 'STARE', 'ROAD', 'NUCLEUS']
393 |     dataset_list = ['er', 'er_fractal', 'er_fractal_three_decoder', 'er_fractal_three_decoder_weighted']
394 |     model_choice = ['Multi_decoder_Net']
395 |     date = '20240312'
396 |
397 |     opts = dict()
398 |     opts['dataset_type'] = 'er_fractal_three_decoder_weighted'
399 |     opts["dataset_name"] = 'ER'
400 |     opts["num_epochs"] = 50
401 |     opts["fractal_dir"] = 'Fractal_info_5'
402 |     opts["weight_dir"] = 'Weight_5_1'
403 |     opts["edge_dir"] = 'mask_edge'
404 |     opts["skeleton_dir"] = 'mask_skeleton'
405 |     opts["train_data_dir"] = "./dataset_txts/train_er.txt"
406 |     opts["eval_data_dir"] = "./dataset_txts/test_er.txt"
407 |     opts["train_batch_size"] = 32
408 |     opts["eval_batch_size"] = 32
409 |     opts["optimizer"] = "SGD"
410 |     opts["model_type"] = "Multi_decoder_Net"
411 |     opts["loss_criterion"] = "weighted_soft_iou_loss"
412 |     opts["lr"] = 0.05
413 |     opts["step_size"] = 3
414 |     opts["window_size"] = 5
415 |     opts["update_d"] = 5
416 |     opts["threshold"] = 0.3
417 |     opts["warmup_step"] = 1000
418 |     opts["warmup_method"] = 'exp'
419 |     opts["weight_decay"] = 0.0005
420 |     opts["gpu_list"] = "0,1,2,3"
421 |     log_dir = "./train_logs/" + str(
422 |         opts["dataset_type"]) + "_train_two_iou_bce_" + opts["model_type"] + '_' + opts["loss_criterion"] + '_' + str(
423 |         opts["train_batch_size"]) + '_' + str(opts["lr"]) + '_' + str(opts["num_epochs"]) + '_' + str(
424 |         opts["threshold"]) + '_' + str(opts["warmup_step"]) + '_' + date + '_' + str(
425 |         opts["weight_decay"]) + '_warmup_' + opts["fractal_dir"]
426 |     opts["log_dir"] = log_dir
427 |     opts["pretrained_model"] = None
428 |     opts["resume"] = None
429 |     opts["display_iter"] = 10
430 |     opts["save_every_epoch"] = 5
431 |
432 |     train_eval_model(opts)
433 |
--------------------------------------------------------------------------------
/train_tdnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 |
5 | os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
6 | import sys
7 | import importlib
8 | import shutil
9 | import json
10 | import numpy as np
11 | import time
12 | import requests
13 | import torch
14 | import torch.nn as nn
15 | import torchvision
16 | import torchvision.transforms as transforms
17 | from sklearn.metrics import roc_curve, auc
18 | import matplotlib.pyplot as plt
19 |
20 | plt.switch_backend('agg')
21 | root_dir = os.path.abspath(os.path.dirname(__file__))
22 | sys.path.append(root_dir)
23 | sys.path.append(os.path.join(root_dir, "datasets"))
24 | sys.path.append(os.path.join(root_dir, "models"))
25 | sys.path.append(os.path.join(root_dir, "optim"))
26 |
27 | from datasets.dataset import build_data_loader
28 | from models.utils import init_weights
29 | from models.optimize import create_criterion, create_optimizer, update_learning_rate, warmup_learning_rate
30 |
31 | from models.md_net import Two_decoder_Net
32 | from datasets.metric import *
33 |
34 | print("PyTorch Version: ", torch.__version__)
35 |
36 |
37 | def FrobeniusNorm(input):  # [b,c,h,w]
38 |     b, c, h, w = input.size()
39 |     triu = torch.eye(h).cuda()
40 |     triu = triu.unsqueeze(0).unsqueeze(0)
41 |     triu = triu.repeat(b, c, 1, 1)
42 |
43 |     x = torch.matmul(input, input.transpose(-2, -1))
44 |     tr = torch.mul(x, triu)
45 |     y = torch.sum(tr)
46 |     return y
47 |
48 |
49 | def print_table(data):
50 |     col_width = [max(len(item) for item in col) for col in data]
51 |     for row_idx in range(len(data[0])):
52 |
for col_idx, col in enumerate(data): 53 | item = col[row_idx] 54 | align = '<' if not col_idx == 0 else '>' 55 | print(('{:' + align + str(col_width[col_idx]) + '}').format(item), end=" ") 56 | print() 57 | 58 | 59 | def gmm_loss(label, prd, mu_f, mu_b, std_f, std_b, f_k): 60 | b_k = 1 - f_k 61 | 62 | f_likelihood = - f_k * ( 63 | torch.log(np.sqrt(2 * 3.14) * std_f) + torch.pow((prd - mu_f), 2) / (2 * torch.pow(std_f, 2)) + 1e-10) 64 | b_likelihood = - b_k * ( 65 | torch.log(np.sqrt(2 * 3.14) * std_b) + torch.pow((prd - mu_b), 2) / (2 * torch.pow(std_b, 2)) + 1e-10) 66 | likelihood = f_likelihood + b_likelihood 67 | loss = torch.mean(torch.pow(label - torch.exp(likelihood), 2)) 68 | return loss 69 | 70 | 71 | def train_one_epoch(epoch, model_type, total_steps, dataloader, model, 72 | device, criterion_iou, criterion_bce, optimizer, lr, 73 | display_iter, log_file, warmup_step, warmup_method, decoder_type=''): 74 | model.train() 75 | 76 | smooth_loss = 0.0 77 | current_step = 0 78 | t0 = 0.0 79 | 80 | for inputs in dataloader: 81 | 82 | t1 = time.time() 83 | 84 | images = inputs['image'].to(device) 85 | # c_images = inputs['c_masks'].to(device) 86 | labels = inputs['mask'].to(device) 87 | if decoder_type == '': 88 | skeletons = inputs['skeleton'].to(device) 89 | else: 90 | skeletons = inputs[decoder_type].to(device) 91 | 92 | # forward pass 93 | pred, pred_skeleton = model(images) 94 | 95 | # compute loss 96 | loss1 = criterion_iou(pred, labels) 97 | loss2 = criterion_bce(pred_skeleton, skeletons) 98 | 99 | loss = loss1 + loss2 100 | 101 | # predictions 102 | t0 += (time.time() - t1) 103 | 104 | total_steps += 1 105 | current_step += 1 106 | smooth_loss += loss.item() 107 | 108 | # backpropagate when training 109 | optimizer.zero_grad() 110 | lr_update = warmup_learning_rate(optimizer, total_steps, warmup_step, lr, warmup_method) 111 | # lr_update = update_learning_rate(optimizer, epoch, lr, step=lr_decay) 112 | loss.backward() 113 | # loss.backward(retain_graph = True) 114 | optimizer.step() 115 | 116 | # log loss 117 | if total_steps % display_iter == 0: 118 | smooth_loss = smooth_loss / current_step 119 | message = "Epoch: %d Step: %d LR: %.6f Loss: %.4f Runtime: %.2fs/%diters." 
% (
120 |                 epoch + 1, total_steps, lr_update, smooth_loss, t0, display_iter)
121 |             print("==> %s" % (message))
122 |             with open(log_file, "a+") as fid:
123 |                 fid.write('%s\n' % message)
124 |
125 |             t0 = 0.0
126 |             current_step = 0
127 |             smooth_loss = 0.0
128 |
129 |     return total_steps
130 |
131 |
132 | def eval_one_epoch(epoch, model_type, threshold, dataloader, model, device, log_file):
133 |     with torch.no_grad():
134 |         model.eval()
135 |
136 |         total_iou = 0.0
137 |         total_f1 = 0.0
138 |         # total_distance = 0.0
139 |         total_acc = 0.0
140 |         total_img = 0
141 |
142 |         for inputs in dataloader:
143 |             images = inputs['image'].to(device)
144 |             labels = inputs['mask']
145 |
146 |             total_img += len(images)
147 |             outputs, pred_skeleton = model(images)
148 |             preds = outputs > threshold
149 |             preds = preds.cpu()
150 |
151 |             # metric
152 |             val_acc = acc(preds, labels)
153 |
154 |             total_acc += val_acc
155 |
156 |             val_iou = IoU(preds, labels)
157 |             total_iou += val_iou
158 |
159 |             val_f1 = F1_score(preds, labels)
160 |             total_f1 += val_f1
161 |
162 |         # iou
163 |         epoch_iou = total_iou / total_img
164 |         epoch_f1 = total_f1 / total_img
165 |         epoch_acc = total_acc / total_img
166 |
167 |         message = "total Threshold: {:.3f} =====> Evaluation IOU: {:.4f}; F1_score: {:.4f}; Acc: {:.4f}".format(
168 |             threshold, epoch_iou, epoch_f1, epoch_acc)
169 |         print("==> %s" % (message))
170 |         with open(log_file, "a+") as fid:
171 |             fid.write('%s\n' % message)
172 |
173 |         return epoch_acc, epoch_iou, epoch_f1
174 |
175 |
176 | def train_eval_model(opts):
177 |     # parse model configuration
178 |     num_epochs = opts["num_epochs"]
179 |     train_batch_size = opts["train_batch_size"]
180 |     val_batch_size = opts["eval_batch_size"]
181 |     dataset_type = opts["dataset_type"]
182 |     model_type = opts["model_type"]
183 |     warmup_step = opts["warmup_step"]
184 |     warmup_method = opts["warmup_method"]
185 |     decoder_type = opts["decoder_type"]
186 |     dataset_name = opts["dataset_name"]
187 |
188 |     opti_mode = opts["optimizer"]
189 |     loss_criterion = opts["loss_criterion"]
190 |     lr = opts["lr"]
191 |     wd = opts["weight_decay"]
192 |
193 |     gpus = opts["gpu_list"].split(',')
194 |     os.environ['CUDA_VISIBLE_DEVICES'] = opts["gpu_list"]
195 |     train_dir = opts["log_dir"]
196 |
197 |     train_data_dir = opts["train_data_dir"]
198 |     eval_data_dir = opts["eval_data_dir"]
199 |
200 |     pretrained = opts["pretrained_model"]
201 |     resume = opts["resume"]
202 |     display_iter = opts["display_iter"]
203 |     save_epoch = opts["save_every_epoch"]
204 |     fractal_dir = opts['fractal_dir']
205 |     edge_dir = opts["edge_dir"]
206 |     skeleton_dir = opts["skeleton_dir"]
207 |
208 |     # backup train configs
209 |     log_file = os.path.join(train_dir, "log_file.txt")
210 |     os.makedirs(train_dir, exist_ok=True)
211 |     model_dir = os.path.join(train_dir, "code_backup")
212 |     os.makedirs(model_dir, exist_ok=True)
213 |     if resume is None and os.path.exists(log_file): os.remove(log_file)
214 |     shutil.copy("./models/md_net.py", os.path.join(model_dir, "md_net.py"))  # Two_decoder_Net is defined in models/md_net.py
215 |     shutil.copy("./train_tdnet.py", os.path.join(model_dir, "train_tdnet.py"))
216 |     shutil.copy("./datasets/dataset.py", os.path.join(model_dir, "dataset.py"))
217 |
218 |     ckt_dir = os.path.join(train_dir, "checkpoints")
219 |     os.makedirs(ckt_dir, exist_ok=True)
220 |
221 |     # format printing configs
222 |     print("*" * 50)
223 |     table_key = []
224 |     table_value = []
225 |     n = 0
226 |     for key, value in opts.items():
227 |         table_key.append(key)
228 |         table_value.append(str(value))
229 |         n += 1
230 |     print_table([table_key, ["="] * n, table_value])
231 |
232 |     # format gpu list
233 |     gpu_list = []
234 |     for str_id in gpus:
235 |         id = int(str_id)
236 |         gpu_list.append(id)
237 |
238 |     # dataloader
239 |     print("==> Create dataloader")
240 |     dataloaders_dict = {
241 |         "train": build_data_loader(dataset_name, train_data_dir, train_batch_size, dataset_type, is_train=True,
242 |                                    fractal_dir=fractal_dir, edge_dir=edge_dir, skeleton_dir=skeleton_dir, decoder_type=decoder_type),
243 |         "eval": build_data_loader(dataset_name, eval_data_dir, val_batch_size, dataset_type, is_train=False,
244 |                                   fractal_dir=fractal_dir, decoder_type=decoder_type)}
245 |
246 |     # define parameters of two networks
247 |     print("==> Create network")
248 |     if 'fractal' in opts["dataset_type"]:
249 |         if 'road' in opts["dataset_type"]:
250 |             num_channels = 5
251 |         else:
252 |             num_channels = 3
253 |     elif 'copy' in opts["dataset_type"]:
254 |         if 'road' in opts["dataset_type"]:
255 |             num_channels = 6
256 |         else:
257 |             num_channels = 3
258 |     elif 'road' in opts["dataset_type"]:
259 |         num_channels = 3
260 |     else:
261 |         num_channels = 1
262 |
263 |     num_classes = 1
264 |
265 |     model = Two_decoder_Net(num_channels, num_classes)
266 |
267 |     init_weights(model)
268 |
269 |     # loss layer
270 |     criterion_iou = create_criterion(criterion=loss_criterion)
271 |     criterion_bce = create_criterion(criterion="bce")
272 |
273 |     best_acc = 0.0
274 |     start_epoch = 0
275 |
276 |     # load pretrained model
277 |     if pretrained is not None and os.path.isfile(pretrained):
278 |         print("==> Train from model '{}'".format(pretrained))
279 |         checkpoint_gan = torch.load(pretrained)
280 |         model.load_state_dict(checkpoint_gan['model_state_dict'])
281 |         print("==> Loaded checkpoint '{}'".format(pretrained))
282 |         for param in model.parameters():
283 |             param.requires_grad = False
284 |
285 |     # resume training
286 |     elif resume is not None and os.path.isfile(resume):
287 |         print("==> Resume from checkpoint '{}'".format(resume))
288 |         checkpoint = torch.load(resume)
289 |         start_epoch = checkpoint['epoch'] + 1
290 |         best_acc = checkpoint['best_acc']
291 |         model_dict = model.state_dict()
292 |         pretrained_dict = {k: v for k, v in checkpoint['model_state_dict'].items() if
293 |                            k in model_dict and v.size() == model_dict[k].size()}
294 |         model_dict.update(pretrained_dict)
295 |         model.load_state_dict(model_dict)  # load the merged dict, not just the filtered subset
296 |         print("==> Loaded checkpoint '{}' (epoch {})".format(resume, checkpoint['epoch'] + 1))
297 |
298 |     # train from scratch
299 |     else:
300 |         print("==> Train from initial or random state.")
301 |
302 |     # define multiple-gpu mode
303 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
304 |     model.cuda()
305 |     model = nn.DataParallel(model)
306 |
307 |     # print learnable parameters
308 |     print("==> List learnable parameters")
309 |     for name, param in model.named_parameters():
310 |         if param.requires_grad == True:
311 |             print("\t{}, size {}".format(name, param.size()))
312 |     params_to_update = [{'params': model.parameters()}]
313 |
314 |     # define optimizer
315 |     print("==> Create optimizer")
316 |     optimizer = create_optimizer(params_to_update, opti_mode, lr=lr, momentum=0.9, wd=wd)
317 |     if resume is not None and os.path.isfile(resume): optimizer.load_state_dict(checkpoint['optimizer'])
318 |
319 |     # start training
320 |     since = time.time()
321 |
322 |     # Each epoch has a training and validation phase
323 |     print("==> Start training")
324 |     total_steps = 0
325 |     threshold = opts["threshold"]
326 |     epochs = []
327 |     ious = []
328 |     best_iou = 0.0
329 |     for epoch in 
range(start_epoch, num_epochs): 330 | 331 | print('-' * 50) 332 | print("==> Epoch {}/{}".format(epoch + 1, num_epochs)) 333 | 334 | total_steps = train_one_epoch(epoch, model_type, total_steps, 335 | dataloaders_dict['train'], 336 | model, device, 337 | criterion_iou, criterion_bce, optimizer, lr, 338 | display_iter, log_file, warmup_step, warmup_method, decoder_type) 339 | 340 | epoch_acc, epoch_iou, epoch_f1 = eval_one_epoch(epoch, model_type, threshold, dataloaders_dict['eval'], 341 | model, device, log_file) 342 | epochs.append(epoch) 343 | ious.append(epoch_iou) 344 | 345 | if best_iou < epoch_iou and epoch >= 5: 346 | best_iou = epoch_iou 347 | best_acc = epoch_acc 348 | torch.save({'epoch': epoch, 349 | 'model_state_dict': model.module.state_dict(), 350 | 'optimizer': optimizer.state_dict(), 351 | 'best_acc': best_acc}, 352 | os.path.join(ckt_dir, "best.pth")) 353 | 354 | if (epoch + 1) % save_epoch == 0 and (epoch + 1) >= 30: 355 | torch.save({'epoch': epoch, 356 | 'model_state_dict': model.module.state_dict(), 357 | 'optimizer': optimizer.state_dict(), 358 | 'best_iou': epoch_iou}, 359 | os.path.join(ckt_dir, "checkpoints_" + str(epoch + 1) + ".pth")) 360 | 361 | time_elapsed = time.time() - since 362 | time_message = 'Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60) 363 | print(time_message) 364 | plt.figure(figsize=(10, 10)) 365 | plt.plot(epochs, ious) 366 | plt.ylim(0, 0.9) 367 | # set the label of x and y 368 | plt.xlabel("epoch") 369 | plt.ylabel("iou") 370 | plt.title("Train model= " + str(model_type) + "; lr=" + str(lr)) 371 | plt.legend() 372 | plt.savefig(os.path.join(train_dir, 'lr_' + str(lr) + '_train_iou.png')) 373 | with open(log_file, "a+") as fid: 374 | fid.write('%s\n' % time_message) 375 | fid.write('==> Best val Acc: {:4f}; Iou: {:4f}'.format(best_acc, best_iou)) 376 | print('==> Best val Acc: {:4f}; Iou: {:4f}'.format(best_acc, best_iou)) 377 | 378 | 379 | if __name__ == '__main__': 380 | dataset_names = ['ER', 'MITO', 'ROSE', 'STARE', 'ROAD', 'NUCLEUS'] 381 | dataset_list = ['er', 'er_fractal', 'er_fractal_two_decoder', 'nucleus_fractal_two_decoder', 382 | 'nucleus_fractal_two_decoder_weighted'] 383 | model_choice = ['Two_decoder_Net'] 384 | date = '20240312' 385 | 386 | opts = dict() 387 | opts['dataset_type'] = 'er_fractal_two_decoder' 388 | opts["dataset_name"] = 'ER' 389 | opts["num_epochs"] = 50 390 | opts["fractal_dir"] = 'Fractal_info_5' 391 | opts["edge_dir"] = 'mask_edge' 392 | opts["skeleton_dir"] = 'mask_skeleton' 393 | opts["train_data_dir"] = "./dataset_txts/train_er.txt" 394 | opts["eval_data_dir"] = "./dataset_txts/test_er.txt" 395 | opts["train_batch_size"] = 32 396 | opts["eval_batch_size"] = 32 397 | opts["optimizer"] = "SGD" 398 | opts["model_type"] = "Two_decoder_Net" 399 | opts["decoder_type"] = 'skeleton' 400 | opts["loss_criterion"] = "iou" 401 | opts["lr"] = 0.03 402 | opts["threshold"] = 0.5 403 | opts["warmup_step"] = 1000 404 | opts["warmup_method"] = 'exp' 405 | opts["weight_decay"] = 0.0005 406 | opts["gpu_list"] = "0,1,2,3" 407 | log_dir = "./train_logs/" + str(opts["dataset_type"]) + '_' + opts["model_type"] + '_' + opts[ 408 | "decoder_type"] + "_iou_bce_loss_one_half_" + \ 409 | '_' + str(opts["train_batch_size"]) + '_' + str(opts["lr"]) + '_' + str(opts["num_epochs"]) + '_' + str( 410 | opts["threshold"]) + '_' + \ 411 | str(opts["warmup_step"]) + '_' + date + '_warmup_' + opts["fractal_dir"] 412 | opts["log_dir"] = log_dir 413 | opts["pretrained_model"] = None 414 | opts["resume"] = None 
415 | opts["display_iter"] = 10 416 | opts["save_every_epoch"] = 5 417 | 418 | train_eval_model(opts) 419 | -------------------------------------------------------------------------------- /train_unet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 5 | 6 | import sys 7 | import shutil 8 | import time 9 | import torch.nn as nn 10 | 11 | root_dir = os.path.abspath(os.path.dirname(__file__)) 12 | sys.path.append(root_dir) 13 | sys.path.append(os.path.join(root_dir, "datasets")) 14 | sys.path.append(os.path.join(root_dir, "models")) 15 | sys.path.append(os.path.join(root_dir, "optim")) 16 | from models.unet import UNet 17 | from datasets.dataset import build_data_loader 18 | from models.utils import init_weights 19 | from models.optimize import create_criterion, create_optimizer, update_learning_rate, warmup_learning_rate 20 | from datasets.metric import * 21 | 22 | print("PyTorch Version: ", torch.__version__) 23 | 24 | 25 | def FrobeniusNorm(input): # [b,c,h,w] 26 | b, c, h, w = input.size() 27 | triu = torch.eye(h).cuda() 28 | triu = triu.unsqueeze(0).unsqueeze(0) 29 | triu = triu.repeat(b, c, 1, 1) 30 | 31 | x = torch.matmul(input, input.transpose(-2, -1)) 32 | tr = torch.mul(x, triu) 33 | y = torch.sum(tr) 34 | return y 35 | 36 | 37 | def print_table(data): 38 | col_width = [max(len(item) for item in col) for col in data] 39 | for row_idx in range(len(data[0])): 40 | for col_idx, col in enumerate(data): 41 | item = col[row_idx] 42 | align = '<' if not col_idx == 0 else '>' 43 | print(('{:' + align + str(col_width[col_idx]) + '}').format(item), end=" ") 44 | print() 45 | 46 | 47 | def gmm_loss(label, prd, mu_f, mu_b, std_f, std_b, f_k): 48 | b_k = 1 - f_k 49 | 50 | f_likelihood = - f_k * ( 51 | torch.log(np.sqrt(2 * 3.14) * std_f) + torch.pow((prd - mu_f), 2) / (2 * torch.pow(std_f, 2)) + 1e-10) 52 | b_likelihood = - b_k * ( 53 | torch.log(np.sqrt(2 * 3.14) * std_b) + torch.pow((prd - mu_b), 2) / (2 * torch.pow(std_b, 2)) + 1e-10) 54 | likelihood = f_likelihood + b_likelihood 55 | loss = torch.mean(torch.pow(label - torch.exp(likelihood), 2)) 56 | return loss 57 | 58 | 59 | def train_one_epoch(epoch, total_steps, dataloader, model, 60 | device, criterion, optimizer, lr, 61 | display_iter, log_file, warmup_step, warmup_method): 62 | model.train() 63 | 64 | smooth_loss = 0.0 65 | current_step = 0 66 | t0 = 0.0 67 | 68 | for inputs in dataloader: 69 | 70 | t1 = time.time() 71 | 72 | images = inputs['image'].to(device) 73 | # c_images = inputs['c_masks'].to(device) 74 | labels = inputs['mask'].to(device) 75 | # print(inputs["ID"]) 76 | 77 | # forward pass 78 | pred = model(images) 79 | 80 | # compute loss 81 | loss = criterion(pred, labels) 82 | 83 | # predictions 84 | t0 += (time.time() - t1) 85 | 86 | total_steps += 1 87 | current_step += 1 88 | smooth_loss += loss.item() 89 | 90 | # backpropagate when training 91 | optimizer.zero_grad() 92 | lr_update = warmup_learning_rate(optimizer, total_steps, warmup_step, lr, warmup_method) 93 | # lr_update = update_learning_rate(optimizer, epoch, lr, step=lr_decay) 94 | loss.backward() 95 | # loss.backward(retain_graph = True) 96 | optimizer.step() 97 | 98 | # log loss 99 | if total_steps % display_iter == 0: 100 | smooth_loss = smooth_loss / current_step 101 | message = "Epoch: %d Step: %d LR: %.6f Loss: %.4f Runtime: %.2fs/%diters." 


def eval_one_epoch(epoch, threshold, dataloader, model, device, log_file):
    with torch.no_grad():
        model.eval()

        total_iou = 0.0
        total_f1 = 0.0
        total_acc = 0.0
        total_img = 0

        for inputs in dataloader:
            images = inputs['image'].to(device)
            labels = inputs['mask']

            total_img += len(images)
            outputs = model(images)

            # binarize predictions; assumes the model output is already in [0, 1]
            preds = outputs > threshold
            preds = preds.cpu()

            # metrics, summed over the batch and averaged per image below
            total_acc += acc(preds, labels)
            total_iou += IoU(preds, labels)
            total_f1 += F1_score(preds, labels)

        epoch_iou = total_iou / total_img
        epoch_f1 = total_f1 / total_img
        epoch_acc = total_acc / total_img

        message = "Threshold: {:.3f} =====> Evaluation IOU: {:.4f}; F1_score: {:.4f}; Acc: {:.4f}".format(
            threshold, epoch_iou, epoch_f1, epoch_acc)
        print("==> %s" % (message))
        with open(log_file, "a+") as fid:
            fid.write('%s\n' % message)

        return epoch_acc, epoch_iou, epoch_f1
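

# Illustrative only: acc / IoU / F1_score come from datasets/metric.py, which is
# not shown here. Because eval_one_epoch sums each metric over batches and then
# divides by the number of images, the real IoU() presumably returns a per-batch
# *sum* of per-image IoU values. A minimal sketch under that assumption:
def _iou_sum_sketch(preds, labels, eps=1e-7):
    preds = preds.float().view(preds.size(0), -1)
    labels = labels.float().view(labels.size(0), -1)
    inter = (preds * labels).sum(dim=1)
    union = preds.sum(dim=1) + labels.sum(dim=1) - inter
    return ((inter + eps) / (union + eps)).sum().item()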


def train_eval_model(opts):
    # parse model configuration
    num_epochs = opts["num_epochs"]
    train_batch_size = opts["train_batch_size"]
    val_batch_size = opts["eval_batch_size"]
    dataset_type = opts["dataset_type"]
    dataset_name = opts['dataset_name']
    model_type = opts["model_type"]
    fractal_dir = opts["fractal_dir"]

    opti_mode = opts["optimizer"]
    loss_criterion = opts["loss_criterion"]
    lr = opts["lr"]
    warmup_step = opts["warmup_step"]
    warmup_method = opts["warmup_method"]
    wd = opts["weight_decay"]

    gpus = opts["gpu_list"].split(',')
    # NOTE: the variable is CUDA_VISIBLE_DEVICES (plural); it is also set at the
    # top of this file and only takes effect if assigned before CUDA initializes
    os.environ['CUDA_VISIBLE_DEVICES'] = opts["gpu_list"]
    train_dir = opts["log_dir"]

    train_data_dir = opts["train_data_dir"]
    eval_data_dir = opts["eval_data_dir"]

    pretrained = opts["pretrained_model"]
    resume = opts["resume"]
    display_iter = opts["display_iter"]
    save_epoch = opts["save_every_epoch"]

    # back up train configs
    log_file = os.path.join(train_dir, "log_file.txt")
    os.makedirs(train_dir, exist_ok=True)
    model_dir = os.path.join(train_dir, "code_backup")
    os.makedirs(model_dir, exist_ok=True)
    if resume is None and os.path.exists(log_file):
        os.remove(log_file)
    shutil.copy("./models/unet.py", os.path.join(model_dir, "unet.py"))
    shutil.copy("./train_unet.py", os.path.join(model_dir, "train_unet.py"))
    shutil.copy("./datasets/dataset.py", os.path.join(model_dir, "dataset.py"))

    ckt_dir = os.path.join(train_dir, "checkpoints")
    os.makedirs(ckt_dir, exist_ok=True)

    # format printing configs
    print("*" * 50)
    table_key = []
    table_value = []
    n = 0
    for key, value in opts.items():
        table_key.append(key)
        table_value.append(str(value))
        n += 1
    print_table([table_key, ["="] * n, table_value])

    # format gpu list
    gpu_list = [int(str_id) for str_id in gpus]

    # dataloader
    print("==> Create dataloader")
    dataloaders_dict = {
        "train": build_data_loader(dataset_name, train_data_dir, train_batch_size, dataset_type, is_train=True,
                                   fractal_dir=fractal_dir),
        "eval": build_data_loader(dataset_name, eval_data_dir, val_batch_size, dataset_type, is_train=False,
                                  fractal_dir=fractal_dir)}

    # determine the number of input channels (same logic as the multi-decoder
    # scripts: FFM and copy variants stack extra channels on the raw image)
    print("==> Create network")
    if 'fractal' in opts["dataset_type"]:
        if 'road' in opts["dataset_type"]:
            num_channels = 5
        else:
            num_channels = 3
    elif 'copy' in opts["dataset_type"]:
        if 'road' in opts["dataset_type"]:
            num_channels = 6
        else:
            num_channels = 3
    elif 'road' in opts["dataset_type"]:
        num_channels = 3
    else:
        num_channels = 1

    num_classes = 1
    model = UNet(num_channels, num_classes)

    init_weights(model)

    # loss layer
    criterion = create_criterion(criterion=loss_criterion)

    best_acc = 0.0
    start_epoch = 0

    # load pretrained model (weights are frozen afterwards)
    if pretrained is not None and os.path.isfile(pretrained):
        print("==> Train from model '{}'".format(pretrained))
        checkpoint_pre = torch.load(pretrained)
        model.load_state_dict(checkpoint_pre['model_state_dict'])
        print("==> Loaded checkpoint '{}'".format(pretrained))
        for param in model.parameters():
            param.requires_grad = False

    # resume training
    elif resume is not None and os.path.isfile(resume):
        print("==> Resume from checkpoint '{}'".format(resume))
        checkpoint = torch.load(resume)
        start_epoch = checkpoint['epoch'] + 1
        best_acc = checkpoint['best_acc']
        model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in checkpoint['model_state_dict'].items() if
                           k in model_dict and v.size() == model_dict[k].size()}
        model_dict.update(pretrained_dict)
        # load the merged dict, not just the filtered subset
        model.load_state_dict(model_dict)
        print("==> Loaded checkpoint '{}' (epoch {})".format(resume, checkpoint['epoch'] + 1))

    # train from scratch
    else:
        print("==> Train from initial or random state.")

    # multi-GPU mode
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.cuda()
    model = nn.DataParallel(model)

    # print learnable parameters
    print("==> List learnable parameters")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print("\t{}, size {}".format(name, param.size()))
    params_to_update = [{'params': model.parameters()}]

    # define optimizer
    print("==> Create optimizer")
    optimizer = create_optimizer(params_to_update, opti_mode, lr=lr, momentum=0.9, wd=wd)
    if resume is not None and os.path.isfile(resume):
        optimizer.load_state_dict(checkpoint['optimizer'])

    # start training
    since = time.time()

    # each epoch has a training and a validation phase
    print("==> Start training")
    total_steps = 0
    threshold = opts["threshold"]
    epochs = []
    ious = []
    best_iou = 0.0
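
    # Unlike the two-decoder script, the loop below selects "best.pth" by
    # validation accuracy (best_acc) rather than IoU, while still recording
    # the IoU curve for the plot at the end.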
    for epoch in range(start_epoch, num_epochs):

        print('-' * 50)
        print("==> Epoch {}/{}".format(epoch + 1, num_epochs))

        total_steps = train_one_epoch(epoch, total_steps,
                                      dataloaders_dict['train'],
                                      model, device,
                                      criterion, optimizer, lr,
                                      display_iter, log_file, warmup_step, warmup_method)

        epoch_acc, epoch_iou, epoch_f1 = eval_one_epoch(epoch, threshold, dataloaders_dict['eval'],
                                                        model, device, log_file)
        epochs.append(epoch)
        ious.append(epoch_iou)

        if best_acc < epoch_acc and epoch >= 5:
            best_iou = epoch_iou
            best_acc = epoch_acc
            torch.save({'epoch': epoch,
                        'model_state_dict': model.module.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_acc': best_acc},
                       os.path.join(ckt_dir, "best.pth"))

        if (epoch + 1) % save_epoch == 0 and (epoch + 1) >= 30:
            torch.save({'epoch': epoch,
                        'model_state_dict': model.module.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_iou': epoch_iou},
                       os.path.join(ckt_dir, "checkpoints_" + str(epoch + 1) + ".pth"))

    time_elapsed = time.time() - since
    time_message = 'Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)
    print(time_message)
    plt.figure(figsize=(10, 10))
    plt.plot(epochs, ious, label="val IoU")
    plt.ylim(0, 0.9)
    # set the axis labels
    plt.xlabel("epoch")
    plt.ylabel("iou")
    plt.title("Train model=" + str(model_type) + "; lr=" + str(lr))
    plt.legend()
    plt.savefig(os.path.join(train_dir, 'lr_' + str(lr) + '_train_iou.png'))
    with open(log_file, "a+") as fid:
        fid.write('%s\n' % time_message)
        fid.write('==> Best val Acc: {:.4f}; IoU: {:.4f}'.format(best_acc, best_iou))
    print('==> Best val Acc: {:.4f}; IoU: {:.4f}'.format(best_acc, best_iou))


if __name__ == '__main__':
    dataset_names = ['ER', 'MITO', 'ROSE', 'STARE', 'ROAD', 'NUCLEUS']
    date = '20240312'
    dataset_type_list = ['er', 'er_fractal']

    opts = dict()
    opts['dataset_type'] = 'er_fractal'
    opts["dataset_name"] = 'ER'
    opts["model_type"] = 'UNet'
    opts["fractal_dir"] = 'Fractal_info_5'
    opts["num_epochs"] = 50
    opts["train_data_dir"] = "./dataset_txts/train_er.txt"
    opts["eval_data_dir"] = "./dataset_txts/test_er.txt"
    opts["train_batch_size"] = 32
    opts["eval_batch_size"] = 32
    opts["optimizer"] = "SGD"
    opts["loss_criterion"] = "iou"
    opts["threshold"] = 0.3
    opts["lr"] = 0.05
    opts["warmup_step"] = 1000
    opts["warmup_method"] = 'exp'
    opts["weight_decay"] = 0.0005
    opts["gpu_list"] = "0,1,2,3"

    log_dir = "./train_logs/" + str(opts["dataset_type"]) + "_" + opts["model_type"] + "_" + \
              opts["loss_criterion"] + "_" + str(opts["train_batch_size"]) + '_' + str(opts["lr"]) + \
              '_' + str(opts["num_epochs"]) + '_' + str(opts["threshold"]) + '_' + \
              str(opts["warmup_step"]) + '_' + date + '_warmup_' + opts["fractal_dir"]

    opts["log_dir"] = log_dir
    opts["pretrained_model"] = None
    opts["resume"] = None
    opts["display_iter"] = 10
    opts["save_every_epoch"] = 5

    train_eval_model(opts)
--------------------------------------------------------------------------------