├── README.md
├── code
│   ├── dataloaders
│   │   ├── la_heart.py
│   │   ├── la_heart_processing.py
│   │   └── utils.py
│   ├── networks
│   │   ├── discriminator.py
│   │   └── vnet_sdf.py
│   ├── test_LA.py
│   ├── test_util.py
│   ├── train_gan_sdfloss.py
│   └── utils
│       ├── losses.py
│       ├── losses_2.py
│       ├── metrics.py
│       ├── ramps.py
│       └── util.py
├── data
│   ├── test.list
│   └── train.list
└── model
    ├── model_16label
    │   └── best.pth
    └── model_8label
        └── best.pth

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SASSnet
Code for the paper: Shape-aware Semi-supervised 3D Semantic Segmentation for Medical Images (MICCAI 2020)

Our code is based on [UA-MT](https://github.com/yulequan/UA-MT).

You can find the paper on [arXiv](https://arxiv.org/abs/2007.10732).

# Usage

1. Clone the repo:
```
git clone https://github.com/kleinzcy/SASSnet.git
cd SASSnet
```
2. Put the data in `data/2018LA_Seg_Training Set`.

3. Train the model
```
cd code
# for 16 labels
python train_gan_sdfloss.py --gpu 0 --labelnum 16 --consistency 0.01 --exp model_name
# for 8 labels
python train_gan_sdfloss.py --gpu 0 --labelnum 8 --consistency 0.015 --exp model_name
```

These parameters are the best settings from our experiments.

4. Test the model
```
python test_LA.py --model model_name --gpu 0 --iter 6000
```
Our best models are saved in the `model` directory.

# Citation

If you find our work useful, please cite us.
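
# Data preprocessing

The raw 2018 LA dataset ships as `.nrrd` volumes, while the loaders read one `mri_norm2.h5` file per case. If your copy has no `.h5` files yet, `code/dataloaders/la_heart_processing.py` (included below) crops each case around the ground-truth mask, normalizes the intensities, and writes `mri_norm2.h5` next to the source files. The input path is hard-coded in that script, so point its glob at your data location first:
```
cd code/dataloaders
# edit the glob path inside la_heart_processing.py to match your data location
python la_heart_processing.py
```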
--------------------------------------------------------------------------------
/code/dataloaders/la_heart.py:
--------------------------------------------------------------------------------
import os
import torch
import numpy as np
from glob import glob
from torch.utils.data import Dataset
import h5py
import itertools
from torch.utils.data.sampler import Sampler

class LAHeart(Dataset):
    """ LA Dataset """
    def __init__(self, base_dir=None, split='train', num=None, transform=None):
        # base_dir is the `data` directory holding train.list/test.list and
        # the "2018LA_Seg_Training Set" folder
        self._base_dir = base_dir
        self.transform = transform
        self.sample_list = []

        train_path = self._base_dir + '/train.list'
        test_path = self._base_dir + '/test.list'

        if split == 'train':
            with open(train_path, 'r') as f:
                self.image_list = f.readlines()
        elif split == 'test':
            with open(test_path, 'r') as f:
                self.image_list = f.readlines()

        self.image_list = [item.replace('\n', '') for item in self.image_list]
        if num is not None:
            self.image_list = self.image_list[:num]
        print("total {} samples".format(len(self.image_list)))

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image_name = self.image_list[idx]
        h5f = h5py.File(self._base_dir + "/2018LA_Seg_Training Set/" + image_name + "/mri_norm2.h5", 'r')
        image = h5f['image'][:]
        label = h5f['label'][:]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)

        return sample


class CenterCrop(object):
    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        # pad the sample if necessary
        if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \
                self.output_size[2]:
            pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0)
            ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0)
            pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0)
            image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
            label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)

        (w, h, d) = image.shape

        w1 = int(round((w - self.output_size[0]) / 2.))
        h1 = int(round((h - self.output_size[1]) / 2.))
        d1 = int(round((d - self.output_size[2]) / 2.))

        label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]

        return {'image': image, 'label': label}


class RandomCrop(object):
    """
    Randomly crop the image in a sample
    Args:
        output_size (tuple): Desired output size
    """

    def __init__(self, output_size, with_sdf=False):
        self.output_size = output_size
        self.with_sdf = with_sdf

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        if self.with_sdf:
            sdf = sample['sdf']

        # pad the sample if necessary
        if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \
                self.output_size[2]:
            pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0)
            ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0)
            pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0)
            image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
            label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
            if self.with_sdf:
                sdf = np.pad(sdf, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)

        (w, h, d) = image.shape
        # if np.random.uniform() > 0.33:
        #     w1 = np.random.randint((w - self.output_size[0])//4, 3*(w - self.output_size[0])//4)
        #     h1 = np.random.randint((h - self.output_size[1])//4, 3*(h - self.output_size[1])//4)
        # else:
        w1 = np.random.randint(0, w - self.output_size[0])
        h1 = np.random.randint(0, h - self.output_size[1])
        d1 = np.random.randint(0, d - self.output_size[2])

        label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        if self.with_sdf:
            sdf = sdf[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
            return {'image': image, 'label': label, 'sdf': sdf}
        else:
            return {'image': image, 'label': label}


class RandomRotFlip(object):
    """
    Randomly rotate (by a multiple of 90 degrees) and flip the image and label in a sample
    """

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        k = np.random.randint(0, 4)
        image = np.rot90(image, k)
        label = np.rot90(label, k)
        axis = np.random.randint(0, 2)
        image = np.flip(image, axis=axis).copy()
        label = np.flip(label, axis=axis).copy()

        return {'image': image, 'label': label}


class RandomNoise(object):
    def __init__(self, mu=0, sigma=0.1):
        self.mu = mu
        self.sigma = sigma

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        noise = np.clip(self.sigma * np.random.randn(image.shape[0], image.shape[1], image.shape[2]), -2*self.sigma, 2*self.sigma)
        noise = noise + self.mu
        image = image + noise
        return {'image': image, 'label': label}


class CreateOnehotLabel(object):
    def __init__(self, num_classes):
        self.num_classes = num_classes

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        onehot_label = np.zeros((self.num_classes, label.shape[0], label.shape[1], label.shape[2]), dtype=np.float32)
        for i in range(self.num_classes):
            onehot_label[i, :, :, :] = (label == i).astype(np.float32)
        return {'image': image, 'label': label, 'onehot_label': onehot_label}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image = sample['image']
        image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2]).astype(np.float32)
        if 'onehot_label' in sample:
            return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long(),
                    'onehot_label': torch.from_numpy(sample['onehot_label']).long()}
        else:
            return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long()}


class TwoStreamBatchSampler(Sampler):
    """Iterate two sets of indices

    An 'epoch' is one iteration through the primary indices.
    During the epoch, the secondary indices are iterated through
    as many times as needed.
    """
    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size

        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0

    def __iter__(self):
        primary_iter = iterate_once(self.primary_indices)
        secondary_iter = iterate_eternally(self.secondary_indices)
        return (
            primary_batch + secondary_batch
            for (primary_batch, secondary_batch)
            in zip(grouper(primary_iter, self.primary_batch_size),
                   grouper(secondary_iter, self.secondary_batch_size))
        )

    def __len__(self):
        return len(self.primary_indices) // self.primary_batch_size

def iterate_once(iterable):
    return np.random.permutation(iterable)


def iterate_eternally(indices):
    def infinite_shuffles():
        while True:
            yield np.random.permutation(indices)
    return itertools.chain.from_iterable(infinite_shuffles())


def grouper(iterable, n):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3) --> ABC DEF
    args = [iter(iterable)] * n
    return zip(*args)
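
For reference, here is a condensed sketch of how `train_gan_sdfloss.py` (further down) wires the dataset, the transforms, and `TwoStreamBatchSampler` together; the path, the 16/80 split, and the batch sizes are that script's defaults:
```
from torch.utils.data import DataLoader
from torchvision import transforms
from dataloaders.la_heart import LAHeart, RandomRotFlip, RandomCrop, ToTensor, TwoStreamBatchSampler

db_train = LAHeart(base_dir='../data', split='train',
                   transform=transforms.Compose([RandomRotFlip(), RandomCrop((112, 112, 80)), ToTensor()]))
# 16 labeled cases are iterated once per epoch; the remaining 64 unlabeled cases cycle forever
batch_sampler = TwoStreamBatchSampler(list(range(16)), list(range(16, 80)),
                                      batch_size=4, secondary_batch_size=2)
trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True)
# each batch holds 2 labeled samples first, then 2 unlabeled samples
```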
--------------------------------------------------------------------------------
/code/dataloaders/la_heart_processing.py:
--------------------------------------------------------------------------------
import numpy as np
from glob import glob
from tqdm import tqdm
import h5py
import nrrd

output_size = [112, 112, 80]

def convert_h5():
    listt = glob('../../LA_dataset/2018LA_Seg_Training Set/*/lgemri.nrrd')
    for item in tqdm(listt):
        image, img_header = nrrd.read(item)
        label, gt_header = nrrd.read(item.replace('lgemri.nrrd', 'laendo.nrrd'))
        label = (label == 255).astype(np.uint8)
        w, h, d = label.shape

        # bounding box of the foreground
        tempL = np.nonzero(label)
        minx, maxx = np.min(tempL[0]), np.max(tempL[0])
        miny, maxy = np.min(tempL[1]), np.max(tempL[1])
        minz, maxz = np.min(tempL[2]), np.max(tempL[2])

        # enlarge the box to at least output_size, plus a random margin
        px = max(output_size[0] - (maxx - minx), 0) // 2
        py = max(output_size[1] - (maxy - miny), 0) // 2
        pz = max(output_size[2] - (maxz - minz), 0) // 2
        minx = max(minx - np.random.randint(10, 20) - px, 0)
        maxx = min(maxx + np.random.randint(10, 20) + px, w)
        miny = max(miny - np.random.randint(10, 20) - py, 0)
        maxy = min(maxy + np.random.randint(10, 20) + py, h)
        minz = max(minz - np.random.randint(5, 10) - pz, 0)
        maxz = min(maxz + np.random.randint(5, 10) + pz, d)

        image = (image - np.mean(image)) / np.std(image)
        image = image.astype(np.float32)
        # crop all three axes
        image = image[minx:maxx, miny:maxy, minz:maxz]
        label = label[minx:maxx, miny:maxy, minz:maxz]
        print(label.shape)
        f = h5py.File(item.replace('lgemri.nrrd', 'mri_norm2.h5'), 'w')
        f.create_dataset('image', data=image, compression="gzip")
        f.create_dataset('label', data=label, compression="gzip")
        f.close()

if __name__ == '__main__':
    convert_h5()
--------------------------------------------------------------------------------
/code/dataloaders/utils.py:
--------------------------------------------------------------------------------
import os
import torch
import numpy as np
import torch.nn as nn
from skimage import measure
import scipy.ndimage as nd


def recursive_glob(rootdir='.', suffix=''):
    """Performs recursive glob with given suffix and rootdir
        :param rootdir is the root directory
        :param suffix is the suffix to be searched
    """
    return [os.path.join(looproot, filename)
            for looproot, _, filenames in os.walk(rootdir)
            for filename in filenames if filename.endswith(suffix)]

def get_cityscapes_labels():
    return np.array([
        # [  0,   0,   0],
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [0, 130, 180],
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32]])

def get_pascal_labels():
    """Load the mapping that associates pascal classes with label colors
    Returns:
        np.ndarray with dimensions (21, 3)
    """
    return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                       [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                       [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                       [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                       [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                       [0, 64, 128]])


def encode_segmap(mask):
    """Encode segmentation label images as pascal classes
    Args:
        mask (np.ndarray): raw segmentation label image of dimension
            (M, N, 3), in which the Pascal classes are encoded as colours.
    Returns:
        (np.ndarray): class map with dimensions (M,N), where the value at
        a given location is the integer denoting the class index.
    """
    mask = mask.astype(int)
    label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
    for ii, label in enumerate(get_pascal_labels()):
        label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
    label_mask = label_mask.astype(int)
    return label_mask


def decode_seg_map_sequence(label_masks, dataset='pascal'):
    rgb_masks = []
    for label_mask in label_masks:
        rgb_mask = decode_segmap(label_mask, dataset)
        rgb_masks.append(rgb_mask)
    rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2]))
    return rgb_masks

def decode_segmap(label_mask, dataset, plot=False):
    """Decode segmentation class labels into a color image
    Args:
        label_mask (np.ndarray): an (M,N) array of integer values denoting
            the class label at each spatial location.
        plot (bool, optional): whether to show the resulting color image
            in a figure.
    Returns:
        (np.ndarray, optional): the resulting decoded color image.
    """
    if dataset == 'pascal':
        n_classes = 21
        label_colours = get_pascal_labels()
    elif dataset == 'cityscapes':
        n_classes = 19
        label_colours = get_cityscapes_labels()
    else:
        raise NotImplementedError

    r = label_mask.copy()
    g = label_mask.copy()
    b = label_mask.copy()
    for ll in range(0, n_classes):
        r[label_mask == ll] = label_colours[ll, 0]
        g[label_mask == ll] = label_colours[ll, 1]
        b[label_mask == ll] = label_colours[ll, 2]
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
    rgb[:, :, 0] = r / 255.0
    rgb[:, :, 1] = g / 255.0
    rgb[:, :, 2] = b / 255.0
    if plot:
        # imported lazily so headless machines can still use this module
        import matplotlib.pyplot as plt
        plt.imshow(rgb)
        plt.show()
    else:
        return rgb

def generate_param_report(logfile, param):
    log_file = open(logfile, 'w')
    # for key, val in param.items():
    #     log_file.write(key + ':' + str(val) + '\n')
    log_file.write(str(param))
    log_file.close()

def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True):
    n, c, h, w = logit.size()
    # logit = logit.permute(0, 2, 3, 1)
    target = target.squeeze(1)
    # reduction='sum' replaces the deprecated size_average=False keyword
    if weight is None:
        criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, reduction='sum')
    else:
        criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(weight)).float().cuda(), ignore_index=ignore_index, reduction='sum')
    loss = criterion(logit, target.long())

    if size_average:
        loss /= (h * w)

    if batch_average:
        loss /= n

    return loss

def lr_poly(base_lr, iter_, max_iter=100, power=0.9):
    return base_lr * ((1 - float(iter_) / max_iter) ** power)


def get_iou(pred, gt, n_classes=21):
    total_iou = 0.0
    for i in range(len(pred)):
        pred_tmp = pred[i]
        gt_tmp = gt[i]

        intersect = [0] * n_classes
        union = [0] * n_classes
        for j in range(n_classes):
            match = (pred_tmp == j) + (gt_tmp == j)

            it = torch.sum(match == 2).item()
            un = torch.sum(match > 0).item()

            intersect[j] += it
            union[j] += un

        iou = []
        for k in range(n_classes):
            if union[k] == 0:
                continue
            iou.append(intersect[k] / union[k])

        img_iou = (sum(iou) / len(iou))
        total_iou += img_iou

    return total_iou

def get_dice(pred, gt):
    total_dice = 0.0
    pred = pred.long()
    gt = gt.long()
    for i in range(len(pred)):
        pred_tmp = pred[i]
        gt_tmp = gt[i]
        dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0+torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item()
        print(dice)
        total_dice += dice

    return total_dice

def get_mc_dice(pred, gt, num=2):
    # num is the total number of classes, including the background
    total_dice = np.zeros(num-1)
    pred = pred.long()
    gt = gt.long()
    for i in range(len(pred)):
        for j in range(1, num):
            pred_tmp = (pred[i] == j)
            gt_tmp = (gt[i] == j)
            dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0+torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item()
            total_dice[j-1] += dice
    return total_dice

def post_processing(prediction):
    prediction = nd.binary_fill_holes(prediction)
    label_cc, num_cc = measure.label(prediction, return_num=True)
    total_cc = np.sum(prediction)
    # drop connected components smaller than 20% of the total foreground
    for cc in range(1, num_cc+1):
        single_cc = (label_cc == cc)
        single_vol = np.sum(single_cc)
        if single_vol/total_cc < 0.2:
            prediction[single_cc] = 0

    return prediction
--------------------------------------------------------------------------------
/code/networks/discriminator.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2019/12/30 9:34 PM
# @Author  : chuyu zhang
# @File    : discriminator.py
# @Software: PyCharm

import torch.nn as nn
import torch.nn.functional as F
import torch


class FCDiscriminator(nn.Module):

    def __init__(self, num_classes, ndf=64, n_channel=1):
        super(FCDiscriminator, self).__init__()
        self.conv0 = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Conv2d(n_channel, ndf, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1)
        self.classifier = nn.Linear(ndf*8, 2)
        self.avgpool = nn.AvgPool2d((7, 7))

        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.dropout = nn.Dropout2d(0.5)
        # self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear')
        # self.sigmoid = nn.Sigmoid()

    def forward(self, map, feature):
        map_feature = self.conv0(map)
        image_feature = self.conv1(feature)
        x = torch.add(map_feature, image_feature)

        x = self.conv2(x)
        x = self.leaky_relu(x)
        x = self.dropout(x)

        x = self.conv3(x)
        x = self.leaky_relu(x)
        x = self.dropout(x)

        x = self.conv4(x)
        x = self.leaky_relu(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        # x = self.up_sample(x)
        # x = self.sigmoid(x)

        return x


class FC3DDiscriminator(nn.Module):

    def __init__(self, num_classes, ndf=64, n_channel=1):
        super(FC3DDiscriminator, self).__init__()
        # downsample 16
        self.conv0 = nn.Conv3d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Conv3d(n_channel, ndf, kernel_size=4, stride=2, padding=1)

        self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv3d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv3d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1)
        self.avgpool = nn.AvgPool3d((7, 7, 5))
        self.classifier = nn.Linear(ndf*8, 2)

        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.dropout = nn.Dropout3d(0.5)
        self.Softmax = nn.Softmax()

    def forward(self, map, image):
        batch_size = map.shape[0]
        map_feature = self.conv0(map)
        image_feature = self.conv1(image)
        x = torch.add(map_feature, image_feature)
        x = self.leaky_relu(x)
        x = self.dropout(x)

        x = self.conv2(x)
        x = self.leaky_relu(x)
        x = self.dropout(x)

        x = self.conv3(x)
        x = self.leaky_relu(x)
        x = self.dropout(x)

        x = self.conv4(x)
        x = self.leaky_relu(x)

        x = self.avgpool(x)

        x = x.view(batch_size, -1)
        x = self.classifier(x)
        x = x.reshape((batch_size, 2))
        # x = self.Softmax(x)

        return x


class FC3DDiscriminatorNIH(nn.Module):
    def __init__(self, num_classes, ndf=64, n_channel=1):
        super(FC3DDiscriminatorNIH, self).__init__()
        # downsample 16
        self.conv0 = nn.Conv3d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Conv3d(n_channel, ndf, kernel_size=4, stride=2, padding=1)

        self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv3d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv3d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1)
        self.avgpool = nn.AvgPool3d((13, 10, 9))
        self.classifier = nn.Linear(ndf*8, 2)

        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.dropout = nn.Dropout3d(0.5)
        self.Softmax = nn.Softmax()

    def forward(self, map, image):
        batch_size = map.shape[0]
        map_feature = self.conv0(map)
        image_feature = self.conv1(image)
        x = torch.add(map_feature, image_feature)
        x = self.leaky_relu(x)
        x = self.dropout(x)

        x = self.conv2(x)
        x = self.leaky_relu(x)
        x = self.dropout(x)

        x = self.conv3(x)
        x = self.leaky_relu(x)
        x = self.dropout(x)

        x = self.conv4(x)
        x = self.leaky_relu(x)

        x = self.avgpool(x)

        x = x.view(batch_size, -1)
        x = self.classifier(x)
        x = x.reshape((batch_size, 2))
        # x = self.Softmax(x)

        return x


class FCDiscriminatorDAP(nn.Module):
    def __init__(self, num_classes, ndf=64):
        super(FCDiscriminatorDAP, self).__init__()

        self.conv1 = nn.Conv3d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv3d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
        self.classifier = nn.Conv3d(ndf*4, 1, kernel_size=4, stride=2, padding=1)

        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.up_sample = nn.Upsample(scale_factor=16, mode='trilinear', align_corners=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.conv1(x)
        x = self.leaky_relu(x)
        x = self.conv2(x)
        x = self.leaky_relu(x)
        x = self.conv3(x)
        x = self.leaky_relu(x)
        x = self.classifier(x)
        x = self.up_sample(x)
        x = self.sigmoid(x)

        return x

if __name__ == '__main__':
    # compute FLOPS & PARAMETERS
    from thop import profile
    from thop import clever_format
    model = FC3DDiscriminator(num_classes=1)
    input = torch.randn(4, 1, 112, 112, 80)
    flops, params = profile(model, inputs=(input, input))
    macs, params = clever_format([flops, params], "%.3f")
    print(macs, params)

    model = FCDiscriminatorDAP(num_classes=2)
    input = torch.randn(4, 2, 112, 112, 80)
    flops, params = profile(model, inputs=(input,))
    macs, params = clever_format([flops, params], "%.3f")
    print(macs, params)

    import ipdb; ipdb.set_trace()
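
The fixed `AvgPool3d((7, 7, 5))` ties `FC3DDiscriminator` to 112×112×80 inputs: its four stride-2 stages divide each spatial dimension by 16 (112/16 = 7, 80/16 = 5). A quick smoke test of that assumption:
```
import torch
from networks.discriminator import FC3DDiscriminator

D = FC3DDiscriminator(num_classes=1)    # one SDF channel, as in train_gan_sdfloss.py
sdf = torch.randn(4, 1, 112, 112, 80)   # tanh-head output
img = torch.randn(4, 1, 112, 112, 80)   # corresponding image patch
print(D(sdf, img).shape)                # torch.Size([4, 2]) -> labeled/unlabeled logits
```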
--------------------------------------------------------------------------------
/code/networks/vnet_sdf.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F

"""
Differences from the standard V-Net:
an extra 1x1x1 head (out_conv) followed by nn.Tanh is added at the end, so the network
also predicts a signed distance map with outputs in [-1, 1]; out_conv2 produces the
usual segmentation logits.
"""

class ConvBlock(nn.Module):
    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
        super(ConvBlock, self).__init__()

        ops = []
        for i in range(n_stages):
            if i == 0:
                input_channel = n_filters_in
            else:
                input_channel = n_filters_out

            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            elif normalization != 'none':
                assert False
            ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class ResidualConvBlock(nn.Module):
    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
        super(ResidualConvBlock, self).__init__()

        ops = []
        for i in range(n_stages):
            if i == 0:
                input_channel = n_filters_in
            else:
                input_channel = n_filters_out

            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            elif normalization != 'none':
                assert False

            if i != n_stages-1:
                ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = (self.conv(x) + x)
        x = self.relu(x)
        return x


class DownsamplingConvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(DownsamplingConvBlock, self).__init__()

        ops = []
        if normalization != 'none':
            ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            else:
                assert False
        else:
            ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))

        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class UpsamplingDeconvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(UpsamplingDeconvBlock, self).__init__()

        ops = []
        if normalization != 'none':
            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            else:
                assert False
        else:
            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))

        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class Upsampling(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(Upsampling, self).__init__()

        ops = []
        ops.append(nn.Upsample(scale_factor=stride, mode='trilinear', align_corners=False))
        ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1))
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            assert False
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class VNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False, has_residual=False):
        super(VNet, self).__init__()
        self.has_dropout = has_dropout
        convBlock = ConvBlock if not has_residual else ResidualConvBlock

        self.block_one = convBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = convBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = convBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = convBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = convBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = convBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = convBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = convBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = convBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
        self.out_conv2 = nn.Conv3d(n_filters, n_classes, 1, padding=0)
        self.tanh = nn.Tanh()

        self.dropout = nn.Dropout3d(p=0.5, inplace=False)
        # self.__init_weight()

    def encoder(self, input):
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)
        # x5 = F.dropout3d(x5, p=0.5, training=True)
        if self.has_dropout:
            x5 = self.dropout(x5)

        res = [x1, x2, x3, x4, x5]

        return res

    def decoder(self, features):
        x1 = features[0]
        x2 = features[1]
        x3 = features[2]
        x4 = features[3]
        x5 = features[4]

        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1
        x9 = self.block_nine(x8_up)
        # x9 = F.dropout3d(x9, p=0.5, training=True)
        if self.has_dropout:
            x9 = self.dropout(x9)
        out = self.out_conv(x9)
        out_tanh = self.tanh(out)
        out_seg = self.out_conv2(x9)
        return out_tanh, out_seg


    def forward(self, input, turnoff_drop=False):
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False
        features = self.encoder(input)
        out_tanh, out_seg = self.decoder(features)
        if turnoff_drop:
            self.has_dropout = has_dropout
        return out_tanh, out_seg

    # def __init_weight(self):
    #     for m in self.modules():
    #         if isinstance(m, nn.Conv3d):
    #             torch.nn.init.kaiming_normal_(m.weight)
    #         elif isinstance(m, nn.BatchNorm3d):
    #             m.weight.data.fill_(1)

if __name__ == '__main__':
    # compute FLOPS & PARAMETERS
    from thop import profile
    from thop import clever_format
    model = VNet(n_channels=1, n_classes=2)
    input = torch.randn(4, 1, 112, 112, 80)
    flops, params = profile(model, inputs=(input,))
    macs, params = clever_format([flops, params], "%.3f")
    print(macs, params)

    import ipdb; ipdb.set_trace()
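
A minimal shape and range check for the two-head network above (`n_classes=1` matches how the training and test scripts instantiate it, and 112×112×80 is divisible by the encoder's 16× downsampling):
```
import torch
from networks.vnet_sdf import VNet

net = VNet(n_channels=1, n_classes=1, normalization='batchnorm')
x = torch.randn(2, 1, 112, 112, 80)
out_tanh, out_seg = net(x)
print(out_tanh.shape, out_seg.shape)                 # both torch.Size([2, 1, 112, 112, 80])
assert out_tanh.min() >= -1 and out_tanh.max() <= 1  # SDF head is Tanh-bounded
```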
"/2018LA_Seg_Training Set/" + item.replace('\n', '') + "/mri_norm2.h5" for item in 30 | image_list] 31 | 32 | 33 | def test_calculate_metric(epoch_num): 34 | net = VNet(n_channels=1, n_classes=num_classes-1, normalization='batchnorm', has_dropout=False).cuda() 35 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(epoch_num) + '.pth') 36 | net.load_state_dict(torch.load(save_mode_path)) 37 | print("init weight from {}".format(save_mode_path)) 38 | net.eval() 39 | 40 | avg_metric = test_all_case(net, image_list, num_classes=num_classes, 41 | patch_size=(112, 112, 80), stride_xy=18, stride_z=4, 42 | save_result=True, test_save_path=test_save_path, 43 | metric_detail=FLAGS.detail, nms=FLAGS.nms) 44 | 45 | return avg_metric 46 | 47 | 48 | if __name__ == '__main__': 49 | metric = test_calculate_metric(FLAGS.iter) #6000 50 | print(metric) 51 | 52 | # python test_LA.py --model 0214_re01 --gpu 0 53 | -------------------------------------------------------------------------------- /code/test_util.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import math 3 | import nibabel as nib 4 | import numpy as np 5 | from medpy import metric 6 | import torch 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | from skimage.measure import label 10 | 11 | def getLargestCC(segmentation): 12 | labels = label(segmentation) 13 | assert( labels.max() != 0 ) # assume at least 1 CC 14 | largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1 15 | return largestCC 16 | 17 | 18 | def test_all_case(net, image_list, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, preproc_fn=None, metric_detail=0, nms=0): 19 | total_metric = 0.0 20 | loader = tqdm(image_list) if not metric_detail else image_list 21 | ith = 0 22 | for image_path in loader: 23 | # id = image_path.split('/')[-2] 24 | h5f = h5py.File(image_path, 'r') 25 | image = h5f['image'][:] 26 | label = h5f['label'][:] 27 | if preproc_fn is not None: 28 | image = preproc_fn(image) 29 | prediction, score_map = test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 30 | if nms: 31 | prediction = getLargestCC(prediction) 32 | 33 | if np.sum(prediction)==0: 34 | single_metric = (0,0,0,0) 35 | else: 36 | single_metric = calculate_metric_percase(prediction, label[:]) 37 | if metric_detail: 38 | print('%02d,\t%.5f, %.5f, %.5f, %.5f' % (ith, single_metric[0], single_metric[1], single_metric[2], single_metric[3])) 39 | 40 | 41 | total_metric += np.asarray(single_metric) 42 | 43 | if save_result: 44 | nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path + "%02d_pred.nii.gz" % ith) 45 | nib.save(nib.Nifti1Image(image[:].astype(np.float32), np.eye(4)), test_save_path + "%02d_img.nii.gz" % ith) 46 | nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path + "%02d_gt.nii.gz" % ith) 47 | ith += 1 48 | 49 | avg_metric = total_metric / len(image_list) 50 | print('average metric is {}'.format(avg_metric)) 51 | 52 | return avg_metric 53 | 54 | 55 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 56 | w, h, d = image.shape 57 | 58 | # if the size of image is less than patch_size, then padding it 59 | add_pad = False 60 | if w < patch_size[0]: 61 | w_pad = patch_size[0]-w 62 | add_pad = True 63 | else: 64 | w_pad = 0 65 | if h < patch_size[1]: 66 | h_pad = patch_size[1]-h 67 | add_pad = True 68 | else: 69 | h_pad = 0 70 | if d < 
patch_size[2]: 71 | d_pad = patch_size[2]-d 72 | add_pad = True 73 | else: 74 | d_pad = 0 75 | wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2 76 | hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2 77 | dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2 78 | if add_pad: 79 | image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0) 80 | ww,hh,dd = image.shape 81 | 82 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 83 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 84 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 85 | # print("{}, {}, {}".format(sx, sy, sz)) 86 | score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32) 87 | cnt = np.zeros(image.shape).astype(np.float32) 88 | 89 | for x in range(0, sx): 90 | xs = min(stride_xy*x, ww-patch_size[0]) 91 | for y in range(0, sy): 92 | ys = min(stride_xy * y,hh-patch_size[1]) 93 | for z in range(0, sz): 94 | zs = min(stride_z * z, dd-patch_size[2]) 95 | test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] 96 | test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32) 97 | test_patch = torch.from_numpy(test_patch).cuda() 98 | 99 | with torch.no_grad(): 100 | y1_tanh, y1 = net(test_patch) 101 | # ensemble 102 | y = torch.sigmoid(y1) 103 | 104 | y = y.cpu().data.numpy() 105 | y = y[0,:,:,:,:] 106 | score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 107 | = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y 108 | cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 109 | = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 110 | score_map = score_map/np.expand_dims(cnt,axis=0) 111 | label_map = (score_map[0]>0.5).astype(np.int) 112 | 113 | if add_pad: 114 | label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 115 | score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 116 | return label_map, score_map 117 | 118 | 119 | def cal_dice(prediction, label, num=2): 120 | total_dice = np.zeros(num-1) 121 | for i in range(1, num): 122 | prediction_tmp = (prediction==i) 123 | label_tmp = (label==i) 124 | prediction_tmp = prediction_tmp.astype(np.float) 125 | label_tmp = label_tmp.astype(np.float) 126 | 127 | dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp)) 128 | total_dice[i - 1] += dice 129 | 130 | return total_dice 131 | 132 | 133 | def calculate_metric_percase(pred, gt): 134 | dice = metric.binary.dc(pred, gt) 135 | jc = metric.binary.jc(pred, gt) 136 | hd = metric.binary.hd95(pred, gt) 137 | asd = metric.binary.asd(pred, gt) 138 | 139 | return dice, jc, hd, asd -------------------------------------------------------------------------------- /code/train_gan_sdfloss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from tqdm import tqdm 4 | from tensorboardX import SummaryWriter 5 | import shutil 6 | import argparse 7 | import logging 8 | import time 9 | import random 10 | import numpy as np 11 | 12 | import torch 13 | import torch.optim as optim 14 | from torchvision import transforms 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | import torch.nn as nn 18 | from torch.nn import BCEWithLogitsLoss, MSELoss 19 | from torch.utils.data import DataLoader 20 | from torchvision.utils import make_grid 21 | 22 | from networks.vnet_sdf import VNet 23 | from 
--------------------------------------------------------------------------------
/code/train_gan_sdfloss.py:
--------------------------------------------------------------------------------
import os
import sys
from tqdm import tqdm
from tensorboardX import SummaryWriter
import shutil
import argparse
import logging
import time
import random
import numpy as np

import torch
import torch.optim as optim
from torchvision import transforms
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, MSELoss
from torch.utils.data import DataLoader
from torchvision.utils import make_grid

from networks.vnet_sdf import VNet
from networks.discriminator import FC3DDiscriminator

from dataloaders import utils
from utils import ramps, losses, metrics
from dataloaders.la_heart import LAHeart, RandomCrop, CenterCrop, RandomRotFlip, ToTensor, TwoStreamBatchSampler
from utils.util import compute_sdf

parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str, default='../data/', help='dataset root directory')
parser.add_argument('--exp', type=str, default='UAMT_001', help='model name')
parser.add_argument('--max_iterations', type=int, default=6000, help='maximum iteration number to train')
parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu')
parser.add_argument('--labeled_bs', type=int, default=2, help='labeled_batch_size per gpu')
parser.add_argument('--base_lr', type=float, default=0.01, help='base learning rate')
parser.add_argument('--D_lr', type=float, default=1e-4, help='maximum discriminator learning rate to train')
parser.add_argument('--deterministic', type=int, default=1, help='whether to use deterministic training')
parser.add_argument('--labelnum', type=int, default=16, help='number of labeled samples')
parser.add_argument('--seed', type=int, default=1337, help='random seed')
parser.add_argument('--gpu', type=str, default='0', help='GPU to use')
parser.add_argument('--beta', type=float, default=0.3, help='balance factor to control regional and sdm loss')
parser.add_argument('--gamma', type=float, default=0.5, help='balance factor to control supervised and consistency loss')
### costs
parser.add_argument('--ema_decay', type=float, default=0.99, help='ema_decay')
parser.add_argument('--consistency_type', type=str, default="mse", help='consistency_type')
parser.add_argument('--consistency', type=float, default=0.01, help='consistency')
parser.add_argument('--consistency_rampup', type=float, default=40.0, help='consistency_rampup')
args = parser.parse_args()

train_data_path = args.root_path
snapshot_path = "../model/" + args.exp + "/"


os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
batch_size = args.batch_size * len(args.gpu.split(','))
max_iterations = args.max_iterations
base_lr = args.base_lr
labeled_bs = args.labeled_bs

if not args.deterministic:
    cudnn.benchmark = True
    cudnn.deterministic = False
else:
    cudnn.benchmark = False
    cudnn.deterministic = True
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

num_classes = 2
patch_size = (112, 112, 80)

def get_current_consistency_weight(epoch):
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_rampup)

if __name__ == "__main__":
    ## make logger file
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)
    if os.path.exists(snapshot_path + '/code'):
        shutil.rmtree(snapshot_path + '/code')
    # `ignore` must be passed as a keyword; ignore_patterns takes the patterns as separate arguments
    shutil.copytree('.', snapshot_path + '/code', ignore=shutil.ignore_patterns('.git', '__pycache__'))

    logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    def create_model(ema=False):
        # Network definition
        net = VNet(n_channels=1, n_classes=num_classes-1, normalization='batchnorm', has_dropout=True)
        model = net.cuda()
        if ema:
            for param in model.parameters():
                param.detach_()
        return model

    model = create_model()

    D = FC3DDiscriminator(num_classes=num_classes - 1)
    D = D.cuda()

    db_train = LAHeart(base_dir=train_data_path,
                       split='train',  # train/val split
                       transform=transforms.Compose([
                           RandomRotFlip(),
                           RandomCrop(patch_size),
                           ToTensor(),
                       ]))

    labelnum = args.labelnum  # default 16
    labeled_idxs = list(range(labelnum))
    unlabeled_idxs = list(range(labelnum, 80))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, batch_size-labeled_bs)
    def worker_init_fn(worker_id):
        random.seed(args.seed+worker_id)
    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)

    model.train()

    Dopt = optim.Adam(D.parameters(), lr=args.D_lr, betas=(0.9, 0.99))
    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
    ce_loss = BCEWithLogitsLoss()
    mse_loss = MSELoss()

    if args.consistency_type == 'mse':
        consistency_criterion = losses.softmax_mse_loss
    elif args.consistency_type == 'kl':
        consistency_criterion = losses.softmax_kl_loss
    else:
        assert False, args.consistency_type

    writer = SummaryWriter(snapshot_path+'/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))

    iter_num = 0
    max_epoch = max_iterations//len(trainloader)+1
    lr_ = base_lr

    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(trainloader):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()

            # Discriminator targets follow the sampler layout: labeled samples first,
            # unlabeled samples second (assumes batch_size=4 with labeled_bs=2)
            Dtarget = torch.tensor([1, 1, 0, 0]).cuda()
            model.train()
            D.eval()

            outputs_tanh, outputs = model(volume_batch)
            outputs_soft = torch.sigmoid(outputs)

            ## calculate the loss
            with torch.no_grad():
                gt_dis = compute_sdf(label_batch[:].cpu().numpy(), outputs[:labeled_bs, 0, ...].shape)
                gt_dis = torch.from_numpy(gt_dis).float().cuda()
            loss_sdf = mse_loss(outputs_tanh[:labeled_bs, 0, ...], gt_dis)
            loss_seg = ce_loss(outputs[:labeled_bs, 0, ...], label_batch[:labeled_bs].float())
            loss_seg_dice = losses.dice_loss(outputs_soft[:labeled_bs, 0, :, :, :], label_batch[:labeled_bs] == 1)

            consistency_weight = get_current_consistency_weight(iter_num//150)

            supervised_loss = loss_seg_dice + args.beta * loss_sdf

            Doutputs = D(outputs_tanh[labeled_bs:], volume_batch[labeled_bs:])
            # G wants D to misclassify unlabeled data as labeled data.
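            # Dtarget is [1, 1, 0, 0]: class 1 marks the labeled half of the batch, class 0 the
            # unlabeled half. Dtarget[:labeled_bs] is all ones, so this cross-entropy rewards the
            # generator whenever D assigns the "labeled" class to predictions on unlabeled volumes.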
174 | loss_adv = F.cross_entropy(Doutputs, (Dtarget[:labeled_bs]).long()) 175 | 176 | loss = supervised_loss + consistency_weight*loss_adv 177 | 178 | optimizer.zero_grad() 179 | loss.backward() 180 | optimizer.step() 181 | 182 | dc = metrics.dice(torch.argmax(outputs_soft[:labeled_bs], dim=1), label_batch[:labeled_bs]) 183 | 184 | # Train D 185 | model.eval() 186 | D.train() 187 | with torch.no_grad(): 188 | outputs_tanh, outputs = model(volume_batch) 189 | 190 | Doutputs = D(outputs_tanh, volume_batch) 191 | # D want to classify unlabel data and label data rightly. 192 | D_loss = F.cross_entropy(Doutputs, Dtarget.long()) 193 | 194 | # Dtp and Dfn is unreliable because of the num of samples is small(4) 195 | Dacc = torch.mean((torch.argmax(Doutputs, dim=1).float()==Dtarget.float()).float()) 196 | Dtp = torch.mean((torch.argmax(Doutputs, dim=1).float()==Dtarget.float()).float()) 197 | Dfn = torch.mean((torch.argmax(Doutputs, dim=1).float()==Dtarget.float()).float()) 198 | Dopt.zero_grad() 199 | D_loss.backward() 200 | Dopt.step() 201 | 202 | iter_num = iter_num + 1 203 | writer.add_scalar('lr', lr_, iter_num) 204 | writer.add_scalar('loss/loss', loss, iter_num) 205 | writer.add_scalar('loss/loss_seg', loss_seg, iter_num) 206 | writer.add_scalar('loss/loss_dice', loss_seg_dice, iter_num) 207 | writer.add_scalar('loss/loss_hausdorff', loss_sdf, iter_num) 208 | writer.add_scalar('train/consistency_weight', consistency_weight, iter_num) 209 | writer.add_scalar('loss/loss_adv', consistency_weight*loss_adv, iter_num) 210 | writer.add_scalar('GAN/loss_adv', loss_adv, iter_num) 211 | writer.add_scalar('GAN/D_loss', D_loss, iter_num) 212 | writer.add_scalar('GAN/Dtp', Dtp, iter_num) 213 | writer.add_scalar('GAN/Dfn', Dfn, iter_num) 214 | 215 | logging.info( 216 | 'iteration %d : loss : %f, loss_weight: %f, loss_haus: %f, loss_seg: %f, loss_dice: %f' % 217 | (iter_num, loss.item(), consistency_weight, loss_sdf.item(), 218 | loss_seg.item(), loss_seg_dice.item())) 219 | 220 | ## change lr 221 | if iter_num % 2500 == 0: 222 | lr_ = base_lr * 0.1 ** (iter_num // 2500) 223 | for param_group in optimizer.param_groups: 224 | param_group['lr'] = lr_ 225 | if iter_num % 1000 == 0: 226 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') 227 | torch.save(model.state_dict(), save_mode_path) 228 | logging.info("save model to {}".format(save_mode_path)) 229 | 230 | if iter_num >= max_iterations: 231 | break 232 | time1 = time.time() 233 | if iter_num >= max_iterations: 234 | iterator.close() 235 | break 236 | writer.close() 237 | -------------------------------------------------------------------------------- /code/utils/__pycache__/losses.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinzcy/SASSnet/1761aa8af08f42e5a36a737c88beb4dc798af35c/code/utils/__pycache__/losses.cpython-36.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/losses.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinzcy/SASSnet/1761aa8af08f42e5a36a737c88beb4dc798af35c/code/utils/__pycache__/losses.cpython-37.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/losses_2.cpython-36.pyc: -------------------------------------------------------------------------------- 
/code/utils/__pycache__/losses.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/SASSnet/1761aa8af08f42e5a36a737c88beb4dc798af35c/code/utils/__pycache__/losses.cpython-36.pyc
--------------------------------------------------------------------------------
/code/utils/__pycache__/losses.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/SASSnet/1761aa8af08f42e5a36a737c88beb4dc798af35c/code/utils/__pycache__/losses.cpython-37.pyc
--------------------------------------------------------------------------------
/code/utils/__pycache__/losses_2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/SASSnet/1761aa8af08f42e5a36a737c88beb4dc798af35c/code/utils/__pycache__/losses_2.cpython-36.pyc
--------------------------------------------------------------------------------
/code/utils/__pycache__/metrics.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/SASSnet/1761aa8af08f42e5a36a737c88beb4dc798af35c/code/utils/__pycache__/metrics.cpython-36.pyc
--------------------------------------------------------------------------------
/code/utils/__pycache__/metrics.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/SASSnet/1761aa8af08f42e5a36a737c88beb4dc798af35c/code/utils/__pycache__/metrics.cpython-37.pyc
--------------------------------------------------------------------------------
/code/utils/__pycache__/ramps.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/SASSnet/1761aa8af08f42e5a36a737c88beb4dc798af35c/code/utils/__pycache__/ramps.cpython-36.pyc
--------------------------------------------------------------------------------
/code/utils/__pycache__/ramps.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/SASSnet/1761aa8af08f42e5a36a737c88beb4dc798af35c/code/utils/__pycache__/ramps.cpython-37.pyc
--------------------------------------------------------------------------------
/code/utils/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/SASSnet/1761aa8af08f42e5a36a737c88beb4dc798af35c/code/utils/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/code/utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | import numpy as np
4 | import torch.nn as nn
5 | from torch.autograd import Variable
6 | 
7 | def dice_loss(score, target):
8 |     target = target.float()
9 |     smooth = 1e-5
10 |     intersect = torch.sum(score * target)
11 |     y_sum = torch.sum(target * target)
12 |     z_sum = torch.sum(score * score)
13 |     loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
14 |     loss = 1 - loss
15 |     return loss
16 | 
17 | def dice_loss1(score, target):
18 |     target = target.float()
19 |     smooth = 1e-5
20 |     intersect = torch.sum(score * target)
21 |     y_sum = torch.sum(target)
22 |     z_sum = torch.sum(score)
23 |     loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
24 |     loss = 1 - loss
25 |     return loss
26 | 
27 | def entropy_loss(p, C=2):
28 |     ## p N*C*W*H*D
29 |     y1 = -1*torch.sum(p*torch.log(p+1e-6), dim=1)/torch.tensor(np.log(C)).cuda()
30 |     ent = torch.mean(y1)
31 | 
32 |     return ent
33 | 
34 | def softmax_dice_loss(input_logits, target_logits):
35 |     """Takes softmax on both sides and returns the Dice loss
36 | 
37 |     Note:
38 |     - Returns the Dice loss averaged over all classes
39 |       (dice_loss1 is applied per class and then averaged).
40 |     - Sends gradients to inputs but not the targets.
41 |     """
42 |     assert input_logits.size() == target_logits.size()
43 |     input_softmax = F.softmax(input_logits, dim=1)
44 |     target_softmax = F.softmax(target_logits, dim=1)
45 |     n = input_logits.shape[1]
46 |     dice = 0
47 |     for i in range(0, n):
48 |         dice += dice_loss1(input_softmax[:, i], target_softmax[:, i])
49 |     mean_dice = dice / n
50 | 
51 |     return mean_dice
52 | 
53 | 
54 | def entropy_loss_map(p, C=2):
55 |     ent = -1*torch.sum(p * torch.log(p + 1e-6), dim=1, keepdim=True)/torch.tensor(np.log(C)).cuda()
56 |     return ent
57 | 
58 | def softmax_mse_loss(input_logits, target_logits, sigmoid=False):
59 |     """Takes softmax (or sigmoid) on both sides and returns the squared error
60 | 
61 |     Note:
62 |     - Returns the element-wise squared error without reduction; take the
63 |       mean (or sum) afterwards.
64 |     - Sends gradients to inputs but not the targets.
65 |     """
66 |     assert input_logits.size() == target_logits.size()
67 |     if sigmoid:
68 |         input_softmax = torch.sigmoid(input_logits)
69 |         target_softmax = torch.sigmoid(target_logits)
70 |     else:
71 |         input_softmax = F.softmax(input_logits, dim=1)
72 |         target_softmax = F.softmax(target_logits, dim=1)
73 | 
74 |     mse_loss = (input_softmax-target_softmax)**2
75 |     return mse_loss
76 | 
77 | 
78 | 
79 | def softmax_kl_loss(input_logits, target_logits, sigmoid=False):
80 |     """Takes softmax on both sides and returns KL divergence
81 | 
82 |     Note:
83 |     - Returns the element-wise KL divergence (reduction='none'); reduce it
84 |       afterwards if you need a scalar.
85 |     - Sends gradients to inputs but not the targets.
86 |     """
87 |     assert input_logits.size() == target_logits.size()
88 |     if sigmoid:
89 |         input_log_softmax = torch.log(torch.sigmoid(input_logits))
90 |         target_softmax = torch.sigmoid(target_logits)
91 |     else:
92 |         input_log_softmax = F.log_softmax(input_logits, dim=1)
93 |         target_softmax = F.softmax(target_logits, dim=1)
94 | 
95 | 
96 |     # return F.kl_div(input_log_softmax, target_softmax)
97 |     kl_div = F.kl_div(input_log_softmax, target_softmax, reduction='none')
98 |     # mean_kl_div = torch.mean(0.2*kl_div[:,0,...]+0.8*kl_div[:,1,...])
99 |     return kl_div
100 | 
101 | def symmetric_mse_loss(input1, input2):
102 |     """Like F.mse_loss but sends gradients in both directions
103 | 
104 |     Note:
105 |     - Returns the mean squared difference over all elements (a scalar),
106 |       not the sum.
107 |     - Sends gradients to both input1 and input2.
108 |     """
109 |     assert input1.size() == input2.size()
110 |     return torch.mean((input1 - input2)**2)
111 | 
112 | 
113 | class FocalLoss(nn.Module):
114 |     def __init__(self, gamma=2, alpha=None, size_average=True):
115 |         super(FocalLoss, self).__init__()
116 |         self.gamma = gamma
117 |         self.alpha = alpha
118 |         if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1-alpha])
119 |         if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
120 |         self.size_average = size_average
121 | 
122 |     def forward(self, input, target):
123 |         if input.dim() > 2:
124 |             input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
125 |             input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
126 |             input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
127 |         target = target.view(-1, 1)
128 | 
129 |         logpt = F.log_softmax(input, dim=1)
130 |         logpt = logpt.gather(1, target)
131 |         logpt = logpt.view(-1)
132 |         pt = Variable(logpt.data.exp())
133 | 
134 |         if self.alpha is not None:
135 |             if self.alpha.type() != input.data.type():
136 |                 self.alpha = self.alpha.type_as(input.data)
137 |             at = self.alpha.gather(0, target.data.view(-1))
138 |             logpt = logpt * Variable(at)
139 | 
140 |         loss = -1 * (1-pt)**self.gamma * logpt
141 |         if self.size_average: return loss.mean()
142 |         else: return loss.sum()
143 | 
--------------------------------------------------------------------------------
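A few usage notes for `losses.py` above (a sketch under assumptions, not repo code; it assumes you run from `code/` so that `utils.losses` is importable):

```
# Sketch of how the losses above behave; shapes and values are illustrative.
import torch
from utils import losses
from utils.losses import FocalLoss

# dice_loss: ~0 for perfect overlap, ~1 for disjoint masks.
pred = torch.tensor([1.0, 1.0, 0.0, 0.0])
gt = torch.tensor([1.0, 1.0, 0.0, 0.0])
print(losses.dice_loss(pred, gt))      # ~0.0
print(losses.dice_loss(1 - pred, gt))  # ~1.0

# softmax_mse_loss / softmax_kl_loss return *unreduced* tensors,
# so the caller reduces them, as the docstrings note.
student = torch.randn(2, 2, 4, 4, 4)   # (N, C, ...) logits
teacher = torch.randn(2, 2, 4, 4, 4)
consistency_loss = losses.softmax_mse_loss(student, teacher).mean()

# FocalLoss expects (N, C, ...) logits and integer class labels.
criterion = FocalLoss(gamma=2, alpha=0.25)
logits = torch.randn(4, 2, 8, 8)
target = torch.randint(0, 2, (4, 8, 8))
print(criterion(logits, target))
```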
/code/utils/losses_2.py:
--------------------------------------------------------------------------------
1 | import torch  # needed by hd_loss, sdf_loss and boundary_loss below
2 | # from torch.nn import functional as F
3 | import numpy as np
4 | from scipy.ndimage import distance_transform_edt as distance
5 | from skimage import segmentation as skimage_seg
6 | 
7 | def compute_dtm(img_gt, out_shape, normalize=False, fg=False):
8 |     """
9 |     compute the distance transform map of the foreground in a binary mask
10 |     input: segmentation, shape = (batch_size, x, y, z)
11 |     output: the foreground Distance Transform Map (DTM)
12 |     dtm(x) = 0                 ; x on the segmentation boundary
13 |              inf_{y∈∂S} |x-y|  ; x elsewhere
14 |     """
15 | 
16 |     fg_dtm = np.zeros(out_shape)
17 | 
18 |     for b in range(out_shape[0]):  # batch size
19 |         posmask = img_gt[b].astype(bool)
20 |         if not fg:
21 |             if posmask.any():
22 |                 negmask = 1 - posmask
23 |                 posdis = distance(posmask)
24 |                 negdis = distance(negmask)
25 |                 boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
26 |                 if normalize:
27 |                     fg_dtm[b] = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) + (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis))
28 |                 else:
29 |                     fg_dtm[b] = posdis + negdis
30 |                 fg_dtm[b][boundary==1] = 0
31 |         else:
32 |             if posmask.any():
33 |                 posdis = distance(posmask)
34 |                 boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
35 |                 if normalize:
36 |                     fg_dtm[b] = (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis))
37 |                 else:
38 |                     fg_dtm[b] = posdis
39 |                 fg_dtm[b][boundary==1] = 0
40 | 
41 |     return fg_dtm
42 | 
43 | def hd_loss(seg_soft, gt, gt_dtm=None, one_side=True, seg_dtm=None):
44 |     """
45 |     compute Hausdorff distance loss for binary segmentation
46 |     input: seg_soft: softmax results, shape=(b,x,y,z)
47 |            gt: ground truth, shape=(b,x,y,z)
48 |            seg_dtm: segmentation distance transform map, shape=(b,x,y,z)
49 |            gt_dtm: ground truth distance transform map, shape=(b,x,y,z)
50 |     output: hd_loss; scalar
51 |     """
52 | 
53 |     delta_s = (seg_soft - gt.float()) ** 2
54 |     g_dtm = gt_dtm ** 2
55 |     dtm = g_dtm if one_side else g_dtm + seg_dtm ** 2
56 |     multipled = torch.einsum('bxyz, bxyz->bxyz', delta_s, dtm)
57 |     # hd_loss = multipled.sum()*1.0/(gt_dtm > 0).sum()
58 |     hd_loss = multipled.mean()
59 | 
60 |     return hd_loss
61 | 
62 | 
63 | 
64 | def save_sdf(gt_path=None):
65 |     '''
66 |     generate the SDM for a ground-truth segmentation
67 |     '''
68 |     import nibabel as nib
69 |     dir_path = 'C:/Seolen/PycharmProjects/semi_seg/semantic-semi-supervised-master/model/gan_sdfloss3D_0229_04/test'
70 |     gt_path = dir_path + '/00_gt.nii.gz'
71 |     gt_img = nib.load(gt_path)
72 |     gt = gt_img.get_fdata().astype(np.uint8)  # get_data() was removed in newer nibabel
73 |     posmask = gt.astype(bool)
74 |     negmask = ~posmask
75 |     posdis = distance(posmask)
76 |     negdis = distance(negmask)
77 |     boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
78 |     # sdf = (negdis - np.min(negdis)) / (np.max(negdis) - np.min(negdis)) - (posdis - np.min(posdis)) / (np.max(posdis) - np.min(posdis))
79 |     sdf = (posdis - np.min(posdis)) / (np.max(posdis) - np.min(posdis))
80 |     sdf[boundary==1] = 0
81 |     sdf = sdf.astype(np.float32)
82 | 
83 |     sdf = nib.Nifti1Image(sdf, gt_img.affine)
84 |     save_path = dir_path + '/00_sdm_pos.nii.gz'
85 |     nib.save(sdf, save_path)
86 | 
87 | 
88 | 
89 | def compute_sdf(img_gt, out_shape):
90 |     """
91 |     compute the signed distance map of a binary mask
92 |     input: segmentation, shape = (batch_size, x, y, z)
93 |     output: the Signed Distance Map (SDM)
94 |     sdf(x) =  0                 ; x on the segmentation boundary
95 |              -inf_{y∈∂S} |x-y|  ; x inside the segmentation
96 |              +inf_{y∈∂S} |x-y|  ; x outside the segmentation
97 |     normalized to [-1, 1]
98 |     """
99 | 
100 |     img_gt = img_gt.astype(np.uint8)
101 |     normalized_sdf = np.zeros(out_shape)
102 | 
103 |     for b in range(out_shape[0]):  # batch size
104 |         posmask = img_gt[b].astype(bool)
105 |         if posmask.any():
106 |             negmask = ~posmask
107 |             posdis = distance(posmask)
108 |             negdis = distance(negmask)
109 |             boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
110 |             sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis))
111 |             sdf[boundary==1] = 0
112 |             normalized_sdf[b] = sdf
113 |             # assert np.min(sdf) == -1.0, print(np.min(posdis), np.max(posdis), np.min(negdis), np.max(negdis))
114 |             # assert np.max(sdf) == 1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis))
115 | 
116 |     return normalized_sdf
117 | 
118 | def sdf_loss(net_output, gt_sdm):
119 |     # print('net_output.shape, gt_sdm.shape', net_output.shape, gt_sdm.shape)
120 |     # ([4, 1, 112, 112, 80])
121 | 
122 |     smooth = 1e-5
123 |     # compute eq (4)
124 |     intersect = torch.sum(net_output * gt_sdm)
125 |     pd_sum = torch.sum(net_output ** 2)
126 |     gt_sum = torch.sum(gt_sdm ** 2)
127 |     L_product = (intersect + smooth) / (intersect + pd_sum + gt_sum + smooth)
128 |     # print('L_product.shape', L_product.shape)  (4,2)
129 |     L_SDF = 1/3 - L_product + torch.norm(net_output - gt_sdm, 1)/torch.numel(net_output)
130 | 
131 |     return L_SDF
132 | 
133 | 
134 | 
135 | 
136 | 
137 | def boundary_loss(outputs_soft, gt_sdf):
138 |     """
139 |     compute boundary loss for binary segmentation
140 |     input: outputs_soft: sigmoid results, shape=(b,2,x,y,z)
141 |            gt_sdf: SDF of the ground truth (original or normalized), shape=(b,2,x,y,z)
142 |     output: boundary_loss; scalar
143 |     """
144 |     pc = outputs_soft[:,1,...]
145 |     dc = gt_sdf[:,1,...]
146 |     multipled = torch.einsum('bxyz, bxyz->bxyz', pc, dc)
147 |     bd_loss = multipled.mean()
148 | 
149 |     return bd_loss
150 | 
151 | if __name__ == '__main__':
152 |     save_sdf()
--------------------------------------------------------------------------------
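A toy sketch of what `compute_sdf` produces and how `hd_loss` consumes `compute_dtm` (not repo code; it assumes you run from `code/` so `utils.losses_2` is importable):

```
# Toy check: the normalized SDF is ~ -1 deep inside the object, ~ +1 far
# outside, and exactly 0 on the boundary.
import numpy as np
import torch
from utils.losses_2 import compute_sdf, compute_dtm, hd_loss

mask = np.zeros((1, 8, 8, 8), dtype=np.uint8)
mask[0, 2:6, 2:6, 2:6] = 1                     # a small cube
sdf = compute_sdf(mask, mask.shape)
print(sdf.min(), sdf.max())                    # close to -1 and +1
# train_gan_sdfloss.py regresses the tanh head (outputs_tanh, range [-1, 1])
# onto exactly this target with an MSE loss (loss_sdf).

# hd_loss weights the squared prediction error by the squared ground-truth DTM:
gt_dtm = torch.from_numpy(compute_dtm(mask, mask.shape)).float()
seg_soft = torch.rand(1, 8, 8, 8)
print(hd_loss(seg_soft, torch.from_numpy(mask), gt_dtm=gt_dtm, one_side=True))
```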
/code/utils/metrics.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time    : 2019/12/14 4:41 PM
4 | # @Author  : chuyu zhang
5 | # @File    : metrics.py
6 | # @Software: PyCharm
7 | 
8 | 
9 | import numpy as np
10 | from medpy import metric
11 | 
12 | 
13 | def cal_dice(prediction, label, num=2):
14 |     total_dice = np.zeros(num-1)
15 |     for i in range(1, num):
16 |         prediction_tmp = (prediction == i)
17 |         label_tmp = (label == i)
18 |         prediction_tmp = prediction_tmp.astype(float)
19 |         label_tmp = label_tmp.astype(float)
20 | 
21 |         dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp))
22 |         total_dice[i - 1] += dice
23 | 
24 |     return total_dice
25 | 
26 | 
27 | def calculate_metric_percase(pred, gt):
28 |     dc = metric.binary.dc(pred, gt)
29 |     jc = metric.binary.jc(pred, gt)
30 |     hd = metric.binary.hd95(pred, gt)
31 |     asd = metric.binary.asd(pred, gt)
32 | 
33 |     return dc, jc, hd, asd
34 | 
35 | 
36 | def dice(input, target, ignore_index=None):
37 |     smooth = 1.
38 |     # clone so that the masking below does not modify the original tensors
39 |     iflat = input.clone().view(-1)
40 |     tflat = target.clone().view(-1)
41 |     if ignore_index is not None:
42 |         mask = tflat == ignore_index
43 |         tflat[mask] = 0
44 |         iflat[mask] = 0
45 |     intersection = (iflat * tflat).sum()
46 | 
47 |     return (2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)
--------------------------------------------------------------------------------
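A short sketch of per-case evaluation with `calculate_metric_percase` (assumes `medpy` is installed and `utils.metrics` is importable from `code/`; the toy masks are illustrative):

```
# Per-case evaluation on two overlapping toy cubes.
import numpy as np
from utils.metrics import calculate_metric_percase

pred = np.zeros((32, 32, 32), dtype=np.uint8); pred[8:24, 8:24, 8:24] = 1
gt = np.zeros((32, 32, 32), dtype=np.uint8); gt[10:24, 8:24, 8:24] = 1
dc, jc, hd, asd = calculate_metric_percase(pred, gt)
print(f"Dice={dc:.4f} Jaccard={jc:.4f} 95HD={hd:.2f} ASD={asd:.2f}")
```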
/code/utils/ramps.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, Curious AI Ltd. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 | 
8 | """Functions for ramping hyperparameters up or down
9 | 
10 | Each function takes the current training step or epoch, and the
11 | ramp length in the same format, and returns a multiplier between
12 | 0 and 1.
13 | """
14 | 
15 | 
16 | import numpy as np
17 | 
18 | 
19 | def sigmoid_rampup(current, rampup_length):
20 |     """Exponential rampup from https://arxiv.org/abs/1610.02242"""
21 |     if rampup_length == 0:
22 |         return 1.0
23 |     else:
24 |         current = np.clip(current, 0.0, rampup_length)
25 |         phase = 1.0 - current / rampup_length
26 |         return float(np.exp(-5.0 * phase * phase))
27 | 
28 | 
29 | def linear_rampup(current, rampup_length):
30 |     """Linear rampup"""
31 |     assert current >= 0 and rampup_length >= 0
32 |     if current >= rampup_length:
33 |         return 1.0
34 |     else:
35 |         return current / rampup_length
36 | 
37 | 
38 | def cosine_rampdown(current, rampdown_length):
39 |     """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
40 |     assert 0 <= current <= rampdown_length
41 |     return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))
42 | 
--------------------------------------------------------------------------------
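A sketch of how the training script turns `sigmoid_rampup` into the consistency weight it calls `get_current_consistency_weight(iter_num // 150)`. The `consistency_rampup=40.0` default below is an assumption carried over from the UA-MT codebase this repo builds on, not something defined in this file:

```
# Sketch: consistency-weight schedule built on sigmoid_rampup.
from utils import ramps

def get_current_consistency_weight(epoch, consistency=0.01, consistency_rampup=40.0):
    # Exponential rampup from https://arxiv.org/abs/1610.02242;
    # `consistency` and `consistency_rampup` mirror the CLI arguments.
    return consistency * ramps.sigmoid_rampup(epoch, consistency_rampup)

# The training loop passes epoch = iter_num // 150, so the weight ramps from
# ~0 up to `consistency` over the first 150 * consistency_rampup iterations.
```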
/code/utils/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import os
8 | import pickle
9 | import numpy as np
10 | from scipy.ndimage import distance_transform_edt as distance
11 | from skimage import segmentation as skimage_seg
12 | import torch
13 | from torch.utils.data.sampler import Sampler
14 | 
15 | import networks
16 | 
17 | def load_model(path):
18 |     """Loads a model and returns it without its DataParallel table."""
19 |     if os.path.isfile(path):
20 |         print("=> loading checkpoint '{}'".format(path))
21 |         checkpoint = torch.load(path)
22 | 
23 |         # size of the top layer
24 |         N = checkpoint['state_dict']['top_layer.bias'].size()
25 | 
26 |         # build skeleton of the model
27 |         sob = 'sobel.0.weight' in checkpoint['state_dict'].keys()
28 |         model = networks.__dict__[checkpoint['arch']](sobel=sob, out=int(N[0]))  # was `models`, which is never imported
29 | 
30 |         # deal with a dataparallel table
31 |         def rename_key(key):
32 |             if not 'module' in key:
33 |                 return key
34 |             return ''.join(key.split('.module'))
35 | 
36 |         checkpoint['state_dict'] = {rename_key(key): val
37 |                                     for key, val
38 |                                     in checkpoint['state_dict'].items()}
39 | 
40 |         # load weights
41 |         model.load_state_dict(checkpoint['state_dict'])
42 |         print("Loaded")
43 |     else:
44 |         model = None
45 |         print("=> no checkpoint found at '{}'".format(path))
46 |     return model
47 | 
48 | 
49 | class UnifLabelSampler(Sampler):
50 |     """Samples elements uniformly across pseudolabels.
51 |     Args:
52 |         N (int): size of returned iterator.
53 |         images_lists: dict of key (target), value (list of data with this target)
54 |     """
55 | 
56 |     def __init__(self, N, images_lists):
57 |         self.N = N
58 |         self.images_lists = images_lists
59 |         self.indexes = self.generate_indexes_epoch()
60 | 
61 |     def generate_indexes_epoch(self):
62 |         size_per_pseudolabel = int(self.N / len(self.images_lists)) + 1
63 |         res = np.zeros(size_per_pseudolabel * len(self.images_lists))
64 | 
65 |         for i in range(len(self.images_lists)):
66 |             indexes = np.random.choice(
67 |                 self.images_lists[i],
68 |                 size_per_pseudolabel,
69 |                 replace=(len(self.images_lists[i]) <= size_per_pseudolabel)
70 |             )
71 |             res[i * size_per_pseudolabel: (i + 1) * size_per_pseudolabel] = indexes
72 | 
73 |         np.random.shuffle(res)
74 |         return res[:self.N].astype('int')
75 | 
76 |     def __iter__(self):
77 |         return iter(self.indexes)
78 | 
79 |     def __len__(self):
80 |         return self.N
81 | 
82 | 
83 | class AverageMeter(object):
84 |     """Computes and stores the average and current value"""
85 |     def __init__(self):
86 |         self.reset()
87 | 
88 |     def reset(self):
89 |         self.val = 0
90 |         self.avg = 0
91 |         self.sum = 0
92 |         self.count = 0
93 | 
94 |     def update(self, val, n=1):
95 |         self.val = val
96 |         self.sum += val * n
97 |         self.count += n
98 |         self.avg = self.sum / self.count
99 | 
100 | 
101 | def learning_rate_decay(optimizer, t, lr_0):
102 |     for param_group in optimizer.param_groups:
103 |         lr = lr_0 / np.sqrt(1 + lr_0 * param_group['weight_decay'] * t)
104 |         param_group['lr'] = lr
105 | 
106 | 
107 | class Logger():
108 |     """ Class to update every epoch to keep track of the results
109 |     Methods:
110 |         - log() log and save
111 |     """
112 | 
113 |     def __init__(self, path):
114 |         self.path = path
115 |         self.data = []
116 | 
117 |     def log(self, train_point):
118 |         self.data.append(train_point)
119 |         with open(os.path.join(self.path), 'wb') as fp:
120 |             pickle.dump(self.data, fp, -1)
121 | 
122 | 
123 | def compute_sdf(img_gt, out_shape):
124 |     """
125 |     compute the signed distance map of a binary mask
126 |     input: segmentation, shape = (batch_size, x, y, z)
127 |     output: the Signed Distance Map (SDM)
128 |     sdf(x) =  0                 ; x on the segmentation boundary
129 |              -inf_{y∈∂S} |x-y|  ; x inside the segmentation
130 |              +inf_{y∈∂S} |x-y|  ; x outside the segmentation
131 |     normalized to [-1, 1]
132 |     """
133 | 
134 |     img_gt = img_gt.astype(np.uint8)
135 |     normalized_sdf = np.zeros(out_shape)
136 | 
137 |     for b in range(out_shape[0]):  # batch size
138 |         posmask = img_gt[b].astype(bool)
139 |         if posmask.any():
140 |             negmask = ~posmask
141 |             posdis = distance(posmask)
142 |             negdis = distance(negmask)
143 |             boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
144 |             sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis))
145 |             sdf[boundary==1] = 0
146 |             normalized_sdf[b] = sdf
147 |             # assert np.min(sdf) == -1.0, print(np.min(posdis), np.max(posdis), np.min(negdis), np.max(negdis))
148 |             # assert np.max(sdf) == 1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis))
149 | 
150 |     return normalized_sdf
--------------------------------------------------------------------------------
/data/test.list:
--------------------------------------------------------------------------------
1 | UPT6DX9IQY9JAZ7HJKA7
2 | UTBUJIWZMKP64E3N73YC
3 | ULHWPWKKLTE921LQLH1P
4 | V0MZOWJ6MU3RMRCV9EXR
5 | VDOF02M8ZHEAADFMS6NP
6 | VG4C826RAAKVMV9BQLVD
7 | VIXBEFTNVHZWKAKURJBN
8 | VQ2L3WM8KEVF6L44E6G9
9 | WBG9WYZ1B25WDT5WAT8T
10 | WMDG2EFA6L2SNDZXIRU0
11 | WNPKE0W404QE9AELX1LR
12 | WSJB9P4JCXUVHBOYFVWL
13 | WW8F5CO4S4K5IM5Z7EXX
14 | X18LU5AOBNNDMLTA0JZL
15 | XYDLYJ5CS19FDBVLJIPI
16 | Y7ZU0B2APPF54WG6PDMF
17 | YDKD1HVHSME6NVMA8I39
18 | Z9GMG63CJLL0VW893BB1
19 | ZIJLJAVQV3FJ6JSQOH1E
20 | ZQPMJ4XEC5A4BISD45P1
21 | 
--------------------------------------------------------------------------------
/data/train.list:
--------------------------------------------------------------------------------
1 | 06SR5RBREL16DQ6M8LWS
2 | 0RZDK210BSMWAA6467LU
3 | 1D7CUD1955YZPGK8XHJX
4 | 1GU15S0GJ6PFNARO469W
5 | 1MHBF3G6DCPWHSKG7XCP
6 | 23X6SY44VT9KFHR7S7OC
7 | 2XL5HSFSE93RMOJDRGR4
8 | 38CWS74285MFGZZXR09Z
9 | 3C2QTUNI0852XV7ZH4Q1
10 | 3DA0T2V6JJ2NLUAV6FWM
11 | 4498CA6DZWELOXCBRYRF
12 | 45C45I6IXAFGNRO067W9
13 | 4CHFJGF6ZUM7CMZTNFQF
14 | 4EPVTT1HPA8U60CDUKXE
15 | 57SGAJMLCTCH92QUA0EE
16 | 5BHTH9RHH3PQT913I59W
17 | 5FKQL4K14KCB72Y8YMC2
18 | 5HH0WPWIY06DLAFOBQ4M
19 | 5QFK2PMHNX7UALK52NNA
20 | 5UB5KFD2PK38Z4LS6W80
21 | 6799D6LEBH3NSRV1KH27
22 | 78NJ5YFQF72BGC8RO51C
23 | 7FUCNXB39F78WTOP5K71
24 | 8GYK8A9MBRC9TV0FVSRA
25 | 8M99G0JLAXG9GLPV0O8G
26 | 8RE90C8H5DKF4V6HO8UU
27 | 8ZG2TRZ81MAWHZPN9KKG
28 | 9DCM2IB45SK6YKQNYUQY
29 | 9DHWWP5Y66VDMPXISZ13
30 | 9DQYTIU00I4JC0OEOKQQ
31 | A11O45O3NAXWM7T2H8CH
32 | A4R1S23KR0KU2WSYHK2X
33 | A5RNNK0A891WUSC2V624
34 | AT5CRO5JUDBWD4RUPXSQ
35 | BNK95S2SJXEGSW7VAKYU
36 | BXJWOUYP2J3EN4U92517
37 | BYSRSI3H4YTWKMM3MADP
38 | BZUFJX66T0W6ZPVTL9DU
39 | CB5P5W7X310NIIVU7UZV
40 | CBIJFVZ5L9BS0LKWE8YL
41 | CCGAKN4EDT72KC8TTJ76
42 | CLXFYOBQDCVXQ9P7YC07
43 | CMPXO4J23G58J53Q98SZ
44 | CZPMV6KWZ4I7IJJP9FOK
45 | DLKXBV73A55ZTSZ0QQI2
46 | DQ5UYBGR5QP6L692QSG6
47 | DYXSCIWHLSUOZIDDSZ40
48 | E2ZMO66WGS74UKXTZPPQ
49 | EJ5V7SPR4961JWD6SS8V
50 | FGM5NIWN3URY4HF4WNUW
51 | GSC9KNY0VEZXFSGWNF25
52 | HVE7DR3CUA2IM3RC6OMA
53 | HZZ4O0BRKF8S0YX3NNF7
54 | I2VZ7N8H9QYNYT7ZZF1Y
55 | IDWWHGWJ5STOQXSDT6GU
56 | IIY6TYJMTJIZRIZLB9YW
57 | IJJY51YW3W4YJJ7DTVTK
58 | IQYKPTWXVV9H0IHB8YXC
59 | JEC6HJ7SQJXBKVREX03F
60 | JGFOLWJF7YCYD8DPHQNH
61 | K32FD6LRSUSSXGS1YUOX
62 | KM5RYAMP4P4ZP6XWP3Q2
63 | KSNYHUBHHUJTYJ14UQZR
64 | LH4FVU3TQDEC87YGN6FL
65 | LJSDNMND9SHKM7Q4IRHJ
66 | MFTDVMBWFNQ3F5KHBRDR
67 | MJHV7F65TB2A76CQLOC3
68 | MVKIPGBKTNSENNP1S4HB
69 | O5TSIKRD4AIB8K84WIR9
70 | OIRDLE32TXZX942FVZMM
71 | P1OTI3IWJUIB5NRLULLH
72 | PVNXUK681N9BY14K4Z86
73 | Q0MEX9ZIKAGJORSPLQ3Y
74 | Q7J0WYM695R9MA285ZW0
75 | QZC1W0FNR19KJFLOCFLH
76 | R8ER97O9UUN77C02VE2J
77 | RSZY41MT2FGDKHWWL5L2
78 | SN4LF8SGBSRQUPTDSX78
79 | TDDI6L3Y0L9VVFP9MNFS
80 | UZUZZT2W9IUSHL6ASOX3
81 | 
--------------------------------------------------------------------------------
/model/model_16label/best.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/SASSnet/1761aa8af08f42e5a36a737c88beb4dc798af35c/model/model_16label/best.pth
--------------------------------------------------------------------------------
/model/model_8label/best.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinzcy/SASSnet/1761aa8af08f42e5a36a737c88beb4dc798af35c/model/model_8label/best.pth
--------------------------------------------------------------------------------
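A hedged sketch of loading one of the released checkpoints for inference. The `VNet` constructor arguments mirror `create_model()` in `train_gan_sdfloss.py`; `num_classes=2`, the relative checkpoint path, and the 112x112x80 patch size (taken from the shape comment in `losses_2.py`) are assumptions, not something this repo pins down in one place:

```
# Sketch: inference with a released checkpoint (run from code/).
import torch
from networks.vnet_sdf import VNet

num_classes = 2  # assumed: binary LA segmentation
net = VNet(n_channels=1, n_classes=num_classes - 1,
           normalization='batchnorm', has_dropout=True).cuda()
net.load_state_dict(torch.load('../model/model_16label/best.pth'))
net.eval()  # disables dropout for deterministic predictions

with torch.no_grad():
    # The model returns (tanh-normalized SDF head, segmentation logits),
    # matching the training loop's `outputs_tanh, outputs = model(...)`.
    tanh_out, logits = net(torch.randn(1, 1, 112, 112, 80).cuda())
    prob = torch.sigmoid(logits)
```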