├── @voxresnet info
├── DataReader
│   ├── Generate_Dataset_name.py
│   ├── __pycache__
│   │   ├── reader.cpython-36.pyc
│   │   ├── sampler.cpython-36.pyc
│   │   └── transform.cpython-36.pyc
│   ├── beautiful_swan_2-t2.jpg
│   ├── reader.py
│   ├── sampler.py
│   ├── split
│   │   ├── .Data_set.swp
│   │   ├── Data_set
│   │   ├── images
│   │   └── labels
│   ├── transform.py
│   └── wget-log
├── Model
│   ├── UnetGenerator_3d.py
│   ├── Unet_Zoo.py
│   ├── __pycache__
│   │   ├── UnetGenerator_3d.cpython-36.pyc
│   │   └── model.cpython-36.pyc
│   └── model.py
├── README.md
├── check_function.py
├── common.py
├── config.py
├── losses.py
├── main.py
├── model.py
└── utils.py

/@voxresnet info:
--------------------------------------------------------------------------------
1 | @voxresnet info
2 | preprocess steps:
3 | #1. subtract Gaussian-smoothed images
4 | 
5 | #2. adaptive histogram equalization
6 | 
7 | #3. normalize each slice to zero mean and unit std
8 | 
9 | train the network
10 | #cross-validation is used to investigate the effect of multi-modality.
11 | train: randomly crop the training data to 80*80*80 because of GPU memory limitations
12 | 
13 | test: the probability map of the whole volume is generated with an overlap-tiling strategy that stitches the sub-volume results
14 | 
15 | evaluation metric
16 | 
17 | #dice coefficient
18 | 
19 | @denseseg
20 | preprocess steps:
--------------------------------------------------------------------------------
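A minimal sketch of the three preprocessing steps above (illustrative only: `preprocess_volume`, its defaults, and the slice axis are assumptions, not part of this repo; it needs scipy and scikit-image):

import numpy as np
from scipy.ndimage import gaussian_filter
from skimage import exposure

def preprocess_volume(volume, sigma=3.0, eps=1e-8):
    """Steps 1-3 applied to an (H, W, Z) volume, slice by slice along Z."""
    vol = volume.astype(np.float32)
    # 1. subtract a Gaussian-smoothed copy to suppress low-frequency background
    vol = vol - gaussian_filter(vol, sigma=sigma)
    out = np.empty_like(vol)
    for z in range(vol.shape[2]):
        s = vol[:, :, z]
        # 2. adaptive histogram equalization (CLAHE); equalize_adapthist
        #    expects floats in [0, 1], so rescale the slice first
        s = (s - s.min()) / (s.max() - s.min() + eps)
        s = exposure.equalize_adapthist(s)
        # 3. normalize each slice to zero mean and unit std
        out[:, :, z] = (s - s.mean()) / (s.std() + eps)
    return out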
/DataReader/Generate_Dataset_name.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append("..")
4 | from utils import Mkdir
5 | 
6 | def last_9chars(x):
7 |     return(x[-9:-4])
8 | 
9 | def ImageList2txt(image_list, txt_file_name, folder_name):
10 |     txt_file_name = './split/' + txt_file_name
11 |     txt_file = open(txt_file_name, 'w')
12 |     for img_name in image_list:
13 |         txt_file.write(os.path.join(folder_name, img_name))
14 |         txt_file.write('\n')
15 |     txt_file.close()
16 | 
17 | def Make_Dataset_txt(Dataset_dir):
18 |     image_dir = Dataset_dir + '/images'
19 |     label_dir = Dataset_dir + '/labels'
20 |     #save image txt file
21 |     image_list = os.listdir(image_dir)
22 |     image_list = sorted(image_list, key = last_9chars)
23 |     label_list = os.listdir(label_dir)
24 |     label_list = sorted(label_list, key = last_9chars)
25 |     #image_list2txt
26 |     ImageList2txt(image_list = image_list, txt_file_name = 'images',
27 |                   folder_name = 'images')
28 |     ImageList2txt(image_list = label_list, txt_file_name = 'labels',
29 |                   folder_name = 'labels')
30 |     print('success!')
31 | 
32 | def check_txt_file(split, Data_dir):
33 |     with open(split) as f:
34 |         content = f.readlines()
35 |     for line in content:
36 |         print(line)
37 | 
38 | if __name__ == '__main__':
39 |     # Dataset_dir = '/home/liuh/Desktop/datasets/data'
40 |     # txt_file_dir = os.path.join('./', 'split', 'Data_set')
41 |     # Dataset_dir, image_txt_name, label_txt_name
42 |     # Make_Dataset_txt(Dataset_dir = Dataset_dir)
43 | 
44 |     # check_txt_file(split)
45 |     pass  # keeps the module importable while the calls above stay commented out
--------------------------------------------------------------------------------
/DataReader/__pycache__/reader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiuHao-THU/3D_Segmentation_Pytorch/2736b7bfb4f5cdfe69bb5102a7a7bb076e628033/DataReader/__pycache__/reader.cpython-36.pyc
--------------------------------------------------------------------------------
/DataReader/__pycache__/sampler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiuHao-THU/3D_Segmentation_Pytorch/2736b7bfb4f5cdfe69bb5102a7a7bb076e628033/DataReader/__pycache__/sampler.cpython-36.pyc
--------------------------------------------------------------------------------
/DataReader/__pycache__/transform.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiuHao-THU/3D_Segmentation_Pytorch/2736b7bfb4f5cdfe69bb5102a7a7bb076e628033/DataReader/__pycache__/transform.cpython-36.pyc
--------------------------------------------------------------------------------
/DataReader/beautiful_swan_2-t2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiuHao-THU/3D_Segmentation_Pytorch/2736b7bfb4f5cdfe69bb5102a7a7bb076e628033/DataReader/beautiful_swan_2-t2.jpg
--------------------------------------------------------------------------------
/DataReader/reader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from time import time as timer
3 | from skimage import io
4 | import numpy as np
5 | def read_names_from_split(split_dir):
6 |     """read split txt file from source"""
7 |     file_dir_list = []
8 |     with open(split_dir) as f:
9 |         content = f.readlines()
10 |     for line in content:
11 |         #save the data as list
12 |         file_dir_list.append(line[0:-1])
13 |     return file_dir_list
14 | 
15 | class ScienceDataset():
16 | 
17 |     def __init__(self, split, Data_dir, transform=None, mode='train'):
18 |         super(ScienceDataset, self).__init__()
19 |         start = timer()
20 |         self.transform = transform
21 |         self.mode = mode
22 |         self.Data_dir = Data_dir
23 |         split_dir = os.path.join(Data_dir, 'split', split)
24 |         # get image_list
25 |         print(split_dir)
26 |         print('this is split_dir')
27 |         ids = read_names_from_split(split_dir)
28 |         # save
29 |         self.ids = ids
30 |         print(self.ids)
31 |         # print
32 |         print('\ttime = %0.2f min' % ((timer() - start) / 60))
33 |         print('\tnum_ids_images = %d' % (len(self.ids)))
34 |         print('')
35 | 
36 |     def __getitem__(self, index):
37 |         # read image
38 |         id_image = self.ids[index]
39 |         image_folder, image_name = id_image.split('/')
40 |         image_dir = os.path.join(self.Data_dir,
41 |                                  'images',
42 |                                  'TH_' + image_name)
43 |         if self.mode in ['train']:
44 |             # read label
45 |             label_dir = os.path.join(self.Data_dir,
46 |                                      'labels',
47 |                                      image_name)
48 |             # load images and labels with augmentation
49 |             image = io.imread(image_dir).astype(np.float32)
50 |             label = io.imread(label_dir).astype(np.int32)
51 |             if self.transform is not None:
52 |                 return self.transform(image, label, index)
53 |             else:
54 |                 input = image
55 |                 return input, label, index
56 | 
57 |         if self.mode in ['test']:
58 |             # load the image first so it is defined whether or not a
59 |             # transform is given (originally it was only read inside the
60 |             # transform branch, leaving `image` undefined in the else case)
61 |             image = io.imread(image_dir).astype(np.float32)
62 |             if self.transform is not None:
63 |                 return self.transform(image, index)
64 |             else:
65 |                 return image, index
66 | 
67 |     def __len__(self):
68 |         return len(self.ids)
69 | 
70 | 
71 | def check_read_names_from_split(split_dir):
72 |     txt_file = \
73 |         read_names_from_split(split_dir = split_dir)
74 |     print(txt_file)
75 | 
76 | 
77 | # if __name__ == '__main__':
78 | #     split_dir = '/home/liuh/Documents/3D_pytorch/build/DataReader/split/labels'
79 | #     check_read_names_from_split(split_dir)
--------------------------------------------------------------------------------
/DataReader/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
from torch._six import int_classes as _int_classes 3 | 4 | 5 | class Sampler(object): 6 | """Base class for all Samplers. 7 | 8 | Every Sampler subclass has to provide an __iter__ method, providing a way 9 | to iterate over indices of dataset elements, and a __len__ method that 10 | returns the length of the returned iterators. 11 | """ 12 | 13 | def __init__(self, data_source): 14 | pass 15 | 16 | def __iter__(self): 17 | raise NotImplementedError 18 | 19 | def __len__(self): 20 | raise NotImplementedError 21 | 22 | 23 | 24 | class SequentialSampler(Sampler): 25 | """Samples elements sequentially, always in the same order. 26 | 27 | Arguments: 28 | data_source (Dataset): dataset to sample from 29 | """ 30 | 31 | def __init__(self, data_source): 32 | self.data_source = data_source 33 | 34 | def __iter__(self): 35 | return iter(range(len(self.data_source))) 36 | 37 | def __len__(self): 38 | return len(self.data_source) 39 | 40 | 41 | 42 | class RandomSampler(Sampler): 43 | """Samples elements randomly, without replacement. 44 | 45 | Arguments: 46 | data_source (Dataset): dataset to sample from 47 | """ 48 | 49 | def __init__(self, data_source): 50 | self.data_source = data_source 51 | 52 | def __iter__(self): 53 | return iter(torch.randperm(len(self.data_source)).tolist()) 54 | 55 | def __len__(self): 56 | return len(self.data_source) 57 | 58 | 59 | 60 | class SubsetRandomSampler(Sampler): 61 | """Samples elements randomly from a given list of indices, without replacement. 62 | 63 | Arguments: 64 | indices (list): a list of indices 65 | """ 66 | 67 | def __init__(self, indices): 68 | self.indices = indices 69 | 70 | def __iter__(self): 71 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 72 | 73 | def __len__(self): 74 | return len(self.indices) 75 | 76 | 77 | 78 | class WeightedRandomSampler(Sampler): 79 | """Samples elements from [0,..,len(weights)-1] with given probabilities (weights). 80 | 81 | Arguments: 82 | weights (list) : a list of weights, not necessary summing up to one 83 | num_samples (int): number of samples to draw 84 | replacement (bool): if ``True``, samples are drawn with replacement. 85 | If not, they are drawn without replacement, which means that when a 86 | sample index is drawn for a row, it cannot be drawn again for that row. 87 | """ 88 | 89 | def __init__(self, weights, num_samples, replacement=True): 90 | if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \ 91 | num_samples <= 0: 92 | raise ValueError("num_samples should be a positive integeral " 93 | "value, but got num_samples={}".format(num_samples)) 94 | if not isinstance(replacement, bool): 95 | raise ValueError("replacement should be a boolean value, but got " 96 | "replacement={}".format(replacement)) 97 | self.weights = torch.tensor(weights, dtype=torch.double) 98 | self.num_samples = num_samples 99 | self.replacement = replacement 100 | 101 | def __iter__(self): 102 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement)) 103 | 104 | def __len__(self): 105 | return self.num_samples 106 | 107 | 108 | 109 | class BatchSampler(object): 110 | """Wraps another sampler to yield a mini-batch of indices. 111 | 112 | Args: 113 | sampler (Sampler): Base sampler. 114 | batch_size (int): Size of mini-batch. 
115 | drop_last (bool): If ``True``, the sampler will drop the last batch if 116 | its size would be less than ``batch_size`` 117 | 118 | Example: 119 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=False)) 120 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 121 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=True)) 122 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 123 | """ 124 | 125 | def __init__(self, sampler, batch_size, drop_last): 126 | if not isinstance(sampler, Sampler): 127 | raise ValueError("sampler should be an instance of " 128 | "torch.utils.data.Sampler, but got sampler={}" 129 | .format(sampler)) 130 | if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \ 131 | batch_size <= 0: 132 | raise ValueError("batch_size should be a positive integeral value, " 133 | "but got batch_size={}".format(batch_size)) 134 | if not isinstance(drop_last, bool): 135 | raise ValueError("drop_last should be a boolean value, but got " 136 | "drop_last={}".format(drop_last)) 137 | self.sampler = sampler 138 | self.batch_size = batch_size 139 | self.drop_last = drop_last 140 | 141 | def __iter__(self): 142 | batch = [] 143 | for idx in self.sampler: 144 | batch.append(int(idx)) 145 | if len(batch) == self.batch_size: 146 | yield batch 147 | batch = [] 148 | if len(batch) > 0 and not self.drop_last: 149 | yield batch 150 | 151 | def __len__(self): 152 | if self.drop_last: 153 | return len(self.sampler) // self.batch_size 154 | else: 155 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 156 | -------------------------------------------------------------------------------- /DataReader/split/.Data_set.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiuHao-THU/3D_Segmentation_Pytorch/2736b7bfb4f5cdfe69bb5102a7a7bb076e628033/DataReader/split/.Data_set.swp -------------------------------------------------------------------------------- /DataReader/split/Data_set: -------------------------------------------------------------------------------- 1 | label/50191 2 | label/TH_50191 3 | label/TH_50192 4 | label/50192 5 | label/50193 6 | label/TH_50193 7 | label/50194 8 | label/TH_50194 9 | label/50195 10 | label/TH_50195 11 | label/TH_50196 12 | label/50196 13 | label/50197 14 | label/TH_50197 15 | label/TH_50198 16 | label/50198 17 | label/TH_50199 18 | label/50199 19 | label/50201 20 | label/TH_50201 21 | label/50203 22 | label/TH_50203 23 | label/TH_50204 24 | label/50204 25 | label/TH_50205 26 | label/50205 27 | label/50206 28 | label/TH_50206 29 | label/TH_50207 30 | label/50207 31 | label/TH_50208 32 | label/50208 33 | label/TH_50209 34 | label/50209 35 | label/50212 36 | label/TH_50212 37 | label/TH_50213 38 | label/50213 39 | label/TH_50214 40 | label/50214 41 | label/50216 42 | label/TH_50216 43 | label/TH_50218 44 | label/50218 45 | label/50219 46 | label/TH_50219 47 | label/TH_50220 48 | label/50220 49 | label/50221 50 | label/TH_50221 51 | label/50222 52 | label/TH_50222 53 | label/TH_50223 54 | label/50223 55 | label/50224 56 | label/TH_50224 57 | label/TH_50227 58 | label/50227 59 | label/50229 60 | label/TH_50229 61 | label/50230 62 | label/TH_50230 63 | label/50231 64 | label/TH_50231 65 | label/50232 66 | label/TH_50232 67 | label/TH_50233 68 | label/50233 69 | label/TH_50234 70 | label/50234 71 | label/50235 72 | label/TH_50235 73 | label/TH_50236 74 | label/50236 75 | label/TH_50237 76 | label/50237 77 | label/50239 78 | label/TH_50239 79 | label/TH_50240 80 | 
label/50240 81 | label/TH_50241 82 | label/50241 83 | label/TH_50242 84 | label/50242 85 | label/TH_50243 86 | label/50243 87 | label/TH_50244 88 | label/50244 89 | label/TH_50245 90 | label/50245 91 | label/50246 92 | label/TH_50246 93 | label/50247 94 | label/TH_50247 95 | label/50248 96 | label/TH_50248 97 | label/50251 98 | label/TH_50251 99 | label/50252 100 | label/TH_50252 101 | label/50255 102 | label/TH_50255 103 | label/TH_50263 104 | label/50263 105 | label/TH_50265 106 | label/50265 107 | label/50266 108 | label/TH_50266 109 | label/50267 110 | label/TH_50267 111 | label/TH_50268 112 | label/50268 113 | label/TH_50271 114 | label/50271 115 | label/50272 116 | label/TH_50272 117 | label/50274 118 | label/TH_50274 119 | label/TH_50275 120 | label/50275 121 | label/TH_50276 122 | label/50276 123 | label/50278 124 | label/TH_50278 125 | label/50279 126 | label/TH_50279 127 | label/TH_50281 128 | label/50281 129 | label/TH_50285 130 | label/50285 131 | label/TH_50286 132 | label/50286 133 | label/TH_50287 134 | label/50287 135 | label/TH_50288 136 | label/50288 137 | label/TH_50289 138 | label/50289 139 | label/50290 140 | label/TH_50290 141 | label/50291 142 | label/TH_50291 143 | label/50292 144 | label/TH_50292 145 | label/TH_50293 146 | label/50293 147 | label/50294 148 | label/TH_50294 149 | label/TH_50295 150 | label/50295 151 | label/50296 152 | label/TH_50296 153 | label/TH_50297 154 | label/50297 155 | label/TH_50298 156 | label/50298 157 | label/50303 158 | label/TH_50303 159 | label/50304 160 | label/TH_50304 161 | label/TH_50305 162 | label/50305 163 | label/50306 164 | label/TH_50306 165 | label/TH_50307 166 | label/50307 167 | label/50310 168 | label/TH_50310 169 | label/50311 170 | label/TH_50311 171 | label/TH_50312 172 | label/50312 173 | label/TH_50313 174 | label/50313 175 | label/50314 176 | label/TH_50314 177 | label/50317 178 | label/TH_50317 179 | label/TH_50318 180 | label/50318 181 | label/TH_50320 182 | label/50320 183 | label/50321 184 | label/TH_50321 185 | label/TH_50322 186 | label/50322 187 | label/TH_50323 188 | label/50323 189 | label/TH_50324 190 | label/50324 191 | label/50325 192 | label/TH_50325 193 | label/TH_50326 194 | label/50326 195 | label/50327 196 | label/TH_50327 197 | label/50400 198 | label/TH_50400 199 | -------------------------------------------------------------------------------- /DataReader/split/images: -------------------------------------------------------------------------------- 1 | images/TH_50191.tif 2 | images/TH_50192.tif 3 | images/TH_50193.tif 4 | images/TH_50194.tif 5 | images/TH_50195.tif 6 | images/TH_50196.tif 7 | images/TH_50197.tif 8 | images/TH_50198.tif 9 | images/TH_50199.tif 10 | images/TH_50201.tif 11 | images/TH_50203.tif 12 | images/TH_50204.tif 13 | images/TH_50205.tif 14 | images/TH_50206.tif 15 | images/TH_50207.tif 16 | images/TH_50208.tif 17 | images/TH_50209.tif 18 | images/TH_50212.tif 19 | images/TH_50213.tif 20 | images/TH_50214.tif 21 | images/TH_50216.tif 22 | images/TH_50218.tif 23 | images/TH_50219.tif 24 | images/TH_50220.tif 25 | images/TH_50221.tif 26 | images/TH_50222.tif 27 | images/TH_50223.tif 28 | images/TH_50224.tif 29 | images/TH_50227.tif 30 | images/TH_50229.tif 31 | images/TH_50230.tif 32 | images/TH_50231.tif 33 | images/TH_50232.tif 34 | images/TH_50233.tif 35 | images/TH_50234.tif 36 | images/TH_50235.tif 37 | images/TH_50236.tif 38 | images/TH_50237.tif 39 | images/TH_50239.tif 40 | images/TH_50240.tif 41 | images/TH_50241.tif 42 | images/TH_50242.tif 43 | images/TH_50243.tif 
44 | images/TH_50244.tif 45 | images/TH_50245.tif 46 | images/TH_50246.tif 47 | images/TH_50247.tif 48 | images/TH_50248.tif 49 | images/TH_50251.tif 50 | images/TH_50252.tif 51 | images/TH_50255.tif 52 | images/TH_50263.tif 53 | images/TH_50265.tif 54 | images/TH_50266.tif 55 | images/TH_50267.tif 56 | images/TH_50268.tif 57 | images/TH_50271.tif 58 | images/TH_50272.tif 59 | images/TH_50274.tif 60 | images/TH_50275.tif 61 | images/TH_50276.tif 62 | images/TH_50278.tif 63 | images/TH_50279.tif 64 | images/TH_50281.tif 65 | images/TH_50285.tif 66 | images/TH_50286.tif 67 | images/TH_50287.tif 68 | images/TH_50288.tif 69 | images/TH_50289.tif 70 | images/TH_50290.tif 71 | images/TH_50291.tif 72 | images/TH_50292.tif 73 | images/TH_50293.tif 74 | images/TH_50294.tif 75 | images/TH_50295.tif 76 | images/TH_50296.tif 77 | images/TH_50297.tif 78 | images/TH_50298.tif 79 | images/TH_50303.tif 80 | images/TH_50304.tif 81 | images/TH_50305.tif 82 | images/TH_50306.tif 83 | images/TH_50307.tif 84 | images/TH_50310.tif 85 | images/TH_50311.tif 86 | images/TH_50312.tif 87 | images/TH_50313.tif 88 | images/TH_50314.tif 89 | images/TH_50317.tif 90 | images/TH_50318.tif 91 | images/TH_50320.tif 92 | images/TH_50321.tif 93 | images/TH_50322.tif 94 | images/TH_50323.tif 95 | images/TH_50324.tif 96 | images/TH_50325.tif 97 | images/TH_50326.tif 98 | images/TH_50327.tif 99 | images/TH_50400.tif 100 | -------------------------------------------------------------------------------- /DataReader/split/labels: -------------------------------------------------------------------------------- 1 | labels/50191.tif 2 | labels/50192.tif 3 | labels/50193.tif 4 | labels/50194.tif 5 | labels/50195.tif 6 | labels/50196.tif 7 | labels/50197.tif 8 | labels/50198.tif 9 | labels/50199.tif 10 | labels/50201.tif 11 | labels/50203.tif 12 | labels/50204.tif 13 | labels/50205.tif 14 | labels/50206.tif 15 | labels/50207.tif 16 | labels/50208.tif 17 | labels/50209.tif 18 | labels/50212.tif 19 | labels/50213.tif 20 | labels/50214.tif 21 | labels/50216.tif 22 | labels/50218.tif 23 | labels/50219.tif 24 | labels/50220.tif 25 | labels/50221.tif 26 | labels/50222.tif 27 | labels/50223.tif 28 | labels/50224.tif 29 | labels/50227.tif 30 | labels/50229.tif 31 | labels/50230.tif 32 | labels/50231.tif 33 | labels/50232.tif 34 | labels/50233.tif 35 | labels/50234.tif 36 | labels/50235.tif 37 | labels/50236.tif 38 | labels/50237.tif 39 | labels/50239.tif 40 | labels/50240.tif 41 | labels/50241.tif 42 | labels/50242.tif 43 | labels/50243.tif 44 | labels/50244.tif 45 | labels/50245.tif 46 | labels/50246.tif 47 | labels/50247.tif 48 | labels/50248.tif 49 | labels/50251.tif 50 | labels/50252.tif 51 | labels/50255.tif 52 | labels/50263.tif 53 | labels/50265.tif 54 | labels/50266.tif 55 | labels/50267.tif 56 | labels/50268.tif 57 | labels/50271.tif 58 | labels/50272.tif 59 | labels/50274.tif 60 | labels/50275.tif 61 | labels/50276.tif 62 | labels/50278.tif 63 | labels/50279.tif 64 | labels/50281.tif 65 | labels/50285.tif 66 | labels/50286.tif 67 | labels/50287.tif 68 | labels/50288.tif 69 | labels/50289.tif 70 | labels/50290.tif 71 | labels/50291.tif 72 | labels/50292.tif 73 | labels/50293.tif 74 | labels/50294.tif 75 | labels/50295.tif 76 | labels/50296.tif 77 | labels/50297.tif 78 | labels/50298.tif 79 | labels/50303.tif 80 | labels/50304.tif 81 | labels/50305.tif 82 | labels/50306.tif 83 | labels/50307.tif 84 | labels/50310.tif 85 | labels/50311.tif 86 | labels/50312.tif 87 | labels/50313.tif 88 | labels/50314.tif 89 | labels/50317.tif 90 | 
labels/50318.tif
91 | labels/50320.tif
92 | labels/50321.tif
93 | labels/50322.tif
94 | labels/50323.tif
95 | labels/50324.tif
96 | labels/50325.tif
97 | labels/50326.tif
98 | labels/50327.tif
99 | labels/50400.tif
100 | 
--------------------------------------------------------------------------------
/DataReader/transform.py:
--------------------------------------------------------------------------------
1 | #basic operation flip transpose random crop
2 | import random
3 | import numpy as np
4 | from scipy.ndimage.interpolation import rotate
5 | 
6 | def random_boolean():
7 |     random_bool = np.random.choice([True, False])
8 |     return random_bool
9 | 
10 | def random_flip_dimensions(n_dimensions):
11 |     axis = list()
12 |     for dim in range(n_dimensions):
13 |         if random_boolean():
14 |             axis.append(dim)
15 |     return axis
16 | 
17 | def Flip_image(image):
18 |     n_dim = len(image.shape)
19 |     flip_axis = random_flip_dimensions(n_dim)
20 |     try:
21 |         new_data = np.copy(image)
22 |         for axis_index in flip_axis:
23 |             new_data = np.flip(new_data, axis=axis_index)
24 | 
25 |     except TypeError:
26 |         new_data = np.flip(image, axis=flip_axis)
27 | 
28 |     return new_data
29 | 
30 | def Random_crop(image, label, crop_size):
31 |     """generate a random cropped cube according to
32 |     crop_size = [crop_size_x, ..., crop_size_z]"""
33 |     w, h, z = image.shape
34 | 
35 |     random_x_l = w - crop_size[0]
36 |     random_y_l = h - crop_size[1]
37 |     random_z_l = z - crop_size[2]
38 |     # generate random offsets (each volume dimension must be strictly
39 |     # larger than the requested crop size, or randint raises)
40 |     random_x = random.randint(0, random_x_l - 1)
41 |     random_y = random.randint(0, random_y_l - 1)
42 |     random_z = random.randint(0, random_z_l - 1)
43 |     cropped_image = image[random_x: random_x + crop_size[0],
44 |                           random_y: random_y + crop_size[1],
45 |                           random_z: random_z + crop_size[2]]
46 |     cropped_label = label[random_x: random_x + crop_size[0],
47 |                           random_y: random_y + crop_size[1],
48 |                           random_z: random_z + crop_size[2]]
49 |     return cropped_image, cropped_label
--------------------------------------------------------------------------------
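A short sketch of how these helpers behave (shapes are illustrative; note that calling `Flip_image` on the image and the label separately draws different random axes, so the pair would fall out of alignment, which may be why the flips are commented out in main.py's `train_augment`):

import numpy as np
from DataReader.transform import Flip_image, Random_crop

image = np.random.rand(128, 128, 128).astype(np.float32)   # dummy volume
label = (image > 0.5).astype(np.int32)                     # dummy mask
img_c, lab_c = Random_crop(image, label, crop_size=[48, 96, 96])
assert img_c.shape == (48, 96, 96) and lab_c.shape == (48, 96, 96)
flipped = Flip_image(img_c)   # flips a random subset of the 3 axes
assert flipped.shape == img_c.shape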
/DataReader/wget-log:
--------------------------------------------------------------------------------
1 | --2018-04-13 00:53:44--  http://hd.wallpaperswide.com/thumbs/beautiful_swan_2-t2.jpg
2 | Resolving hd.wallpaperswide.com (hd.wallpaperswide.com)... 88.198.175.3
3 | Connecting to hd.wallpaperswide.com (hd.wallpaperswide.com)|88.198.175.3|:80... connected.
4 | HTTP request sent, awaiting response... Read error (Connection reset by peer) in headers.
5 | Retrying.
6 | 
7 | 
--------------------------------------------------------------------------------
/Model/UnetGenerator_3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | def conv_block_3d(in_dim, out_dim, act_fn):
5 |     model = nn.Sequential(
6 |         nn.Conv3d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
7 |         nn.BatchNorm3d(out_dim),
8 |         act_fn,
9 |     )
10 |     return model
11 | 
12 | def conv_trans_block_3d(in_dim, out_dim, act_fn):
13 |     model = nn.Sequential(
14 |         nn.ConvTranspose3d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
15 |         nn.BatchNorm3d(out_dim),
16 |         act_fn,
17 |     )
18 |     return model
19 | 
20 | def maxpool_3d():
21 |     pool = nn.MaxPool3d(kernel_size=2, stride=2, padding=0)
22 |     return pool
23 | 
24 | def conv_block_2_3d(in_dim, out_dim, act_fn):
25 |     model = nn.Sequential(
26 |         conv_block_3d(in_dim, out_dim, act_fn),
27 |         nn.Conv3d(out_dim, out_dim, kernel_size=3, stride=1, padding=1),
28 |         nn.BatchNorm3d(out_dim),
29 |     )
30 |     return model
31 | 
32 | def conv_block_3_3d(in_dim, out_dim, act_fn):
33 |     model = nn.Sequential(
34 |         conv_block_3d(in_dim, out_dim, act_fn),
35 |         conv_block_3d(out_dim, out_dim, act_fn),
36 |         nn.Conv3d(out_dim, out_dim, kernel_size=3, stride=1, padding=1),
37 |         nn.BatchNorm3d(out_dim),
38 |     )
39 |     return model
40 | 
41 | class UnetGenerator_3d(nn.Module):
42 | 
43 |     def __init__(self, in_dim, out_dim, num_filter):
44 |         super(UnetGenerator_3d, self).__init__()
45 |         self.in_dim = in_dim
46 |         self.out_dim = out_dim
47 |         self.num_filter = num_filter
48 |         act_fn = nn.LeakyReLU(0.2, inplace=True)
49 | 
50 |         print("\n------Initiating U-Net------\n")
51 | 
52 |         self.down_1 = conv_block_2_3d(self.in_dim, self.num_filter, act_fn)
53 |         self.pool_1 = maxpool_3d()
54 |         self.down_2 = conv_block_2_3d(self.num_filter, self.num_filter*2, act_fn)
55 |         self.pool_2 = maxpool_3d()
56 |         self.down_3 = conv_block_2_3d(self.num_filter*2, self.num_filter*4, act_fn)
57 |         self.pool_3 = maxpool_3d()
58 | 
59 |         self.bridge = conv_block_2_3d(self.num_filter*4, self.num_filter*8, act_fn)
60 | 
61 |         self.trans_1 = conv_trans_block_3d(self.num_filter*8, self.num_filter*8, act_fn)
62 |         self.up_1 = conv_block_2_3d(self.num_filter*12, self.num_filter*4, act_fn)
63 |         self.trans_2 = conv_trans_block_3d(self.num_filter*4, self.num_filter*4, act_fn)
64 |         self.up_2 = conv_block_2_3d(self.num_filter*6, self.num_filter*2, act_fn)
65 |         self.trans_3 = conv_trans_block_3d(self.num_filter*2, self.num_filter*2, act_fn)
66 |         self.up_3 = conv_block_2_3d(self.num_filter*3, self.num_filter*1, act_fn)
67 | 
68 |         self.out = conv_block_3d(self.num_filter, out_dim, act_fn)
69 | 
70 | 
71 |     def forward(self, x):
72 |         down_1 = self.down_1(x)
73 |         pool_1 = self.pool_1(down_1)
74 |         down_2 = self.down_2(pool_1)
75 |         pool_2 = self.pool_2(down_2)
76 |         down_3 = self.down_3(pool_2)
77 |         pool_3 = self.pool_3(down_3)
78 | 
79 |         bridge = self.bridge(pool_3)
80 | 
81 |         trans_1 = self.trans_1(bridge)
82 |         concat_1 = torch.cat([trans_1, down_3], dim=1)
83 |         up_1 = self.up_1(concat_1)
84 |         trans_2 = self.trans_2(up_1)
85 |         concat_2 = torch.cat([trans_2, down_2], dim=1)
86 |         up_2 = self.up_2(concat_2)
87 |         trans_3 = self.trans_3(up_2)
88 |         concat_3 = torch.cat([trans_3, down_1], dim=1)
89 |         up_3 = self.up_3(concat_3)
90 | 
91 |         out = self.out(up_3)
92 | 
93 |         return out
--------------------------------------------------------------------------------
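A minimal shape check for this generator (a sketch; the input spatial size must be divisible by 8 because of the three pooling stages, and the [48, 96, 96] crop_size in config.py satisfies that):

import torch
from Model.UnetGenerator_3d import UnetGenerator_3d

model = UnetGenerator_3d(in_dim=1, out_dim=2, num_filter=4)
model.eval()                          # use running stats in the BatchNorm layers
x = torch.randn(1, 1, 48, 96, 96)     # (batch, channel, D, H, W)
with torch.no_grad():
    y = model(x)
print(y.shape)                        # torch.Size([1, 2, 48, 96, 96]), matching the input grid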
/Model/Unet_Zoo.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | 
4 | class UNet(nn.Module):
5 |     def __init__(self, num_channels=1, num_classes=2):
6 |         super(UNet, self).__init__()
7 |         num_feat = [64, 128, 256, 512, 1024]
8 | 
9 |         self.down1 = nn.Sequential(Conv3x3(num_channels, num_feat[0]))
10 | 
11 |         self.down2 = nn.Sequential(nn.MaxPool2d(kernel_size=2),
12 |                                    Conv3x3(num_feat[0], num_feat[1]))
13 | 
14 |         self.down3 = nn.Sequential(nn.MaxPool2d(kernel_size=2),
15 |                                    Conv3x3(num_feat[1], num_feat[2]))
16 | 
17 |         self.down4 = nn.Sequential(nn.MaxPool2d(kernel_size=2),
18 |                                    Conv3x3(num_feat[2], num_feat[3]))
19 | 
20 |         self.bottom = nn.Sequential(nn.MaxPool2d(kernel_size=2),
21 |                                     Conv3x3(num_feat[3], num_feat[4]))
22 | 
23 |         self.up1 = UpConcat(num_feat[4], num_feat[3])
24 |         self.upconv1 = Conv3x3(num_feat[4], num_feat[3])
25 | 
26 |         self.up2 = UpConcat(num_feat[3], num_feat[2])
27 |         self.upconv2 = Conv3x3(num_feat[3], num_feat[2])
28 | 
29 |         self.up3 = UpConcat(num_feat[2], num_feat[1])
30 |         self.upconv3 = Conv3x3(num_feat[2], num_feat[1])
31 | 
32 |         self.up4 = UpConcat(num_feat[1], num_feat[0])
33 |         self.upconv4 = Conv3x3(num_feat[1], num_feat[0])
34 | 
35 |         self.final = nn.Sequential(nn.Conv2d(num_feat[0],
36 |                                              num_classes,
37 |                                              kernel_size=1),
38 |                                    nn.Softmax2d())
39 | 
40 |     def forward(self, inputs, return_features=False):
41 |         # print(inputs)              # leftover debug output, disabled
42 |         # print('this is inputs')
43 |         # print(inputs.data.size())
44 |         down1_feat = self.down1(inputs)
45 |         # print(down1_feat.size())
46 |         down2_feat = self.down2(down1_feat)
47 |         # print(down2_feat.size())
48 |         down3_feat = self.down3(down2_feat)
49 |         # print(down3_feat.size())
50 |         down4_feat = self.down4(down3_feat)
51 |         # print(down4_feat.size())
52 |         bottom_feat = self.bottom(down4_feat)
53 | 
54 |         # print(bottom_feat.size())
55 |         up1_feat = self.up1(bottom_feat, down4_feat)
56 |         # print(up1_feat.size())
57 |         up1_feat = self.upconv1(up1_feat)
58 |         # print(up1_feat.size())
59 |         up2_feat = self.up2(up1_feat, down3_feat)
60 |         # print(up2_feat.size())
61 |         up2_feat = self.upconv2(up2_feat)
62 |         # print(up2_feat.size())
63 |         up3_feat = self.up3(up2_feat, down2_feat)
64 |         # print(up3_feat.size())
65 |         up3_feat = self.upconv3(up3_feat)
66 |         # print(up3_feat.size())
67 |         up4_feat = self.up4(up3_feat, down1_feat)
68 |         # print(up4_feat.size())
69 |         up4_feat = self.upconv4(up4_feat)
70 |         # print(up4_feat.size())
71 | 
72 |         if return_features:
73 |             outputs = up4_feat
74 |         else:
75 |             outputs = self.final(up4_feat)
76 | 
77 |         return outputs
78 | 
79 | 
80 | class UNetSmall(nn.Module):
81 |     def __init__(self, num_channels=1, num_classes=2):
82 |         super(UNetSmall, self).__init__()
83 |         num_feat = [32, 64, 128, 256]
84 | 
85 |         self.down1 = nn.Sequential(Conv3x3Small(num_channels, num_feat[0]))
86 | 
87 |         self.down2 = nn.Sequential(nn.MaxPool2d(kernel_size=2),
88 |                                    nn.BatchNorm2d(num_feat[0]),
89 |                                    Conv3x3Small(num_feat[0], num_feat[1]))
90 | 
91 |         self.down3 = nn.Sequential(nn.MaxPool2d(kernel_size=2),
92 |                                    nn.BatchNorm2d(num_feat[1]),
93 |                                    Conv3x3Small(num_feat[1], num_feat[2]))
94 | 
95 |         self.bottom = nn.Sequential(nn.MaxPool2d(kernel_size=2),
96 |                                     nn.BatchNorm2d(num_feat[2]),
97 |                                     Conv3x3Small(num_feat[2], num_feat[3]),
98 |                                     nn.BatchNorm2d(num_feat[3]))
99 | 
100 |         self.up1 = UpSample(num_feat[3], num_feat[2])
101 |         self.upconv1 = nn.Sequential(Conv3x3Small(num_feat[3] + num_feat[2], num_feat[2]),
102 |                                      nn.BatchNorm2d(num_feat[2]))
103 | 
104 |         self.up2 = UpSample(num_feat[2], num_feat[1])
105 |         self.upconv2 = nn.Sequential(Conv3x3Small(num_feat[2] + num_feat[1], num_feat[1]),
106 |                                      nn.BatchNorm2d(num_feat[1]))
107 | 
108 |         self.up3 = 
UpSample(num_feat[1], num_feat[0]) 109 | self.upconv3 = nn.Sequential(Conv3x3Small(num_feat[1] + num_feat[0], num_feat[0]), 110 | nn.BatchNorm2d(num_feat[0])) 111 | 112 | self.final = nn.Sequential(nn.Conv2d(num_feat[0], 113 | 1, 114 | kernel_size=1), 115 | nn.Sigmoid()) 116 | 117 | def forward(self, inputs, return_features=False): 118 | # print(inputs.data.size()) 119 | down1_feat = self.down1(inputs) 120 | # print(down1_feat.size()) 121 | down2_feat = self.down2(down1_feat) 122 | # print(down2_feat.size()) 123 | down3_feat = self.down3(down2_feat) 124 | # print(down3_feat.size()) 125 | bottom_feat = self.bottom(down3_feat) 126 | 127 | # print(bottom_feat.size()) 128 | up1_feat = self.up1(bottom_feat, down3_feat) 129 | # print(up1_feat.size()) 130 | up1_feat = self.upconv1(up1_feat) 131 | # print(up1_feat.size()) 132 | up2_feat = self.up2(up1_feat, down2_feat) 133 | # print(up2_feat.size()) 134 | up2_feat = self.upconv2(up2_feat) 135 | # print(up2_feat.size()) 136 | up3_feat = self.up3(up2_feat, down1_feat) 137 | # print(up3_feat.size()) 138 | up3_feat = self.upconv3(up3_feat) 139 | # print(up3_feat.size()) 140 | 141 | if return_features: 142 | outputs = up3_feat 143 | else: 144 | outputs = self.final(up3_feat) 145 | 146 | return outputs 147 | 148 | 149 | class Conv3x3(nn.Module): 150 | def __init__(self, in_feat, out_feat): 151 | super(Conv3x3, self).__init__() 152 | 153 | self.conv1 = nn.Sequential(nn.Conv2d(in_feat, out_feat, 154 | kernel_size=3, 155 | stride=1, 156 | padding=1), 157 | nn.BatchNorm2d(out_feat), 158 | nn.ReLU()) 159 | 160 | self.conv2 = nn.Sequential(nn.Conv2d(out_feat, out_feat, 161 | kernel_size=3, 162 | stride=1, 163 | padding=1), 164 | nn.BatchNorm2d(out_feat), 165 | nn.ReLU()) 166 | 167 | def forward(self, inputs): 168 | outputs = self.conv1(inputs) 169 | outputs = self.conv2(outputs) 170 | return outputs 171 | 172 | 173 | class Conv3x3Drop(nn.Module): 174 | def __init__(self, in_feat, out_feat): 175 | super(Conv3x3Drop, self).__init__() 176 | 177 | self.conv1 = nn.Sequential(nn.Conv2d(in_feat, out_feat, 178 | kernel_size=3, 179 | stride=1, 180 | padding=1), 181 | nn.Dropout(p=0.2), 182 | nn.ReLU()) 183 | 184 | self.conv2 = nn.Sequential(nn.Conv2d(out_feat, out_feat, 185 | kernel_size=3, 186 | stride=1, 187 | padding=1), 188 | nn.BatchNorm2d(out_feat), 189 | nn.ReLU()) 190 | 191 | def forward(self, inputs): 192 | outputs = self.conv1(inputs) 193 | outputs = self.conv2(outputs) 194 | return outputs 195 | 196 | 197 | class Conv3x3Small(nn.Module): 198 | def __init__(self, in_feat, out_feat): 199 | super(Conv3x3Small, self).__init__() 200 | 201 | self.conv1 = nn.Sequential(nn.Conv2d(in_feat, out_feat, 202 | kernel_size=3, 203 | stride=1, 204 | padding=1), 205 | nn.ELU(), 206 | nn.Dropout(p=0.2)) 207 | 208 | self.conv2 = nn.Sequential(nn.Conv2d(out_feat, out_feat, 209 | kernel_size=3, 210 | stride=1, 211 | padding=1), 212 | nn.ELU()) 213 | 214 | def forward(self, inputs): 215 | outputs = self.conv1(inputs) 216 | outputs = self.conv2(outputs) 217 | return outputs 218 | 219 | 220 | class UpConcat(nn.Module): 221 | def __init__(self, in_feat, out_feat): 222 | super(UpConcat, self).__init__() 223 | 224 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 225 | 226 | # self.deconv = nn.ConvTranspose2d(in_feat, out_feat, 227 | # kernel_size=3, 228 | # stride=1, 229 | # dilation=1) 230 | 231 | self.deconv = nn.ConvTranspose2d(in_feat, 232 | out_feat, 233 | kernel_size=2, 234 | stride=2) 235 | 236 | def forward(self, inputs, down_outputs): 237 | # TODO: Upsampling required 
after deconv? 238 | # outputs = self.up(inputs) 239 | outputs = self.deconv(inputs) 240 | out = torch.cat([down_outputs, outputs], 1) 241 | return out 242 | 243 | 244 | class UpSample(nn.Module): 245 | def __init__(self, in_feat, out_feat): 246 | super(UpSample, self).__init__() 247 | 248 | self.up = nn.Upsample(scale_factor=2, mode='nearest') 249 | 250 | self.deconv = nn.ConvTranspose2d(in_feat, 251 | out_feat, 252 | kernel_size=2, 253 | stride=2) 254 | 255 | def forward(self, inputs, down_outputs): 256 | # TODO: Upsampling required after deconv? 257 | outputs = self.up(inputs) 258 | # outputs = self.deconv(inputs) 259 | out = torch.cat([outputs, down_outputs], 1) 260 | return out -------------------------------------------------------------------------------- /Model/__pycache__/UnetGenerator_3d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiuHao-THU/3D_Segmentation_Pytorch/2736b7bfb4f5cdfe69bb5102a7a7bb076e628033/Model/__pycache__/UnetGenerator_3d.cpython-36.pyc -------------------------------------------------------------------------------- /Model/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiuHao-THU/3D_Segmentation_Pytorch/2736b7bfb4f5cdfe69bb5102a7a7bb076e628033/Model/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /Model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | class Net(nn.Module): 6 | def __init__(self, n_class = 2): 7 | super().__init__() 8 | # conv1 9 | self.conv1_1 = nn.Conv3d(1, 8, 3, padding=60) 10 | # torch.nn.Conv3d(in_channels, out_channels, kernel_size, 11 | # stride=1, padding=0, dilation=1, groups=1, bias=True) 12 | self.relu1_1 = nn.ReLU(inplace=True) 13 | self.conv1_2 = nn.Conv3d(8, 8, 3, padding=1) 14 | self.relu1_2 = nn.ReLU(inplace=True) 15 | self.pool1 = nn.MaxPool3d(2, stride=2, ceil_mode=True) # 1/2 16 | 17 | # conv2 18 | self.conv2_1 = nn.Conv3d(8, 16, 3, padding=15) 19 | self.relu2_1 = nn.ReLU(inplace=True) 20 | self.conv2_2 = nn.Conv3d(16, 16, 3, padding=1) 21 | self.relu2_2 = nn.ReLU(inplace=True) 22 | self.pool2 = nn.MaxPool3d(2, stride=2, ceil_mode=True) # 1/4 23 | 24 | # conv3 25 | self.conv3_1 = nn.Conv3d(16, 32, 3, padding=1) 26 | self.relu3_1 = nn.ReLU(inplace=True) 27 | self.conv3_2 = nn.Conv3d(32, 32, 3, padding=1) 28 | self.relu3_2 = nn.ReLU(inplace=True) 29 | self.conv3_3 = nn.Conv3d(32, 32, 3, padding=1) 30 | self.relu3_3 = nn.ReLU(inplace=True) 31 | self.pool3 = nn.MaxPool3d(2, stride=2, ceil_mode=True) # 1/8 32 | 33 | # conv4 34 | self.conv4_1 = nn.Conv3d(32, 64, 3, padding=1) 35 | self.relu4_1 = nn.ReLU(inplace=True) 36 | self.conv4_2 = nn.Conv3d(64, 64, 3, padding=1) 37 | self.relu4_2 = nn.ReLU(inplace=True) 38 | self.conv4_3 = nn.Conv3d(64, 64, 3, padding=1) 39 | self.relu4_3 = nn.ReLU(inplace=True) 40 | self.pool4 = nn.MaxPool3d(2, stride=2, ceil_mode=True) # 1/16 41 | 42 | # conv5 43 | self.conv5_1 = nn.Conv3d(64, 64, 3, padding=1) 44 | self.relu5_1 = nn.ReLU(inplace=True) 45 | self.conv5_2 = nn.Conv3d(64, 64, 3, padding=1) 46 | self.relu5_2 = nn.ReLU(inplace=True) 47 | self.conv5_3 = nn.Conv3d(64, 64, 3, padding=1) 48 | self.relu5_3 = nn.ReLU(inplace=True) 49 | self.pool5 = nn.MaxPool3d(2, stride=2, ceil_mode=True) # 1/32 50 | 51 | # 
fc6
52 |         self.fc6 = nn.Conv3d(64, 512, 7)
53 |         self.relu6 = nn.ReLU(inplace=True)
54 |         self.drop6 = nn.Dropout3d()
55 | 
56 |         # fc7
57 |         self.fc7 = nn.Conv3d(512, 512, 1)
58 |         self.relu7 = nn.ReLU(inplace=True)
59 |         self.drop7 = nn.Dropout3d()
60 | 
61 |         self.score_fr = nn.Conv3d(512, n_class, 1)
62 |         self.score_pool3 = nn.Conv3d(32, n_class, 1)
63 |         self.score_pool4 = nn.Conv3d(64, n_class, 1)
64 | 
65 |         self.upscore2 = nn.ConvTranspose3d(
66 |             n_class, n_class, 4, stride=2, bias=False)
67 |         self.upscore8 = nn.ConvTranspose3d(
68 |             n_class, n_class, 16, stride=8, bias=False)
69 |         self.upscore_pool4 = nn.ConvTranspose3d(
70 |             n_class, n_class, 4, stride=2, bias=False)
71 | 
72 |         self._initialize_weights()
73 | 
74 | 
75 |     def get_upsampling_weight(self, in_channels, out_channels, kernel_size):
76 |         """Make a 3D trilinear kernel suitable for upsampling"""
77 |         factor = (kernel_size + 1) // 2
78 |         if kernel_size % 2 == 1:
79 |             center = factor - 1
80 |         else:
81 |             center = factor - 0.5
82 |         og = np.ogrid[:kernel_size, :kernel_size, :kernel_size]
83 |         filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor) * \
84 |                (1 - abs(og[2] - center) / factor)  # the third axis was missing, which made the kernel constant along z
85 |         weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size, kernel_size),
86 |                           dtype=np.float64)
87 |         weight[range(in_channels), range(out_channels), :, :, :] = filt
88 |         return torch.from_numpy(weight).float()
89 | 
90 |     def _initialize_weights(self):
91 |         for m in self.modules():
92 |             if isinstance(m, nn.Conv3d):
93 |                 m.weight.data.zero_()
94 |                 m.weight.data.normal_(0.0, 0.1)
95 |                 if m.bias is not None:
96 |                     m.bias.data.zero_()
97 |             if isinstance(m, nn.ConvTranspose3d):
98 |                 assert m.kernel_size[0] == m.kernel_size[1]
99 |                 initial_weight = self.get_upsampling_weight(
100 |                     m.in_channels, m.out_channels, m.kernel_size[0])
101 |                 m.weight.data.copy_(initial_weight)
102 | 
103 |     # def _initialize_weights(self):
104 |     #     for m in self.modules():
105 |     #         if isinstance(m, nn.Conv3d):
106 |     #             m.weight.data.zero_()
107 |     #             m.weight.data.normal_(0.0, 0.1)
108 |     #             if m.bias is not None:
109 |     #                 m.bias.data.zero_()
110 |     #                 m.bias.data.normal_(0.0, 0.1)
111 |     #         if isinstance(m, nn.ConvTranspose3d):
112 |     #             m.weight.data.zero_()
113 |     #             m.weight.data.normal_(0.0, 0.1)
114 |     #             # assert m.kernel_size[0] == m.kernel_size[1]
115 |     #             # initial_weight = self.get_upsampling_weight(
116 |     #             #     m.in_channels, m.out_channels, m.kernel_size[0])
117 |     #             # m.weight.data.copy_(initial_weight)
118 |     def copy_params_from_vgg16(self, vgg16):
119 |         features = [
120 |             self.conv1_1, self.relu1_1,
121 |             self.conv1_2, self.relu1_2,
122 |             self.pool1,
123 |             self.conv2_1, self.relu2_1,
124 |             self.conv2_2, self.relu2_2,
125 |             self.pool2,
126 |             self.conv3_1, self.relu3_1,
127 |             self.conv3_2, self.relu3_2,
128 |             self.conv3_3, self.relu3_3,
129 |             self.pool3,
130 |             self.conv4_1, self.relu4_1,
131 |             self.conv4_2, self.relu4_2,
132 |             self.conv4_3, self.relu4_3,
133 |             self.pool4,
134 |             self.conv5_1, self.relu5_1,
135 |             self.conv5_2, self.relu5_2,
136 |             self.conv5_3, self.relu5_3,
137 |             self.pool5,
138 |         ]
139 |         for l1, l2 in zip(vgg16.features, features):  # NOTE: the l2 layers are Conv3d, so the Conv2d copy below never fires as written
140 |             print("what is l1? ", l1)
141 |             print("what is l2? 
", l2) 142 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 143 | assert l1.weight.size() == l2.weight.size() 144 | assert l1.bias.size() == l2.bias.size() 145 | l2.weight.data.copy_(l1.weight.data) 146 | l2.bias.data.copy_(l1.bias.data) 147 | for i, name in zip([0, 3], ['fc6', 'fc7']): 148 | l1 = vgg16.classifier[i] 149 | l2 = getattr(self, name) 150 | l2.weight.data.copy_(l1.weight.data.view(l2.weight.size())) 151 | l2.bias.data.copy_(l1.bias.data.view(l2.bias.size())) 152 | 153 | def forward(self, x): 154 | h = x 155 | happyprint("init: ", x.data[0].shape) 156 | 157 | h = self.relu1_1(self.conv1_1(h)) 158 | happyprint("after conv1_1: ", h.data[0].shape) 159 | 160 | h = self.relu1_2(self.conv1_2(h)) 161 | h = self.pool1(h) 162 | 163 | happyprint("after pool1: ", h.data[0].shape) 164 | 165 | h = self.relu2_1(self.conv2_1(h)) 166 | h = self.relu2_2(self.conv2_2(h)) 167 | h = self.pool2(h) 168 | 169 | happyprint("after pool2: ", h.data[0].shape) 170 | 171 | h = self.relu3_1(self.conv3_1(h)) 172 | h = self.relu3_2(self.conv3_2(h)) 173 | h = self.relu3_3(self.conv3_3(h)) 174 | h = self.pool3(h) 175 | pool3 = h # 1/8 176 | 177 | happyprint("after pool3: ", h.data[0].shape) 178 | 179 | h = self.relu4_1(self.conv4_1(h)) 180 | h = self.relu4_2(self.conv4_2(h)) 181 | h = self.relu4_3(self.conv4_3(h)) 182 | h = self.pool4(h) 183 | pool4 = h # 1/16 184 | 185 | happyprint("after pool4: ", h.data[0].shape) 186 | 187 | h = self.relu5_1(self.conv5_1(h)) 188 | h = self.relu5_2(self.conv5_2(h)) 189 | h = self.relu5_3(self.conv5_3(h)) 190 | h = self.pool5(h) 191 | 192 | happyprint("after pool5: ", h.data[0].shape) 193 | 194 | h = self.relu6(self.fc6(h)) 195 | h = self.drop6(h) 196 | 197 | h = self.relu7(self.fc7(h)) 198 | h = self.drop7(h) 199 | 200 | h = self.score_fr(h) 201 | 202 | happyprint("after score_fr: ", h.data[0].shape) 203 | h = self.upscore2(h) 204 | 205 | happyprint("after upscore2: ", h.data[0].shape) 206 | upscore2 = h # 1/16 207 | 208 | h = self.score_pool4(pool4 * 0.01) # XXX: scaling to train at once 209 | happyprint("after score_pool4: ", h.data[0].shape) 210 | 211 | h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3], 5:5 + upscore2.size()[4]] 212 | 213 | score_pool4c = h # 1/16 214 | happyprint("after score_pool4c: ", h.data[0].shape) 215 | 216 | h = upscore2 + score_pool4c # 1/16 217 | h = self.upscore_pool4(h) 218 | upscore_pool4 = h # 1/8 219 | happyprint("after upscore_pool4: ", h.data[0].shape) 220 | 221 | h = self.score_pool3(pool3 * 0.0001) # XXX: scaling to train at once 222 | h = h[:, :, 223 | 9:9 + upscore_pool4.size()[2], 224 | 9:9 + upscore_pool4.size()[3], 225 | 9:9 + upscore_pool4.size()[4]] 226 | score_pool3c = h # 1/8 227 | happyprint("after score_pool3: ", h.data[0].shape) 228 | 229 | # print(upscore_pool4.data[0].shape) 230 | # print(score_pool3c.data[0].shape) 231 | 232 | # Adjusting stride in self.upscore2 and self.upscore_pool4 233 | # and self.conv1_1 can change the tensor shape (size). 234 | # I don't know why! 
235 | 
236 |         h = upscore_pool4 + score_pool3c  # 1/8
237 | 
238 |         h = self.upscore8(h)  # dim: 88^3
239 |         h = h[:, :, 31:31 + x.size()[2], 31:31 + x.size()[3], 31:31 + x.size()[4]].contiguous()
240 |         happyprint("after upscore8: ", h.data[0].shape)
241 |         return h
242 | 
243 | 
244 | def happyprint(*args):
245 |     """Minimal stand-in: forward() calls happyprint for shape tracing, but the
246 |     helper is never defined or imported in this repo, so a stub is added here."""
247 |     print(*args)
248 | 
--------------------------------------------------------------------------------
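A quick check of the corrected trilinear upsampling kernel built by get_upsampling_weight (a sketch, run from the repo root; instantiating Net also exercises _initialize_weights):

import torch
from Model.model import Net

net = Net(n_class=2)
w = net.get_upsampling_weight(2, 2, 4)   # kernel_size=4 pairs with the stride-2 upscore layers
print(w.shape)                           # torch.Size([2, 2, 4, 4, 4])
# after the fix the kernel varies along all three spatial axes:
print(w[0, 0, :, 0, 0])                  # tensor([0.0156, 0.0469, 0.0469, 0.0156])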
/README.md:
--------------------------------------------------------------------------------
1 | # 3D-Segmentation-Pytorch-version
2 | 3D Segmentation Pytorch version (unet basic model)
3 | 2nd version
4 | main.py for training
5 | test.py for testing
6 | 
7 | oops, the newest version is on my lab computer..
8 | 
--------------------------------------------------------------------------------
/check_function.py:
--------------------------------------------------------------------------------
1 | # standalone sanity checks for the data pipeline; the imports below are the
2 | # ones these helpers implicitly rely on (mirroring main.py)
3 | import matplotlib.pyplot as plt
4 | from torch.utils.data import DataLoader
5 | from config import config
6 | from common import Data_Dir
7 | from DataReader.reader import ScienceDataset
8 | from DataReader.transform import Flip_image, Random_crop
9 | 
10 | def check_train_augment(image, label, indices):
11 |     """include random crop and random flip transpose"""
12 |     # image = Flip_image(image)
13 |     # label = Flip_image(label)
14 |     image, label = Random_crop(image, label, config['crop_size'])
15 |     return image, label, indices
16 | 
17 | 
18 | def check_valid_augment(image, label, indices):
19 |     """include random crop and random flip transpose"""
20 |     image, label = Random_crop(image, label, config['crop_size'])
21 |     return image, label, indices
22 | 
23 | 
24 | def check_images_labels(images, labels):
25 |     # show one slice of the image and label batch
26 |     # batch, channel, width, height, depth = images.cpu().numpy().shape
27 |     plt.imshow(images[0, 0, :, :].cpu().numpy())   # was `data`, which is undefined here
28 |     plt.title('data')
29 |     plt.pause(0.3)
30 |     plt.imshow(labels[0, :, :].cpu().numpy())      # was `target`, likewise undefined
31 |     plt.title('target')
32 |     plt.pause(0.3)
33 | 
34 | 
35 | def check_dataloader():
36 |     """check dataloader for dataset"""
37 |     train_dataset = ScienceDataset(
38 |         split='Train_DataSet',
39 |         Data_dir=Data_Dir,
40 |         transform=check_train_augment,
41 |         mode='train')
42 | 
43 |     train_loader = DataLoader(
44 |         train_dataset,
45 |         sampler=None,
46 |         shuffle=True,
47 |         batch_size=config['train_batch_size'],
48 |         drop_last=config['drop_last'],
49 |         num_workers=config['num_workers'],
50 |         pin_memory=config['pin_memory'])
51 | 
52 |     valid_dataset = ScienceDataset(
53 |         split='Valid_DataSet',
54 |         Data_dir=Data_Dir,
55 |         transform=check_valid_augment,
56 |         mode='train')
57 | 
58 |     valid_loader = DataLoader(
59 |         valid_dataset,
60 |         sampler=None,
61 |         shuffle=True,
62 |         batch_size=config['valid_batch_size'],
63 |         drop_last=config['drop_last'],
64 |         num_workers=config['num_workers'],
65 |         pin_memory=config['pin_memory'])
66 | 
67 |     # check train_loader:
68 |     for batch_image, batch_label, indices in train_loader:
69 |         plt.imshow(batch_image[0, :, :, 48])
70 |         plt.pause(0.1)
71 |         plt.imshow(batch_label[0, :, :, 48])
72 |         plt.pause(0.1)
73 |         print(batch_image.shape, batch_label.shape)
74 |         print(indices)
75 |     # check valid loader
76 |     # for batch_image, batch_label, indices in valid_loader:
77 |     #     print(batch_image.shape, batch_label.shape)
78 | 
79 |     print('success!')
--------------------------------------------------------------------------------
/common.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import sys
4 | import random
5 | import matplotlib
6 | from utils import *
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # '3,2' #'3,2,1,0'
10 | # --------------------------------------------------------------------
11 | PROJECT_PATH = '/home/wanggh/Music/3D_Pytorch'
12 | print('@%s: ' % PROJECT_PATH)
13 | 
14 | Data_Dir = PROJECT_PATH + '/Data'
15 | out_dir = PROJECT_PATH + '/results'
16 | Split_dir = Data_Dir + '/split'
17 | Check_Pints_dir = out_dir + '/checkpoints'
18 | 
19 | Mkdir(Check_Pints_dir)
20 | 
21 | if 1:
22 |     SEED = 35202
23 |     random.seed(SEED)
24 |     np.random.seed(SEED)
25 |     torch.manual_seed(SEED)
26 |     torch.cuda.manual_seed_all(SEED)
27 |     print('\tset random seed')
28 |     print('\t\tSEED=%d' % SEED)
29 | if 1:
30 |     torch.backends.cudnn.benchmark = True
31 |     torch.backends.cudnn.enabled = True
32 |     print('\tset cuda environment')
33 |     print(
34 |         '\t\ttorch.__version__ =',
35 |         torch.__version__)
36 |     print(
37 |         '\t\ttorch.version.cuda =',
38 |         torch.version.cuda)
39 |     print(
40 |         '\t\ttorch.backends.cudnn.version() =',
41 |         torch.backends.cudnn.version())
42 |     try:
43 |         print(
44 |             '\t\tos[\'CUDA_VISIBLE_DEVICES\'] =',
45 |             os.environ['CUDA_VISIBLE_DEVICES'])
46 |         NUM_CUDA_DEVICES = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
47 |     except Exception:
48 |         print('\t\tos[\'CUDA_VISIBLE_DEVICES\'] =', 'None')
49 |         NUM_CUDA_DEVICES = 1
50 | 
51 |     print(
52 |         '\t\ttorch.cuda.device_count() =',
53 |         torch.cuda.device_count())
54 |     print(
55 |         '\t\ttorch.cuda.current_device() =',
56 |         torch.cuda.current_device())
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | config = {}
4 | config['train_batch_size'] = 1
5 | config['valid_batch_size'] = 1
6 | config['test_batch_size'] = 1
7 | config['epochs'] = 200
8 | config['lr'] = 0.01
9 | config['momentum'] = 0.5
10 | config['no-cuda'] = False
11 | config['seed'] = 1
12 | config['log_interval'] = 1
13 | config['cuda'] = torch.cuda.is_available()  # args.no_cuda
14 | config['label_split_dir'] = '/home/liuh/Documents/3D_pytorch/build/DataReader/split/labels'
15 | config['image_split_dir'] = '/home/liuh/Documents/3D_pytorch/build/DataReader/split/images'
16 | config['crop_size'] = [48, 96, 96]
17 | config['num_workers'] = 4
18 | config['pin_memory'] = True
19 | config['drop_last'] = True
20 | config['save_epoch_num'] = 2
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F  # (was `torch.functional`, which is not the functional API)
5 | import numpy as np
6 | 
7 | 
8 | class DICELossMultiClass(nn.Module):
9 | 
10 |     def __init__(self):
11 |         super(DICELossMultiClass, self).__init__()
12 | 
13 |     def forward(self, output, mask):
14 | 
15 |         probs = output[:, 1, :, :]  # NOTE: the reductions below assume 4D (B, C, H, W) outputs; a 5D volume would need its extra spatial dim summed as well
16 |         mask = torch.squeeze(mask, 1)
17 | 
18 |         num = probs * mask
19 |         num = torch.sum(num, 2)
20 |         num = torch.sum(num, 1)
21 | 
22 |         # print( num )
23 | 
24 |         den1 = probs * probs
25 |         # print(den1.size())
26 |         den1 = torch.sum(den1, 2)
27 |         den1 = torch.sum(den1, 1)
28 | 
29 |         # print(den1.size())
30 | 
31 |         den2 = mask * mask
32 |         # print(den2.size())
33 |         den2 = torch.sum(den2, 2)
34 |         den2 = torch.sum(den2, 1)
35 | 
36 |         # print(den2.size())
37 |         eps = 0.0000001
38 |         dice = 2 * ((num + eps) / (den1 + den2 + eps))
39 |         # dice_eso = dice[:, 1:]
40 |         dice_eso = dice
41 | 
42 |         loss = 1 - torch.sum(dice_eso) / dice_eso.size(0)
43 |         return loss
44 | 
45 | 
46 | class DICELoss(nn.Module):
47 | 
48 |     def __init__(self):
49 |         super(DICELoss, self).__init__()
50 | 
51 |     def forward(self, output, mask):
52 | 
53 |         probs = torch.squeeze(output, 1)
54 |         mask = torch.squeeze(mask, 1)
55 | 
56 |         intersection = probs * mask
57 |         intersection = torch.sum(intersection, 2)
58 |         intersection = torch.sum(intersection, 1)
59 | 
60 |         # print( num )
61 | 
62 |         den1 = probs * probs
63 |         # print(den1.size())
64 |         den1 = torch.sum(den1, 2)
65 |         den1 = torch.sum(den1, 1)
66 | 
67 |         # print(den1.size())
68 | 
69 |         den2 = mask * mask
70 |         # print(den2.size())
71 |         den2 = torch.sum(den2, 2)
72 |         den2 = torch.sum(den2, 1)
73 | 
74 |         # print(den2.size())
75 |         eps = 0.0000001
76 |         dice = 2 * ((intersection + eps) / (den1 + den2 + eps))
77 |         # dice_eso = dice[:, 1:]
78 |         dice_eso = dice
79 | 
80 |         loss = 1 - torch.sum(dice_eso) / dice_eso.size(0)
81 |         return loss
--------------------------------------------------------------------------------
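A tiny worked example of the soft Dice loss defined above (a sketch; DICELoss expects (batch, 1, H, W)-style tensors):

import torch
from losses import DICELoss

output = torch.tensor([[[[0.9, 0.1], [0.8, 0.2]]]])  # predicted foreground probabilities
mask   = torch.tensor([[[[1.0, 0.0], [1.0, 0.0]]]])  # ground-truth mask
loss = DICELoss()(output, mask)
# intersection = 0.9 + 0.8 = 1.7; den1 = 0.81 + 0.01 + 0.64 + 0.04 = 1.5; den2 = 2.0
# dice = 2 * 1.7 / 3.5 = 0.971, so loss = 1 - dice = 0.029
print(loss.item())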
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: antigen
3 | # @Date:   2018-04-12 22:39:01
4 | # @Last Modified by:   antigen
5 | # @Last Modified time: 2018-04-15 23:53:05
6 | import torch
7 | from common import *
8 | import torch.nn as nn
9 | from config import config
10 | import torch.optim as optim
11 | import torch.utils as utils
12 | import torch.nn.init as init
13 | import torch.nn.functional as F
14 | import torch.utils.data as data
15 | from DataReader.sampler import *
16 | import torchvision.utils as v_utils
17 | import torchvision.datasets as dset
18 | from torch.utils.data import DataLoader
19 | import torchvision.transforms as transforms
20 | from torch.autograd import Variable
21 | from time import time as timer
22 | from losses import DICELossMultiClass
23 | # from Model.model import Net
24 | # from Model.Unet_Zoo import UNet
25 | from Model.UnetGenerator_3d import *
26 | from DataReader.reader import ScienceDataset
27 | from DataReader.transform import Flip_image, Random_crop
28 | # define log
29 | log = Logger()
30 | 
31 | start = timer()
32 | # define loss function
33 | criterion = DICELossMultiClass()
34 | 
35 | def loss_function(output, label):
36 |     batch_size, channel, x, y, z = output.size()
37 |     total_loss = 0
38 |     for i in range(batch_size):
39 |         for j in range(z):
40 |             loss = 0
41 |             output_z = output[i:i+1, :, :, :, j]
42 |             label_z = label[i, :, :, :, j]
43 | 
44 |             softmax_output_z = nn.Softmax2d()(output_z)
45 |             logsoftmax_output_z = torch.log(softmax_output_z)
46 | 
47 |             loss = nn.NLLLoss2d()(logsoftmax_output_z, label_z)
48 |             total_loss += loss
49 | 
50 |     return total_loss
51 | 
52 | def save_check_points(net, check_points_dir, epoch, optimizer):
53 |     iter_num = config['train_batch_size'] * epoch
54 |     if (epoch+1) % config['save_epoch_num'] == 0:  # save last
55 |         torch.save(net.state_dict(), check_points_dir + '/%s_model.pth' % (str(epoch).zfill(5)))  # use the argument and a '/' (was the misspelled global with no separator)
56 |         torch.save({
57 |             'optimizer': optimizer.state_dict(),
58 |             'iter' : iter_num,
59 |             'epoch' : epoch,
60 |         }, check_points_dir + '/%s_optimizer.pth' % (str(epoch).zfill(5)))
61 | 
62 | def adjust_lr(optimizer, epoch):
63 |     lr = config['lr'] * (0.1 ** (epoch // 20))
64 |     for param_group in optimizer.param_groups:
65 |         param_group['lr'] = lr
66 |     return lr
67 | def train(epoch, train_loader, model, optimizer):
68 |     iter_num = 0
69 |     train_loss = 0
70 |     for batch_idx, (data, target, indices) in enumerate(train_loader):
71 |         if config['cuda']:
72 |             data, target = data.cuda(), target.cuda()
73 |         data, target = Variable(data), Variable(target)  # (the second variable was misnamed train_augment, shadowing the augment function below)
74 |         optimizer.zero_grad()
75 |         out = model(data)
76 |         loss = loss_function(out, target)
77 |         # loss = F.nll_loss(log_p, target)
78 |         # loss = 
criterion(output_score, target) 79 | loss.backward() 80 | optimizer.step() 81 | # save the checkpoints 82 | save_check_points(model, Check_Pints_dir, epoch, optimizer) 83 | train_loss = train_loss + loss.data[0] 84 | learning_rate = adjust_lr(optimizer, epoch) 85 | #learning rate decay 86 | 87 | 88 | 89 | train_loss = train_loss/len(train_loader.dataset) 90 | return train_loss, learning_rate 91 | 92 | # some train info name batch_index for train_loss = loss.data[0] 93 | 94 | 95 | def valid(valid_loader, model): 96 | model.eval() 97 | valid_loss = 0 98 | correct = 0 99 | for data, target, indices in valid_loader: 100 | if config['cuda']: 101 | data, target = data.cuda(), target.cuda() 102 | data, target = Variable(data), Variable(target) 103 | output = model(data) 104 | loss = loss_function(output, target) 105 | valid_loss = valid_loss + loss.data[0] 106 | 107 | valid_loss /= len(valid_loader.dataset) 108 | 109 | return valid_loss 110 | 111 | 112 | def train_augment(image, label, indices): 113 | """include random crop and random flip transpose""" 114 | # image = Flip_image(image) 115 | # label = Flip_image(label) 116 | 117 | image, label = Random_crop(image, label, config['crop_size']) 118 | image = np.expand_dims(image, axis = 0) 119 | label = np.expand_dims(label, axis = 0) 120 | input = torch.from_numpy(image).float().div(255) 121 | label = torch.from_numpy(label).long() 122 | return input, label, indices 123 | 124 | 125 | def valid_augment(image, label, indices): 126 | """include random crop and random flip transpose""" 127 | image, label = Random_crop(image, label, config['crop_size']) 128 | image = np.expand_dims(image, axis=0) 129 | label = np.expand_dims(label, axis=0) 130 | input = torch.from_numpy(image.copy()).float().div(255) 131 | label = torch.from_numpy(label).long() 132 | return input, label, indices 133 | 134 | 135 | def main(): 136 | 137 | initial_checkpoint = None 138 | pretrain_file = None 139 | log.open(out_dir+'/log.train.txt', mode='a') 140 | log.write('** some experiment setting **\n') 141 | log.write('\tSEED = %u\n' % SEED) 142 | log.write('\tPROJECT_PATH = %s\n' % PROJECT_PATH) 143 | log.write('\tout_dir = %s\n' % out_dir) 144 | log.write('\n') 145 | # load image data 146 | train_dataset = ScienceDataset( 147 | split='Train_DataSet', 148 | Data_dir=Data_Dir, 149 | transform=train_augment, mode='train') 150 | 151 | train_loader = DataLoader( 152 | train_dataset, 153 | sampler=None, 154 | shuffle=True, 155 | batch_size=config['train_batch_size'], 156 | drop_last=config['drop_last'], 157 | num_workers=config['num_workers'], 158 | pin_memory=config['pin_memory']) 159 | 160 | valid_dataset = ScienceDataset( 161 | split='Valid_DataSet', 162 | Data_dir=Data_Dir, 163 | transform=valid_augment, 164 | mode='train') 165 | 166 | valid_loader = DataLoader( 167 | valid_dataset, 168 | sampler=None, 169 | shuffle=True, 170 | batch_size=config['valid_batch_size'], 171 | drop_last=False, 172 | num_workers=config['num_workers'], 173 | pin_memory=config['pin_memory']) 174 | 175 | log.write('** dataset setting **\n') 176 | log.write( 177 | '\tWIDTH, HEIGHT = %d, %d, %d\n' % ( 178 | config['crop_size'][0], 179 | config['crop_size'][1], 180 | config['crop_size'][2])) 181 | log.write('\ttrain_dataset.split = %s\n' % (len(train_dataset.ids))) 182 | log.write('\tvalid_dataset.split = %s\n' % (len(valid_dataset.ids))) 183 | log.write('\tlen(train_dataset) = %d\n' % (len(train_dataset))) 184 | log.write('\tlen(valid_dataset) = %d\n' % (len(valid_dataset))) 185 | log.write('\tlen(train_loader) 

def main():

    initial_checkpoint = None
    pretrain_file = None
    log.open(out_dir + '/log.train.txt', mode='a')
    log.write('** some experiment setting **\n')
    log.write('\tSEED = %u\n' % SEED)
    log.write('\tPROJECT_PATH = %s\n' % PROJECT_PATH)
    log.write('\tout_dir = %s\n' % out_dir)
    log.write('\n')
    # load image data
    train_dataset = ScienceDataset(
        split='Train_DataSet',
        Data_dir=Data_Dir,
        transform=train_augment, mode='train')

    train_loader = DataLoader(
        train_dataset,
        sampler=None,
        shuffle=True,
        batch_size=config['train_batch_size'],
        drop_last=config['drop_last'],
        num_workers=config['num_workers'],
        pin_memory=config['pin_memory'])

    valid_dataset = ScienceDataset(
        split='Valid_DataSet',
        Data_dir=Data_Dir,
        transform=valid_augment,
        mode='train')

    valid_loader = DataLoader(
        valid_dataset,
        sampler=None,
        shuffle=False,
        batch_size=config['valid_batch_size'],
        drop_last=False,
        num_workers=config['num_workers'],
        pin_memory=config['pin_memory'])

    log.write('** dataset setting **\n')
    log.write(
        '\tcrop_size (W, H, D) = %d, %d, %d\n' % (
            config['crop_size'][0],
            config['crop_size'][1],
            config['crop_size'][2]))
    log.write('\ttrain_dataset.split = Train_DataSet\n')
    log.write('\tvalid_dataset.split = Valid_DataSet\n')
    log.write('\tlen(train_dataset) = %d\n' % (len(train_dataset)))
    log.write('\tlen(valid_dataset) = %d\n' % (len(valid_dataset)))
    log.write('\tlen(train_loader) = %d\n' % (len(train_loader)))
    log.write('\tlen(valid_loader) = %d\n' % (len(valid_loader)))
    log.write('\tbatch_size = %d\n' % (config['train_batch_size']))
    log.write('\n')

    start_epoch = 0
    start_iteration = 0
    # initial model
    model = UnetGenerator_3d(in_dim=1, out_dim=2, num_filter=4)
    if config['cuda']:
        model.cuda()
    optimizer = optim.SGD(
        model.parameters(),
        lr=config['lr'],
        momentum=config['momentum'])
    # optimizer = optim.Adam(
    #     model.parameters(),
    #     lr=config['lr'],
    #     betas=(0.9, 0.999),
    #     eps=1e-8,
    #     weight_decay=0)

    if initial_checkpoint is not None:
        log.write('\tinitial_checkpoint = %s\n' % initial_checkpoint)
        model.load_state_dict(
            torch.load(initial_checkpoint,
                       map_location=lambda storage, loc: storage))
    if pretrain_file is not None:
        log.write('\tpretrain_file = %s\n' % pretrain_file)
        # `skip` must list the layer names to leave untouched when loading
        model.load_pretrain(pretrain_file, skip)

    log.write('** net setting **\n')
    log.write('\tinitial_checkpoint = %s\n' % initial_checkpoint)
    log.write('\tpretrain_file = %s\n' % pretrain_file)
    log.write('%s\n\n' % (type(model)))
    log.write('\n')

    # start training here! ##############################################
    log.write('** start training here! **\n')
    log.write(' optimizer=%s\n' % str(optimizer))
    log.write(' momentum=%f\n' % optimizer.param_groups[0]['momentum'])
    log.write(' LR=%s\n\n' % str(config['lr']))
    log.write(' images_per_epoch = %d\n\n' % len(train_dataset))
    log.write(' iter    epoch | train_loss | valid_loss | rate\n')
    log.write('------------------------------------------------\n')

    for epoch in range(1, config['epochs'] + 1):
        model.train()
        train_loss, learning_rate = train(epoch, train_loader, model, optimizer)
        valid_loss = valid(valid_loader, model)
        log.write('%d k | %d | %0.3f | %0.3f | %0.5f\n' % (
            config['train_batch_size'] * epoch, epoch,
            train_loss, valid_loss, learning_rate))
        log.write('\n')


if __name__ == "__main__":
    main()
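
# --- Added sketch: `adjust_lr` and `save_check_points` are imported from
# --- elsewhere in this repository and are not shown in this file. A typical
# --- step-decay schedule consistent with how `adjust_lr` is called in train()
# --- might look like this (the name and decay constants are assumptions):
def _adjust_lr_sketch(optimizer, epoch, base_lr=0.01, step=30, gamma=0.1):
    # decay the learning rate by `gamma` every `step` epochs
    lr = base_lr * (gamma ** (epoch // step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr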
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self, n_class=2):
        super().__init__()
        # conv1
        # torch.nn.Conv3d(in_channels, out_channels, kernel_size,
        #                 stride=1, padding=0, dilation=1, groups=1, bias=True)
        # the large padding inflates the volume so the fixed crops in forward()
        # can re-align the upsampled maps with the skip branches (FCN-style)
        self.conv1_1 = nn.Conv3d(1, 8, 3, padding=60)
        self.relu1_1 = nn.ReLU(inplace=True)
        self.conv1_2 = nn.Conv3d(8, 8, 3, padding=1)
        self.relu1_2 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool3d(2, stride=2, ceil_mode=True)  # 1/2

        # conv2
        self.conv2_1 = nn.Conv3d(8, 16, 3, padding=15)
        self.relu2_1 = nn.ReLU(inplace=True)
        self.conv2_2 = nn.Conv3d(16, 16, 3, padding=1)
        self.relu2_2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool3d(2, stride=2, ceil_mode=True)  # 1/4

        # conv3
        self.conv3_1 = nn.Conv3d(16, 32, 3, padding=1)
        self.relu3_1 = nn.ReLU(inplace=True)
        self.conv3_2 = nn.Conv3d(32, 32, 3, padding=1)
        self.relu3_2 = nn.ReLU(inplace=True)
        self.conv3_3 = nn.Conv3d(32, 32, 3, padding=1)
        self.relu3_3 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool3d(2, stride=2, ceil_mode=True)  # 1/8

        # conv4
        self.conv4_1 = nn.Conv3d(32, 64, 3, padding=1)
        self.relu4_1 = nn.ReLU(inplace=True)
        self.conv4_2 = nn.Conv3d(64, 64, 3, padding=1)
        self.relu4_2 = nn.ReLU(inplace=True)
        self.conv4_3 = nn.Conv3d(64, 64, 3, padding=1)
        self.relu4_3 = nn.ReLU(inplace=True)
        self.pool4 = nn.MaxPool3d(2, stride=2, ceil_mode=True)  # 1/16

        # conv5
        self.conv5_1 = nn.Conv3d(64, 64, 3, padding=1)
        self.relu5_1 = nn.ReLU(inplace=True)
        self.conv5_2 = nn.Conv3d(64, 64, 3, padding=1)
        self.relu5_2 = nn.ReLU(inplace=True)
        self.conv5_3 = nn.Conv3d(64, 64, 3, padding=1)
        self.relu5_3 = nn.ReLU(inplace=True)
        self.pool5 = nn.MaxPool3d(2, stride=2, ceil_mode=True)  # 1/32

        # fc6
        self.fc6 = nn.Conv3d(64, 512, 7)
        self.relu6 = nn.ReLU(inplace=True)
        self.drop6 = nn.Dropout3d()

        # fc7
        self.fc7 = nn.Conv3d(512, 512, 1)
        self.relu7 = nn.ReLU(inplace=True)
        self.drop7 = nn.Dropout3d()

        self.score_fr = nn.Conv3d(512, n_class, 1)
        self.score_pool3 = nn.Conv3d(32, n_class, 1)
        self.score_pool4 = nn.Conv3d(64, n_class, 1)

        self.upscore2 = nn.ConvTranspose3d(
            n_class, n_class, 4, stride=2, bias=False)
        self.upscore8 = nn.ConvTranspose3d(
            n_class, n_class, 16, stride=8, bias=False)
        self.upscore_pool4 = nn.ConvTranspose3d(
            n_class, n_class, 4, stride=2, bias=False)

        self._initialize_weights()

    def get_upsampling_weight(self, in_channels, out_channels, kernel_size):
        """Make a 3D trilinear kernel suitable for upsampling"""
        factor = (kernel_size + 1) // 2
        if kernel_size % 2 == 1:
            center = factor - 1
        else:
            center = factor - 0.5
        og = np.ogrid[:kernel_size, :kernel_size, :kernel_size]
        # the 3-D kernel is the product of three 1-D triangular profiles
        # (the original code dropped the third factor)
        filt = (1 - abs(og[0] - center) / factor) * \
               (1 - abs(og[1] - center) / factor) * \
               (1 - abs(og[2] - center) / factor)
        weight = np.zeros((in_channels, out_channels,
                           kernel_size, kernel_size, kernel_size),
                          dtype=np.float64)
        # identity channel mapping: assumes in_channels == out_channels
        weight[range(in_channels), range(out_channels), :, :, :] = filt
        return torch.from_numpy(weight).float()
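
    # Worked example (added note): for kernel_size=4, factor = (4+1)//2 = 2 and
    # center = 1.5, so the 1-D profile is [0.25, 0.75, 0.75, 0.25]; the returned
    # 3-D kernel is the outer product of three such profiles, and
    # get_upsampling_weight(2, 2, 4) has shape (2, 2, 4, 4, 4).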

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                m.weight.data.normal_(0.0, 0.1)
                if m.bias is not None:
                    m.bias.data.zero_()
            if isinstance(m, nn.ConvTranspose3d):
                # transposed convs start from a fixed trilinear kernel
                assert m.kernel_size[0] == m.kernel_size[1] == m.kernel_size[2]
                initial_weight = self.get_upsampling_weight(
                    m.in_channels, m.out_channels, m.kernel_size[0])
                m.weight.data.copy_(initial_weight)

    # alternative random initialization (kept from the original, disabled):
    # def _initialize_weights(self):
    #     for m in self.modules():
    #         if isinstance(m, nn.Conv3d):
    #             m.weight.data.normal_(0.0, 0.1)
    #             if m.bias is not None:
    #                 m.bias.data.normal_(0.0, 0.1)
    #         if isinstance(m, nn.ConvTranspose3d):
    #             m.weight.data.normal_(0.0, 0.1)

    def copy_params_from_vgg16(self, vgg16):
        features = [
            self.conv1_1, self.relu1_1,
            self.conv1_2, self.relu1_2,
            self.pool1,
            self.conv2_1, self.relu2_1,
            self.conv2_2, self.relu2_2,
            self.pool2,
            self.conv3_1, self.relu3_1,
            self.conv3_2, self.relu3_2,
            self.conv3_3, self.relu3_3,
            self.pool3,
            self.conv4_1, self.relu4_1,
            self.conv4_2, self.relu4_2,
            self.conv4_3, self.relu4_3,
            self.pool4,
            self.conv5_1, self.relu5_1,
            self.conv5_2, self.relu5_2,
            self.conv5_3, self.relu5_3,
            self.pool5,
        ]
        # NOTE: vgg16 layers are 2D (nn.Conv2d) while this network is 3D, and
        # the channel widths differ (8/16/32/64 here vs 64/128/256/512 in
        # VGG16), so weights are only copied where the shapes actually match.
        # The original code tested `isinstance(l2, nn.Conv2d)`, which never
        # matched here and made the loop a no-op; the hard asserts would have
        # failed anyway, so they are replaced with size guards.
        for l1, l2 in zip(vgg16.features, features):
            if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv3d):
                if l1.weight.size() == l2.weight.size():
                    l2.weight.data.copy_(l1.weight.data)
                    l2.bias.data.copy_(l1.bias.data)
        for i, name in zip([0, 3], ['fc6', 'fc7']):
            l1 = vgg16.classifier[i]
            l2 = getattr(self, name)
            if l1.weight.numel() == l2.weight.numel():
                l2.weight.data.copy_(l1.weight.data.view(l2.weight.size()))
                l2.bias.data.copy_(l1.bias.data.view(l2.bias.size()))

    def forward(self, x):
        h = x

        h = self.relu1_1(self.conv1_1(h))
        h = self.relu1_2(self.conv1_2(h))
        h = self.pool1(h)

        h = self.relu2_1(self.conv2_1(h))
        h = self.relu2_2(self.conv2_2(h))
        h = self.pool2(h)

        h = self.relu3_1(self.conv3_1(h))
        h = self.relu3_2(self.conv3_2(h))
        h = self.relu3_3(self.conv3_3(h))
        h = self.pool3(h)
        pool3 = h  # 1/8

        h = self.relu4_1(self.conv4_1(h))
        h = self.relu4_2(self.conv4_2(h))
        h = self.relu4_3(self.conv4_3(h))
        h = self.pool4(h)
        pool4 = h  # 1/16

        h = self.relu5_1(self.conv5_1(h))
        h = self.relu5_2(self.conv5_2(h))
        h = self.relu5_3(self.conv5_3(h))
        h = self.pool5(h)

        h = self.relu6(self.fc6(h))
        h = self.drop6(h)

        h = self.relu7(self.fc7(h))
        h = self.drop7(h)

        h = self.score_fr(h)
        h = self.upscore2(h)
        upscore2 = h  # 1/16

        h = self.score_pool4(pool4 * 0.01)  # XXX: scaling to train at once
        h = h[:, :,
              5:5 + upscore2.size()[2],
              5:5 + upscore2.size()[3],
              5:5 + upscore2.size()[4]]
        score_pool4c = h  # 1/16

        h = upscore2 + score_pool4c  # 1/16
        h = self.upscore_pool4(h)
        upscore_pool4 = h  # 1/8

        h = self.score_pool3(pool3 * 0.0001)  # XXX: scaling to train at once
        h = h[:, :,
              9:9 + upscore_pool4.size()[2],
              9:9 + upscore_pool4.size()[3],
              9:9 + upscore_pool4.size()[4]]
        score_pool3c = h  # 1/8

        # The strides of upscore2/upscore_pool4 and the padding of conv1_1
        # jointly determine the intermediate sizes, which is why the fixed
        # crop offsets (5, 9 and 31) must stay in sync with those settings.

        h = upscore_pool4 + score_pool3c  # 1/8

        h = self.upscore8(h)
        h = h[:, :,
              31:31 + x.size()[2],
              31:31 + x.size()[3],
              31:31 + x.size()[4]].contiguous()
        return h
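
# --- Added sanity check (an assumption, not part of the original file): the
# --- fixed crops in forward() re-align the oversized transposed-convolution
# --- outputs with the skip branches, so the network returns exactly the input
# --- spatial size. For a 64^3 input this can be verified on CPU:
if __name__ == '__main__':
    net = Net(n_class=2)
    x = torch.randn(1, 1, 64, 64, 64)
    with torch.no_grad():
        y = net(x)
    assert y.shape == (1, 2, 64, 64, 64)
    print('output shape:', tuple(y.shape))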
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import sys
import os
import shutil


# destructive variant (kept from the original, disabled): wipes an existing
# directory instead of reusing it
# def Mkdir(dir):
#     if os.path.isdir(dir):
#         shutil.rmtree(dir, ignore_errors=True)
#     else:
#         os.makedirs(dir)
def Mkdir(directory):
    # create the directory if it does not already exist
    if not os.path.isdir(directory):
        os.makedirs(directory)


# http://stackoverflow.com/questions/34950201/pycharm-print-end-r-statement-not-working
class Logger(object):
    def __init__(self):
        self.terminal = sys.stdout
        self.file = None

    def open(self, file, mode=None):
        if mode is None:
            mode = 'w'
        self.file = open(file, mode)

    def write(self, message, is_terminal=1, is_file=1):
        # carriage-return progress lines are shown on screen
        # but kept out of the log file
        if '\r' in message:
            is_file = 0

        if is_terminal == 1:
            self.terminal.write(message)
            self.terminal.flush()

        if is_file == 1:
            self.file.write(message)
            self.file.flush()

    def flush(self):
        # needed for python 3 compatibility: handles the flush command by
        # doing nothing; extra behavior could be added here if required
        pass
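
# --- Added usage sketch (an illustration, not part of the original file):
# --- Logger tees output to stdout and to a log file, as main.py uses it:
#   log = Logger()
#   log.open('./log.train.txt', mode='a')
#   log.write('epoch 1 done\n')              # written to terminal and file
#   log.write('\rprogress 50%', is_file=0)   # '\r' lines never reach the file
--------------------------------------------------------------------------------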