├── README.md
├── cppnet
│   ├── __pycache__
│   │   ├── dataloader_custom.cpython-37.pyc
│   │   ├── distance_loss_sampling_refine.cpython-37.pyc
│   │   ├── load_save_model.cpython-37.pyc
│   │   ├── metric_v2.cpython-37.pyc
│   │   ├── stats_utils.cpython-37.pyc
│   │   └── train_sampling_refine_withgt_separate_metric.cpython-37.pyc
│   ├── dataloader_aug_pannuke.py
│   ├── dataloader_custom.py
│   ├── distance_loss_sampling_refine.py
│   ├── distance_loss_sampling_refine_cls.py
│   ├── load_save_model.py
│   ├── main_cppnet_dsb.py
│   ├── main_cppnet_pannuke.py
│   ├── metric_v2.py
│   ├── models
│   │   ├── SamplingFeatures2.py
│   │   ├── __pycache__
│   │   │   ├── SamplingFeatures2.cpython-37.pyc
│   │   │   ├── cpp_net.cpython-37.pyc
│   │   │   ├── feature_extractor.cpython-37.pyc
│   │   │   └── unet_parts_gn.cpython-37.pyc
│   │   ├── cpp_net.py
│   │   ├── cppnet_res50.py
│   │   ├── feature_extractor.py
│   │   ├── resnet50_preact.py
│   │   └── unet_parts_gn.py
│   ├── predict_eval.py
│   ├── predict_eval_pannuke.py
│   ├── stats_utils.py
│   ├── train.py
│   └── train_sampling_refine_withgt_separate_metric.py
├── feature_extractor
│   ├── __pycache__
│   │   ├── dataloader_aug.cpython-37.pyc
│   │   ├── instance_loss.cpython-37.pyc
│   │   ├── load_save_model.cpython-37.pyc
│   │   └── train.cpython-37.pyc
│   ├── dataloader_aug.py
│   ├── dataloader_aug_pannuke_cls.py
│   ├── instance_loss.py
│   ├── load_save_model.py
│   ├── main_shape.py
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── unet_model.cpython-37.pyc
│   │   │   └── unet_parts_gn.cpython-37.pyc
│   │   ├── unet_model.py
│   │   ├── unet_model_3layer.py
│   │   └── unet_parts_gn.py
│   └── train.py
└── reorganize_datasets
    ├── reorganize_bbbc006.py
    ├── reorganize_dsb2018.py
    └── reorganize_pannuke.py

/README.md:
--------------------------------------------------------------------------------
 1 | # CPP-Net: Context-aware Polygon Proposal Network for Nucleus Segmentation
 2 | 
 3 | ## Requirements
 4 | ```
 5 | pytorch==1.11.0
 6 | stardist==0.6.0
 7 | csbdeep==0.6.3
 8 | ```
 9 | 
10 | 
11 | ## Prepare the datasets
12 | ```
13 | DATA_PATH/train/images/*.tif or *.png
14 | DATA_PATH/val/images/*.tif or *.png
15 | DATA_PATH/test/images/*.tif or *.png
16 | DATA_PATH/train/masks/*.tif
17 | DATA_PATH/val/masks/*.tif
18 | DATA_PATH/test/masks/*.tif
19 | ...
20 | ```
21 | 
22 | Change the paths in the scripts under reorganize_datasets, and run them:
23 | ```
24 | python reorganize_datasets/reorganize_dsb2018.py
25 | python reorganize_datasets/reorganize_bbbc006.py
26 | python reorganize_datasets/reorganize_pannuke.py
27 | ```
28 | The download links for the datasets can also be found in these scripts.
29 | 
30 | Change "type_list" in the function getDataLoaders (in cppnet/dataloader_custom.py and feature_extractor/dataloader_aug.py) according to the names of your dataset splits.
31 | 
32 | ## Prepare the instance shape-aware feature extractor
33 | 
34 | Modify the DATA_PATH in ./feature_extractor/main_shape.py. Here, the parameter --n_cls counts both the foreground classes and the background.
35 | Run the script, for example:
36 | ```
37 | python feature_extractor/main_shape.py --gpuid 0 --dataset DSB2018 --n_cls 1
38 | python feature_extractor/main_shape.py --gpuid 0 --dataset PanNuke --n_cls 6
39 | ```
40 | 
41 | ## Train and Eval
42 | 
43 | 
44 | After training the SAP model, modify the SAP_Weight_path in ./cppnet/main_cppnet_dsb.py,
45 | 
46 | or set SAP_Weight_path=None to disable the SAP loss.
47 | 
48 | Modify the DATA_PATH in ./cppnet/main_cppnet_dsb.py.
49 | 
50 | ```
51 | python cppnet/main_cppnet_dsb.py --gpuid 0
52 | ```
53 | 
54 | After training CPP-Net, modify the MODEL_WEIGHT_PATH in ./cppnet/main_cppnet_dsb.py.
55 | 
56 | Modify the DATASET_PATH_IMAGE and DATASET_PATH_LABEL in ./cppnet/main_cppnet_dsb.py
57 | (e.g., DATASET_PATH_IMAGE=DATA_PATH/test/images and DATASET_PATH_LABEL=DATA_PATH/test/masks).
58 | 
59 | Modify the paths in cppnet/predict_eval.py, and run the script to evaluate model performance:
60 | 
61 | ```
62 | python cppnet/predict_eval.py --gpuid 0
63 | ```
64 | 
65 | For each fold in PanNuke, use the script cppnet/predict_eval_pannuke.py to obtain a '.npy' file that contains the predictions.
66 | 
67 | 
68 | ### PyTorch StarDist
69 | A PyTorch reimplementation of StarDist is available at https://github.com/ASHISRAVINDRAN/stardist_pytorch; part of the code in this project is modified from that repository.
--------------------------------------------------------------------------------
/cppnet/__pycache__/dataloader_custom.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csccsccsccsc/cpp-net/e99cbc790fddb1964355753be12ca8f10e51f247/cppnet/__pycache__/dataloader_custom.cpython-37.pyc
--------------------------------------------------------------------------------
/cppnet/__pycache__/distance_loss_sampling_refine.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csccsccsccsc/cpp-net/e99cbc790fddb1964355753be12ca8f10e51f247/cppnet/__pycache__/distance_loss_sampling_refine.cpython-37.pyc
--------------------------------------------------------------------------------
/cppnet/__pycache__/load_save_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csccsccsccsc/cpp-net/e99cbc790fddb1964355753be12ca8f10e51f247/cppnet/__pycache__/load_save_model.cpython-37.pyc
--------------------------------------------------------------------------------
/cppnet/__pycache__/metric_v2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csccsccsccsc/cpp-net/e99cbc790fddb1964355753be12ca8f10e51f247/cppnet/__pycache__/metric_v2.cpython-37.pyc
--------------------------------------------------------------------------------
/cppnet/__pycache__/stats_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csccsccsccsc/cpp-net/e99cbc790fddb1964355753be12ca8f10e51f247/cppnet/__pycache__/stats_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/cppnet/__pycache__/train_sampling_refine_withgt_separate_metric.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csccsccsccsc/cpp-net/e99cbc790fddb1964355753be12ca8f10e51f247/cppnet/__pycache__/train_sampling_refine_withgt_separate_metric.cpython-37.pyc -------------------------------------------------------------------------------- /cppnet/dataloader_aug_pannuke.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.data import Dataset, DataLoader 3 | from skimage import io 4 | from skimage.transform import resize 5 | import numpy as np 6 | from stardist import star_dist,edt_prob 7 | from csbdeep.utils import normalize 8 | import random 9 | import scipy.io as scio 10 | 11 | class PanNukeDataset(Dataset): 12 | def __init__(self, root_dir, n_rays, max_dist=None, if_training=False, resz=None): 13 | self.img_filefold = os.path.join(root_dir,'images') 14 | self.target_filefold = os.path.join(root_dir,'masks') 15 | 16 | self.img_list = [] 17 | self.target_list = [] 18 | with open(os.path.join(self.img_filefold, 'name_list.txt'), 'r') as f: 19 | for line in f.readlines(): 20 | line_terms = line.split(',') 21 | self.img_list.append(line_terms[1].strip()) 22 | for ic in range(0, 5): 23 | with open(os.path.join(self.target_filefold, 'name_list_c'+str(ic)+'.txt'), 'r') as f: 24 | ic_target_list = [] 25 | for line in f.readlines(): 26 | line_terms = line.split(',') 27 | ic_target_list.append(line_terms[1].strip()) 28 | self.target_list.append(ic_target_list) 29 | 30 | self.n_rays = n_rays 31 | self.max_dist = max_dist 32 | self.if_training=if_training 33 | self.resz = resz 34 | 35 | def __len__(self): 36 | return len(self.img_list) 37 | 38 | def __getitem__(self, idx): 39 | image = io.imread(self.img_list[idx]) 40 | image = image.astype(np.float32) 41 | for imod in range(3): 42 | tmp_image = image[:, :, imod] 43 | meanv = tmp_image.mean() 44 | stdv = tmp_image.std() 45 | image[:, :, imod] = (tmp_image-meanv)/stdv 46 | # target: [0, 4] 47 | target = [] 48 | for ic in range(5): 49 | ic_target = io.imread(self.target_list[ic][idx]) 50 | if ic > 0: 51 | last_target_max = target[ic-1].max() 52 | ic_target[ic_target>0] += last_target_max 53 | target.append(ic_target) 54 | target = np.stack(target, axis=2) 55 | if self.if_training: 56 | aug_type = random.randint(0, 5) # rot90: 0, 1, 2; flip: 3, 4; ori: 5 57 | if aug_type<=2: 58 | image = np.rot90(image, k=aug_type+1, axes=(0, 1)).copy() 59 | target = np.rot90(target, k=aug_type+1, axes=(0, 1)).copy() 60 | elif aug_type<=4: 61 | image = np.flip(image, axis=aug_type-3).copy() 62 | target = np.flip(target, axis=aug_type-3).copy() 63 | obj_probabilities = [] 64 | allcls_target = target.max(axis=2) 65 | distances = star_dist(allcls_target, self.n_rays) 66 | if self.max_dist: 67 | distances[distances>self.max_dist] = self.max_dist 68 | obj_probabilities = edt_prob(allcls_target) 69 | # seg + 1 !!! 
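        # (the note above refers to the +1 below: argmax over the five class
        #  channels gives a class index in 0..4, the +1 shifts it to 1..5, and
        #  the foreground mask keeps background pixels at 0)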
70 |         # background = 0; foreground = 1:5
71 |         seg_target = ((np.argmax(target, axis=2)+1).astype(np.int64) * (target.max(axis=2)>0).astype(np.int64)).astype(np.int64)
72 | 
73 |         if self.resz is not None:
74 |             image = resize(image, self.resz, order=1, preserve_range=True)
75 |             obj_probabilities = resize(obj_probabilities, self.resz, order=1, preserve_range=True)
76 |             distances = resize(distances, self.resz, order=1, preserve_range=True)
77 |             seg_target = resize(seg_target, self.resz, order=0, preserve_range=True)
78 |         image = np.transpose(image, (2, 0, 1))
79 |         distances = np.transpose(distances, (2,0,1))
80 |         obj_probabilities = np.expand_dims(obj_probabilities, axis=0)
81 |         # seg_target = np.expand_dims(seg_target, axis=0)
82 |         # scio.savemat('test.mat', {'image':image, 'prob':obj_probabilities, 'dist':distances, 'target':target, 'seg':seg_target})
83 |         # print(self.img_list[idx])
84 |         return image, obj_probabilities, distances, seg_target
85 | 
86 | def getDataLoaders(n_rays, max_dist, root_dir, type_list=['fold_1', 'fold_2'], batch_size=8, resz=None):
87 |     trainset = PanNukeDataset(root_dir=root_dir+'/'+type_list[0]+'/', n_rays=n_rays, max_dist=max_dist, if_training=True, resz=resz)
88 |     testset = PanNukeDataset(root_dir=root_dir+'/'+type_list[1]+'/', n_rays=n_rays, max_dist=max_dist, if_training=False, resz=resz)
89 |     trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
90 |     testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)
91 |     return trainloader,testloader
--------------------------------------------------------------------------------
/cppnet/dataloader_custom.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | from torch.utils.data import Dataset, DataLoader
 3 | from skimage import io
 4 | from skimage.transform import resize
 5 | import numpy as np
 6 | from stardist import star_dist,edt_prob
 7 | from csbdeep.utils import normalize
 8 | import random
 9 | 
10 | class DSB2018Dataset(Dataset):
11 |     def __init__(self, root_dir, n_rays, max_dist=None, if_training=False, resz=None, crop=None, image_flag='images', mask_flag='masks'):
12 |         self.raw_files = os.listdir(os.path.join(root_dir, image_flag))
13 |         self.target_files = os.listdir( os.path.join(root_dir, mask_flag))
14 |         self.raw_files.sort()
15 |         self.target_files.sort()
16 |         self.root_dir = root_dir
17 |         self.n_rays = n_rays
18 |         self.max_dist = max_dist
19 |         self.if_training=if_training
20 |         self.resz = resz
21 |         self.crop = crop
22 | 
23 |         self.image_flag = image_flag
24 |         self.mask_flag = mask_flag
25 | 
26 |     def __len__(self):
27 |         return len(self.raw_files)
28 | 
29 |     def __getitem__(self, idx):
30 |         assert self.raw_files[idx] == self.target_files[idx]
31 |         img_name = os.path.join(self.root_dir, self.image_flag, self.raw_files[idx])
32 |         image = io.imread(img_name)
33 |         target_name = os.path.join(self.root_dir, self.mask_flag, self.target_files[idx])
34 |         target = io.imread(target_name)
35 | 
36 |         if self.crop is not None and (None not in self.crop):
37 |             h, w = image.shape
38 |             dh = h-self.crop[0]
39 |             dw = w-self.crop[1]
40 |             sh = random.randint(0, dh-1) if dh > 0 else 0  # fix: random.randint(0, -1) raised ValueError when the crop size equals the image size
41 |             sw = random.randint(0, dw-1) if dw > 0 else 0
42 |             image = image[sh:(sh+self.crop[0]), sw:(sw+self.crop[1])]
43 |             target = target[sh:(sh+self.crop[0]), sw:(sw+self.crop[1])]
44 | 
45 |         image = normalize(image, 1, 99.8, axis=(0,1))
46 | 
47 |         if self.if_training:
48 |             aug_type = random.randint(0, 5) # rot90: 0, 1, 2; flip: 3, 4; ori: 5
49 |             if aug_type<=2:
50 |                 image =
np.rot90(image, aug_type).copy() 51 | target = np.rot90(target, aug_type).copy() 52 | elif aug_type<=4: 53 | image = np.flip(image, aug_type-3).copy() 54 | target = np.flip(target, aug_type-3).copy() 55 | distances = star_dist(target, self.n_rays) 56 | if self.max_dist: 57 | distances[distances>self.max_dist] = self.max_dist 58 | obj_probabilities = edt_prob(target) 59 | 60 | if self.resz is not None: 61 | image = resize(image, self.resz, order=1, preserve_range=True) 62 | obj_probabilities = resize(obj_probabilities, self.resz, order=1, preserve_range=True) 63 | distances = resize(distances, self.resz, order=1, preserve_range=True) 64 | 65 | distances = np.transpose(distances, (2,0,1)) 66 | image = np.expand_dims(image,0) 67 | obj_probabilities = np.expand_dims(obj_probabilities,0) 68 | 69 | return image, obj_probabilities, distances 70 | 71 | def getDataLoaders(n_rays, max_dist, root_dir, type_list=['train', 'test'], image_flag='images', mask_flag='masks', batch_size=1, train_crop=None, test_crop=None, resz=None): 72 | trainset = DSB2018Dataset(root_dir=root_dir+'/'+type_list[0]+'/', n_rays=n_rays, max_dist=max_dist, if_training=True, crop=train_crop, resz=resz, image_flag=image_flag, mask_flag=mask_flag) 73 | testset = DSB2018Dataset(root_dir=root_dir+'/'+type_list[1]+'/', n_rays=n_rays, max_dist=max_dist, if_training=False, crop=test_crop, resz=resz, image_flag=image_flag, mask_flag=mask_flag) 74 | trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4) 75 | testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4) 76 | return trainloader,testloader 77 | 78 | 79 | class PanNukeDataset(Dataset): 80 | def __init__(self, root_dir, n_rays, max_dist=None, if_training=False, resz=None): 81 | self.img_filefold = os.path.join(root_dir,'images') 82 | self.target_filefold = os.path.join(root_dir,'masks') 83 | 84 | self.img_list = [] 85 | self.target_list = [] 86 | with open(os.path.join(self.img_filefold, 'name_list.txt'), 'r') as f: 87 | for line in f.readlines(): 88 | line_terms = line.split(',') 89 | self.img_list.append(line_terms[1].strip()) 90 | for ic in range(0, 5): 91 | with open(os.path.join(self.target_filefold, 'name_list_c'+str(ic)+'.txt'), 'r') as f: 92 | ic_target_list = [] 93 | for line in f.readlines(): 94 | line_terms = line.split(',') 95 | ic_target_list.append(line_terms[1].strip()) 96 | self.target_list.append(ic_target_list) 97 | 98 | self.n_rays = n_rays 99 | self.max_dist = max_dist 100 | self.if_training=if_training 101 | self.resz = resz 102 | 103 | def __len__(self): 104 | return len(self.img_list) 105 | 106 | def __getitem__(self, idx): 107 | image = io.imread(self.img_list[idx]) 108 | image = image.astype(np.float32) 109 | for imod in range(3): 110 | tmp_image = image[:, :, imod] 111 | meanv = tmp_image.mean() 112 | stdv = tmp_image.std() 113 | image[:, :, imod] = (tmp_image-meanv)/stdv 114 | # target: [0, 4] 115 | target = [] 116 | for ic in range(5): 117 | ic_target = io.imread(self.target_list[ic][idx]) 118 | if ic > 0: 119 | last_target_max = target[ic-1].max() 120 | ic_target[ic_target>0] += last_target_max 121 | target.append(ic_target) 122 | target = np.stack(target, axis=2) 123 | if self.if_training: 124 | aug_type = random.randint(0, 5) # rot90: 0, 1, 2; flip: 3, 4; ori: 5 125 | if aug_type<=2: 126 | image = np.rot90(image, k=aug_type+1, axes=(0, 1)).copy() 127 | target = np.rot90(target, k=aug_type+1, axes=(0, 1)).copy() 128 | elif aug_type<=4: 129 | image = np.flip(image, axis=aug_type-3).copy() 
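                # the same flip must be applied to the stacked class masks so they stay aligned with the image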
130 | target = np.flip(target, axis=aug_type-3).copy() 131 | obj_probabilities = [] 132 | allcls_target = target.max(axis=2) 133 | distances = star_dist(allcls_target, self.n_rays) 134 | if self.max_dist: 135 | distances[distances>self.max_dist] = self.max_dist 136 | obj_probabilities = edt_prob(allcls_target) 137 | # seg + 1 !!! 138 | # background = 0; foreground = 1:5 139 | seg_target = ((np.argmax(target, axis=2)+1).astype(np.int64) * (target.max(axis=2)>0).astype(np.int64)).astype(np.int64) 140 | 141 | if self.resz is not None: 142 | image = resize(image, self.resz, order=1, preserve_range=True) 143 | obj_probabilities = resize(obj_probabilities, self.resz, order=1, preserve_range=True) 144 | distances = resize(distances, self.resz, order=1, preserve_range=True) 145 | seg_target = resize(seg_target, self.resz, order=0, preserve_range=True) 146 | image = np.transpose(image, (2, 0, 1)) 147 | distances = np.transpose(distances, (2,0,1)) 148 | obj_probabilities = np.expand_dims(obj_probabilities, axis=0) 149 | # seg_target = np.expand_dims(seg_target, axis=0) 150 | # scio.savemat('test.mat', {'image':image, 'prob':obj_probabilities, 'dist':distances, 'target':target, 'seg':seg_target}) 151 | # print(self.img_list[idx]) 152 | return image, obj_probabilities, distances, seg_target 153 | 154 | def getPanNukeDataLoaders(n_rays, max_dist, root_dir, type_list=['fold_1', 'fold_2'], batch_size=8, resz=None): 155 | trainset = PanNukeDataset(root_dir=root_dir+'/'+type_list[0]+'/', n_rays=n_rays, max_dist=max_dist, if_training=True, resz=resz) 156 | testset = PanNukeDataset(root_dir=root_dir+'/'+type_list[1]+'/', n_rays=n_rays, max_dist=max_dist, if_training=False, resz=resz) 157 | trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4) 158 | testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4) 159 | return trainloader,testloader -------------------------------------------------------------------------------- /cppnet/distance_loss_sampling_refine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def dice_loss(pred, target, eps=1e-7): 6 | b = pred.shape[0] 7 | n_cls = pred.shape[1] 8 | loss = 0.0 9 | for ic in range(n_cls): 10 | ic_target = (target == ic).float().view(b, -1) 11 | ic_pred = pred[:, ic, :, :].view(b, -1) 12 | loss += (2*(ic_pred*ic_target).sum(dim=1)+eps) / (ic_pred.pow(2).sum(dim=1)+ic_target.pow(2).sum(dim=1)+eps) 13 | loss /= n_cls 14 | loss = 1.0 - loss.mean() 15 | return loss 16 | 17 | class L1Loss_List_withSAP_withSeg(torch.nn.Module): 18 | def __init__(self, feature_extractor=None, scale=[1,1,1,1]): 19 | super(L1Loss_List_withSAP_withSeg, self).__init__() 20 | 21 | self.scale = scale 22 | self.feature_extractor = feature_extractor 23 | if self.feature_extractor is not None: 24 | assert len(scale)==4 25 | else: 26 | assert len(scale)==3 27 | def forward(self, prediction, target_dists, **kwargs): 28 | 29 | prob = kwargs.get('labels', None) 30 | pred_dists = prediction[0] 31 | pred_probs = prediction[1] 32 | 33 | l1loss = 0.0 34 | bceloss = 0.0 35 | 36 | for i_dist in pred_dists: 37 | l1loss_map = F.l1_loss(i_dist, target_dists, reduction='none') 38 | l1loss += torch.mean(prob*l1loss_map) 39 | for i_prob in pred_probs: 40 | bceloss += F.binary_cross_entropy(i_prob, prob) 41 | 42 | loss = self.scale[0]*l1loss + self.scale[1]*bceloss 43 | 44 | if self.scale[2] > 0: 45 | segloss = 0.0 46 | pred_segs = prediction[2] 47 | seg 
= (prob>0).float()
48 |             for i_seg in pred_segs:
49 |                 segloss += F.binary_cross_entropy(i_seg, seg)
50 |             loss += self.scale[2]*segloss
51 | 
52 |         metric = loss.data.clone().cpu()
53 | 
54 |         if self.feature_extractor is not None:
55 |             self.feature_extractor.zero_grad()
56 |             sap_loss = 0.0
57 |             f_target = self.feature_extractor(torch.cat((prob, target_dists), dim=1))
58 |             for i_dist in pred_dists:
59 |                 f_pred = self.feature_extractor(torch.cat((pred_probs[-1]*pred_segs[-1], i_dist*pred_segs[-1]), dim=1))
60 |                 sap_loss += F.l1_loss(f_pred, f_target)
61 |             loss += self.scale[3]*sap_loss
62 |         else:
63 |             sap_loss = 0.0
64 | 
65 |         # print('loss: {:.5f}, metric: {:.5f}, l1: {:.5f}, bce: {:.5f}, seg: {:.5f}, sap: {:.5f}'\
66 |         #     .format(loss, metric, l1loss, bceloss, segloss, sap_loss))
67 | 
68 |         return loss, metric
69 | 
--------------------------------------------------------------------------------
/cppnet/distance_loss_sampling_refine_cls.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn.functional as F
 3 | 
 4 | 
 5 | def dice_loss(pred, target, eps=1e-7):
 6 |     b = pred.shape[0]
 7 |     n_cls = pred.shape[1]
 8 |     loss = 0.0
 9 |     for ic in range(n_cls):
10 |         ic_target = (target == ic).float().view(b, -1)
11 |         ic_pred = pred[:, ic, :, :].view(b, -1)
12 |         loss += (2*(ic_pred*ic_target).sum(dim=1)+eps) / (ic_pred.pow(2).sum(dim=1)+ic_target.pow(2).sum(dim=1)+eps)
13 |     loss /= n_cls
14 |     loss = 1.0 - loss.mean()
15 |     return loss
16 | 
17 | class L1Loss_List_withSAP_withSeg(torch.nn.Module):
18 |     def __init__(self, feature_extractor, scale=[1,1,1,1], cls_balance_mode=False):
19 |         super(L1Loss_List_withSAP_withSeg, self).__init__()
20 |         # fix: an unconditional `assert len(scale)==4` here rejected the valid 3-term case handled below, so it is dropped
21 |         self.scale = scale
22 |         self.feature_extractor = feature_extractor
23 |         if self.feature_extractor is not None:
24 |             assert len(scale)==4
25 |         else:
26 |             assert len(scale)==3
27 |         self.cls_balance_mode = cls_balance_mode
28 |         assert(self.cls_balance_mode in [True, False])
29 | 
30 |     def forward(self, prediction, target_dists, **kwargs):
31 | 
32 |         prob = kwargs.get('labels', None)
33 |         pred_dists = prediction[0]
34 |         pred_probs = prediction[1]
35 |         pred_segs, seg_target = prediction[2], kwargs.get('seg_labels', None)  # fix: both names were used below but never defined; the 'seg_labels' kwarg key is an assumption
36 |         l1loss = 0.0
37 |         bceloss = 0.0
38 |         segloss = 0.0
39 |         for i_dist in pred_dists:
40 |             l1loss_map = F.l1_loss(i_dist, target_dists, reduction='none')
41 |             l1loss += torch.mean(prob*l1loss_map)
42 |         for i_prob in pred_probs:
43 |             bceloss += F.binary_cross_entropy(i_prob, prob)
44 |         for i_seg in pred_segs:
45 |             if self.cls_balance_mode:
46 |                 segloss_map = F.cross_entropy(i_seg, seg_target, reduction='none')
47 |                 cur_segloss = 0.0
48 |                 for ic in range(i_seg.shape[1]):
49 |                     icmask = (seg_target==ic).float()
50 |                     cur_segloss += ((icmask * segloss_map).sum()+1e-5) / (icmask.sum()+1e-5)
51 |                 segloss += cur_segloss / i_seg.shape[1] \
52 |                     + dice_loss(F.softmax(i_seg, dim=1), seg_target)
53 |             else:
54 |                 segloss += F.cross_entropy(i_seg, seg_target) \
55 |                     + dice_loss(F.softmax(i_seg, dim=1), seg_target)
56 | 
57 |         loss = self.scale[0]*l1loss + self.scale[1]*bceloss
58 |         metric = loss.data.clone().cpu()
59 |         loss += self.scale[2]*segloss
60 | 
61 |         if self.feature_extractor is not None:
62 |             self.feature_extractor.zero_grad()
63 |             sap_loss = 0.0
64 |             target_mask = (target_dists.max(dim=1, keepdim=True)[0]>0).float()
65 |             f_target = self.feature_extractor(torch.cat((prob, target_dists), dim=1))
66 |             for i_dist in pred_dists:
67 |                 f_pred = self.feature_extractor(torch.cat((pred_probs[-1], i_dist*target_mask), dim=1))
68 |                 sap_loss += F.l1_loss(f_pred,
f_target) 69 | loss += self.scale[3]*sap_loss 70 | else: 71 | sap_loss = 0.0 72 | 73 | # print('loss: {:.5f}, metric: {:.5f}, l1: {:.5f}, bce: {:.5f}, seg: {:.5f}, sap: {:.5f}'\ 74 | # .format(loss, metric, l1loss, bceloss, segloss, sap_loss)) 75 | 76 | return loss, metric 77 | -------------------------------------------------------------------------------- /cppnet/load_save_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import shutil 4 | 5 | def load_model(model,Model_name,Train_mode,Dataset): 6 | Model_name=Model_name.upper() 7 | Train_mode=Train_mode.upper() 8 | Dataset =Dataset.upper() 9 | filepath= os.getcwd()+'/'+Dataset+'/'+Train_mode+'/'+Model_name+'/'+Model_name+'_'+Train_mode+'_'+Dataset+'.t7' 10 | ############################# 11 | print('File to be loaded:'+filepath) 12 | if os.path.isfile(filepath): 13 | try: 14 | model=model.module #For DATAPARALLEL 15 | except: 16 | pass 17 | print('Loading File: '+filepath) 18 | model.load_state_dict(torch.load(filepath)) 19 | return model 20 | else: 21 | print ('WARNING!!!: Weight of '+Model_name+' not loaded. No Existing file') 22 | return model 23 | 24 | 25 | def save_model(model,trainAcc_to_file,testAcc_to_file,trainloss_to_file,testloss_to_file,Parameters, 26 | Model_name,Train_mode,Dataset,model2=None,**kwargs): 27 | try: 28 | model=model.module 29 | except: 30 | pass 31 | 32 | path= kwargs['save_path'] 33 | if not os.path.exists(path): 34 | os.makedirs(path) 35 | 36 | stage='' 37 | if model2 is not None: 38 | weights_filename1=Model_name+'_'+Train_mode+'_'+Dataset+'_1.t7' 39 | weights_filename2=Model_name+'_'+Train_mode+'_'+Dataset+'_2.t7' 40 | torch.save(model.state_dict(),path+weights_filename1) 41 | torch.save(model2.state_dict(),path+weights_filename2) 42 | else: 43 | weights_filename=Model_name+'_'+Train_mode+'_'+Dataset+'.t7' 44 | torch.save(model.state_dict(),path+weights_filename) 45 | print(path+weights_filename+' saved') 46 | 47 | if testAcc_to_file is not None: 48 | testacc_filename='Testacc_'+stage+Model_name+'_'+Train_mode+'_'+Dataset+'.csv' 49 | if os.path.isfile(path+testacc_filename): 50 | thefile = open(path+testacc_filename, 'a') 51 | else: 52 | thefile = open(path+testacc_filename, 'w') 53 | for item in testAcc_to_file: 54 | thefile.write("%s," % item) 55 | thefile.close() 56 | 57 | if testloss_to_file is not None: 58 | testloss_filename='Testloss_'+stage+Model_name+'_'+Train_mode+'_'+Dataset+'.csv' 59 | if os.path.isfile(path+testloss_filename): 60 | thefile = open(path+testloss_filename, 'a') 61 | else: 62 | thefile = open(path+testloss_filename, 'w') 63 | for item in testloss_to_file: 64 | thefile.write("%s," % item) 65 | thefile.close() 66 | 67 | if trainloss_to_file is not None: 68 | trainloss_filename='Trainloss_'+stage+Model_name+'_'+Train_mode+'_'+Dataset+'.csv' 69 | if os.path.isfile(path+trainloss_filename): 70 | thefile = open(path+trainloss_filename, 'a') 71 | else: 72 | thefile = open(path+trainloss_filename, 'w') 73 | for item in trainloss_to_file: 74 | thefile.write("%s," % item) 75 | thefile.close() 76 | 77 | if trainAcc_to_file is not None: 78 | trainacc_filename='Trainacc_'+stage+Model_name+'_'+Train_mode+'_'+Dataset+'.csv' 79 | if os.path.isfile(path+trainacc_filename): 80 | thefile = open(path+trainacc_filename, 'a') 81 | else: 82 | thefile = open(path+trainacc_filename, 'w') 83 | for item in trainAcc_to_file: 84 | thefile.write("%s," % item) 85 | thefile.close() 86 | 87 | 
param_filename='Parameters_'+Model_name+'_'+Train_mode+'_'+Dataset+'.txt' 88 | if os.path.isfile(path+param_filename): 89 | thefile = open(path+param_filename, 'a') 90 | else: 91 | thefile = open(path+param_filename, 'w') 92 | thefile.write('%s \n' %stage) 93 | thefile.write("Patience_scheduler=%s, Weight_decay=%s \n" %(Parameters[2],Parameters[3])) 94 | if not Parameters[1][0][1:] == Parameters[1][0][:-1]: 95 | for i in range(len(Parameters[1][0])): 96 | thefile.write("Initial learning rate for param_groups %s is %s epochs \n" %(str(i),Parameters[1][0][i])) 97 | else: 98 | thefile.write("Initial learning rate is %s epochs \n" %Parameters[1][0][0]) 99 | thefile.write("\n\n" ) 100 | 101 | for epoch,lr in zip(Parameters[0],Parameters[1][1:]): 102 | thefile.write("In epoch %s, maximum of the learning rates decreased to %s \n" %(epoch, lr)) 103 | thefile.write("Trained for %s epochs \n\n" %Parameters[0][-1]) 104 | 105 | thefile.write("Train Statistics \n") 106 | if trainAcc_to_file is not None: 107 | thefile.write('Accuracy: %s \n' %trainAcc_to_file[-1]) 108 | thefile.write('Average Loss: %s \n\n'%trainloss_to_file[-1]) 109 | 110 | thefile.write("Test Statistics \n") 111 | if testAcc_to_file is not None: 112 | thefile.write('Accuracy: %s \n' %testAcc_to_file[-1]) 113 | for i in range(len(testAcc_to_file)): 114 | if testAcc_to_file[i]==testAcc_to_file[-1]: 115 | break 116 | if i+1==len(testAcc_to_file): 117 | i=-1 118 | thefile.write('Maximum test accuracy in epoch %s (if 0 it means that the initial state was the best)\n\n'%str(i+1)) 119 | 120 | thefile.write('Average Loss: %s \n\n'%testloss_to_file[-1]) 121 | thefile.write('Total time elapsed %s\n\n' %Parameters[4]) 122 | thefile.write('Note: %s\n\n' %kwargs['additional_notes']) 123 | thefile.write(20*'-'+'\n\n') 124 | thefile.close() 125 | print(os.getcwd()+'/CHECKPOINT/checkpoint_'+Model_name+'_'+Train_mode+'_'+Dataset) 126 | shutil.rmtree(os.getcwd()+'/CHECKPOINT/checkpoint_'+Model_name+'_'+Train_mode+'_'+Dataset) 127 | 128 | 129 | ''' 130 | def checkpoint_save(model,trainAcc_to_file,testAcc_to_file,trainloss_to_file,testloss_to_file,Parameters,Model_name,Train_mode,Dataset): 131 | 132 | path=os.getcwd()+'/CHECKPOINT/checkpoint_'+Model_name+'_'+Train_mode+'_'+Dataset 133 | if not os.path.exists(path): 134 | os.makedirs(path) 135 | 136 | torch.save(model.state_dict(),path+'/CHECKPOINT.t7') 137 | print(path+'/CHECKPOINT.t7'+' saved') 138 | 139 | thefile = open(path+'/Testacc_CHECKPOINT.csv', 'w') 140 | for item in testAcc_to_file: 141 | thefile.write("%s," % item) 142 | thefile.close() 143 | 144 | 145 | thefile = open(path+'/Testloss_CHECKPOINT.csv', 'w') 146 | for item in testloss_to_file: 147 | thefile.write("%s," % item) 148 | thefile.close() 149 | 150 | thefile = open(path+'/Trainloss_CHECKPOINT.csv', 'w') 151 | for item in trainloss_to_file: 152 | thefile.write("%s," % item) 153 | thefile.close() 154 | 155 | 156 | thefile = open(path+'/Trainacc_CHECKPOINT.csv', 'w') 157 | for item in trainAcc_to_file: 158 | thefile.write("%s," % item) 159 | thefile.close() 160 | 161 | thefile = open(path+'/Parameters_CHECKPOINT.txt', 'w') 162 | 163 | thefile.write("Patience_scheduler=%s, Weight_decay=%s \n" %(Parameters[2],Parameters[3])) 164 | if not Parameters[1][0][1:] == Parameters[1][0][:-1]: 165 | for i in range(len(Parameters[1][0])): 166 | thefile.write("Initial learning rate for param_grooups %s is %s epochs \n" %(str(i),Parameters[1][0][i])) 167 | else: 168 | thefile.write("Initial learning rate is %s epochs \n" %Parameters[1][0][0]) 169 | 170 
| for epoch,lr in zip(Parameters[0],Parameters[1][1:]): 171 | thefile.write("In epoch %s, maximum learning rate decreased to %s \n" %(epoch, lr)) 172 | if not(Parameters[0]==[]): 173 | thefile.write("Trained for %s epochs \n" %Parameters[0][-1]) 174 | thefile.write("\n\n" ) 175 | if not(trainAcc_to_file==[]): 176 | thefile.write("Train Statistics \n") 177 | thefile.write('Accuracy: %s \n' %trainAcc_to_file[-1]) 178 | thefile.write('Average Loss: %s \n\n'%trainloss_to_file[-1]) 179 | 180 | thefile.write("Test Statistics \n") 181 | thefile.write('Accuracy: %s \n' %testAcc_to_file[-1]) 182 | thefile.write('Average Loss: %s \n\n'%testloss_to_file[-1]) 183 | thefile.write(20*'-'+'\n\n') 184 | thefile.close() 185 | 186 | 187 | ###################################################################################################### 188 | ''' 189 | def checkpoint_save_stage(model,trainloss_to_file,testloss_to_file,train_metric_to_file,test_metric_to_file,Parameters,Model_name,Train_mode,Dataset,model2=None): 190 | 191 | path=os.getcwd()+'/CHECKPOINT/checkpoint_'+Model_name+'_'+Train_mode+'_'+Dataset 192 | if not os.path.exists(path): 193 | os.makedirs(path) 194 | 195 | if model2 is not None: 196 | torch.save(model.state_dict(),path+'/CHECKPOINT1.t7') 197 | torch.save(model2.state_dict(),path+'/CHECKPOINT2.t7') 198 | else: 199 | torch.save(model.state_dict(),path+'/CHECKPOINT.t7') 200 | print(path+'/CHECKPOINT.t7'+' saved') 201 | 202 | thefile = open(path+'/Testloss_CHECKPOINT.csv', 'w') 203 | for item in testloss_to_file: 204 | thefile.write("%s," % item) 205 | thefile.close() 206 | 207 | thefile = open(path+'/Trainloss_CHECKPOINT.csv', 'w') 208 | for item in trainloss_to_file: 209 | thefile.write("%s," % item) 210 | thefile.close() 211 | 212 | thefile = open(path+'/Parameters_CHECKPOINT.txt', 'w') 213 | thefile.write("STAGE1 \n" ) 214 | thefile.write("Patience_scheduler=%s, Weight_decay=%s \n" %(Parameters[2],Parameters[3])) 215 | if not Parameters[1][0][1:] == Parameters[1][0][:-1]: 216 | for i in range(len(Parameters[1][0])): 217 | thefile.write("Initial learning rate for param_groups %s is %s epochs \n" %(str(i),Parameters[1][0][i])) 218 | else: 219 | thefile.write("Initial learning rate is %s epochs \n" %Parameters[1][0][0]) 220 | 221 | for epoch,lr in zip(Parameters[0],Parameters[1][1:]): 222 | thefile.write("In epoch %s, maximum learning rate decreased to %s \n" %(epoch, lr)) 223 | if not(Parameters[0]==[]): 224 | thefile.write("Trained for %s epochs \n" %Parameters[0][-1]) 225 | thefile.write("\n\n" ) 226 | if not(trainloss_to_file==[]): 227 | thefile.write("Train Statistics \n") 228 | thefile.write('Accuracy: %s \n' %train_metric_to_file[-1]) 229 | thefile.write('Average Loss: %s \n\n'%trainloss_to_file[-1]) 230 | 231 | thefile.write("Test Statistics \n") 232 | thefile.write('Accuracy: %s \n' %test_metric_to_file[-1]) 233 | thefile.write('Average Loss: %s \n\n'%testloss_to_file[-1]) 234 | thefile.write(20*'-'+'\n\n') 235 | thefile.close() 236 | -------------------------------------------------------------------------------- /cppnet/main_cppnet_dsb.py: -------------------------------------------------------------------------------- 1 | import os 2 | print('Working dir', os.getcwd()) 3 | from load_save_model import save_model 4 | from train_sampling_refine_withgt_separate_metric import Trainer 5 | import torch.optim 6 | from torch.optim.lr_scheduler import ReduceLROnPlateau 7 | from models.cpp_net import CPPNet 8 | from models.feature_extractor import Feature_Extractor 9 | from 
distance_loss_sampling_refine import L1Loss_List_withSAP_withSeg as L1BCELoss 10 | import dataloader_custom 11 | import random 12 | import numpy as np 13 | 14 | import argparse 15 | 16 | 17 | def run(data_path, nc_in=1, init_lr=1e-4, n_rays=32, SAP_weight_path=None, n_sampling=6, K=5): 18 | crop_sz = None 19 | print(n_sampling, K) 20 | erosion_factor_list = [float(i+1)/n_sampling for i in range(n_sampling)] 21 | print(erosion_factor_list) 22 | 23 | for irnd in range(K): 24 | 25 | Trainloader, Testloader = dataloader_custom.getDataLoaders(n_rays, max_dist=None, root_dir=data_path) 26 | 27 | model = CPPNet(nc_in, n_rays, erosion_factor_list=erosion_factor_list) 28 | 29 | if SAP_weight_path is not None: 30 | SAP_weight = torch.load(SAP_weight_path) 31 | SAP_model = Feature_Extractor(n_rays+1, 32) 32 | SAP_model_weight = SAP_model.state_dict() 33 | for k, v in SAP_weight.items(): 34 | if k in SAP_model_weight.keys(): 35 | SAP_model_weight.update({k:v}) 36 | print('Loaded: ', k, v.shape) 37 | SAP_model.load_state_dict(SAP_model_weight) 38 | SAP_model = SAP_model.cuda() 39 | SAP_model.eval() 40 | loss_scale = [1,1,1,1] 41 | else: 42 | SAP_model = None 43 | loss_scale = [1,1,1] 44 | loss = L1BCELoss(SAP_model, loss_scale) 45 | 46 | model_name='UNet2D_sampling_ensemble_n' + str(len(erosion_factor_list)) + '_r' + str(irnd) + '_weight_correct_conf_train3' + '_' + str(n_sampling) + '_withseg' + '_SAP_loss' + '_Others' 47 | print('model='+model_name) 48 | dataset='DSB2018_aug' 49 | print('dataset='+dataset) 50 | train_mode='StarDist' 51 | print('No.of rays',n_rays) 52 | 53 | kwargs={} 54 | additional_notes= '.' 55 | kwargs['additional_notes'] = additional_notes 56 | SAVE_PATH = os.getcwd()+'/'+dataset+'/'+train_mode+'_'+model_name+'/' 57 | kwargs['save_path'] = SAVE_PATH 58 | RESULTS_DIRECTORY = os.getcwd()+'/'+dataset+'/'+train_mode+'_'+model_name+'/plots/' 59 | 60 | trainer = Trainer(loss, None, None, validate_every=2) 61 | optimizer = torch.optim.Adam(model.parameters(), lr=init_lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-5) 62 | scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, verbose=True, patience=10, eps=1e-8, threshold=1e-20) 63 | 64 | print ('Starting Training') 65 | # # Pre train 66 | trainer.pretrain(model, Trainloader, optimizer, 5) 67 | trainloss_to_file, testloss_to_file, trainMetric_to_file, testMetric_to_file, Parameters = trainer.Train( 68 | model,optimizer, 69 | Trainloader,Testloader,epochs=None,Train_mode=train_mode, 70 | Model_name=model_name, 71 | Dataset=dataset,scheduler=scheduler 72 | ) 73 | print('Saving Final Model') 74 | save_model(model, trainMetric_to_file, testMetric_to_file, trainloss_to_file, testloss_to_file, Parameters, model_name,train_mode,dataset,plot=False,**kwargs) 75 | 76 | 77 | DATA_PATH = '/data/cong/datasets/dsb2018/dsb2018_in_stardist/dsb2018/dataset_split_for_training' 78 | SAP_Weight_path = None# '/data/cong/workplace/stardist/shape_project/DSB2018_aug/StarDist2Others_32_UNet2D_32d_rnd0/UNet2D_32d_rnd0_StarDist2Others_32_DSB2018_aug.t7' 79 | 80 | if __name__ == '__main__': 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument('--gpuid', type=int, default=0) 83 | parser.add_argument('--n_rays', type=int, default=32) 84 | parser.add_argument('--n_sampling', type=int, default=6) 85 | parser.add_argument('--nc_in', type=int, default=1) 86 | args = parser.parse_args() 87 | 88 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpuid) 89 | # torch.set_num_threads(8) 90 | 91 | run(DATA_PATH, nc_in=args.nc_in, 
SAP_weight_path=SAP_Weight_path, n_rays=args.n_rays, n_sampling=args.n_sampling)
92 | 
--------------------------------------------------------------------------------
/cppnet/main_cppnet_pannuke.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | print('Working dir', os.getcwd())
 3 | from load_save_model import save_model
 4 | from train_sampling_refine_withgt_separate_metric import Trainer
 5 | import torch.optim
 6 | from torch.optim.lr_scheduler import ReduceLROnPlateau
 7 | 
 8 | # from models.cpp_net import CPPNet
 9 | from models.cppnet_res50 import CPPNet  # fix: the module in models/ is cppnet_res50.py, not cpp_net_res50.py
10 | 
11 | from models.feature_extractor import Feature_Extractor
12 | from distance_loss_sampling_refine_cls import L1Loss_List_withSAP_withSeg as L1BCELoss
13 | import dataloader_custom
14 | import random
15 | import numpy as np
16 | 
17 | import argparse
18 | 
19 | 
20 | def run(data_path, nc_in=1, init_lr=1e-4, n_rays=32, n_cls=6, SAP_weight_path=None, n_sampling=6, train_type_idx=0, cls_balance_mode=False):  # fix: the old default -1 always failed the range check below
21 |     crop_sz = None
22 |     erosion_factor_list = [float(i+1)/n_sampling for i in range(n_sampling)]
23 |     print(erosion_factor_list)
24 | 
25 |     # https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke
26 |     train_type_list = [['fold_1', 'fold_2'], ['fold_2', 'fold_1'], ['fold_3', 'fold_2']]
27 |     assert 0 <= train_type_idx < len(train_type_list)
28 |     Trainloader, Testloader = dataloader_custom.getPanNukeDataLoaders(n_rays, max_dist=None, root_dir=data_path, type_list=train_type_list[train_type_idx], batch_size=6)  # fix: max_dist is a required argument of getPanNukeDataLoaders
29 | 
30 |     model = CPPNet(nc_in, n_rays, erosion_factor_list=erosion_factor_list, n_seg_cls=n_cls)
31 | 
32 |     if SAP_weight_path is not None:
33 |         SAP_weight = torch.load(SAP_weight_path)
34 |         SAP_model = Feature_Extractor(n_rays+1+n_cls, 32)
35 |         SAP_model_weight = SAP_model.state_dict()
36 |         for k, v in SAP_weight.items():
37 |             if k in SAP_model_weight.keys():
38 |                 SAP_model_weight.update({k:v})
39 |                 print('Loaded: ', k, v.shape)
40 |         SAP_model.load_state_dict(SAP_model_weight)
41 |         SAP_model = SAP_model.cuda()
42 |         SAP_model.eval()
43 |         loss_scale = [1,1,1,1]
44 |     else:
45 |         SAP_model = None
46 |         loss_scale = [1,1,1]
47 |     loss = L1BCELoss(SAP_model, loss_scale, cls_balance_mode=cls_balance_mode)
48 | 
49 |     model_name='UNet2D_sampling_ensemble_n' + str(len(erosion_factor_list)) + '_r' + str(train_type_idx) + '_weight_correct_conf_train3' + '_' + str(n_sampling) + '_withseg' + '_SAP_loss' + '_Others'
50 |     print('model='+model_name)
51 |     dataset='PanNuke'
52 |     print('dataset='+dataset)
53 |     train_mode='StarDist'
54 |     print('No.of rays',n_rays)
55 | 
56 |     kwargs={}
57 |     additional_notes= '.'
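    # these kwargs are consumed by save_model below ('save_path' and 'additional_notes')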
58 |     kwargs['additional_notes'] = additional_notes
59 |     SAVE_PATH = os.getcwd()+'/'+dataset+'/'+train_mode+'_'+model_name+'/'
60 |     kwargs['save_path'] = SAVE_PATH
61 |     RESULTS_DIRECTORY = os.getcwd()+'/'+dataset+'/'+train_mode+'_'+model_name+'/plots/'
62 | 
63 |     trainer = Trainer(loss, None, None, validate_every=2)
64 |     optimizer = torch.optim.Adam(model.parameters(), lr=init_lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-5)
65 |     scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, verbose=True, patience=10, eps=1e-8, threshold=1e-20)
66 | 
67 |     print ('Starting Training')
68 |     # # Pre train
69 |     trainer.pretrain(model, Trainloader, optimizer, 5)
70 |     trainloss_to_file, testloss_to_file, trainMetric_to_file, testMetric_to_file, Parameters = trainer.Train(
71 |         model,optimizer,
72 |         Trainloader,Testloader,epochs=None,Train_mode=train_mode,
73 |         Model_name=model_name,
74 |         Dataset=dataset,scheduler=scheduler
75 |     )
76 |     print('Saving Final Model')
77 |     save_model(model, trainMetric_to_file, testMetric_to_file, trainloss_to_file, testloss_to_file, Parameters, model_name,train_mode,dataset,plot=False,**kwargs)
78 | 
79 | 
80 | DATA_PATH = '/data/cong/dataset/pannuke/reorganized_dataset'
81 | SAP_Weight_path = None
82 | 
83 | if __name__ == '__main__':
84 |     parser = argparse.ArgumentParser()
85 |     parser.add_argument('--gpuid', type=int, default=0)
86 |     parser.add_argument('--n_rays', type=int, default=32)
87 |     parser.add_argument('--n_sampling', type=int, default=6)
88 |     parser.add_argument('--nc_in', type=int, default=1)  # note: the PanNuke loader yields 3-channel RGB images, so pass --nc_in 3 here
89 |     parser.add_argument('--nc_cls', type=int, default=6)
90 |     args = parser.parse_args()
91 | 
92 |     os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpuid)
93 |     # torch.set_num_threads(8)
94 | 
95 |     run(DATA_PATH, nc_in=args.nc_in, SAP_weight_path=SAP_Weight_path, n_rays=args.n_rays, n_sampling=args.n_sampling, n_cls=args.nc_cls)  # fix: run() takes n_cls, not nc_cls
96 | 
--------------------------------------------------------------------------------
/cppnet/metric_v2.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | 
 3 | from numba import jit
 4 | from tqdm import tqdm
 5 | from scipy.optimize import linear_sum_assignment
 6 | from collections import namedtuple
 7 | from csbdeep.utils import _raise
 8 | 
 9 | from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label
10 | 
11 | 
12 | matching_criteria = dict()
13 | 
14 | 
15 | 
16 | def label_are_sequential(y):
17 |     """ returns true if y has only sequential labels from 1...
""" 18 | labels = np.unique(y) 19 | return (set(labels)-{0}) == set(range(1,1+labels.max())) 20 | 21 | 22 | 23 | def is_array_of_integers(y): 24 | return isinstance(y,np.ndarray) and np.issubdtype(y.dtype, np.integer) 25 | 26 | def _check_label_array(y, name=None, check_sequential=False): 27 | err = ValueError("{label} must be an array of {integers}.".format( 28 | label = 'labels' if name is None else name, 29 | integers = ('sequential ' if check_sequential else '') + 'non-negative integers', 30 | )) 31 | is_array_of_integers(y) or _raise(err) 32 | if check_sequential: 33 | label_are_sequential(y) or _raise(err) 34 | else: 35 | y.min() >= 0 or _raise(err) 36 | return True 37 | 38 | 39 | 40 | def label_overlap(x, y, check=True): 41 | if check: 42 | _check_label_array(x,'x',True) 43 | _check_label_array(y,'y',True) 44 | x.shape == y.shape or _raise(ValueError("x and y must have the same shape")) 45 | return _label_overlap(x, y) 46 | 47 | @jit(nopython=True) 48 | def _label_overlap(x, y): 49 | x = x.ravel() 50 | y = y.ravel() 51 | overlap = np.zeros((1+x.max(),1+y.max()), dtype=np.uint) 52 | for i in range(len(x)): 53 | overlap[x[i],y[i]] += 1 54 | return overlap 55 | 56 | 57 | 58 | def intersection_over_union(overlap): 59 | _check_label_array(overlap,'overlap') 60 | if np.sum(overlap) == 0: 61 | return overlap 62 | n_pixels_pred = np.sum(overlap, axis=0, keepdims=True) 63 | n_pixels_true = np.sum(overlap, axis=1, keepdims=True) 64 | return overlap / (n_pixels_pred + n_pixels_true - overlap) 65 | 66 | matching_criteria['iou'] = intersection_over_union 67 | 68 | 69 | 70 | def intersection_over_true(overlap): 71 | _check_label_array(overlap,'overlap') 72 | if np.sum(overlap) == 0: 73 | return overlap 74 | n_pixels_true = np.sum(overlap, axis=1, keepdims=True) 75 | return overlap / n_pixels_true 76 | 77 | matching_criteria['iot'] = intersection_over_true 78 | 79 | 80 | 81 | def intersection_over_pred(overlap): 82 | _check_label_array(overlap,'overlap') 83 | if np.sum(overlap) == 0: 84 | return overlap 85 | n_pixels_pred = np.sum(overlap, axis=0, keepdims=True) 86 | return overlap / n_pixels_pred 87 | 88 | matching_criteria['iop'] = intersection_over_pred 89 | 90 | 91 | 92 | def precision(tp,fp,fn): 93 | return tp/(tp+fp) if tp > 0 else 0 94 | def recall(tp,fp,fn): 95 | return tp/(tp+fn) if tp > 0 else 0 96 | def accuracy(tp,fp,fn): 97 | # also known as "average precision" (?) 
98 | # -> https://www.kaggle.com/c/data-science-bowl-2018#evaluation 99 | return tp/(tp+fp+fn) if tp > 0 else 0 100 | def f1(tp,fp,fn): 101 | # also known as "dice coefficient" 102 | return (2*tp)/(2*tp+fp+fn) if tp > 0 else 0 103 | 104 | 105 | def wrap_match(y_pred, y_true): 106 | dists = y_pred[0].detach().cpu().numpy().squeeze() 107 | probs = y_pred[1].detach().cpu().numpy().squeeze() 108 | y_true = y_true.detach().cpu().numpy().squeeze().astype(np.int16) 109 | dists = np.transpose(dists,(1,2,0)) 110 | coord = dist_to_coord(dists) 111 | points = non_maximum_suppression(coord,probs,prob_thresh=0.4) 112 | star_label = polygons_to_label(coord,probs,points) 113 | stat = matching(y_true, star_label, thresh=0.5, criterion='iou', report_matches=False) 114 | return stat.accuracy 115 | 116 | def matching(y_true, y_pred, thresh=0.5, criterion='iou', report_matches=False): 117 | """ 118 | if report_matches=True, return (matched_pairs,matched_scores) are independent of 'thresh' 119 | """ 120 | _check_label_array(y_true,'y_true') 121 | _check_label_array(y_pred,'y_pred') 122 | y_true.shape == y_pred.shape or _raise(ValueError("y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes".format(y_true=y_true, y_pred=y_pred))) 123 | criterion in matching_criteria or _raise(ValueError("Matching criterion '%s' not supported." % criterion)) 124 | if thresh is None: thresh = 0 125 | thresh = float(thresh) if np.isscalar(thresh) else map(float,thresh) 126 | 127 | y_true, _, map_rev_true = relabel_sequential(y_true) 128 | y_pred, _, map_rev_pred = relabel_sequential(y_pred) 129 | 130 | overlap = label_overlap(y_true, y_pred, check=False) 131 | scores = matching_criteria[criterion](overlap) 132 | assert 0 <= np.min(scores) <= np.max(scores) <= 1 133 | 134 | # ignoring background 135 | scores = scores[1:,1:] 136 | n_true, n_pred = scores.shape 137 | n_matched = min(n_true, n_pred) 138 | 139 | def _single(thr): 140 | not_trivial = n_matched > 0 and np.any(scores >= thr) 141 | if not_trivial: 142 | # compute optimal matching with scores as tie-breaker 143 | costs = -(scores >= thr).astype(float) - scores / (2*n_matched) 144 | true_ind, pred_ind = linear_sum_assignment(costs) 145 | assert n_matched == len(true_ind) == len(pred_ind) 146 | match_ok = scores[true_ind,pred_ind] >= thr 147 | tp = np.count_nonzero(match_ok) 148 | else: 149 | tp = 0 150 | fp = n_pred - tp 151 | fn = n_true - tp 152 | # assert tp+fp == n_pred 153 | # assert tp+fn == n_true 154 | stats_dict = dict ( 155 | criterion = criterion, 156 | thresh = thr, 157 | fp = fp, 158 | tp = tp, 159 | fn = fn, 160 | precision = precision(tp,fp,fn), 161 | recall = recall(tp,fp,fn), 162 | accuracy = accuracy(tp,fp,fn), 163 | f1 = f1(tp,fp,fn), 164 | n_true = n_true, 165 | n_pred = n_pred, 166 | mean_true_score = np.sum(scores[true_ind,pred_ind][match_ok]) / n_true if not_trivial else 0.0, 167 | ) 168 | if bool(report_matches): 169 | if not_trivial: 170 | stats_dict.update ( 171 | # int() to be json serializable 172 | matched_pairs = tuple((int(map_rev_true[i]),int(map_rev_pred[j])) for i,j in zip(1+true_ind,1+pred_ind)), 173 | matched_scores = tuple(scores[true_ind,pred_ind]), 174 | matched_tps = tuple(map(int,np.flatnonzero(match_ok))), 175 | ) 176 | else: 177 | stats_dict.update ( 178 | matched_pairs = (), 179 | matched_scores = (), 180 | matched_tps = (), 181 | ) 182 | return namedtuple('Matching',stats_dict.keys())(*stats_dict.values()) 183 | 184 | return _single(thresh) if np.isscalar(thresh) else tuple(map(_single,thresh)) 185 | 186 | 187 
| 188 | def matching_dataset(y_true, y_pred, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False): 189 | len(y_true) == len(y_pred) or _raise(ValueError("y_true and y_pred must have the same length.")) 190 | return matching_dataset_lazy ( 191 | tuple(zip(y_true,y_pred)), thresh=thresh, criterion=criterion, by_image=by_image, show_progress=show_progress, parallel=parallel, 192 | ) 193 | 194 | 195 | 196 | def matching_dataset_lazy(y_gen, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False): 197 | 198 | expected_keys = set(('fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score')) 199 | 200 | single_thresh = False 201 | if np.isscalar(thresh): 202 | single_thresh = True 203 | thresh = (thresh,) 204 | 205 | tqdm_kwargs = {} 206 | tqdm_kwargs['disable'] = not bool(show_progress) 207 | if int(show_progress) > 1: 208 | tqdm_kwargs['total'] = int(show_progress) 209 | 210 | # compute matching stats for every pair of label images 211 | if parallel: 212 | from concurrent.futures import ThreadPoolExecutor 213 | fn = lambda pair: matching(*pair, thresh=thresh, criterion=criterion, report_matches=False) 214 | with ThreadPoolExecutor() as pool: 215 | stats_all = tuple(pool.map(fn, tqdm(y_gen,**tqdm_kwargs))) 216 | else: 217 | stats_all = tuple ( 218 | matching(y_t, y_p, thresh=thresh, criterion=criterion, report_matches=False) 219 | for y_t,y_p in tqdm(y_gen,**tqdm_kwargs) 220 | ) 221 | 222 | # accumulate results over all images for each threshold separately 223 | n_images, n_threshs = len(stats_all), len(thresh) 224 | accumulate = [{} for _ in range(n_threshs)] 225 | for stats in stats_all: 226 | for i,s in enumerate(stats): 227 | acc = accumulate[i] 228 | for k,v in s._asdict().items(): 229 | if k == 'mean_true_score' and not bool(by_image): 230 | # convert mean_true_score to "sum_true_score" 231 | acc[k] = acc.setdefault(k,0) + v * s.n_true 232 | else: 233 | try: 234 | acc[k] = acc.setdefault(k,0) + v 235 | except TypeError: 236 | pass 237 | 238 | # normalize/compute 'precision', 'recall', 'accuracy', 'f1' 239 | for thr,acc in zip(thresh,accumulate): 240 | set(acc.keys()) == expected_keys or _raise(ValueError("unexpected keys")) 241 | acc['criterion'] = criterion 242 | acc['thresh'] = thr 243 | acc['by_image'] = bool(by_image) 244 | if bool(by_image): 245 | for k in ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score'): 246 | acc[k] /= n_images 247 | else: 248 | tp, fp, fn = acc['tp'], acc['fp'], acc['fn'] 249 | acc.update( 250 | precision = precision(tp,fp,fn), 251 | recall = recall(tp,fp,fn), 252 | accuracy = accuracy(tp,fp,fn), 253 | f1 = f1(tp,fp,fn), 254 | mean_true_score = acc['mean_true_score'] / acc['n_true'] if acc['n_true'] > 0 else 0.0, 255 | ) 256 | 257 | accumulate = tuple(namedtuple('DatasetMatching',acc.keys())(*acc.values()) for acc in accumulate) 258 | return accumulate[0] if single_thresh else accumulate 259 | 260 | 261 | 262 | # copied from scikit-image master for now (remove when part of a release) 263 | def relabel_sequential(label_field, offset=1): 264 | """Relabel arbitrary labels to {`offset`, ... `offset` + number_of_labels}. 265 | This function also returns the forward map (mapping the original labels to 266 | the reduced labels) and the inverse map (mapping the reduced labels back 267 | to the original ones). 
268 | Parameters 269 | ---------- 270 | label_field : numpy array of int, arbitrary shape 271 | An array of labels, which must be non-negative integers. 272 | offset : int, optional 273 | The return labels will start at `offset`, which should be 274 | strictly positive. 275 | Returns 276 | ------- 277 | relabeled : numpy array of int, same shape as `label_field` 278 | The input label field with labels mapped to 279 | {offset, ..., number_of_labels + offset - 1}. 280 | The data type will be the same as `label_field`, except when 281 | offset + number_of_labels causes overflow of the current data type. 282 | forward_map : numpy array of int, shape ``(label_field.max() + 1,)`` 283 | The map from the original label space to the returned label 284 | space. Can be used to re-apply the same mapping. See examples 285 | for usage. The data type will be the same as `relabeled`. 286 | inverse_map : 1D numpy array of int, of length offset + number of labels 287 | The map from the new label space to the original space. This 288 | can be used to reconstruct the original label field from the 289 | relabeled one. The data type will be the same as `relabeled`. 290 | Notes 291 | ----- 292 | The label 0 is assumed to denote the background and is never remapped. 293 | The forward map can be extremely big for some inputs, since its 294 | length is given by the maximum of the label field. However, in most 295 | situations, ``label_field.max()`` is much smaller than 296 | ``label_field.size``, and in these cases the forward map is 297 | guaranteed to be smaller than either the input or output images. 298 | Examples 299 | -------- 300 | >>> from skimage.segmentation import relabel_sequential 301 | >>> label_field = np.array([1, 1, 5, 5, 8, 99, 42]) 302 | >>> relab, fw, inv = relabel_sequential(label_field) 303 | >>> relab 304 | array([1, 1, 2, 2, 3, 5, 4]) 305 | >>> fw 306 | array([0, 1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 307 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 308 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 309 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 310 | 0, 0, 0, 0, 0, 0, 0, 5]) 311 | >>> inv 312 | array([ 0, 1, 5, 8, 42, 99]) 313 | >>> (fw[label_field] == relab).all() 314 | True 315 | >>> (inv[relab] == label_field).all() 316 | True 317 | >>> relab, fw, inv = relabel_sequential(label_field, offset=5) 318 | >>> relab 319 | array([5, 5, 6, 6, 7, 9, 8]) 320 | """ 321 | offset = int(offset) 322 | if offset <= 0: 323 | raise ValueError("Offset must be strictly positive.") 324 | if np.min(label_field) < 0: 325 | raise ValueError("Cannot relabel array that contains negative values.") 326 | m = label_field.max() 327 | if not np.issubdtype(label_field.dtype, np.integer): 328 | new_type = np.min_scalar_type(int(m)) 329 | label_field = label_field.astype(new_type) 330 | m = m.astype(new_type) # Ensures m is an integer 331 | labels = np.unique(label_field) 332 | labels0 = labels[labels != 0] 333 | required_type = np.min_scalar_type(offset + len(labels0)) 334 | if np.dtype(required_type).itemsize > np.dtype(label_field.dtype).itemsize: 335 | label_field = label_field.astype(required_type) 336 | new_labels0 = np.arange(offset, offset + len(labels0)) 337 | if np.all(labels0 == new_labels0): 338 | return label_field, labels, labels 339 | forward_map = np.zeros(int(m + 1), dtype=label_field.dtype) 340 | forward_map[labels0] = new_labels0 341 | if not (labels == 0).any(): 342 | labels = np.concatenate(([0], 
labels))
343 |     inverse_map = np.zeros(offset - 1 + len(labels), dtype=label_field.dtype)
344 |     inverse_map[(offset - 1):] = labels
345 |     relabeled = forward_map[label_field]
346 |     return relabeled, forward_map, inverse_map
--------------------------------------------------------------------------------
/cppnet/models/SamplingFeatures2.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn.functional as F
 3 | import torch.nn as nn
 4 | import math
 5 | #import time
 6 | 
 7 | def feature_sampling(feature_map, coord_map, nd_sampling, sampling_mode='nearest'):
 8 |     b, c, h, w = feature_map.shape
 9 |     # coord_map: b, k, 2, h, w
10 |     # 'k' for k rays in each image
11 |     _, k, _, h, w = coord_map.shape
12 |     x_ = torch.arange(w).view(1, -1).expand(h, -1)
13 |     y_ = torch.arange(h).view(-1, 1).expand(-1, w)
14 |     grid = torch.stack([x_, y_], dim=0).float()
15 |     # grid: b, 1, 2, h, w
16 |     grid = grid.unsqueeze(0).expand(b, -1, -1, -1, -1).cuda()
17 |     # sampling_coord: b, k, 2, h, w
18 |     sampling_coord = grid + coord_map
19 |     sampling_coord[:, :, 0, :, :] = sampling_coord[:, :, 0, :, :]/(w-1)
20 |     sampling_coord[:, :, 1, :, :] = sampling_coord[:, :, 1, :, :]/(h-1)
21 |     sampling_coord = sampling_coord*2.0-1.0
22 | 
23 |     assert nd_sampling <= 0 or k*nd_sampling == c  # fix: the unconditional assert(k*nd_sampling==c) made the shared-feature branch (nd_sampling <= 0) unreachable
24 | 
25 |     if nd_sampling > 0:
26 |         sampling_coord = sampling_coord.permute(1, 0, 3, 4, 2).flatten(start_dim=0, end_dim=1) # kb, h, w, 2
27 |         sampling_features = F.grid_sample(feature_map.view(b, k, nd_sampling, h, w).permute(1, 0, 2, 3, 4).flatten(start_dim=0, end_dim=1), sampling_coord, mode=sampling_mode) # kb, c', h, w
28 |         sampling_features = sampling_features.view(k, b, nd_sampling, h, w).permute(1, 0, 2, 3, 4) # b, k, c', h, w
29 |     else:
30 |         sampling_coord = sampling_coord.permute(0, 1, 3, 4, 2).flatten(start_dim=1, end_dim=2) # b, kh, w, 2
31 |         sampling_features = F.grid_sample(feature_map, sampling_coord, mode=sampling_mode)
32 |         sampling_features = sampling_features.view(b, c, k, h, w).permute(0, 2, 1, 3, 4) # b, k, c'/c, h, w
33 | 
34 |     sampling_features = sampling_features.flatten(start_dim=1, end_dim=2) # b, k*c', h, w
35 | 
36 |     return sampling_features, sampling_coord
37 | 
38 | class SamplingFeatures(nn.Module):
39 |     def __init__(self, n_rays, sampling_mode='nearest'):
40 |         super(SamplingFeatures, self).__init__()
41 |         self.n_rays = n_rays
42 |         self.angles = torch.arange(n_rays).float()/float(n_rays)*math.pi*2.0 # 0 - 2*pi
43 |         self.sin_angles = torch.sin(self.angles).cuda().view(1, n_rays, 1, 1)
44 |         self.cos_angles = torch.cos(self.angles).cuda().view(1, n_rays, 1, 1)
45 |         self.sampling_mode = sampling_mode
46 |     def forward(self, feature_map, dist, nd_sampling):
47 |         # feature_map: b, c, h, w
48 |         # dist: b, k, h, w
49 |         # sampled_features: b, k*c, h, w
50 |         offset_ih = self.sin_angles * dist
51 |         offset_iw = self.cos_angles * dist
52 |         offsets = torch.stack([offset_iw, offset_ih], dim=2)
53 |         sampled_features, sampling_coord = feature_sampling(feature_map, offsets, nd_sampling, self.sampling_mode)
54 |         return sampled_features, sampling_coord
55 | 
--------------------------------------------------------------------------------
/cppnet/models/__pycache__/SamplingFeatures2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csccsccsccsc/cpp-net/e99cbc790fddb1964355753be12ca8f10e51f247/cppnet/models/__pycache__/SamplingFeatures2.cpython-37.pyc
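For orientation, a minimal usage sketch of SamplingFeatures from cppnet/models/SamplingFeatures2.py above; the tensor shapes follow the comments in its forward, while the batch/ray/feature sizes are made-up values. The module calls .cuda() internally, so a GPU is required:

```python
import torch
from models.SamplingFeatures2 import SamplingFeatures  # run from inside cppnet/

# made-up sizes: batch 2, 32 rays, nd_sampling=4 channels per ray, 64x64 maps
sampler = SamplingFeatures(n_rays=32)                # allocates sin/cos ray tables on the GPU
feature_map = torch.randn(2, 32 * 4, 64, 64).cuda()  # channel count must equal n_rays * nd_sampling
dist = 10.0 * torch.rand(2, 32, 64, 64).cuda()       # per-pixel radial distance along each ray
sampled, coords = sampler(feature_map, dist, nd_sampling=4)
print(sampled.shape)  # torch.Size([2, 128, 64, 64]) -> k * nd_sampling channels
```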
-------------------------------------------------------------------------------- /cppnet/models/__pycache__/cpp_net.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csccsccsccsc/cpp-net/e99cbc790fddb1964355753be12ca8f10e51f247/cppnet/models/__pycache__/cpp_net.cpython-37.pyc -------------------------------------------------------------------------------- /cppnet/models/__pycache__/feature_extractor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csccsccsccsc/cpp-net/e99cbc790fddb1964355753be12ca8f10e51f247/cppnet/models/__pycache__/feature_extractor.cpython-37.pyc -------------------------------------------------------------------------------- /cppnet/models/__pycache__/unet_parts_gn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csccsccsccsc/cpp-net/e99cbc790fddb1964355753be12ca8f10e51f247/cppnet/models/__pycache__/unet_parts_gn.cpython-37.pyc -------------------------------------------------------------------------------- /cppnet/models/cpp_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .unet_parts_gn import * 6 | from .SamplingFeatures2 import SamplingFeatures 7 | 8 | class CPPNet(nn.Module): 9 | 10 | def __init__(self, n_channels, n_rays, erosion_factor_list=[0.2, 0.4, 0.6, 0.8, 1.0], return_conf=False, with_seg=True, n_seg_cls=1): 11 | super(CPPNet, self).__init__() 12 | self.inc = inconv(n_channels, 32) 13 | self.down1 = down(32, 64) 14 | self.down2 = down(64, 128) 15 | self.down3 = down(128, 128) 16 | self.up1 = up(256, 64, bilinear=True) 17 | self.up2 = up(128, 32, bilinear=True) 18 | self.up3 = up(64, 32, bilinear=True) 19 | self.features = nn.Conv2d(32, 128, 3, padding=1) 20 | self.out_prob = outconv(128, 1) 21 | self.out_ray = outconv(128, n_rays) 22 | self.conv_0_confidence = outconv(128, n_rays) 23 | self.conv_1_confidence = outconv(1+len(erosion_factor_list), 1+len(erosion_factor_list)) 24 | nn.init.constant_(self.conv_1_confidence.conv.bias, 1.0) 25 | 26 | self.with_seg = with_seg 27 | self.n_seg_cls = n_seg_cls 28 | if self.with_seg: 29 | self.up1_seg = up(256, 64, bilinear=True) 30 | self.up2_seg = up(128, 32, bilinear=True) 31 | self.up3_seg = up(64, 32, bilinear=True) 32 | self.out_seg = outconv(32, n_seg_cls) 33 | if self.n_seg_cls == 1: 34 | self.final_activation_seg = nn.Sigmoid() 35 | else: 36 | self.final_activation_seg = nn.Softmax(dim=1) 37 | 38 | self.final_activation_ray = nn.ReLU() 39 | self.final_activation_prob = nn.Sigmoid() 40 | 41 | # Refinement 42 | self.sampling_feature = SamplingFeatures(n_rays) 43 | self.erosion_factor_list = erosion_factor_list 44 | self.n_rays = n_rays 45 | self.return_conf = return_conf 46 | 47 | def forward(self, img, gt_dist=None): 48 | x1 = self.inc(img) 49 | x2 = self.down1(x1) 50 | x3 = self.down2(x2) 51 | x4 = self.down3(x3) 52 | x = self.up1(x4, x3) 53 | x = self.up2(x, x2) 54 | x = self.up3(x, x1) 55 | x = self.features(x) 56 | out_ray = self.out_ray(x) 57 | out_confidence = self.conv_0_confidence(x) 58 | out_prob = self.out_prob(x) 59 | 60 | if gt_dist is not None: 61 | out_ray_for_sampling = gt_dist 62 | else: 63 | out_ray_for_sampling = out_ray 64 | ray_refined = [ out_ray_for_sampling ] 65 | 66 | confidence_refined = [ out_confidence ] 67 | for 
erosion_factor in self.erosion_factor_list:
68 |             base_dist = (out_ray_for_sampling-1.0)*erosion_factor
69 |             ray_sampled, _ = self.sampling_feature(out_ray_for_sampling, base_dist, 1)
70 |             conf_sampled, _ = self.sampling_feature(out_confidence, base_dist, 1)
71 |             ray_refined.append(ray_sampled + base_dist)
72 |             confidence_refined.append(conf_sampled)
73 |         ray_refined = torch.stack(ray_refined, dim=1)
74 |         b, k, c, h, w = ray_refined.shape
75 | 
76 |         confidence_refined = torch.stack(confidence_refined, dim=1)
77 |         #confidence_refined = torch.cat((confidence_refined, ray_refined), dim=1)
78 |         confidence_refined = confidence_refined.permute([0,2,1,3,4]).contiguous().view(b*c, k, h, w)
79 |         confidence_refined = self.conv_1_confidence(confidence_refined)
80 |         confidence_refined = confidence_refined.view(b, c, k, h, w).permute([0,2,1,3,4])
81 |         confidence_refined = F.softmax(confidence_refined, dim=1)
82 |         if self.return_conf:
83 |             out_conf = [out_confidence, confidence_refined]
84 |         else:
85 |             out_conf = None
86 |         ray_refined = (ray_refined*confidence_refined).sum(dim=1)
87 | 
88 |         out_ray = self.final_activation_ray(out_ray)
89 |         ray_refined = self.final_activation_ray(ray_refined)
90 |         out_prob = self.final_activation_prob(out_prob)
91 | 
92 |         if self.with_seg:
93 |             x_seg = self.up1_seg(x4, x3)
94 |             x_seg = self.up2_seg(x_seg, x2)
95 |             x_seg = self.up3_seg(x_seg, x1)
96 |             out_seg = self.out_seg(x_seg)
97 |             if self.n_seg_cls == 1:
98 |                 out_seg = self.final_activation_seg(out_seg)
99 |             elif not self.training:
100 |                 out_seg = self.final_activation_seg(out_seg)
101 | 
102 |         else:
103 |             out_seg = None
104 | 
105 | 
106 |         return [out_ray, ray_refined], [out_prob], [out_seg, ], [out_conf, ]
107 | 
108 | 
109 |     def init_weight(self):
110 |         for m in self.modules():
111 |             if isinstance(m, nn.Conv2d):
112 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out')
113 |                 if m.bias is not None:  # Conv2d always has a 'bias' attribute, but it can be None
114 |                     nn.init.constant_(m.bias, 0)
115 |             elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.GroupNorm):
116 |                 nn.init.constant_(m.weight, 1.0)
117 |                 nn.init.constant_(m.bias, 0.0)
118 |         nn.init.constant_(self.conv_1_confidence.conv.bias, 1.0)
119 | 
--------------------------------------------------------------------------------
/cppnet/models/cppnet_res50.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from .unet_parts_gn import *
6 | from .SamplingFeatures2 import SamplingFeatures
7 | from .resnet50_preact import resnet50  # this import was missing; resnet50() builds the backbone below
8 | class CPPNet(nn.Module):
9 | 
10 |     def __init__(self, n_channels, n_rays, erosion_factor_list=[0.2, 0.4, 0.6, 0.8, 1.0], return_conf=False, with_seg=True, n_seg_cls=6):
11 |         super(CPPNet, self).__init__()
12 | 
13 |         self.backbone = resnet50(True)
14 |         self.up1 = up(2048+1024, 1024, bilinear=True)
15 |         self.up2 = up(1024+512, 512, bilinear=True)
16 |         self.up3 = up(512+256, 256, bilinear=True)
17 | 
18 |         self.features = nn.Conv2d(256, 256, 3, padding=1)
19 |         self.out_prob = outconv(256, 1)
20 |         self.out_ray = outconv(256, n_rays)
21 |         self.conv_0_confidence = outconv(256, n_rays)
22 |         self.conv_1_confidence = outconv(1+len(erosion_factor_list), 1+len(erosion_factor_list))
23 |         nn.init.constant_(self.conv_1_confidence.conv.bias, 1.0)
24 | 
25 |         self.with_seg = with_seg
26 |         self.n_seg_cls = n_seg_cls
27 |         if self.with_seg:
28 |             self.up1_seg = up(2048+1024, 1024, bilinear=True)
29 |             self.up2_seg = up(1024+512, 512, bilinear=True)
30 |             self.up3_seg = up(512+256, 256, bilinear=True)
31 |             self.out_seg = outconv(256, n_seg_cls)
32 |             if self.n_seg_cls == 1:
33 |                 self.final_activation_seg = nn.Sigmoid()
34 |             else:
35 |                 self.final_activation_seg = nn.Softmax(dim=1)
36 | 
37 |         self.final_activation_ray = nn.ReLU()
38 |         self.final_activation_prob = nn.Sigmoid()
39 | 
40 |         # Refinement
41 |         self.sampling_feature = SamplingFeatures(n_rays)
42 |         self.erosion_factor_list = erosion_factor_list
43 |         self.n_rays = n_rays
44 |         self.return_conf = return_conf
45 | 
46 |     def forward(self, img, gt_dist=None):
47 |         x1, x2, x3, x4 = self.backbone(img.repeat(1,3,1,1))
48 |         x = self.up1(x4, x3)
49 |         x = self.up2(x, x2)
50 |         x = self.up3(x, x1)
51 |         x = self.features(x)
52 |         out_ray = self.out_ray(x)
53 |         out_confidence = self.conv_0_confidence(x)
54 |         out_prob = self.out_prob(x)
55 | 
56 |         if gt_dist is not None:
57 |             out_ray_for_sampling = gt_dist
58 |         else:
59 |             out_ray_for_sampling = out_ray
60 |         ray_refined = [ out_ray_for_sampling ]
61 | 
62 |         confidence_refined = [ out_confidence ]
63 |         for erosion_factor in self.erosion_factor_list:
64 |             base_dist = (out_ray_for_sampling-1.0)*erosion_factor
65 |             ray_sampled, _ = self.sampling_feature(out_ray_for_sampling, base_dist, 1)
66 |             conf_sampled, _ = self.sampling_feature(out_confidence, base_dist, 1)
67 |             ray_refined.append(ray_sampled + base_dist)
68 |             confidence_refined.append(conf_sampled)
69 |         ray_refined = torch.stack(ray_refined, dim=1)
70 |         b, k, c, h, w = ray_refined.shape
71 | 
72 |         confidence_refined = torch.stack(confidence_refined, dim=1)
73 |         #confidence_refined = torch.cat((confidence_refined, ray_refined), dim=1)
74 |         confidence_refined = confidence_refined.permute([0,2,1,3,4]).contiguous().view(b*c, k, h, w)
75 |         confidence_refined = self.conv_1_confidence(confidence_refined)
76 |         confidence_refined = confidence_refined.view(b, c, k, h, w).permute([0,2,1,3,4])
77 |         confidence_refined = F.softmax(confidence_refined, dim=1)
78 |         if self.return_conf:
79 |             out_conf = [out_confidence, confidence_refined]
80 |         else:
81 |             out_conf = None
82 |         ray_refined = (ray_refined*confidence_refined).sum(dim=1)
83 | 
84 |         out_ray = self.final_activation_ray(out_ray)
85 |         ray_refined = self.final_activation_ray(ray_refined)
86 |         out_prob = self.final_activation_prob(out_prob)
87 | 
88 |         if self.with_seg:
89 |             x_seg = self.up1_seg(x4, x3)
90 |             x_seg = self.up2_seg(x_seg, x2)
91 |             x_seg = self.up3_seg(x_seg, x1)
92 |             out_seg = self.out_seg(x_seg)
93 |             if self.n_seg_cls == 1:
94 |                 out_seg = self.final_activation_seg(out_seg)
95 |             elif not self.training:
96 |                 out_seg = self.final_activation_seg(out_seg)
97 |         else:
98 |             out_seg = None
99 | 
100 |         return [out_ray, ray_refined], [out_prob], [out_seg, ], [out_conf, ]
101 | 
102 | 
103 |     def init_weight(self):
104 |         for m in self.modules():
105 |             if isinstance(m, nn.Conv2d):
106 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out')
107 |                 if m.bias is not None:  # the backbone's convs use bias=False, so the old hasattr() guard would pass None to constant_
108 |                     nn.init.constant_(m.bias, 0)
109 |             elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.GroupNorm):
110 |                 nn.init.constant_(m.weight, 1.0)
111 |                 nn.init.constant_(m.bias, 0.0)
112 |         nn.init.constant_(self.conv_1_confidence.conv.bias, 1.0)
113 | 
--------------------------------------------------------------------------------
/cppnet/models/feature_extractor.py:
--------------------------------------------------------------------------------
1 | # full assembly of the sub-parts to form the complete net
2 | 
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 | from .unet_parts_gn import *
6 | import torch.nn.init as init
7 | 
8 | # The encoder only
9 | class Feature_Extractor(nn.Module):
10 |     def __init__(self, n_channels, n_features=32):
11 |         super(Feature_Extractor, self).__init__()
12 |         self.inc = inconv(n_channels, n_features)
13 |         self.down1 = down(n_features, n_features*2)
14 |         self.down2 = down(n_features*2, n_features*4)
15 |         self.down3 = down(n_features*4, n_features*8)
16 |         self.down4 = down(n_features*8, n_features*16)
17 | 
18 |         # self.up1 = up_single(n_features*16, n_features*8, bilinear=True)
19 |         # self.up2 = up_single(n_features*8, n_features*4, bilinear=True)
20 |         # self.up3 = up_single(n_features*4, n_features*2, bilinear=True)
21 |         # self.up4 = up_single(n_features*2, n_features*1, bilinear=True)
22 | 
23 |         # self.features_segbnd = nn.Conv2d(n_features, n_features, 3, padding=1)
24 |         # self.features_bbox = nn.Conv2d(n_features, n_features, 3, padding=1)
25 |         # self.out_segbnd = outconv(n_features, 2)
26 |         # self.out_bbox = outconv(n_features, 4)
27 |         # self.final_activation_prob = nn.Sigmoid()
28 |         # self.final_activation_ray = nn.ReLU()
29 | 
30 |     def forward(self, x):
31 |         x0 = self.inc(x)
32 |         x1 = self.down1(x0)
33 |         x2 = self.down2(x1)
34 |         x3 = self.down3(x2)
35 |         x4 = self.down4(x3)
36 | 
37 |         # x = self.up1(x4, x3)
38 |         # x = self.up2(x, x2)
39 |         # x = self.up3(x, x1)
40 |         # x = self.up4(x, x0)
41 | 
42 |         # x_segbnd = self.final_activation_prob(self.out_segbnd(self.features_segbnd(x)))
43 |         # x_bbox = self.final_activation_ray(self.out_bbox(self.features_bbox(x)))
44 | 
45 |         return x4
--------------------------------------------------------------------------------
/cppnet/models/resnet50_preact.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.hub import load_state_dict_from_url  # the repo ships no models/utils.py; torch.hub provides this helper
4 | 
5 | 
6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
7 |            'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
8 |            'wide_resnet50_2', 'wide_resnet101_2']
9 | 
10 | 
11 | model_urls = {
12 |     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
13 |     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
14 |     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
15 |     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
16 |     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
17 |     'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
18 |     'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
19 |     'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
20 |     'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
21 | }
22 | 
23 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
24 |     """3x3 convolution with padding"""
25 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
26 |                      padding=dilation, groups=groups, bias=False, dilation=dilation)
27 | 
28 | 
29 | def conv1x1(in_planes, out_planes, stride=1):
30 |     """1x1 convolution"""
31 |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
32 | 
33 | 
34 | class BasicBlock(nn.Module):
35 |     expansion = 1
36 | 
37 |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
38 |                  base_width=64, dilation=1, norm_layer=None):
39 |         super(BasicBlock, self).__init__()
40 |         if norm_layer is None:
41 |             norm_layer = nn.BatchNorm2d
42 |         if groups != 1 or base_width != 64:
43 |             raise ValueError('BasicBlock only supports groups=1 and 
base_width=64') 44 | if dilation > 1: 45 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 46 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 47 | self.conv1 = conv3x3(inplanes, planes, stride) 48 | self.bn1 = norm_layer(planes) 49 | self.relu = nn.ReLU(inplace=True) 50 | self.conv2 = conv3x3(planes, planes) 51 | self.bn2 = norm_layer(planes) 52 | self.downsample = downsample 53 | self.stride = stride 54 | 55 | def forward(self, x): 56 | identity = x 57 | 58 | out = self.conv1(x) 59 | out = self.bn1(out) 60 | out = self.relu(out) 61 | 62 | out = self.conv2(out) 63 | out = self.bn2(out) 64 | 65 | if self.downsample is not None: 66 | identity = self.downsample(x) 67 | 68 | out += identity 69 | out = self.relu(out) 70 | 71 | return out 72 | 73 | 74 | class Bottleneck(nn.Module): 75 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 76 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 77 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 78 | # This variant is also known as ResNet V1.5 and improves accuracy according to 79 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 80 | 81 | expansion = 4 82 | 83 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 84 | base_width=64, dilation=1, norm_layer=None): 85 | super(Bottleneck, self).__init__() 86 | if norm_layer is None: 87 | norm_layer = nn.BatchNorm2d 88 | width = int(planes * (base_width / 64.)) * groups 89 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 90 | self.conv1 = conv1x1(inplanes, width) 91 | self.bn1 = norm_layer(width) 92 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 93 | self.bn2 = norm_layer(width) 94 | self.conv3 = conv1x1(width, planes * self.expansion) 95 | self.bn3 = norm_layer(planes * self.expansion) 96 | self.relu = nn.ReLU(inplace=True) 97 | self.downsample = downsample 98 | self.stride = stride 99 | 100 | def forward(self, x): 101 | identity = x 102 | 103 | out = self.conv1(x) 104 | out = self.bn1(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv2(out) 108 | out = self.bn2(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv3(out) 112 | out = self.bn3(out) 113 | 114 | if self.downsample is not None: 115 | identity = self.downsample(x) 116 | 117 | out += identity 118 | out = self.relu(out) 119 | 120 | return out 121 | 122 | 123 | class ResNet(nn.Module): 124 | 125 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 126 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 127 | norm_layer=None): 128 | super(ResNet, self).__init__() 129 | if norm_layer is None: 130 | norm_layer = nn.BatchNorm2d 131 | self._norm_layer = norm_layer 132 | 133 | self.inplanes = 64 134 | self.dilation = 1 135 | if replace_stride_with_dilation is None: 136 | # each element in the tuple indicates if we should replace 137 | # the 2x2 stride with a dilated convolution instead 138 | replace_stride_with_dilation = [False, False, False] 139 | if len(replace_stride_with_dilation) != 3: 140 | raise ValueError("replace_stride_with_dilation should be None " 141 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 142 | self.groups = groups 143 | self.base_width = width_per_group 144 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 145 | bias=False) 
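        # note: self.maxpool is constructed below but _forward_impl skips it,
        # so x1..x4 come out at 1/2, 1/4, 1/8 and 1/16 of the input resolution
        # rather than the standard torchvision strides.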
146 | self.bn1 = norm_layer(self.inplanes) 147 | self.relu = nn.ReLU(inplace=True) 148 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 149 | 150 | self.layer1 = self._make_layer(block, 64, layers[0]) 151 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 152 | dilate=replace_stride_with_dilation[0]) 153 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 154 | dilate=replace_stride_with_dilation[1]) 155 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 156 | dilate=replace_stride_with_dilation[2]) 157 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 158 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 159 | 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 163 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 164 | nn.init.constant_(m.weight, 1) 165 | nn.init.constant_(m.bias, 0) 166 | 167 | # Zero-initialize the last BN in each residual branch, 168 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 169 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 170 | if zero_init_residual: 171 | for m in self.modules(): 172 | if isinstance(m, Bottleneck): 173 | nn.init.constant_(m.bn3.weight, 0) 174 | elif isinstance(m, BasicBlock): 175 | nn.init.constant_(m.bn2.weight, 0) 176 | 177 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 178 | norm_layer = self._norm_layer 179 | downsample = None 180 | previous_dilation = self.dilation 181 | if dilate: 182 | self.dilation *= stride 183 | stride = 1 184 | if stride != 1 or self.inplanes != planes * block.expansion: 185 | downsample = nn.Sequential( 186 | conv1x1(self.inplanes, planes * block.expansion, stride), 187 | norm_layer(planes * block.expansion), 188 | ) 189 | 190 | layers = [] 191 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 192 | self.base_width, previous_dilation, norm_layer)) 193 | self.inplanes = planes * block.expansion 194 | for _ in range(1, blocks): 195 | layers.append(block(self.inplanes, planes, groups=self.groups, 196 | base_width=self.base_width, dilation=self.dilation, 197 | norm_layer=norm_layer)) 198 | 199 | return nn.Sequential(*layers) 200 | 201 | def _forward_impl(self, x): 202 | # See note [TorchScript super()] 203 | x = self.conv1(x) 204 | x = self.bn1(x) 205 | x = self.relu(x) 206 | # x = self.maxpool(x) 207 | 208 | x1 = self.layer1(x) 209 | x2 = self.layer2(x1) 210 | x3 = self.layer3(x2) 211 | x4 = self.layer4(x3) 212 | 213 | # x = self.avgpool(x) 214 | # x = torch.flatten(x, 1) 215 | # x = self.fc(x) 216 | 217 | return [x1, x2, x3, x4] 218 | 219 | def forward(self, x): 220 | return self._forward_impl(x) 221 | 222 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 223 | model = ResNet(block, layers, **kwargs) 224 | if pretrained: 225 | state_dict = load_state_dict_from_url(model_urls[arch], 226 | progress=progress) 227 | state_dict_for_load = {} 228 | target_state_dict_keys = model.state_dict().keys() 229 | for key, value in state_dict.items(): 230 | if key in target_state_dict_keys: 231 | state_dict_for_load.update({key:value}) 232 | # print('Key ' + key + ' loaded') 233 | model.load_state_dict(state_dict_for_load) 234 | return model 235 | 236 | 237 | def resnet18(pretrained=False, progress=True, **kwargs): 238 | r"""ResNet-18 model from 239 | `"Deep Residual Learning for Image Recognition" `_ 240 | 
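    Note: in this repository the classification head is disabled; the ResNet
    builders return the four intermediate feature maps [x1, x2, x3, x4] (the
    avgpool/fc layers are commented out and the stem maxpool is skipped in
    _forward_impl).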
Args: 241 | pretrained (bool): If True, returns a model pre-trained on ImageNet 242 | progress (bool): If True, displays a progress bar of the download to stderr 243 | """ 244 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 245 | **kwargs) 246 | 247 | 248 | def resnet34(pretrained=False, progress=True, **kwargs): 249 | r"""ResNet-34 model from 250 | `"Deep Residual Learning for Image Recognition" `_ 251 | Args: 252 | pretrained (bool): If True, returns a model pre-trained on ImageNet 253 | progress (bool): If True, displays a progress bar of the download to stderr 254 | """ 255 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 256 | **kwargs) 257 | 258 | 259 | def resnet50(pretrained=False, progress=True, **kwargs): 260 | r"""ResNet-50 model from 261 | `"Deep Residual Learning for Image Recognition" `_ 262 | Args: 263 | pretrained (bool): If True, returns a model pre-trained on ImageNet 264 | progress (bool): If True, displays a progress bar of the download to stderr 265 | """ 266 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 267 | **kwargs) 268 | 269 | 270 | def resnet101(pretrained=False, progress=True, **kwargs): 271 | r"""ResNet-101 model from 272 | `"Deep Residual Learning for Image Recognition" `_ 273 | Args: 274 | pretrained (bool): If True, returns a model pre-trained on ImageNet 275 | progress (bool): If True, displays a progress bar of the download to stderr 276 | """ 277 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 278 | **kwargs) 279 | 280 | 281 | def resnet152(pretrained=False, progress=True, **kwargs): 282 | r"""ResNet-152 model from 283 | `"Deep Residual Learning for Image Recognition" `_ 284 | Args: 285 | pretrained (bool): If True, returns a model pre-trained on ImageNet 286 | progress (bool): If True, displays a progress bar of the download to stderr 287 | """ 288 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 289 | **kwargs) 290 | 291 | 292 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 293 | r"""ResNeXt-50 32x4d model from 294 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 295 | Args: 296 | pretrained (bool): If True, returns a model pre-trained on ImageNet 297 | progress (bool): If True, displays a progress bar of the download to stderr 298 | """ 299 | kwargs['groups'] = 32 300 | kwargs['width_per_group'] = 4 301 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 302 | pretrained, progress, **kwargs) 303 | 304 | 305 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 306 | r"""ResNeXt-101 32x8d model from 307 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 308 | Args: 309 | pretrained (bool): If True, returns a model pre-trained on ImageNet 310 | progress (bool): If True, displays a progress bar of the download to stderr 311 | """ 312 | kwargs['groups'] = 32 313 | kwargs['width_per_group'] = 8 314 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 315 | pretrained, progress, **kwargs) 316 | 317 | 318 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 319 | r"""Wide ResNet-50-2 model from 320 | `"Wide Residual Networks" `_ 321 | The model is the same as ResNet except for the bottleneck number of channels 322 | which is twice larger in every block. The number of channels in outer 1x1 323 | convolutions is the same, e.g. 
last block in ResNet-50 has 2048-512-2048
324 |     channels, and in Wide ResNet-50-2 has 2048-1024-2048.
325 |     Args:
326 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
327 |         progress (bool): If True, displays a progress bar of the download to stderr
328 |     """
329 |     kwargs['width_per_group'] = 64 * 2
330 |     return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
331 |                    pretrained, progress, **kwargs)
332 | 
333 | 
334 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
335 |     r"""Wide ResNet-101-2 model from
336 |     `"Wide Residual Networks" `_
337 |     The model is the same as ResNet except for the bottleneck number of channels
338 |     which is twice larger in every block. The number of channels in outer 1x1
339 |     convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
340 |     channels, and in Wide ResNet-50-2 has 2048-1024-2048.
341 |     Args:
342 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
343 |         progress (bool): If True, displays a progress bar of the download to stderr
344 |     """
345 |     kwargs['width_per_group'] = 64 * 2
346 |     return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
347 |                    pretrained, progress, **kwargs)
--------------------------------------------------------------------------------
/cppnet/models/unet_parts_gn.py:
--------------------------------------------------------------------------------
1 | # sub-parts of the U-Net model
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | class double_conv(nn.Module):
8 |     '''(conv => GN => ELU) * 2'''
9 |     def __init__(self, in_ch, out_ch):
10 |         super(double_conv, self).__init__()
11 |         num_groups = out_ch // 8
12 |         self.conv = nn.Sequential(
13 |             nn.Conv2d(in_ch, out_ch, 3, padding=1),
14 |             nn.GroupNorm(num_channels=out_ch, num_groups=num_groups),
15 |             nn.ELU(inplace=True),
16 |             nn.Conv2d(out_ch, out_ch, 3, padding=1),
17 |             nn.GroupNorm(num_channels=out_ch, num_groups=num_groups),
18 |             nn.ELU(inplace=True)
19 |         )
20 | 
21 |     def forward(self, x):
22 |         x = self.conv(x)
23 |         return x
24 | 
25 | 
26 | class inconv(nn.Module):
27 |     def __init__(self, in_ch, out_ch):
28 |         super(inconv, self).__init__()
29 |         self.conv = double_conv(in_ch, out_ch)
30 | 
31 |     def forward(self, x):
32 |         x = self.conv(x)
33 |         return x
34 | 
35 | 
36 | class down(nn.Module):
37 |     def __init__(self, in_ch, out_ch):
38 |         super(down, self).__init__()
39 |         self.mpconv = nn.Sequential(
40 |             nn.MaxPool2d(2),
41 |             double_conv(in_ch, out_ch)
42 |         )
43 | 
44 |     def forward(self, x):
45 |         x = self.mpconv(x)
46 |         return x
47 | 
48 | 
49 | class up2(nn.Module):
50 |     def __init__(self, in_ch, out_ch, bilinear=True):
51 |         super(up2, self).__init__()
52 | 
53 |         if bilinear:
54 |             self.conv_trans = nn.Conv2d(in_ch, in_ch//2, 1)
55 |             self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
56 |         else:
57 |             self.up = nn.ConvTranspose2d(in_ch, in_ch//2, 2, stride=2)
58 | 
59 |         self.conv = double_conv(in_ch, out_ch)
60 | 
61 |     def forward(self, x1, x2):
62 |         x1 = self.up(self.conv_trans(x1))
63 |         # input is CHW
64 |         diffY = x2.size()[2] - x1.size()[2]
65 |         diffX = x2.size()[3] - x1.size()[3]
66 | 
67 |         x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
68 |                         diffY // 2, diffY - diffY//2))
69 |         # for padding issues, see
70 |         # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
71 |         # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
72 |         x = torch.cat([x2, x1], dim=1)
73 |         x 
= self.conv(x) 74 | return x 75 | 76 | class up(nn.Module): 77 | def __init__(self, in_ch, out_ch, bilinear=True): 78 | super(up, self).__init__() 79 | 80 | 81 | if bilinear: 82 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 83 | else: 84 | self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) 85 | 86 | self.conv = double_conv(in_ch, out_ch) 87 | 88 | def forward(self, x1, x2): 89 | x1 = self.up(x1) 90 | 91 | # input is CHW 92 | diffY = x2.size()[2] - x1.size()[2] 93 | diffX = x2.size()[3] - x1.size()[3] 94 | 95 | x1 = F.pad(x1, (diffX // 2, diffX - diffX//2, 96 | diffY // 2, diffY - diffY//2)) 97 | 98 | # for padding issues, see 99 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 100 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 101 | 102 | x = torch.cat([x2, x1], dim=1) 103 | x = self.conv(x) 104 | return x 105 | 106 | 107 | class outconv(nn.Module): 108 | def __init__(self, in_ch, out_ch): 109 | super(outconv, self).__init__() 110 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 111 | def forward(self, x): 112 | x = self.conv(x) 113 | return x 114 | -------------------------------------------------------------------------------- /cppnet/predict_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import numpy as np 5 | from glob import glob 6 | from skimage.io import imread 7 | from csbdeep.utils import normalize 8 | from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label 9 | from stardist import random_label_cmap, ray_angles 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from models.cpp_net import CPPNet 15 | 16 | import math 17 | from tqdm import tqdm 18 | import warnings 19 | warnings.filterwarnings("ignore") 20 | 21 | from stats_utils import get_fast_aji, get_fast_pq, get_fast_dice_2, get_dice_1, remap_label 22 | from metric_v2 import matching, matching_dataset 23 | 24 | try: 25 | import numpy_gpu as npgpu 26 | except: 27 | print('The package "numpy_gpu" is used for comparison only. 
You can try to use the package "numpy_gpu", but it is not necessary.') 28 | 29 | 30 | ap_ious = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9] 31 | 32 | 33 | def predict_each_image( 34 | model_dist, img, 35 | axis_norm=[0,1], center_prob_thres=0.4, seg_prob_thres=0.5, 36 | n_rays=32, FPP=True, sin_angles=None, cos_angles=None, dist_cmp='cuda' 37 | ): 38 | 39 | division = 1 40 | 41 | img = img.copy() 42 | img = normalize(img, 1, 99.8, axis = axis_norm) 43 | 44 | h, w = img.shape 45 | if h%division!=0 or w%division!=0: 46 | dh = (h//division+1) * division - h 47 | dw = (w//division+1) * division - w 48 | img = np.pad(img, ((0, dh), (0, dw)), 'constant') 49 | 50 | assert(dist_cmp in ['cuda', 'cpu', 'np', 'npcuda']) 51 | 52 | input = torch.tensor(img) 53 | input = input.unsqueeze(0).unsqueeze(0) 54 | preds = model_dist(input.cuda()) 55 | dist_cuda = preds[0][-1][:, :, :h, :w] 56 | dist = dist_cuda.data.cpu() 57 | prob = preds[1][-1].data.cpu()[:, :, :h, :w] 58 | seg = preds[2][-1].data.cpu()[:, :, :h, :w] 59 | 60 | dist_numpy = dist.numpy().squeeze() 61 | prob_numpy = prob.numpy().squeeze() 62 | seg = seg.numpy().squeeze() 63 | prob_numpy = prob_numpy*seg# (seg>=seg_prob_thres).astype(np.float32) 64 | 65 | dist_numpy = np.transpose(dist_numpy,(1,2,0)) 66 | coord = dist_to_coord(dist_numpy) 67 | points = non_maximum_suppression(coord, prob_numpy, prob_thresh=center_prob_thres) 68 | star_label = polygons_to_label(coord, prob_numpy, points) 69 | 70 | # st0 = time.time() 71 | # You can try different approaches to finish the process of distance calculation. In our experiments dist_cmp='cuda' seems faster 72 | if FPP and sin_angles is None: 73 | if dist_cmp == 'cuda': 74 | angles = torch.arange(n_rays).float()/float(n_rays)*math.pi*2.0 # 0 - 2*pi 75 | sin_angles = torch.sin(angles).view(1, n_rays, 1, 1) 76 | cos_angles = torch.cos(angles).view(1, n_rays, 1, 1) 77 | sin_angles = sin_angles.cuda() 78 | cos_angles = cos_angles.cuda() 79 | 80 | offset_ih = sin_angles * dist_cuda 81 | offset_iw = cos_angles * dist_cuda 82 | # 1, r, h, w, 2 83 | offsets = torch.stack([offset_iw, offset_ih], dim=-1) 84 | # h, w, 2 85 | mean_coord = np.round(offsets.mean(dim=1).data.cpu().squeeze(dim=0).numpy()).astype(np.int16) 86 | elif dist_cmp == 'cpu': 87 | angles = torch.arange(n_rays).float()/float(n_rays)*math.pi*2.0 # 0 - 2*pi 88 | sin_angles = torch.sin(angles).view(1, n_rays, 1, 1) 89 | cos_angles = torch.cos(angles).view(1, n_rays, 1, 1) 90 | 91 | offset_ih = sin_angles * dist 92 | offset_iw = cos_angles * dist 93 | # 1, r, h, w, 2 94 | offsets = torch.stack([offset_iw, offset_ih], dim=-1) 95 | # h, w, 2 96 | mean_coord = np.round(offsets.mean(dim=1).data.cpu().squeeze(dim=0).numpy()).astype(np.int16) 97 | elif dist_cmp == 'np': 98 | angles = torch.arange(n_rays).float()/float(n_rays)*math.pi*2.0 # 0 - 2*pi 99 | sin_angles = torch.sin(angles).view(1, n_rays, 1, 1).data.numpy() 100 | cos_angles = torch.cos(angles).view(1, n_rays, 1, 1).data.numpy() 101 | 102 | offset_ih = sin_angles * dist.numpy() 103 | offset_iw = cos_angles * dist.numpy() 104 | # 1, r, h, w, 2 105 | offsets = np.stack([offset_iw, offset_ih], axis=-1) 106 | # h, w, 2 107 | mean_coord = np.round(offsets.mean(axis=1).squeeze(axis=0)).astype(np.int16) 108 | elif dist_cmp == 'npcuda': 109 | angles = torch.arange(n_rays).float()/float(n_rays)*math.pi*2.0 # 0 - 2*pi 110 | sin_angles = torch.sin(angles).view(1, n_rays, 1, 1).data.numpy() 111 | cos_angles = torch.cos(angles).view(1, n_rays, 1, 1).data.numpy() 112 | 113 | offset_ih = 
npgpu.dot(sin_angles, dist.numpy()) 114 | offset_iw = npgpu.dot(cos_angles, dist.numpy()) 115 | # 1, r, h, w, 2 116 | offsets = np.stack([offset_iw, offset_ih], axis=-1) 117 | # h, w, 2 118 | mean_coord = np.round(offsets.mean(axis=1).squeeze(axis=0)).astype(np.int16) 119 | 120 | pred = star_label 121 | 122 | # Offset-based Post Processing: 123 | if FPP: 124 | seg_remained = np.logical_and(seg>=seg_prob_thres, pred==0) 125 | while seg_remained.any(): 126 | if seg_remained.any(): 127 | rxs, rys = np.where(seg_remained) 128 | mean_coord_remained = mean_coord[seg_remained, :] 129 | pred_0 = pred.copy() 130 | rxs_a = np.clip((rxs + mean_coord_remained[:, 1]).astype(np.int16), 0, h-1) 131 | rys_a = np.clip((rys + mean_coord_remained[:, 0]).astype(np.int16), 0, w-1) 132 | pred[seg_remained] = pred[(rxs_a, rys_a)] 133 | if not((pred_0 != pred).any()): 134 | break 135 | else: 136 | break 137 | seg_remained = np.logical_and(seg>=seg_prob_thres, pred==0) 138 | 139 | 140 | return pred 141 | 142 | def run( 143 | DATASET_PATH_IMAGE, DATASET_PATH_LABEL, 144 | nc_in, n_rays, n_sampling, model_weight_path, 145 | center_prob_thres=0.4, seg_prob_thres=0.5 146 | ): 147 | 148 | X = sorted(glob(DATASET_PATH_IMAGE)) 149 | X = list(map(imread,X)) 150 | Y = sorted(glob(DATASET_PATH_LABEL)) 151 | Y = list(map(imread,Y)) 152 | 153 | with torch.no_grad(): 154 | 155 | erosion_factor_list = [float(i+1)/n_sampling for i in range(n_sampling)] 156 | model_dist = CPPNet(nc_in, n_rays, erosion_factor_list=erosion_factor_list).cuda() 157 | model_dist.load_state_dict(torch.load(model_weight_path)) 158 | model_dist.eval() 159 | 160 | ajis = [] 161 | pqs = [] 162 | dice2s = [] 163 | dice1s = [] 164 | aps_perimg = [[] for i_t in range(len(ap_ious))] 165 | preds = [] 166 | 167 | 168 | for idx, img_target in enumerate(zip(X,Y)): 169 | 170 | image, target = img_target 171 | h, w = image.shape 172 | 173 | star_label = predict_each_image( 174 | model_dist, image, (0, 1), 175 | center_prob_thres=center_prob_thres, seg_prob_thres=seg_prob_thres, n_rays=n_rays 176 | ) 177 | 178 | aji = get_fast_aji(target, star_label) 179 | pq = get_fast_pq(target, star_label)[0][2] 180 | dice2 = get_fast_dice_2(target, star_label) 181 | dice1 = get_dice_1(target, star_label) 182 | 183 | idx_aps = [] 184 | for i_t, t in enumerate(ap_ious): 185 | i_t_ap = matching(target, star_label, thresh=t).accuracy 186 | aps_perimg[i_t].append(i_t_ap) 187 | idx_aps.append(i_t_ap) 188 | 189 | ajis.append(aji) 190 | pqs.append(pq) 191 | dice1s.append(dice1) 192 | dice2s.append(dice2) 193 | preds.append(star_label) 194 | 195 | print('{:03d}, {:.5f}, {:.5f}, {:.5f}, {:.5f}, {:.5f}'.format(idx, aji, pq, np.mean(idx_aps), dice2, dice1)) 196 | 197 | stats = [matching_dataset(Y, preds, thresh=t, show_progress=False, by_image=True) for t in tqdm(ap_ious)] 198 | 199 | avg = 0.0 200 | aps_perimg = np.array(aps_perimg) 201 | for iou in ap_ious: 202 | print(iou, stats[ap_ious.index(iou)].accuracy) 203 | avg += stats[ap_ious.index(iou)].accuracy 204 | avg /= len(ap_ious) 205 | print('avg : {:.6f}'.format(avg)) 206 | print('aji: {:.6f}'.format(np.mean(ajis))) 207 | print('pq: {:.6f}'.format(np.mean(pqs))) 208 | print('dice2: {:.6f}'.format(np.mean(dice2s))) 209 | print('dice1: {:.6f}'.format(np.mean(dice1s))) 210 | 211 | # np.save('predictions.npy', preds) 212 | 213 | 214 | 215 | DATASET_PATH_IMAGE = 'DATA PATH/test/images/*.tif' 216 | DATASET_PATH_LABEL = 'DATA PATH/test/masks/*.tif' 217 | 218 | MODEL_WEIGHT_PATH = '' 219 | 220 | if __name__ == '__main__': 221 | parser = 
argparse.ArgumentParser()
222 |     parser.add_argument('--gpuid', type=int, default=0)
223 |     parser.add_argument('--nc_in', type=int, default=1)
224 |     parser.add_argument('--n_rays', type=int, default=32)
225 |     parser.add_argument('--n_sampling', type=int, default=6)
226 |     args = parser.parse_args()
227 | 
228 |     os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpuid)
229 | 
230 |     run(
231 |         DATASET_PATH_IMAGE, DATASET_PATH_LABEL,
232 |         args.nc_in, args.n_rays, args.n_sampling, MODEL_WEIGHT_PATH,
233 |     )
--------------------------------------------------------------------------------
/cppnet/predict_eval_pannuke.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, unicode_literals, absolute_import, division
2 | import numpy as np
3 | import os
4 | import matplotlib
5 | from tqdm import tqdm
6 | matplotlib.use('agg')
7 | import matplotlib.pyplot as plt
8 | from glob import glob
9 | #from tifffile import imread
10 | from skimage.io import imread
11 | from skimage.measure import label
12 | from csbdeep.utils import normalize
13 | from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label
14 | from stardist import random_label_cmap, ray_angles
15 | import torch
16 | import cv2
17 | import torch.nn.functional as F
18 | from tqdm import tqdm
19 | import scipy.io as scio
20 | from scipy import stats
21 | import math
22 | import time
23 | # from models.cpp_net import CPPNet
24 | from models.cppnet_res50 import CPPNet  # the module file is named cppnet_res50.py
25 | import argparse  # used by the __main__ block below; this import was missing
26 | from stats_utils import remap_label
27 | 
28 | import warnings
29 | warnings.filterwarnings("ignore")
30 | 
31 | def predict_each_image(model_dist, img, resz=None, offset_refine=True, TARGET_LABELS=32, prob_thres=0.4):
32 | 
33 |     img = img.astype(np.float32)
34 |     for imod in range(3):
35 |         tmp_img = img[:, :, imod]
36 |         meanv = tmp_img.mean()
37 |         stdv = tmp_img.std()
38 |         img[:, :, imod] = (tmp_img-meanv)/stdv
39 | 
40 |     input = torch.tensor(img)
41 |     input = input.unsqueeze(0)
42 |     if len(input.shape) < 4:
43 |         input = input.unsqueeze(1)
44 |     else:
45 |         input = input.permute([0, 3, 1, 2])
46 |     _, _, h, w = input.shape
47 | 
48 |     # Model Prediction
49 |     if resz is not None:
50 |         resz_input = F.interpolate(input.cuda(), size=resz, mode='bilinear', align_corners=True)
51 |     else:
52 |         resz_input = input.cuda()
53 |     preds = model_dist(resz_input)
54 |     dist = preds[0]
55 |     prob = preds[1]
56 |     seg = preds[2]
57 |     if isinstance(dist, (tuple, list)):
58 |         dist_cuda = dist[-1].clone()
59 |         dist = dist[-1].detach().cpu()
60 |     else:
61 |         dist_cuda = dist.clone()
62 |         dist = dist.detach().cpu()
63 |     if isinstance(prob, (tuple, list)):
64 |         prob = prob[-1].detach().cpu()
65 |     else:
66 |         prob = prob.detach().cpu()
67 |     if isinstance(seg, (tuple, list)):
68 |         seg = seg[-1].detach().cpu()
69 |     else:
70 |         seg = seg.detach().cpu()
71 |     seg = F.softmax(seg, dim=1)  # note: CPPNet already applies softmax in eval mode; a second softmax keeps the per-pixel argmax unchanged
72 |     if resz is not None:
73 |         dist_cuda = F.interpolate(dist_cuda, size=[h, w], mode='bilinear', align_corners=True)
74 |         dist = F.interpolate(dist, size=[h, w], mode='bilinear', align_corners=True)
75 |         prob = F.interpolate(prob, size=[h, w], mode='bilinear', align_corners=True)
76 |         seg = F.interpolate(seg, size=[h, w], mode='bilinear', align_corners=True)
77 |     dists = dist.numpy().squeeze()
78 |     probs = prob.numpy().squeeze()
79 |     segs = seg.numpy().squeeze()
80 | 
81 |     # Post Processing
82 |     dists = np.transpose(dists,(1,2,0))
83 |     coord = dist_to_coord(dists)
84 |     points = non_maximum_suppression(coord, probs, prob_thresh=prob_thres)
85 |     binary_star_label = polygons_to_label(coord, probs, points)
86 |     binary_star_label = remap_label(binary_star_label)
87 | 
88 |     # segs: background + (n_cls-1) foreground classes
89 |     # cls_star_labels: one channel per foreground class
90 |     N_CLASSES = segs.shape[0]
91 |     seg_label = np.argmax(segs, axis=0)
92 |     cls_star_labels = np.zeros((N_CLASSES-1, )+binary_star_label.shape, dtype=np.int16)
93 |     cset = np.unique(binary_star_label[binary_star_label>0])
94 |     for ic in cset:
95 |         icmap = binary_star_label==ic
96 |         ic_seg_label = seg_label[icmap]
97 |         ic_cls = stats.mode(ic_seg_label)[0][0]
98 |         if ic_cls>0:
99 |             cls_star_labels[ic_cls-1][icmap] = ic
100 | 
101 |     if offset_refine:
102 |         h, w, n_rays = dists.shape
103 |         angles = torch.arange(n_rays).float()/float(n_rays)*math.pi*2.0 # 0 - 2*pi
104 |         sin_angles = torch.sin(angles).view(1, n_rays, 1, 1).to(dist_cuda.device)
105 |         cos_angles = torch.cos(angles).view(1, n_rays, 1, 1).to(dist_cuda.device)
106 |         offset_ih = sin_angles * dist_cuda
107 |         offset_iw = cos_angles * dist_cuda
108 |         # 1, r, h, w, 2
109 |         offsets = torch.stack([offset_iw, offset_ih], dim=-1)
110 |         # h, w, 2
111 |         mean_coord = np.round(offsets.mean(dim=1).data.cpu().squeeze(dim=0).numpy()).astype(np.int16)
112 | 
113 |     for icls in range(N_CLASSES-1):
114 |         icls_star_label = cls_star_labels[icls]
115 |         binary_seg = seg_label==(icls+1)
116 |         if icls_star_label.any():
117 |             if offset_refine:
118 |                 seg_remained = np.logical_and(binary_seg, icls_star_label==0)
119 |                 pred = icls_star_label.copy()
120 |                 i_iter = 0
121 |                 while seg_remained.any():
122 |                     i_iter += 1
123 |                     rxs, rys = np.where(seg_remained)
124 |                     pred_0 = pred.copy()
125 |                     for rx, ry in zip(rxs, rys):
126 |                         dx_rx, dy_rx = np.clip(int(np.round(rx + mean_coord[rx, ry, 1])), 0, h-1), np.clip(int(np.round(ry + mean_coord[rx, ry, 0])), 0, w-1)
127 |                         pred[rx, ry] = pred[dx_rx, dy_rx]
128 |                     if not((pred_0 != pred).any()):
129 |                         break
130 |                     seg_remained = np.logical_and(binary_seg, pred==0)
131 |                 icls_star_label = pred
132 |             if len(np.unique(icls_star_label[icls_star_label>0])) >= 1:
133 |                 icls_star_label = remap_label(icls_star_label)
134 |             cls_star_labels[icls] = icls_star_label
135 |     cls_star_labels = cls_star_labels.transpose([1, 2, 0])
136 | 
137 |     return cls_star_labels
138 | 
139 | # Classification:
140 | # GT: 0: Neoplastic cells, 1: Inflammatory, 2: Connective/Soft tissue cells, 3: Dead Cells, 4: Epithelial, 6: Background
141 | # Pred: 0: Background, 1: Neoplastic cells, 2: Inflammatory, 3: Connective/Soft tissue cells, 4: Dead Cells, 5: Epithelial
142 | 
143 | 
144 | nk = 6
145 | erosion_factor_list = [float(i+1)/nk for i in range(nk)]
146 | 
147 | offset_refine = True
148 | 
149 | if not offset_refine:
150 |     results_filefold = './PanNuke_aug_results_x1_0806/'
151 |     if not os.path.exists(results_filefold):
152 |         os.makedirs(results_filefold)
153 | else:
154 |     results_filefold = './PanNuke_aug_results_offset_refine_x1_0806/'
155 |     if not os.path.exists(results_filefold):
156 |         os.makedirs(results_filefold)
157 | 
158 | def run(
159 |     DATASET_PATH_IMAGE, model_weight_path, prediction_save_path,
160 |     nc_in=3, n_rays=32, n_sampling=6, n_cls=6,
161 |     center_prob_thres=0.4, resz=None,
162 | ):
163 |     image_name_list = []
164 |     with open(os.path.join(DATASET_PATH_IMAGE, 'name_list.txt'), 'r') as f_img:
165 |         for line in f_img.readlines():
166 |             line_term = (line.split(',')[-1]).strip()
167 |             image_name_list.append(line_term)
168 | 
169 |     n_data = len(image_name_list)
170 |     preds = np.zeros([n_data, 256, 256, n_cls - 1], dtype=np.int16)  # N_CLASSES was undefined here; predict_each_image returns n_cls-1 foreground channels
171 |     with torch.no_grad():
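        # note: erosion_factor_list must match the training configuration; its
        # length sets the channel count of conv_1_confidence (1+len(list)), so a
        # mismatch will surface as a size error in load_state_dict below.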
172 |         erosion_factor_list = [float(i+1)/n_sampling for i in range(n_sampling)]
173 |         model_dist = CPPNet(nc_in, n_rays, erosion_factor_list=erosion_factor_list, n_seg_cls=n_cls)
174 |         model_dist = model_dist.cuda()
175 |         model_dist.load_state_dict(torch.load(model_weight_path))
176 |         model_dist.eval()
177 |         for idx, image_name in enumerate(image_name_list):
178 |             image = imread(image_name)
179 |             pred = predict_each_image(model_dist, image, resz=resz, offset_refine=offset_refine, prob_thres=center_prob_thres)
180 |             preds[idx] = pred
181 | 
182 |     np.save(prediction_save_path, preds)
183 | 
184 | 
185 | DATASET_PATH_IMAGE = 'DATA PATH/test/images/*.tif or .png'
186 | PREDICTION_SAVE_PATH = 'SAVE_PATH/*.npy'
187 | MODEL_WEIGHT_PATH = ''
188 | 
189 | if __name__ == '__main__':
190 |     parser = argparse.ArgumentParser()
191 |     parser.add_argument('--gpuid', type=int, default=0)
192 |     parser.add_argument('--nc_in', type=int, default=3)
193 |     parser.add_argument('--n_rays', type=int, default=32)
194 |     parser.add_argument('--n_sampling', type=int, default=6)
195 |     parser.add_argument('--n_cls', type=int, default=6)
196 |     args = parser.parse_args()
197 | 
198 |     os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpuid)
199 | 
200 |     run(
201 |         DATASET_PATH_IMAGE, MODEL_WEIGHT_PATH, PREDICTION_SAVE_PATH,
202 |         args.nc_in, args.n_rays, args.n_sampling, args.n_cls,
203 |     )
--------------------------------------------------------------------------------
/cppnet/stats_utils.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import numpy as np
3 | from scipy.optimize import linear_sum_assignment
4 | 
5 | import cv2
6 | import matplotlib.pyplot as plt
7 | 
8 | 
9 | # --------------------------Optimised for Speed
10 | def get_fast_aji(true, pred):
11 |     """AJI version distributed by MoNuSeg; it has no permutation problem but suffers from
12 |     over-penalisation, similar to DICE2.
13 | 
14 |     Fast computation requires instance IDs to be in contiguous ordering, i.e. [1, 2, 3, 4]
15 |     and not [2, 3, 6, 10]. Please call `remap_label` beforehand; the `by_size` flag has no
16 |     effect on the result.
17 | 
18 |     """
19 | 
20 |     if true.any() and not(pred.any()):
21 |         return 0.0
22 |     if pred.any() and not(true.any()):
23 |         return 0.0
24 | 
25 |     true = np.copy(true)  # ? do we need this
26 |     pred = np.copy(pred)
27 |     true_id_list = list(np.unique(true))
28 |     pred_id_list = list(np.unique(pred))
29 | 
30 |     true_masks = [
31 |         None,
32 |     ]
33 |     for t in true_id_list[1:]:
34 |         t_mask = np.array(true == t, np.uint8)
35 |         true_masks.append(t_mask)
36 | 
37 |     pred_masks = [
38 |         None,
39 |     ]
40 |     for p in pred_id_list[1:]:
41 |         p_mask = np.array(pred == p, np.uint8)
42 |         pred_masks.append(p_mask)
43 | 
44 |     # prefill with value
45 |     pairwise_inter = np.zeros(
46 |         [len(true_id_list) - 1, len(pred_id_list) - 1], dtype=np.float64
47 |     )
48 |     pairwise_union = np.zeros(
49 |         [len(true_id_list) - 1, len(pred_id_list) - 1], dtype=np.float64
50 |     )
51 | 
52 |     # caching pairwise
53 |     for true_id in true_id_list[1:]:  # 0-th is background
54 |         t_mask = true_masks[true_id]
55 |         pred_true_overlap = pred[t_mask > 0]
56 |         pred_true_overlap_id = np.unique(pred_true_overlap)
57 |         pred_true_overlap_id = list(pred_true_overlap_id)
58 |         for pred_id in pred_true_overlap_id:
59 |             if pred_id == 0:  # ignore
60 |                 continue  # overlapping background
61 |             p_mask = pred_masks[pred_id]
62 |             total = (t_mask + p_mask).sum()
63 |             inter = (t_mask * p_mask).sum()
64 |             pairwise_inter[true_id - 1, pred_id - 1] = inter
65 |             pairwise_union[true_id - 1, pred_id - 1] = total - inter
66 | 
67 |     pairwise_iou = pairwise_inter / (pairwise_union + 1.0e-6)
68 |     # the pred that gives the highest iou for each true; we do not care
69 |     # about reusing a pred instance multiple times here
70 |     paired_pred = np.argmax(pairwise_iou, axis=1)
71 |     pairwise_iou = np.max(pairwise_iou, axis=1)
72 |     # exclude those that have no intersection
73 |     paired_true = np.nonzero(pairwise_iou > 0.0)[0]
74 |     paired_pred = paired_pred[paired_true]
75 |     # print(paired_true.shape, paired_pred.shape)
76 |     overall_inter = (pairwise_inter[paired_true, paired_pred]).sum()
77 |     overall_union = (pairwise_union[paired_true, paired_pred]).sum()
78 | 
79 |     paired_true = list(paired_true + 1)  # index to instance ID
80 |     paired_pred = list(paired_pred + 1)
81 |     # add all unpaired GT and Prediction into the union
82 |     unpaired_true = np.array(
83 |         [idx for idx in true_id_list[1:] if idx not in paired_true]
84 |     )
85 |     unpaired_pred = np.array(
86 |         [idx for idx in pred_id_list[1:] if idx not in paired_pred]
87 |     )
88 |     for true_id in unpaired_true:
89 |         overall_union += true_masks[true_id].sum()
90 |     for pred_id in unpaired_pred:
91 |         overall_union += pred_masks[pred_id].sum()
92 | 
93 |     aji_score = overall_inter / overall_union
94 |     return aji_score
95 | 
96 | 
97 | #####
98 | def get_fast_aji_plus(true, pred):
99 |     """AJI+, an AJI version with maximal unique pairing to obtain the overall intersection.
100 |     Every prediction instance is paired with at most 1 GT instance (1-to-1 mapping), unlike AJI,
101 |     where a prediction instance can be paired against many GT instances (1-to-many).
102 |     Remaining unpaired GT and Prediction instances will be added to the overall union.
103 |     The 1-to-1 mapping prevents AJI's over-penalisation from happening.
104 | 
105 |     Fast computation requires instance IDs to be in contiguous ordering, i.e. [1, 2, 3, 4]
106 |     and not [2, 3, 6, 10]. Please call `remap_label` beforehand; the `by_size` flag has no
107 |     effect on the result.
108 | 
109 |     """
110 | 
111 |     if true.any() and not(pred.any()):
112 |         return 0.0
113 |     if pred.any() and not(true.any()):
114 |         return 0.0
115 |     if not(pred.any()) and not(true.any()):
116 |         return 1.0
117 | 
118 |     true = np.copy(true)  # ? do we need this
119 |     pred = np.copy(pred)
120 |     true_id_list = list(np.unique(true))
121 |     pred_id_list = list(np.unique(pred))
122 | 
123 |     true_masks = [
124 |         None,
125 |     ]
126 |     for t in true_id_list[1:]:
127 |         t_mask = np.array(true == t, np.uint8)
128 |         true_masks.append(t_mask)
129 | 
130 |     pred_masks = [
131 |         None,
132 |     ]
133 |     for p in pred_id_list[1:]:
134 |         p_mask = np.array(pred == p, np.uint8)
135 |         pred_masks.append(p_mask)
136 | 
137 |     # prefill with value
138 |     pairwise_inter = np.zeros(
139 |         [len(true_id_list) - 1, len(pred_id_list) - 1], dtype=np.float64
140 |     )
141 |     pairwise_union = np.zeros(
142 |         [len(true_id_list) - 1, len(pred_id_list) - 1], dtype=np.float64
143 |     )
144 | 
145 |     # caching pairwise
146 |     for true_id in true_id_list[1:]:  # 0-th is background
147 |         t_mask = true_masks[true_id]
148 |         pred_true_overlap = pred[t_mask > 0]
149 |         pred_true_overlap_id = np.unique(pred_true_overlap)
150 |         pred_true_overlap_id = list(pred_true_overlap_id)
151 |         for pred_id in pred_true_overlap_id:
152 |             if pred_id == 0:  # ignore
153 |                 continue  # overlapping background
154 |             p_mask = pred_masks[pred_id]
155 |             total = (t_mask + p_mask).sum()
156 |             inter = (t_mask * p_mask).sum()
157 |             pairwise_inter[true_id - 1, pred_id - 1] = inter
158 |             pairwise_union[true_id - 1, pred_id - 1] = total - inter
159 |     #
160 |     pairwise_iou = pairwise_inter / (pairwise_union + 1.0e-6)
161 |     #### Munkres pairing to find maximal unique pairing
162 |     paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)
163 |     ### extract the paired cost and remove invalid pairs
164 |     paired_iou = pairwise_iou[paired_true, paired_pred]
165 |     # now select all those paired with iou != 0.0, i.e. with intersection
166 |     paired_true = paired_true[paired_iou > 0.0]
167 |     paired_pred = paired_pred[paired_iou > 0.0]
168 |     paired_inter = pairwise_inter[paired_true, paired_pred]
169 |     paired_union = pairwise_union[paired_true, paired_pred]
170 |     paired_true = list(paired_true + 1)  # index to instance ID
171 |     paired_pred = list(paired_pred + 1)
172 |     overall_inter = paired_inter.sum()
173 |     overall_union = paired_union.sum()
174 |     # add all unpaired GT and Prediction into the union
175 |     unpaired_true = np.array(
176 |         [idx for idx in true_id_list[1:] if idx not in paired_true]
177 |     )
178 |     unpaired_pred = np.array(
179 |         [idx for idx in pred_id_list[1:] if idx not in paired_pred]
180 |     )
181 |     for true_id in unpaired_true:
182 |         overall_union += true_masks[true_id].sum()
183 |     for pred_id in unpaired_pred:
184 |         overall_union += pred_masks[pred_id].sum()
185 |     #
186 |     aji_score = overall_inter / overall_union
187 |     return aji_score
188 | 
189 | 
190 | #####
191 | def get_fast_pq(true, pred, match_iou=0.5):
192 |     """`match_iou` is the IoU threshold level to determine the pairing between
193 |     GT instances `p` and prediction instances `g`. `p` and `g` are a pair
194 |     if IoU > `match_iou`. However, a pair of `p` and `g` must be unique
195 |     (1 prediction instance to 1 GT instance mapping).
196 | 
197 |     If `match_iou` < 0.5, Munkres assignment (solving minimum weight matching
198 |     in bipartite graphs) is calculated to find the maximal amount of unique pairing.
199 | 
200 |     If `match_iou` >= 0.5, all IoU(p,g) > 0.5 pairings are proven to be unique and
201 |     the number of pairs is also maximal.
202 | 
203 |     Fast computation requires instance IDs to be in contiguous ordering,
204 |     i.e. [1, 2, 3, 4] and not [2, 3, 6, 10]. Please call `remap_label` beforehand;
205 |     the `by_size` flag has no effect on the result.
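    For intuition, a worked example (illustrative numbers): with 3 GT instances,
    2 predictions, and matched IoUs {0.8, 0.6} at match_iou=0.5, tp=2, fp=0 and
    fn=1, so DQ = 2/(2 + 0.5*0 + 0.5*1) = 0.8, SQ = (0.8+0.6)/2 = 0.7, and
    PQ = DQ*SQ = 0.56.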
206 | 
207 |     Returns:
208 |         [dq, sq, pq]: measurement statistic
209 | 
210 |         [paired_true, paired_pred, unpaired_true, unpaired_pred]:
211 |             pairing information to perform measurement
212 | 
213 |     """
214 |     assert match_iou >= 0.0, "Can't be negative"
215 |     # note: these early exits return the same ([dq, sq, pq], pairing_info) structure as the main path, so callers can always index the result uniformly
216 |     if true.any() and not(pred.any()):
217 |         return [0.0, 0.0, 0.0], [[], [], [], []]
218 |     if pred.any() and not(true.any()):
219 |         return [0.0, 0.0, 0.0], [[], [], [], []]
220 |     if not(pred.any()) and not(true.any()):
221 |         return [1.0, 1.0, 1.0], [[], [], [], []]
222 | 
223 | 
224 |     true = np.copy(true)
225 |     pred = np.copy(pred)
226 |     true_id_list = list(np.unique(true))
227 |     pred_id_list = list(np.unique(pred))
228 | 
229 |     true_masks = [
230 |         None,
231 |     ]
232 |     for t in true_id_list[1:]:
233 |         t_mask = np.array(true == t, np.uint8)
234 |         true_masks.append(t_mask)
235 | 
236 |     pred_masks = [
237 |         None,
238 |     ]
239 |     for p in pred_id_list[1:]:
240 |         p_mask = np.array(pred == p, np.uint8)
241 |         pred_masks.append(p_mask)
242 | 
243 |     # prefill with value
244 |     pairwise_iou = np.zeros(
245 |         [len(true_id_list) - 1, len(pred_id_list) - 1], dtype=np.float64
246 |     )
247 | 
248 |     # caching pairwise iou
249 |     for true_id in true_id_list[1:]:  # 0-th is background
250 |         t_mask = true_masks[true_id]
251 |         pred_true_overlap = pred[t_mask > 0]
252 |         pred_true_overlap_id = np.unique(pred_true_overlap)
253 |         pred_true_overlap_id = list(pred_true_overlap_id)
254 |         for pred_id in pred_true_overlap_id:
255 |             if pred_id == 0:  # ignore
256 |                 continue  # overlapping background
257 |             p_mask = pred_masks[pred_id]
258 |             total = (t_mask + p_mask).sum()
259 |             inter = (t_mask * p_mask).sum()
260 |             iou = inter / (total - inter)
261 |             pairwise_iou[true_id - 1, pred_id - 1] = iou
262 |     #
263 |     if match_iou >= 0.5:
264 |         paired_iou = pairwise_iou[pairwise_iou > match_iou]
265 |         pairwise_iou[pairwise_iou <= match_iou] = 0.0
266 |         paired_true, paired_pred = np.nonzero(pairwise_iou)
267 |         paired_iou = pairwise_iou[paired_true, paired_pred]
268 |         paired_true += 1  # index is instance id - 1
269 |         paired_pred += 1  # hence return back to original
270 |     else:  # * Exhaustive maximal unique pairing
271 |         #### Munkres pairing with the scipy library
272 |         # the algorithm returns (row indices, matched column indices);
273 |         # if there are multiple identical costs in a row, the index of the first
274 |         # occurrence is returned, thus the pairing is kept unique;
275 |         # invert the sign so that high IoU becomes minimum cost
276 |         paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)
277 |         ### extract the paired cost and remove invalid pairs
278 |         paired_iou = pairwise_iou[paired_true, paired_pred]
279 | 
280 |         # now select those above the threshold level;
281 |         # pairs with iou = 0.0, i.e. no intersection => FP or FN
282 |         paired_true = list(paired_true[paired_iou > match_iou] + 1)
283 |         paired_pred = list(paired_pred[paired_iou > match_iou] + 1)
284 |         paired_iou = paired_iou[paired_iou > match_iou]
285 | 
286 |     # get the actual FP and FN
287 |     unpaired_true = [idx for idx in true_id_list[1:] if idx not in paired_true]
288 |     unpaired_pred = [idx for idx in pred_id_list[1:] if idx not in paired_pred]
289 |     # print(paired_iou.shape, paired_true.shape, len(unpaired_true), len(unpaired_pred))
290 | 
291 |     #
292 |     tp = len(paired_true)
293 |     fp = len(unpaired_pred)
294 |     fn = len(unpaired_true)
295 |     # get the F1-score, i.e. DQ
296 |     dq = tp / (tp + 0.5 * fp + 0.5 * fn)
297 |     # get the SQ; no pair has IoU 0, so zeros cannot dilute the mean
298 |     sq = paired_iou.sum() / (tp + 1.0e-6)
299 | 
300 |     return [dq, sq, dq * sq], [paired_true, paired_pred, unpaired_true, unpaired_pred]
301 | 
302 | 
303 | #####
304 | def get_fast_dice_2(true, pred):
303 | #####
304 | def get_fast_dice_2(true, pred):
305 | """Ensemble dice."""
306 | 
307 | if true.any() and not(pred.any()):
308 | return 0.0
309 | if pred.any() and not(true.any()):
310 | return 0.0
311 | if not(pred.any()) and not(true.any()):
312 | return 1.0
313 | 
314 | true = np.copy(true)
315 | pred = np.copy(pred)
316 | true_id = list(np.unique(true))
317 | pred_id = list(np.unique(pred))
318 | 
319 | overall_total = 0
320 | overall_inter = 0
321 | 
322 | true_masks = [np.zeros(true.shape)]
323 | for t in true_id[1:]:
324 | t_mask = np.array(true == t, np.uint8)
325 | true_masks.append(t_mask)
326 | 
327 | pred_masks = [np.zeros(true.shape)]
328 | for p in pred_id[1:]:
329 | p_mask = np.array(pred == p, np.uint8)
330 | pred_masks.append(p_mask)
331 | 
332 | for true_idx in range(1, len(true_id)):
333 | t_mask = true_masks[true_idx]
334 | pred_true_overlap = pred[t_mask > 0]
335 | pred_true_overlap_id = np.unique(pred_true_overlap)
336 | pred_true_overlap_id = list(pred_true_overlap_id)
337 | try: # blindly remove background
338 | pred_true_overlap_id.remove(0)
339 | except ValueError:
340 | pass # just means there is no background
341 | for pred_idx in pred_true_overlap_id:
342 | p_mask = pred_masks[pred_idx]
343 | total = (t_mask + p_mask).sum()
344 | inter = (t_mask * p_mask).sum()
345 | overall_total += total
346 | overall_inter += inter
347 | 
348 | return 2 * overall_inter / overall_total
349 | 
350 | 
351 | ##### -------------------------- as pseudocode
352 | def get_dice_1(true, pred):
353 | """Traditional dice."""
354 | # cast to binary first
355 | true = np.copy(true)
356 | pred = np.copy(pred)
357 | true[true > 0] = 1
358 | pred[pred > 0] = 1
359 | inter = true * pred
360 | denom = true + pred
361 | return 2.0 * np.sum(inter) / max(np.sum(denom), 1.0)
362 | 
363 | 
364 | ####
365 | def get_dice_2(true, pred):
366 | """Ensemble Dice as used in the Computational Precision Medicine Challenge."""
367 | true = np.copy(true)
368 | pred = np.copy(pred)
369 | true_id = list(np.unique(true))
370 | pred_id = list(np.unique(pred))
371 | # remove background aka id 0
372 | true_id.remove(0)
373 | pred_id.remove(0)
374 | 
375 | total_markup = 0
376 | total_intersect = 0
377 | for t in true_id:
378 | t_mask = np.array(true == t, np.uint8)
379 | for p in pred_id:
380 | p_mask = np.array(pred == p, np.uint8)
381 | intersect = p_mask * t_mask
382 | if intersect.sum() > 0:
383 | total_intersect += intersect.sum()
384 | total_markup += t_mask.sum() + p_mask.sum()
385 | return 2 * total_intersect / total_markup
386 | 
387 | 
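# Worked example (editor's illustration): on a toy pair of instance maps,
# `get_dice_1` measures binary foreground overlap, while `get_dice_2` only
# accumulates the overlap of instance pairs that actually intersect, so it
# penalizes merged instances:
#
#     import numpy as np
#     true = np.array([[1, 1, 0], [0, 2, 2]])
#     pred = np.array([[1, 1, 0], [0, 1, 1]])  # the two nuclei merged into one
#     get_dice_1(true, pred)  # 1.0, the binary foreground overlap is perfect
#     get_dice_2(true, pred)  # 2*4/12 = 0.67, each GT instance is matched against the merged mask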
388 | #####
389 | def remap_label(pred, by_size=False):
390 | """Rename all instance ids so that the ids are contiguous, i.e. [0, 1, 2, 3],
391 | not [0, 2, 4, 6]. The ordering of instances (which one comes first)
392 | is preserved unless by_size=True, in which case the instances are reordered
393 | so that bigger nuclei get smaller ids.
394 | 
395 | Args:
396 | pred : the 2d array containing instances, where each instance is marked
397 | by a non-zero integer
398 | by_size : rename so that larger nuclei get smaller ids (on top)
399 | 
400 | """
401 | pred_id = list(np.unique(pred))
402 | pred_id.remove(0)
403 | if len(pred_id) == 0:
404 | return pred # no label
405 | if by_size:
406 | pred_size = []
407 | for inst_id in pred_id:
408 | size = (pred == inst_id).sum()
409 | pred_size.append(size)
410 | # sort the ids by size in descending order
411 | pair_list = zip(pred_id, pred_size)
412 | pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True)
413 | pred_id, pred_size = zip(*pair_list)
414 | 
415 | new_pred = np.zeros(pred.shape, np.int32)
416 | for idx, inst_id in enumerate(pred_id):
417 | new_pred[pred == inst_id] = idx + 1
418 | return new_pred
419 | 
420 | 
421 | #####
422 | def pair_coordinates(setA, setB, radius):
423 | """Use the Munkres (Kuhn-Munkres) algorithm to find the optimal
424 | unique pairing (largest possible number of matches) when pairing points in
425 | set B against points in set A, using distance as the cost function.
426 | 
427 | Args:
428 | setA, setB: np.array (float32) of size Nx2 containing the XY coordinates
429 | of N different points
430 | radius: valid area around a point in setA within which
431 | a given coordinate in setB is a candidate for a match
432 | Return:
433 | pairing: an array of index pairs,
434 | where the point at index pairing[0] in set A is paired with the point
435 | in set B at index pairing[1]
436 | unpairedA, unpairedB: the remaining unpaired points in set A and set B
437 | 
438 | """
439 | # * Euclidean distance as the cost matrix
440 | setA_tile = np.expand_dims(setA, axis=1)
441 | setB_tile = np.expand_dims(setB, axis=0)
442 | setA_tile = np.repeat(setA_tile, setB.shape[0], axis=1)
443 | setB_tile = np.repeat(setB_tile, setA.shape[0], axis=0)
444 | pair_distance = (setA_tile - setB_tile) ** 2
445 | # set A is the row, and set B is paired against set A
446 | pair_distance = np.sqrt(np.sum(pair_distance, axis=-1))
447 | 
448 | # * Munkres pairing with the scipy library
449 | # the algorithm returns (row indices, matched column indices);
450 | # if there are multiple identical costs in a row, the index of the first
451 | # occurrence is returned, so the pairing is guaranteed to be unique
452 | indicesA, paired_indicesB = linear_sum_assignment(pair_distance)
453 | 
454 | # extract the paired cost and remove pairs
455 | # outside of the designated radius
456 | pair_cost = pair_distance[indicesA, paired_indicesB]
457 | 
458 | pairedA = indicesA[pair_cost <= radius]
459 | pairedB = paired_indicesB[pair_cost <= radius]
460 | 
461 | unpairedA = [idx for idx in range(setA.shape[0]) if idx not in list(pairedA)]
462 | unpairedB = [idx for idx in range(setB.shape[0]) if idx not in list(pairedB)]
463 | 
464 | pairing = np.array(list(zip(pairedA, pairedB)))
465 | unpairedA = np.array(unpairedA, dtype=np.int64)
466 | unpairedB = np.array(unpairedB, dtype=np.int64)
467 | 
468 | return pairing, unpairedA, unpairedB
469 | 
-------------------------------------------------------------------------------- /cppnet/train.py: --------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import sys
4 | from load_save_model import checkpoint_save_stage
5 | import os
6 | from torch.optim.lr_scheduler import ReduceLROnPlateau
7 | 
8 | import torch.nn as nn
9 | 
10 | class Trainer():
11 | def __init__(self, loss=None, metric=None, log_dir=None, validate_every=1, verborrea=True):
12 | self.loss_ce = loss
13 | 
self.metric = metric 14 | self.verborrea = verborrea 15 | self.USE_CUDA = torch.cuda.is_available() 16 | 17 | self.validate_every = validate_every 18 | if log_dir is not None: 19 | from tensorboard_local import TensorBoard 20 | self.tb_logger = TensorBoard(log_dir, 20) # log every 20th image 21 | else: 22 | self.tb_logger = None 23 | 24 | def pretrain(self, model, DataSet, optimizer, epoch): #eval is not correct in the method 25 | _loss=0 26 | _correct=0 27 | model.train() 28 | kwargs ={} 29 | kwargs['display'] = True 30 | for iepoch in range(epoch): 31 | for batch_idx, data in enumerate(DataSet): 32 | if len(data) == 3: 33 | inputs, target,distances = data 34 | if self.USE_CUDA: 35 | inputs, distances, target = inputs.cuda(),distances.cuda(),target.cuda() 36 | model.cuda() 37 | kwargs['labels'] = target 38 | else: 39 | inputs,distances = data 40 | if self.USE_CUDA: 41 | inputs, distances = inputs.cuda(),distances.cuda() 42 | model.cuda() 43 | optimizer.zero_grad() 44 | ## While training, using ground truth distances for sampling 45 | prediction = model(inputs, distances) 46 | 47 | distances= distances.squeeze(1) 48 | total_loss, total_metric = self.loss_ce(prediction, distances,**kwargs) 49 | total_loss.backward() 50 | optimizer.step() 51 | _loss += total_metric.item() 52 | if self.metric is not None: 53 | prediction_final = [prediction[2], prediction[3]] 54 | _correct += self.metric(prediction_final,distances) 55 | 56 | _loss_average=_loss/len(DataSet.dataset)/epoch 57 | if self.metric is not None: 58 | _acc=_correct/float(batch_idx+1)# Average over all batches 59 | if self.verborrea: 60 | print('Accuracy: ',_acc.item()) 61 | if self.verborrea: 62 | print('Loss: ',_loss) 63 | print('Average Loss: ',_loss_average) 64 | 65 | 66 | print('Current Erosion List: ') 67 | e_list = model.erosion_factor_list 68 | if isinstance(e_list, nn.Parameter): 69 | print(e_list.data.cpu().numpy()) 70 | else: 71 | print(e_list) 72 | 73 | def Train(self,model, optimizer, TrainSet, TestSet, Train_mode, Model_name, Dataset, epochs=None, scheduler=None): 74 | if self.loss_ce is None: 75 | print("Loss function not set,exiting...") 76 | sys.exit() 77 | 78 | if scheduler is None and epochs is None: 79 | print('WARNING!!!!Creating default min scheduler') 80 | scheduler = ReduceLROnPlateau(optimizer, "min", verbose=True, patience=10, eps=1e-8) 81 | path_checkpoint = os.getcwd()+'/CHECKPOINT/checkpoint_'+Model_name+'_'+Train_mode+'_'+Dataset+'/CHECKPOINT.t7' 82 | print('Checkpoint path',path_checkpoint) 83 | scheduler_mode = scheduler.mode 84 | 85 | max_lr,list_lr = self.update_list_lr(optimizer) 86 | trainloss_to_fil=[] 87 | testloss_to_fil=[] 88 | trainMetric_to_fil=[] 89 | testMetric_to_fil=[] 90 | 91 | if isinstance(scheduler,ReduceLROnPlateau): 92 | patience_num=scheduler.patience 93 | else: 94 | print('Scheduler not supported. But training will continue if epochs are specified.') 95 | if epochs==None: 96 | print('WARNING!!!! 
Number of epochs not specified')
97 | sys.exit()
98 | patience_num='nothing'
99 | 
100 | parameters=[[],[],patience_num,optimizer.param_groups[0]['weight_decay']] # first list for epochs, second for learning rates, 3rd patience, 4th weight_decay, 5th for time
101 | parameters[1].append(list_lr)
102 | 
103 | epoch=0
104 | if epochs==0:
105 | keep_training=False
106 | else:
107 | keep_training=True
108 | print ('INITIAL TEST STATISTICS')
109 | loss_test, metric = self.evaluate(model, TestSet)
110 | checkpoint_save_stage(model,trainloss_to_fil,testloss_to_fil,trainMetric_to_fil,testMetric_to_fil,parameters,Model_name,Train_mode,Dataset)
111 | check_load=0
112 | 
113 | if isinstance(scheduler,ReduceLROnPlateau):
114 | if scheduler_mode == 'min':
115 | scheduler.step(loss_test)
116 | else:
117 | scheduler.step(metric)
118 | else:
119 | best_test = loss_test
120 | scheduler.step()
121 | since_init=time.time()
122 | while keep_training:
123 | epoch=epoch+1
124 | if epochs !=None:
125 | if self.verborrea:
126 | print('Epoch {}/{}, lr={}, patience={}, weight decay={}'.format(epoch, epochs,max_lr,scheduler.patience,optimizer.param_groups[0]['weight_decay']))
127 | else:
128 | if self.verborrea:
129 | print('Epoch {}, lr={}, patience={}, weight decay={}'.format(epoch,max_lr,scheduler.patience,optimizer.param_groups[0]['weight_decay']))
130 | 
131 | if self.verborrea:
132 | print('-' * 20)
133 | 
134 | if self.verborrea:
135 | print ('TRAIN STATISTICS')
136 | train_loss,train_metric= self.train_scratch(model,TrainSet,optimizer,epoch) # training happens here!
137 | 
138 | if epoch % self.validate_every == 0 :
139 | if self.verborrea:
140 | print ('TEST STATISTICS')
141 | print('Validating at epoch',epoch)
142 | test_loss,test_metric= self.evaluate(model,TestSet,epoch)
143 | 
144 | trainloss_to_fil.append(train_loss)
145 | testloss_to_fil.append(test_loss)
146 | trainMetric_to_fil.append(train_metric)
147 | testMetric_to_fil.append(test_metric)
148 | 
149 | if isinstance(scheduler,ReduceLROnPlateau):
150 | prev_num_bad_epochs=scheduler.num_bad_epochs
151 | if self.verborrea:
152 | print('-' * 10)
153 | if scheduler_mode =='min':
154 | save=(test_loss< scheduler.best)
155 | scheduler.step(test_loss)
156 | else:
157 | save=(test_metric>scheduler.best)
158 | scheduler.step(test_metric)
159 | print('Best', scheduler.best)
160 | 
161 | if save:
162 | checkpoint_save_stage(model,trainloss_to_fil,testloss_to_fil,trainMetric_to_fil,testMetric_to_fil,parameters,Model_name,Train_mode,Dataset)
163 | check_load=0
164 | if scheduler.num_bad_epochs==0 and prev_num_bad_epochs==scheduler.patience and not save:
165 | max_lr,list_lr=self.update_list_lr(optimizer)
166 | parameters[0].append(epoch)
167 | parameters[1].append(max_lr)
168 | model.load_state_dict(torch.load(path_checkpoint))
169 | check_load=check_load+1
170 | if self.verborrea: print ('Checkpoint loaded')
171 | 
172 | if max_lr<10*scheduler.eps or check_load==6:
173 | keep_training=False
174 | else:
175 | prev_max_lr=max_lr
176 | 
177 | scheduler.step()
178 | max_lr,list_lr = self.update_list_lr(optimizer)
179 | if test_loss<=best_test:
180 | checkpoint_save_stage(model,trainloss_to_fil,testloss_to_fil,trainMetric_to_fil,testMetric_to_fil,parameters,Model_name,Train_mode,Dataset)
181 | if max_lr
-------------------------------------------------------------------------------- /feature_extractor/dataloader_aug.py: --------------------------------------------------------------------------------
39 | distances[distances>self.max_dist] = self.max_dist
40 | obj_probabilities = edt_prob(target)
41 | 
42 | if self.resz is not None:
43 | obj_probabilities = resize(obj_probabilities, self.resz, order=1, preserve_range=True)
44 | distances = resize(distances, self.resz, order=1, preserve_range=True)
45 | 
46 | distances = np.transpose(distances, (2,0,1))
47 | obj_probabilities = np.expand_dims(obj_probabilities,0)
48 | 
49 | seg = (target>0).astype(np.float32)
50 | 
51 | h, w = target.shape
52 | 
53 | cset = np.unique(target[target>0])
54 | bndmap = np.zeros(target.shape, dtype=np.float32)
55 | cxmap = np.zeros(target.shape, dtype=np.float32)
56 | cymap = np.zeros(target.shape, dtype=np.float32)
57 | chmap = np.zeros(target.shape, dtype=np.float32)
58 | cwmap = np.zeros(target.shape, dtype=np.float32)
59 | 
60 | for ic in cset:
61 | icmap = target==ic
62 | bndmap += np.logical_xor(ndimage.morphology.binary_dilation(icmap, iterations=2), icmap).astype(np.float32)
63 | cx, cy = np.nonzero(icmap)
64 | cxmap[icmap] = cx.mean() / h
65 | cymap[icmap] = cy.mean() / w
66 | chmap[icmap] = cx.max()-cx.min()
67 | cwmap[icmap] = cy.max()-cy.min()
68 | bndmap[bndmap>1] = 1.0
69 | 
70 | # if random.random()>=0.5:
71 | # sigma = random.random()*2
72 | # distances = ndimage.gaussian_filter(distances, sigma=sigma, mode='reflect')
73 | # obj_probabilities = ndimage.gaussian_filter(obj_probabilities, sigma=sigma, mode='reflect')
74 | input_stardist = np.concatenate((obj_probabilities, distances), axis=0)
75 | 
76 | segbnd = np.stack((seg, bndmap), axis=0)
77 | bbox = np.stack((cxmap, cymap, chmap, cwmap), axis=0)
78 | 
79 | return input_stardist, segbnd, bbox
80 | 
81 | def getDataLoaders(n_rays, root_dir, type_list=['train', 'test'], batch_size=1, resz=None, max_dist=65,):
82 | trainset = DSB2018Dataset(root_dir=root_dir+'/'+type_list[0]+'/', n_rays=n_rays, max_dist=max_dist, if_training=True, resz=resz)
83 | testset = DSB2018Dataset(root_dir=root_dir+'/'+type_list[1]+'/', n_rays=n_rays, max_dist=max_dist, if_training=False, resz=resz)
84 | trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
85 | testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)
86 | return trainloader,testloader
-------------------------------------------------------------------------------- /feature_extractor/dataloader_aug_pannuke_cls.py: --------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.data import Dataset, DataLoader
3 | from skimage import io
4 | from skimage.transform import resize
5 | import numpy as np
6 | from stardist import star_dist,edt_prob
7 | from csbdeep.utils import normalize
8 | import random
9 | from scipy import ndimage
10 | import scipy.io as scio
11 | 
12 | class PanNukeDataset(Dataset):
13 | def __init__(self, root_dir, n_rays, max_dist=None, if_training=False, resz=None): 14 | self.img_filefold = os.path.join(root_dir,'images') 15 | self.target_filefold = os.path.join(root_dir,'masks') 16 | 17 | self.img_list = [] 18 | self.target_list = [] 19 | with open(os.path.join(self.img_filefold, 'name_list.txt'), 'r') as f: 20 | for line in f.readlines(): 21 | line_terms = line.split(',') 22 | self.img_list.append(line_terms[1].strip()) 23 | for ic in range(0, 5): 24 | with open(os.path.join(self.target_filefold, 'name_list_c'+str(ic)+'.txt'), 'r') as f: 25 | ic_target_list = [] 26 | for line in f.readlines(): 27 | line_terms = line.split(',') 28 | ic_target_list.append(line_terms[1].strip()) 29 | self.target_list.append(ic_target_list) 30 | 31 | self.n_rays = n_rays 32 | self.max_dist = max_dist 33 | self.if_training=if_training 34 | self.resz = resz 35 | 36 | def __len__(self): 37 | return len(self.img_list) 38 | 39 | def __getitem__(self, idx): 40 | target = [] 41 | for ic in range(5): 42 | ic_target = io.imread(self.target_list[ic][idx]) 43 | if ic > 0: 44 | last_target_max = target[ic-1].max() 45 | ic_target[ic_target>0] += last_target_max 46 | target.append(ic_target) 47 | target = np.stack(target, axis=2) 48 | seg_target = np.concatenate([(target.max(axis=2, keepdims=True)==0).astype(np.float32), (target > 0).astype(np.float32)], axis=2) 49 | target = target.max(axis=2) 50 | 51 | if self.if_training: 52 | aug_type = random.randint(0, 5) # rot90: 0, 1, 2; flip: 3, 4; ori: 5 53 | if aug_type<=2: 54 | target = np.rot90(target, aug_type).copy() 55 | elif aug_type<=4: 56 | target = np.flip(target, aug_type-3).copy() 57 | distances = star_dist(target, self.n_rays) 58 | if self.max_dist: 59 | distances[distances>self.max_dist] = self.max_dist 60 | obj_probabilities = edt_prob(target) 61 | 62 | if self.resz is not None: 63 | obj_probabilities = resize(obj_probabilities, self.resz, order=1, preserve_range=True) 64 | distances = resize(distances, self.resz, order=1, preserve_range=True) 65 | 66 | distances = np.transpose(distances, (2,0,1)) 67 | obj_probabilities = np.expand_dims(obj_probabilities,0) 68 | seg_target = np.transpose(seg_target, (2,0,1)) 69 | 70 | 71 | seg = (target>0).astype(np.float32) 72 | 73 | h, w = target.shape 74 | 75 | cset = np.unique(target[target>0]) 76 | bndmap = np.zeros(target.shape, dtype=np.float32) 77 | cxmap = np.zeros(target.shape, dtype=np.float32) 78 | cymap = np.zeros(target.shape, dtype=np.float32) 79 | chmap = np.zeros(target.shape, dtype=np.float32) 80 | cwmap = np.zeros(target.shape, dtype=np.float32) 81 | 82 | for ic in cset: 83 | icmap = target==ic 84 | bndmap += np.logical_xor(ndimage.morphology.binary_dilation(icmap, iterations=2), icmap).astype(np.float32) 85 | cx, cy = np.nonzero(icmap) 86 | cxmap[icmap] = cx.mean() / h 87 | cymap[icmap] = cy.mean() / w 88 | chmap[icmap] = cx.max()-cx.min() 89 | cwmap[icmap] = cy.max()-cy.min() 90 | bndmap[bndmap>1] = 1.0 91 | 92 | # if random.random()>=0.5: 93 | # sigma = random.random()*2 94 | # distances = ndimage.gaussian_filter(distances, sigma=sigma, mode='reflect') 95 | # obj_probabilities = ndimage.gaussian_filter(obj_probabilities, sigma=sigma, mode='reflect') 96 | input_stardist = np.concatenate((obj_probabilities, distances, seg_target), axis=0) 97 | 98 | segbnd = np.stack((seg, bndmap), axis=0) 99 | segbnd = np.concatenate((segbnd, seg_target), axis=0) 100 | bbox = np.stack((cxmap, cymap, chmap, cwmap), axis=0) 101 | 102 | return input_stardist, segbnd, bbox 103 | 104 | def 
getDataLoaders(n_rays, max_dist, root_dir, type_list=['train', 'test'], batch_size=1, resz=None): 105 | trainset = PanNukeDataset(root_dir=root_dir+'/'+type_list[0]+'/', n_rays=n_rays, max_dist=max_dist, if_training=True, resz=resz) 106 | testset = PanNukeDataset(root_dir=root_dir+'/'+type_list[1]+'/', n_rays=n_rays, max_dist=max_dist, if_training=False, resz=resz) 107 | trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4) 108 | testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4) 109 | return trainloader,testloader -------------------------------------------------------------------------------- /feature_extractor/instance_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | class InstanceLoss(torch.nn.Module): 5 | def __init__(self,scale=[1, 1], n_cls=0): 6 | super(InstanceLoss, self).__init__() 7 | 8 | self.scale = scale 9 | self.n_cls = n_cls 10 | if self.n_cls <= 1: 11 | assert len(scale) == 2 12 | else: 13 | assert len(scale) == 3 14 | 15 | def forward(self, prediction, gt_segbnd, **kwargs): 16 | 17 | segbnd = prediction[0] 18 | bbox = prediction[1] 19 | gt_bbox = kwargs['bbox'] 20 | gt_seg = gt_segbnd[:, 0] 21 | if self.scale[1]>0: 22 | bboxloss = F.l1_loss(bbox, gt_bbox, size_average=False, reduce=False)*gt_seg.unsqueeze(dim=1) 23 | bboxloss = torch.mean(bboxloss) 24 | else: 25 | bboxloss = 0.0 26 | if self.scale[0]>0: 27 | segbndloss = F.binary_cross_entropy(segbnd, gt_segbnd, weight=None, size_average=True, reduce=True) 28 | else: 29 | segbndloss = 0.0 30 | loss = self.scale[0]*segbndloss + self.scale[1]*bboxloss 31 | 32 | if self.n_cls > 1: 33 | gt_cls = gt_segbnd[:, 2:] 34 | pred_cls_log = prediction[2].log() 35 | if self.scale[2]>0: 36 | clsloss = F.kl_div(pred_cls_log, gt_cls, size_average=True, reduce=True) 37 | else: 38 | clsloss = 0.0 39 | loss += self.scale[2]*clsloss 40 | 41 | print('loss: {:.5f}, segbndloss: {:.5f}, bboxloss: {:.5f}, '\ 42 | .format(loss, segbndloss, bboxloss, )) 43 | 44 | return loss 45 | 46 | -------------------------------------------------------------------------------- /feature_extractor/load_save_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import shutil 4 | 5 | def load_model(model,Model_name,Train_mode,Dataset): 6 | Model_name=Model_name.upper() 7 | Train_mode=Train_mode.upper() 8 | Dataset =Dataset.upper() 9 | filepath= os.getcwd()+'/'+Dataset+'/'+Train_mode+'/'+Model_name+'/'+Model_name+'_'+Train_mode+'_'+Dataset+'.t7' 10 | ############################# 11 | print('File to be loaded:'+filepath) 12 | if os.path.isfile(filepath): 13 | try: 14 | model=model.module #For DATAPARALLEL 15 | except: 16 | pass 17 | print('Loading File: '+filepath) 18 | model.load_state_dict(torch.load(filepath)) 19 | return model 20 | else: 21 | print ('WARNING!!!: Weight of '+Model_name+' not loaded. 
No Existing file') 22 | return model 23 | 24 | 25 | def save_model(model,trainAcc_to_file,testAcc_to_file,trainloss_to_file,testloss_to_file,Parameters, 26 | Model_name,Train_mode,Dataset,model2=None,**kwargs): 27 | try: 28 | model=model.module 29 | except: 30 | pass 31 | 32 | path= kwargs['save_path'] 33 | if not os.path.exists(path): 34 | os.makedirs(path) 35 | 36 | stage='' 37 | if model2 is not None: 38 | weights_filename1=Model_name+'_'+Train_mode+'_'+Dataset+'_1.t7' 39 | weights_filename2=Model_name+'_'+Train_mode+'_'+Dataset+'_2.t7' 40 | torch.save(model.state_dict(),path+weights_filename1) 41 | torch.save(model2.state_dict(),path+weights_filename2) 42 | else: 43 | weights_filename=Model_name+'_'+Train_mode+'_'+Dataset+'.t7' 44 | torch.save(model.state_dict(),path+weights_filename) 45 | print(path+weights_filename+' saved') 46 | 47 | if testAcc_to_file is not None: 48 | testacc_filename='Testacc_'+stage+Model_name+'_'+Train_mode+'_'+Dataset+'.csv' 49 | if os.path.isfile(path+testacc_filename): 50 | thefile = open(path+testacc_filename, 'a') 51 | else: 52 | thefile = open(path+testacc_filename, 'w') 53 | for item in testAcc_to_file: 54 | thefile.write("%s," % item) 55 | thefile.close() 56 | 57 | if testloss_to_file is not None: 58 | testloss_filename='Testloss_'+stage+Model_name+'_'+Train_mode+'_'+Dataset+'.csv' 59 | if os.path.isfile(path+testloss_filename): 60 | thefile = open(path+testloss_filename, 'a') 61 | else: 62 | thefile = open(path+testloss_filename, 'w') 63 | for item in testloss_to_file: 64 | thefile.write("%s," % item) 65 | thefile.close() 66 | 67 | if trainloss_to_file is not None: 68 | trainloss_filename='Trainloss_'+stage+Model_name+'_'+Train_mode+'_'+Dataset+'.csv' 69 | if os.path.isfile(path+trainloss_filename): 70 | thefile = open(path+trainloss_filename, 'a') 71 | else: 72 | thefile = open(path+trainloss_filename, 'w') 73 | for item in trainloss_to_file: 74 | thefile.write("%s," % item) 75 | thefile.close() 76 | 77 | if trainAcc_to_file is not None: 78 | trainacc_filename='Trainacc_'+stage+Model_name+'_'+Train_mode+'_'+Dataset+'.csv' 79 | if os.path.isfile(path+trainacc_filename): 80 | thefile = open(path+trainacc_filename, 'a') 81 | else: 82 | thefile = open(path+trainacc_filename, 'w') 83 | for item in trainAcc_to_file: 84 | thefile.write("%s," % item) 85 | thefile.close() 86 | 87 | param_filename='Parameters_'+Model_name+'_'+Train_mode+'_'+Dataset+'.txt' 88 | if os.path.isfile(path+param_filename): 89 | thefile = open(path+param_filename, 'a') 90 | else: 91 | thefile = open(path+param_filename, 'w') 92 | thefile.write('%s \n' %stage) 93 | thefile.write("Patience_scheduler=%s, Weight_decay=%s \n" %(Parameters[2],Parameters[3])) 94 | if not Parameters[1][0][1:] == Parameters[1][0][:-1]: 95 | for i in range(len(Parameters[1][0])): 96 | thefile.write("Initial learning rate for param_groups %s is %s epochs \n" %(str(i),Parameters[1][0][i])) 97 | else: 98 | thefile.write("Initial learning rate is %s epochs \n" %Parameters[1][0][0]) 99 | thefile.write("\n\n" ) 100 | 101 | for epoch,lr in zip(Parameters[0],Parameters[1][1:]): 102 | thefile.write("In epoch %s, maximum of the learning rates decreased to %s \n" %(epoch, lr)) 103 | thefile.write("Trained for %s epochs \n\n" %Parameters[0][-1]) 104 | 105 | thefile.write("Train Statistics \n") 106 | if trainAcc_to_file is not None: 107 | thefile.write('Accuracy: %s \n' %trainAcc_to_file[-1]) 108 | thefile.write('Average Loss: %s \n\n'%trainloss_to_file[-1]) 109 | 110 | thefile.write("Test Statistics \n") 111 | if 
testAcc_to_file is not None: 112 | thefile.write('Accuracy: %s \n' %testAcc_to_file[-1]) 113 | for i in range(len(testAcc_to_file)): 114 | if testAcc_to_file[i]==testAcc_to_file[-1]: 115 | break 116 | if i+1==len(testAcc_to_file): 117 | i=-1 118 | thefile.write('Maximum test accuracy in epoch %s (if 0 it means that the initial state was the best)\n\n'%str(i+1)) 119 | 120 | thefile.write('Average Loss: %s \n\n'%testloss_to_file[-1]) 121 | thefile.write('Total time elapsed %s\n\n' %Parameters[4]) 122 | thefile.write('Note: %s\n\n' %kwargs['additional_notes']) 123 | thefile.write(20*'-'+'\n\n') 124 | thefile.close() 125 | print(os.getcwd()+'/CHECKPOINT/checkpoint_'+Model_name+'_'+Train_mode+'_'+Dataset) 126 | shutil.rmtree(os.getcwd()+'/CHECKPOINT/checkpoint_'+Model_name+'_'+Train_mode+'_'+Dataset) 127 | 128 | 129 | ''' 130 | def checkpoint_save(model,trainAcc_to_file,testAcc_to_file,trainloss_to_file,testloss_to_file,Parameters,Model_name,Train_mode,Dataset): 131 | 132 | path=os.getcwd()+'/CHECKPOINT/checkpoint_'+Model_name+'_'+Train_mode+'_'+Dataset 133 | if not os.path.exists(path): 134 | os.makedirs(path) 135 | 136 | torch.save(model.state_dict(),path+'/CHECKPOINT.t7') 137 | print(path+'/CHECKPOINT.t7'+' saved') 138 | 139 | thefile = open(path+'/Testacc_CHECKPOINT.csv', 'w') 140 | for item in testAcc_to_file: 141 | thefile.write("%s," % item) 142 | thefile.close() 143 | 144 | 145 | thefile = open(path+'/Testloss_CHECKPOINT.csv', 'w') 146 | for item in testloss_to_file: 147 | thefile.write("%s," % item) 148 | thefile.close() 149 | 150 | thefile = open(path+'/Trainloss_CHECKPOINT.csv', 'w') 151 | for item in trainloss_to_file: 152 | thefile.write("%s," % item) 153 | thefile.close() 154 | 155 | 156 | thefile = open(path+'/Trainacc_CHECKPOINT.csv', 'w') 157 | for item in trainAcc_to_file: 158 | thefile.write("%s," % item) 159 | thefile.close() 160 | 161 | thefile = open(path+'/Parameters_CHECKPOINT.txt', 'w') 162 | 163 | thefile.write("Patience_scheduler=%s, Weight_decay=%s \n" %(Parameters[2],Parameters[3])) 164 | if not Parameters[1][0][1:] == Parameters[1][0][:-1]: 165 | for i in range(len(Parameters[1][0])): 166 | thefile.write("Initial learning rate for param_grooups %s is %s epochs \n" %(str(i),Parameters[1][0][i])) 167 | else: 168 | thefile.write("Initial learning rate is %s epochs \n" %Parameters[1][0][0]) 169 | 170 | for epoch,lr in zip(Parameters[0],Parameters[1][1:]): 171 | thefile.write("In epoch %s, maximum learning rate decreased to %s \n" %(epoch, lr)) 172 | if not(Parameters[0]==[]): 173 | thefile.write("Trained for %s epochs \n" %Parameters[0][-1]) 174 | thefile.write("\n\n" ) 175 | if not(trainAcc_to_file==[]): 176 | thefile.write("Train Statistics \n") 177 | thefile.write('Accuracy: %s \n' %trainAcc_to_file[-1]) 178 | thefile.write('Average Loss: %s \n\n'%trainloss_to_file[-1]) 179 | 180 | thefile.write("Test Statistics \n") 181 | thefile.write('Accuracy: %s \n' %testAcc_to_file[-1]) 182 | thefile.write('Average Loss: %s \n\n'%testloss_to_file[-1]) 183 | thefile.write(20*'-'+'\n\n') 184 | thefile.close() 185 | 186 | 187 | ###################################################################################################### 188 | ''' 189 | def checkpoint_save_stage(model,trainloss_to_file,testloss_to_file,train_metric_to_file,test_metric_to_file,Parameters,Model_name,Train_mode,Dataset,model2=None): 190 | 191 | path=os.getcwd()+'/CHECKPOINT/checkpoint_'+Model_name+'_'+Train_mode+'_'+Dataset 192 | if not os.path.exists(path): 193 | os.makedirs(path) 194 | 195 | if model2 is 
not None: 196 | torch.save(model.state_dict(),path+'/CHECKPOINT1.t7') 197 | torch.save(model2.state_dict(),path+'/CHECKPOINT2.t7') 198 | else: 199 | torch.save(model.state_dict(),path+'/CHECKPOINT.t7') 200 | print(path+'/CHECKPOINT.t7'+' saved') 201 | 202 | thefile = open(path+'/Testloss_CHECKPOINT.csv', 'w') 203 | for item in testloss_to_file: 204 | thefile.write("%s," % item) 205 | thefile.close() 206 | 207 | thefile = open(path+'/Trainloss_CHECKPOINT.csv', 'w') 208 | for item in trainloss_to_file: 209 | thefile.write("%s," % item) 210 | thefile.close() 211 | 212 | thefile = open(path+'/Parameters_CHECKPOINT.txt', 'w') 213 | thefile.write("STAGE1 \n" ) 214 | thefile.write("Patience_scheduler=%s, Weight_decay=%s \n" %(Parameters[2],Parameters[3])) 215 | if not Parameters[1][0][1:] == Parameters[1][0][:-1]: 216 | for i in range(len(Parameters[1][0])): 217 | thefile.write("Initial learning rate for param_groups %s is %s epochs \n" %(str(i),Parameters[1][0][i])) 218 | else: 219 | thefile.write("Initial learning rate is %s epochs \n" %Parameters[1][0][0]) 220 | 221 | for epoch,lr in zip(Parameters[0],Parameters[1][1:]): 222 | thefile.write("In epoch %s, maximum learning rate decreased to %s \n" %(epoch, lr)) 223 | if not(Parameters[0]==[]): 224 | thefile.write("Trained for %s epochs \n" %Parameters[0][-1]) 225 | thefile.write("\n\n" ) 226 | if not(trainloss_to_file==[]): 227 | thefile.write("Train Statistics \n") 228 | thefile.write('Accuracy: %s \n' %train_metric_to_file[-1]) 229 | thefile.write('Average Loss: %s \n\n'%trainloss_to_file[-1]) 230 | 231 | thefile.write("Test Statistics \n") 232 | thefile.write('Accuracy: %s \n' %test_metric_to_file[-1]) 233 | thefile.write('Average Loss: %s \n\n'%testloss_to_file[-1]) 234 | thefile.write(20*'-'+'\n\n') 235 | thefile.close() 236 | -------------------------------------------------------------------------------- /feature_extractor/main_shape.py: -------------------------------------------------------------------------------- 1 | import os 2 | print('Working dir',os.getcwd()) 3 | from load_save_model import save_model 4 | from train import Trainer 5 | import torch.optim 6 | from torch.optim.lr_scheduler import ReduceLROnPlateau 7 | from models.unet_model import UNet 8 | from instance_loss import InstanceLoss 9 | import dataloader_aug 10 | import dataloader_aug_pannuke_cls 11 | import random 12 | import numpy as np 13 | 14 | import argparse 15 | 16 | def run(data_path, n_rays, nc_in, nd_features, loss_scale, loss_type, init_lr=1e-4, n_cls=0, dataset_name='DSB2018', train_type_idx=-1): 17 | 18 | if dataset_name in ['DSB2018',]: # 19 | Trainloader, Testloader = dataloader_aug.getDataLoaders(n_rays, root_dir=data_path) 20 | elif dataset_name in ['BBBC006',]: # 21 | Trainloader, Testloader = dataloader_aug.getDataLoaders(n_rays, root_dir=data_path, type_list=['train', 'val']) 22 | elif dataset_name in ['PANNUKE']: # 23 | train_type_list = [['fold_1', 'fold_2'], ['fold_2', 'fold_1'], ['fold_3', 'fold_2']] 24 | assert(train_type_idx <= len(train_type_list) and train_type_idx >= 0) 25 | Trainloader, Testloader = dataloader_aug_pannuke_cls.getDataLoaders(n_rays, root_dir=data_path, type_list=train_type_list[train_type_idx], batch_size=32) 26 | 27 | model = UNet(nc_in, nd_features, loss_type=loss_type, n_cls=n_cls).cuda() 28 | 29 | model_name='UNet2D_'+str(nd_features)+'d' + '' 30 | print('model='+model_name) 31 | dataset=dataset_name 32 | print('dataset='+dataset) 33 | train_mode='StarDist2'+loss_type.capitalize()+'_' + str(n_rays) 34 | print('No.of rays', 
n_rays) 35 | 36 | kwargs={} 37 | additional_notes= '' 38 | kwargs['additional_notes'] = additional_notes 39 | SAVE_PATH = os.getcwd()+'/'+dataset+'/'+train_mode+'_'+model_name+'/' 40 | kwargs['save_path'] = SAVE_PATH 41 | RESULTS_DIRECTORY = os.getcwd()+'/'+dataset+'/'+train_mode+'_'+model_name+'/plots/' 42 | 43 | loss = InstanceLoss(scale=loss_scale, n_cls=n_cls) 44 | trainer = Trainer(loss, None, validate_every=2) 45 | optimizer = torch.optim.Adam(model.parameters(), lr=init_lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-5) 46 | scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, verbose=True, patience=5, eps=1e-8, threshold=1e-20) 47 | 48 | print ('Starting Training') 49 | trainloss_to_file,testloss_to_file,trainMetric_to_file,testMetric_to_file,Parameters = trainer.Train(model,optimizer, 50 | Trainloader, Testloader, epochs=None, Train_mode=train_mode, 51 | Model_name=model_name, 52 | DataSet=dataset,scheduler=scheduler) 53 | print('Saving model...') 54 | save_model(model,trainMetric_to_file,testMetric_to_file,trainloss_to_file,testloss_to_file,Parameters,model_name,train_mode,dataset, plot=False,**kwargs) 55 | 56 | 57 | DATA_PATH = '/data/cong/datasets/dsb2018/dsb2018_in_stardist/dsb2018/dataset_split_for_training' 58 | # For DSB2018 and BBBC006, the number of input channels of the SAP feature extractor is 32+1+0 59 | # For PanNuke, the number of input channels of the SAP feature extractor is 32+1+6 60 | 61 | if __name__ == '__main__': 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--gpuid', type=int, default=0) 64 | parser.add_argument('--n_rays', type=int, default=32) 65 | parser.add_argument('--n_cls', type=int, default=0) 66 | parser.add_argument('--nd_features', type=int, default=32) 67 | parser.add_argument('--loss_type', type=str, default='others') 68 | parser.add_argument('--dataset', type=str, default='DSB2018') 69 | args = parser.parse_args() 70 | 71 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpuid) 72 | # torch.set_num_threads(8) 73 | 74 | loss_names = {'others':[1.0, 1.0], 'bbox':[0.0, 1.0], 'segbnd':[1.0, 0.0], } 75 | 76 | if args.n_cls <= 1: # binary seg. 
only 77 | args.n_cls = 0 78 | 79 | run(DATA_PATH, args.n_rays, args.n_rays+1+args.n_cls, args.nd_features, loss_names[args.loss_type], args.loss_type, n_cls=args.n_cls, dataset_name=args.dataset) -------------------------------------------------------------------------------- /feature_extractor/models/__pycache__/unet_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csccsccsccsc/cpp-net/e99cbc790fddb1964355753be12ca8f10e51f247/feature_extractor/models/__pycache__/unet_model.cpython-37.pyc -------------------------------------------------------------------------------- /feature_extractor/models/__pycache__/unet_parts_gn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csccsccsccsc/cpp-net/e99cbc790fddb1964355753be12ca8f10e51f247/feature_extractor/models/__pycache__/unet_parts_gn.cpython-37.pyc -------------------------------------------------------------------------------- /feature_extractor/models/unet_model.py: -------------------------------------------------------------------------------- 1 | # full assembly of the sub-parts to form the complete net 2 | 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | from .unet_parts_gn import * 6 | import torch.nn.init as init 7 | 8 | class UNet(nn.Module): 9 | def __init__(self, n_channels, n_features=32, loss_type='others', n_cls=0): 10 | super(UNet, self).__init__() 11 | self.inc = inconv(n_channels, n_features) 12 | self.down1 = down(n_features, n_features*2) 13 | self.down2 = down(n_features*2, n_features*4) 14 | self.down3 = down(n_features*4, n_features*8) 15 | self.down4 = down(n_features*8, n_features*16) 16 | 17 | self.up1 = up_single(n_features*16, n_features*8, bilinear=True) 18 | self.up2 = up_single(n_features*8, n_features*4, bilinear=True) 19 | self.up3 = up_single(n_features*4, n_features*2, bilinear=True) 20 | self.up4 = up_single(n_features*2, n_features*1, bilinear=True) 21 | 22 | self.loss_type = loss_type 23 | if self.loss_type=='others' or self.loss_type=='segbnd': 24 | self.features_segbnd = nn.Conv2d(n_features, n_features, 3, padding=1) 25 | self.out_segbnd = outconv(n_features, 2) 26 | if self.loss_type=='others' or self.loss_type=='bbox': 27 | self.features_bbox = nn.Conv2d(n_features, n_features, 3, padding=1) 28 | self.out_bbox = outconv(n_features, 4) 29 | 30 | self.final_activation_prob = nn.Sigmoid() 31 | self.final_activation_ray = nn.ReLU() 32 | 33 | self.n_cls = n_cls 34 | if n_cls > 1: 35 | self.features_cls = nn.Conv2d(n_features, n_features, 3, padding=1) 36 | self.out_cls = outconv(n_features, n_cls) 37 | self.final_activation_cls = nn.Softmax(dim=1) 38 | 39 | 40 | def forward(self, x): 41 | x0 = self.inc(x) 42 | x1 = self.down1(x0) 43 | x2 = self.down2(x1) 44 | x3 = self.down3(x2) 45 | x4 = self.down4(x3) 46 | 47 | x = self.up1(x4, x3) 48 | x = self.up2(x, x2) 49 | x = self.up3(x, x1) 50 | x = self.up4(x, x0) 51 | 52 | if self.loss_type == 'others' or self.loss_type=='segbnd': 53 | x_segbnd = self.final_activation_prob(self.out_segbnd(self.features_segbnd(x))) 54 | if self.loss_type == 'others' or self.loss_type=='bbox': 55 | x_bbox = self.final_activation_ray(self.out_bbox(self.features_bbox(x))) 56 | 57 | 58 | if self.loss_type == 'others': 59 | outputs = [x_segbnd, x_bbox] 60 | elif self.loss_type == 'segbnd': 61 | outputs = [x_segbnd, 0.0] 62 | elif self.loss_type == 'bbox': 63 | outputs = [0.0, x_bbox] 64 | 65 | if 
self.n_cls > 1: 66 | outputs.append(self.final_activation_cls(self.out_cls(self.features_cls(x)))) 67 | 68 | return outputs -------------------------------------------------------------------------------- /feature_extractor/models/unet_model_3layer.py: -------------------------------------------------------------------------------- 1 | # full assembly of the sub-parts to form the complete net 2 | 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | from .unet_parts_gn import * 6 | import torch.nn.init as init 7 | 8 | class UNetStar(nn.Module): 9 | def __init__(self, n_channels, n_features=32, loss_type='others'): 10 | super(UNetStar, self).__init__() 11 | self.inc = inconv(n_channels, n_features) 12 | self.down1 = down(n_features, n_features*2) 13 | self.down2 = down(n_features*2, n_features*4) 14 | self.down3 = down(n_features*4, n_features*8) 15 | 16 | self.up1 = up_single(n_features*8, n_features*4, bilinear=True) 17 | self.up2 = up_single(n_features*4, n_features*2, bilinear=True) 18 | self.up3 = up_single(n_features*2, n_features*1, bilinear=True) 19 | 20 | self.loss_type = loss_type 21 | if self.loss_type=='others' or self.loss_type=='segbnd': 22 | self.features_segbnd = nn.Conv2d(n_features, n_features, 3, padding=1) 23 | self.out_segbnd = outconv(n_features, 2) 24 | if self.loss_type=='others' or self.loss_type=='bbox': 25 | self.features_bbox = nn.Conv2d(n_features, n_features, 3, padding=1) 26 | self.out_bbox = outconv(n_features, 4) 27 | 28 | self.final_activation_prob = nn.Sigmoid() 29 | self.final_activation_ray = nn.ReLU() 30 | 31 | def forward(self, x): 32 | x0 = self.inc(x) 33 | x1 = self.down1(x0) 34 | x2 = self.down2(x1) 35 | x3 = self.down3(x2) 36 | 37 | x = self.up1(x3, x2) 38 | x = self.up2(x, x1) 39 | x = self.up3(x, x0) 40 | 41 | if self.loss_type == 'others' or self.loss_type=='segbnd': 42 | x_segbnd = self.final_activation_prob(self.out_segbnd(self.features_segbnd(x))) 43 | if self.loss_type == 'others' or self.loss_type=='bbox': 44 | x_bbox = self.final_activation_ray(self.out_bbox(self.features_bbox(x))) 45 | 46 | if self.loss_type == 'others': 47 | return x_segbnd, x_bbox 48 | elif self.loss_type == 'segbnd': 49 | return x_segbnd, 0.0 50 | elif self.loss_type == 'bbox': 51 | return 0.0, x_bbox -------------------------------------------------------------------------------- /feature_extractor/models/unet_parts_gn.py: -------------------------------------------------------------------------------- 1 | # sub-parts of the U-Net model 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class double_conv(nn.Module): 8 | '''(conv => BN => ReLU) * 2''' 9 | def __init__(self, in_ch, out_ch): 10 | super(double_conv, self).__init__() 11 | num_groups = out_ch // 8 12 | self.conv = nn.Sequential( 13 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 14 | nn.GroupNorm(num_channels=out_ch,num_groups=num_groups), 15 | nn.ELU(inplace=True), 16 | #nn.ReLU(inplace=True), 17 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 18 | nn.GroupNorm(num_channels=out_ch,num_groups=num_groups), 19 | nn.ELU(inplace=True) 20 | #nn.ReLU(inplace=True) 21 | ) 22 | 23 | def forward(self, x): 24 | x = self.conv(x) 25 | return x 26 | 27 | 28 | class inconv(nn.Module): 29 | def __init__(self, in_ch, out_ch): 30 | super(inconv, self).__init__() 31 | self.conv = double_conv(in_ch, out_ch) 32 | 33 | def forward(self, x): 34 | x = self.conv(x) 35 | return x 36 | 37 | 38 | class down(nn.Module): 39 | def __init__(self, in_ch, out_ch): 40 | super(down, self).__init__() 
41 | self.mpconv = nn.Sequential(
42 | nn.MaxPool2d(2),
43 | double_conv(in_ch, out_ch)
44 | )
45 | 
46 | def forward(self, x):
47 | x = self.mpconv(x)
48 | return x
49 | 
50 | 
51 | class up_single(nn.Module):
52 | def __init__(self, in_ch, out_ch, bilinear=True):
53 | super(up_single, self).__init__()
54 | 
55 | if bilinear:
56 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
57 | else:
58 | self.up = nn.ConvTranspose2d(in_ch, in_ch, 2, stride=2)
59 | 
60 | self.conv = double_conv(in_ch, out_ch)
61 | 
62 | def forward(self, x1, x2):
63 | x1 = self.up(x1)
64 | # input is CHW
65 | diffY = x2.size()[2] - x1.size()[2]
66 | diffX = x2.size()[3] - x1.size()[3]
67 | x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
68 | diffY // 2, diffY - diffY//2))
69 | # for padding issues, see
70 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
71 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
72 | x1 = self.conv(x1)
73 | return x1
74 | 
75 | 
76 | class up_dual(nn.Module):
77 | def __init__(self, in_ch_1, in_ch_2, out_ch, bilinear=True):
78 | super(up_dual, self).__init__()
79 | 
80 | if bilinear:
81 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
82 | else:
83 | self.up = nn.ConvTranspose2d(in_ch_1, in_ch_1, 2, stride=2) # fixed: `in_ch` was undefined in this branch
84 | 
85 | self.conv = double_conv(in_ch_1+in_ch_2, out_ch)
86 | 
87 | def forward(self, x1, x2):
88 | x1 = self.up(x1)
89 | # input is CHW
90 | diffY = x2.size()[2] - x1.size()[2]
91 | diffX = x2.size()[3] - x1.size()[3]
92 | x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
93 | diffY // 2, diffY - diffY//2))
94 | # for padding issues, see
95 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
96 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
97 | x = torch.cat([x2, x1], dim=1)
98 | x = self.conv(x)
99 | return x
100 | 
101 | 
102 | class up(nn.Module):
103 | def __init__(self, in_ch, out_ch, bilinear=True):
104 | super(up, self).__init__()
105 | 
106 | if bilinear:
107 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
108 | else:
109 | self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
110 | 
111 | self.conv = double_conv(in_ch, out_ch)
112 | 
113 | def forward(self, x1, x2):
114 | x1 = self.up(x1)
115 | 
116 | # input is CHW
117 | diffY = x2.size()[2] - x1.size()[2]
118 | diffX = x2.size()[3] - x1.size()[3]
119 | 
120 | x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
121 | diffY // 2, diffY - diffY//2))
122 | 
123 | # for padding issues, see
124 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
125 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
126 | 
127 | x = torch.cat([x2, x1], dim=1)
128 | x = self.conv(x)
129 | return x
130 | 
131 | 
132 | class outconv(nn.Module):
133 | def __init__(self, in_ch, out_ch):
134 | super(outconv, self).__init__()
135 | self.conv = nn.Conv2d(in_ch, out_ch, 1)
136 | def forward(self, x):
137 | x = self.conv(x)
138 | return x
139 | 
-------------------------------------------------------------------------------- /feature_extractor/train.py: --------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import sys
4 | from load_save_model import checkpoint_save_stage
5 | import os
6 | from
torch.optim.lr_scheduler import ReduceLROnPlateau 7 | 8 | class Trainer(): 9 | def __init__(self, loss=None, metric=None, validate_every=1, verborrea=True): 10 | self.loss_ce = loss 11 | self.metric = metric 12 | self.verborrea = verborrea 13 | self.USE_CUDA = torch.cuda.is_available() 14 | self.validate_every = validate_every 15 | 16 | def Train(self,model, optimizer, TrainSet, TestSet, Train_mode, Model_name, DataSet, epochs=None, scheduler=None): 17 | if self.loss_ce is None: 18 | print("Loss function not set,exiting...") 19 | sys.exit() 20 | 21 | if scheduler is None and epochs is None: 22 | print('WARNING!!!!Creating default min scheduler') 23 | scheduler = ReduceLROnPlateau(optimizer, "min", verbose=True,patience=10,eps=1e-8) 24 | path_checkpoint = os.getcwd()+'/CHECKPOINT/checkpoint_'+Model_name+'_'+Train_mode+'_'+DataSet+'/CHECKPOINT.t7' 25 | print('Checkpoint path',path_checkpoint) 26 | scheduler_mode = scheduler.mode 27 | 28 | max_lr,list_lr = self.update_list_lr(optimizer) 29 | trainloss_to_fil=[] 30 | testloss_to_fil=[] 31 | trainMetric_to_fil=[] 32 | testMetric_to_fil=[] 33 | 34 | if isinstance(scheduler,ReduceLROnPlateau): 35 | patience_num=scheduler.patience 36 | else: 37 | print('Scheduler not supported. But training will continue if epochs are specified.') 38 | if epochs==None: 39 | print('WARNING!!!! Number of epochs not specified') 40 | sys.exit() 41 | patience_num='nothing' 42 | 43 | parameters=[[],[],patience_num,optimizer.param_groups[0]['weight_decay']]#first list for epochs, second for learning rate,3rd patience, 4th weight_decay,5 for time 44 | parameters[1].append(list_lr) 45 | 46 | epoch=0 47 | if epochs==0: 48 | keep_training=False 49 | else: 50 | keep_training=True 51 | print ('INITIAL TEST STATISTICS') 52 | loss_test,metric = self.evaluate(model,TestSet) 53 | checkpoint_save_stage(model,trainloss_to_fil,testloss_to_fil,trainMetric_to_fil,testMetric_to_fil,parameters,Model_name,Train_mode,DataSet) 54 | check_load=0 55 | 56 | if isinstance(scheduler,ReduceLROnPlateau): 57 | if scheduler_mode == 'min': 58 | scheduler.step(loss_test) 59 | else: 60 | scheduler.step(metric) 61 | else: 62 | best_test = loss_test 63 | scheduler.step() 64 | since_init=time.time() 65 | while keep_training: 66 | epoch=epoch+1 67 | if epochs !=None: 68 | if self.verborrea: 69 | print('Epoch {}/{}, lr={}. patience={}, weight decay={}'.format(epoch, epochs,max_lr,scheduler.patience,optimizer.param_groups[0]['weight_decay'])) 70 | else: 71 | if self.verborrea: 72 | print('Epoch {}, lr={}, patience={}, weight decay={}'.format(epoch,max_lr,scheduler.patience,optimizer.param_groups[0]['weight_decay'])) 73 | 74 | if self.verborrea: 75 | print('-' * 20) 76 | 77 | if self.verborrea: 78 | print ('TRAIN STATISTICS') 79 | model.train() 80 | train_loss,train_metric= self.train_scratch(model,TrainSet,optimizer) #Training happens here! 
81 | 
82 | if epoch % self.validate_every == 0 :
83 | if self.verborrea:
84 | print ('TEST STATISTICS')
85 | print('Validating at epoch',epoch)
86 | model.eval()
87 | test_loss,test_metric= self.evaluate(model,TestSet)
88 | 
89 | trainloss_to_fil.append(train_loss)
90 | testloss_to_fil.append(test_loss)
91 | trainMetric_to_fil.append(train_metric)
92 | testMetric_to_fil.append(test_metric)
93 | 
94 | if isinstance(scheduler,ReduceLROnPlateau):
95 | prev_num_bad_epochs=scheduler.num_bad_epochs
96 | if self.verborrea:
97 | print('-' * 10)
98 | if scheduler_mode =='min':
99 | save=(test_loss< scheduler.best)
100 | scheduler.step(test_loss)
101 | else:
102 | save=(test_metric>scheduler.best)
103 | scheduler.step(test_metric)
104 | print('Best', scheduler.best)
105 | 
106 | if save:
107 | checkpoint_save_stage(model,trainloss_to_fil,testloss_to_fil,trainMetric_to_fil,testMetric_to_fil,parameters,Model_name,Train_mode,DataSet)
108 | check_load=0
109 | if scheduler.num_bad_epochs==0 and prev_num_bad_epochs==scheduler.patience and not save:
110 | max_lr,list_lr=self.update_list_lr(optimizer)
111 | parameters[0].append(epoch)
112 | parameters[1].append(max_lr)
113 | model.load_state_dict(torch.load(path_checkpoint))
114 | check_load=check_load+1
115 | if self.verborrea: print ('Checkpoint loaded')
116 | 
117 | if max_lr<10*scheduler.eps or check_load==6:
118 | keep_training=False
119 | else:
120 | prev_max_lr=max_lr
121 | 
122 | scheduler.step()
123 | max_lr,list_lr = self.update_list_lr(optimizer)
124 | if test_loss<=best_test:
125 | checkpoint_save_stage(model,trainloss_to_fil,testloss_to_fil,trainMetric_to_fil,testMetric_to_fil,parameters,Model_name,Train_mode,DataSet)
126 | if max_lr
52 | # print(image_id, len(mask_name_list), mask.max(), mask[mask>0].min())
53 | 
54 | # Filename record
55 | f.write(str(image_id)+','+name+','+split_flag+'\n')
-------------------------------------------------------------------------------- /reorganize_datasets/reorganize_pannuke.py: --------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from skimage import io
4 | import warnings
5 | warnings.filterwarnings("ignore")
6 | 
7 | #######
8 | ### 1) Download PanNuke from https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke
9 | ### 2) Save each image and mask in the original ".npy" files as ".tif"
10 | 
11 | for ifold in [1,2,3]:
12 | 
13 | image_dir = 'DATA ROOT PATH/pannuke/fold_'+str(ifold)+'/images/fold'+str(ifold)+'/images.npy'
14 | masks_dir = 'DATA ROOT PATH/pannuke/fold_'+str(ifold)+'/masks/fold'+str(ifold)+'/masks.npy'
15 | imgs = np.load(image_dir).astype(np.uint8)
16 | msks = np.load(masks_dir).astype(np.uint16)
17 | 
18 | img_filefold_tosave = os.path.join(os.getcwd(), 'reorganized_dataset', 'fold_'+str(ifold), 'images')
19 | msk_filefold_tosave = os.path.join(os.getcwd(), 'reorganized_dataset', 'fold_'+str(ifold), 'masks')
20 | 
21 | if not(os.path.exists(img_filefold_tosave)):
22 | os.makedirs(img_filefold_tosave)
23 | if not(os.path.exists(msk_filefold_tosave)):
24 | os.makedirs(msk_filefold_tosave)
25 | img_name_list_file = open(os.path.join(img_filefold_tosave, 'name_list.txt'), 'w')
26 | msk_name_list_file = [ open(os.path.join(msk_filefold_tosave, 'name_list_c'+str(imod)+'.txt'), 'w') for imod in range(6) ]
27 | 
28 | nimg = imgs.shape[0]
29 | 
30 | for iimg in range(nimg):
31 | io.imsave(os.path.join(img_filefold_tosave, 'img_'+str(iimg)+'.png'), imgs[iimg])
32 | img_name_list_file.write(str(iimg) + ',' + os.path.join(img_filefold_tosave,
'img_'+str(iimg)+'.png') + '\n') 33 | print(str(iimg) + ',' + os.path.join(img_filefold_tosave, 'img_'+str(iimg)+'.png')) 34 | for imod in range(6): 35 | io.imsave(os.path.join(msk_filefold_tosave, 'msk_'+str(iimg)+'_c_'+str(imod)+'.tif'), msks[iimg, :, :, imod]) 36 | msk_name_list_file[imod].write(str(iimg) + ',' + os.path.join(msk_filefold_tosave, 'msk_'+str(iimg)+'_c_'+str(imod)+'.tif') + '\n') 37 | 38 | img_name_list_file.close() 39 | for imod in range(6): 40 | msk_name_list_file[imod].close() 41 | --------------------------------------------------------------------------------
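For reference, a minimal sketch (editor's illustration; assumes the layout produced by reorganize_pannuke.py above) of reading back one reorganized image/mask pair from the generated name lists:

```
import os
from skimage import io

fold_dir = os.path.join(os.getcwd(), 'reorganized_dataset', 'fold_1')
# each line of a name list is "<index>,<file path>"
with open(os.path.join(fold_dir, 'images', 'name_list.txt')) as f:
    img_path = f.readline().split(',')[1].strip()
with open(os.path.join(fold_dir, 'masks', 'name_list_c0.txt')) as f:
    msk_path = f.readline().split(',')[1].strip()

img = io.imread(img_path)  # HxWx3 uint8 image tile
msk = io.imread(msk_path)  # HxW uint16 instance map for class 0
print(img.shape, msk.shape, msk.max())
```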