├── checkpoint
├── airway_seg_cascade.py
├── dicom_read.py
├── utils.py
├── test.py
├── organize_data.py
├── zoom.py
├── airway_seg.py
├── lung_seg.py
└── tools.py
/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "model.cptk"
2 | all_model_checkpoint_paths: "model.cptk"
3 | 
--------------------------------------------------------------------------------
/airway_seg_cascade.py:
--------------------------------------------------------------------------------
1 | from lung_seg import Lung_Seg
2 | from airway_seg import airway_seg
3 | import SimpleITK as ST
4 | import numpy as np
5 | 
6 | def airway_segmentation(dicom_dir):
7 |     lung_img = Lung_Seg(dicom_dir)
8 |     airway_mask = airway_seg(lung_img)
9 |     return airway_mask
10 | 
11 | # if __name__ =="__main__":
12 | #     dicom_dir = "./case06/original1"
13 | #     airway_mask = airway_segmentation(dicom_dir)
--------------------------------------------------------------------------------
/dicom_read.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import SimpleITK as ST
3 | import sys
4 | 
5 | def read_dicoms(input_directory):
6 |     # guard the argument itself; the old len(sys.argv)<1 check could never fire
7 |     if not input_directory:
8 |         print "Usage: DicomSeriesReader <input_directory>"
9 |         sys.exit(1)
10 | 
11 |     print "Reading Dicom directory",input_directory
12 |     reader=ST.ImageSeriesReader()
13 | 
14 |     dicom_names=reader.GetGDCMSeriesFileNames(input_directory)
15 |     reader.SetFileNames(dicom_names)
16 |     # print dicom_names
17 | 
18 |     image=reader.Execute()
19 |     return image
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import numpy as np
4 | import SimpleITK as ST
5 | import dicom_read
6 | import cPickle as pickle
7 | import scipy.io as sio
8 | 
9 | def get_range(mask,type=0):
10 |     begin_switch = 0
11 |     begin = 0
12 |     if type==0:
13 |         end_switch = 0
14 |         end = np.shape(mask)[2] - 1
15 |         for i in range(np.shape(mask)[2]):
16 |             if np.max(mask[:, :, i]) == 1 and begin_switch == 0:
17 |                 begin_switch = 1
18 |                 begin = i
19 |             if np.max(mask[:, :, i]) == 0 and end_switch == 0 and begin_switch == 1:
20 |                 end = i
21 |                 end_switch = 1
22 |             if end_switch:
23 |                 break
24 |         return begin,end
25 |     if type==1:
26 |         end_switch = 0
27 |         end = np.shape(mask)[1] - 1
28 |         for i in range(np.shape(mask)[1]):
29 |             if np.max(mask[:, i, :]) == 1 and begin_switch == 0:
30 |                 begin_switch = 1
31 |                 begin = i
32 |             if np.max(mask[:, i, :]) == 0 and end_switch == 0 and begin_switch == 1:
33 |                 end = i
34 |                 end_switch = 1
35 |             if end_switch:
36 |                 break
37 |         return begin, end
38 |     if type==2:
39 |         end_switch = 0
40 |         end = np.shape(mask)[0] - 1
41 |         for i in range(np.shape(mask)[0]):
42 |             if np.max(mask[i, :, :]) == 1 and begin_switch == 0:
43 |                 begin_switch = 1
44 |                 begin = i
45 |             if np.max(mask[i, :, :]) == 0 and end_switch == 0 and begin_switch == 1:
46 |                 end = i
47 |                 end_switch = 1
48 |             if end_switch:
49 |                 break
50 |         return begin, end
51 | 
52 | def get_range_slices(mask,type):
53 |     mask_shape=np.shape(mask)
54 |     ret_begin=0
55 |     ret_end=0
56 |     max_length = 0
57 |     if type==1:
58 |         ret_begin=0
59 |         ret_end=mask_shape[1]-1
60 |         for i in range(mask_shape[2]):
61 |             temp_slice=mask[:,:,i]
62 |             begin_switch = 0
63 |             end_switch = 0
64 |             begin = 0
65 |             end = 0    # a slice may contain no mask at all
66 |             for j in range(np.shape(temp_slice)[1]):
67 |                 if np.max(temp_slice[:, j]) == 1 and begin_switch == 0:
68 |                     begin_switch = 1
69 |                     begin = j
70 |                 if np.max(temp_slice[:,j]) == 0 and end_switch == 0 and begin_switch == 1:
71 |                     end_switch = 1
72 |                     end = j
73 |                 if end_switch:
74 |                     break
75 |             if (end-begin)>max_length:
76 |                 max_length = end-begin    # remember the longest run seen so far
77 |                 ret_begin = begin
78 |                 ret_end = end
79 |     if type==2:
80 |         ret_begin=0
81 |         ret_end=mask_shape[0]-1
82 |         for i in range(mask_shape[2]):
83 |             temp_slice=mask[:,:,i]
84 |             begin_switch = 0
85 |             end_switch = 0
86 |             begin = 0
87 |             end = 0
88 |             for j in range(np.shape(temp_slice)[0]):    # axis 0, since rows are indexed below
89 |                 if np.max(temp_slice[j, :]) == 1 and begin_switch == 0:
90 |                     begin_switch = 1
91 |                     begin = j
92 |                 if np.max(temp_slice[j, :]) == 0 and end_switch == 0 and begin_switch == 1:
93 |                     end_switch = 1
94 |                     end = j
95 |                 if end_switch:
96 |                     break
97 |             if (end-begin)>max_length:
98 |                 max_length = end-begin
99 |                 ret_begin = begin
100 |                 ret_end = end
101 |     return ret_begin,ret_end
102 | 
103 | def get_array(dicom_dir):
104 |     img = dicom_read.read_dicoms(dicom_dir)
105 |     ret_array = ST.GetArrayFromImage(img)
106 |     ret_array = np.transpose(ret_array,[2,1,0])
107 |     return ret_array
108 | 
109 | def get_original_arrays(root_path,type):
110 |     number = 0
111 |     origin_datas = dict()
112 |     for patient_dir in os.listdir(root_path):
113 |         # read folder names
114 |         origin_datas[number] = dict()
115 |         origin_datas[number]['name'] = patient_dir
116 |         dicom_dirs = root_path+'/'+patient_dir
117 |         for sub_dir in os.listdir(dicom_dirs):
118 |             dicom_dir = dicom_dirs+'/'+sub_dir
119 |             if os.path.isdir(dicom_dir) and ('origin' in dicom_dir or type in dicom_dir):
120 |                 origin_datas[number][sub_dir]=get_array(dicom_dirs+'/'+sub_dir)
121 |         number+=1
122 |     origin_datas['mask_type']=type
123 |     return origin_datas
124 | 
125 | def organize_data_pairs(origin_datas):
126 |     ret_pairs=dict()
127 |     number=0
128 |     mask_type = origin_datas['mask_type']
129 |     for data_num in origin_datas.keys():
130 |         if not 'mask_type' == data_num:
131 |             # Read each set of data
132 |             print 'processing data: ',origin_datas[data_num]['name']
133 |             # Each data set has only one mask array; it is used to check that shapes match
134 |             mask_array = origin_datas[data_num][mask_type]
135 |             mask_shape = np.shape(mask_array)
136 |             for data_name in origin_datas[data_num]:
137 |                 # check if this original data array has the same shape as the mask array
138 |                 if 'original' in data_name:
139 |                     original_array = origin_datas[data_num][data_name]
140 |                     original_shape = np.shape(original_array)
141 |                     if mask_shape[0]==original_shape[0] and mask_shape[1]==original_shape[1] and mask_shape[2]==original_shape[2]:
142 |                         ret_pairs[number]=dict()
143 |                         ret_pairs[number]['original']=original_array
144 |                         ret_pairs[number]['mask']=mask_array
145 |                         ret_pairs[number]['name']=origin_datas[data_num]['name']
146 |                         number+=1
147 |     return ret_pairs
148 | 
149 | def get_range_each(data_pair):
150 |     mask_array = data_pair['mask']
151 |     ret=[]
152 |     ending = 0
153 |     z_length=90
154 |     mask_shape = np.shape(mask_array)
155 |     for j in range(mask_shape[2]):
156 |         if np.sum(mask_array[:, :, mask_shape[2] - j - 1]) > 0:
157 |             ending = mask_shape[2] - j - 1
158 |             break
159 |     for i in range(3):
160 |         if i==0:
161 |             begin, end = get_range(mask_array[:, :, ending - z_length:ending], i)
162 |             ret.append(begin+ending-z_length)
163 |             ret.append(end+ending-z_length)
164 |         else:
165 |             begin, end = get_range_slices(mask_array[:, :, ending - z_length:ending],i)
166 |             ret.append(begin)
167 |             ret.append(end)
168 |     return ret
169 | 
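# ---- Hedged example (added; not part of the original utils.py) ----
# A minimal sanity check for the range helpers above. The toy mask is an
# illustrative assumption, not project data.
def _example_get_range():
    mask = np.zeros((4, 4, 10), np.int8)
    mask[:, :, 3:7] = 1                  # the mask occupies z slices 3..6
    begin, end = get_range(mask, type=0)
    print 'z range:', begin, end         # expected: 3 7 (end is the first empty slice after the run)
170 | def 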
get_proper_range(data_pairs): 171 | maxs=[0,0,0] 172 | for number in data_pairs.keys(): 173 | ranger=get_range_each(data_pairs[number]) 174 | size=[] 175 | for i in range(3): 176 | size.append(ranger[i*2+1]-ranger[i*2]) 177 | for i in range(3): 178 | if size[i]>maxs[i]: 179 | maxs[i]=size[i] 180 | print ranger,'---',size,"---",data_pairs[number]['name'] 181 | print 'propel size of sliding window: ',maxs 182 | return maxs 183 | 184 | def out_put_data_pairs(data_pairs): 185 | root_dir = '/opt/analyse_airway/' 186 | if not os.path.exists('./output'): 187 | os.makedirs('./output') 188 | data_meta = dict() 189 | for number in data_pairs: 190 | sio.savemat('./output/data'+str(number)+'.mat',{'original':data_pairs[number]['original'], 191 | 'mask':data_pairs[number]['mask']}) 192 | data_meta[number]=root_dir+'output/data'+str(number)+'.mat' 193 | pickle_writer = open('./data_meta.pkl','wb') 194 | pickle.dump(data_meta,pickle_writer) 195 | pickle_writer.close() -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tensorflow as tf 4 | import scipy.io 5 | import tools 6 | import numpy as np 7 | from dicom_read import read_dicoms 8 | import SimpleITK as ST 9 | import time 10 | import zoom 11 | from zoom import Array_Zoom_in,Array_Reduce 12 | 13 | def get_valid_area(input_array): 14 | array_shape=np.shape(input_array) 15 | central_point=[(array_shape[0]-1)/2,(array_shape[1]-1)/2] 16 | xmin=0 17 | xmax=array_shape[0]-1 18 | ymin=0 19 | ymax=array_shape[1]-1 20 | tags=[0,0,0,0] 21 | for i in range(array_shape[1]/2): 22 | if np.max(input_array[central_point[0]-i,:,:])>0 and tags[0]==0: 23 | xmin=central_point[0]-i 24 | else: 25 | tags[0]=1 26 | if np.max(input_array[central_point[0]+i,:,:])>0 and tags[1]==0: 27 | xmax=central_point[0]+i 28 | else: 29 | tags[1]=1 30 | for j in range(array_shape[1]/2): 31 | if np.max(input_array[:,central_point[1]-j,:])>0 and tags[2]==0: 32 | ymin=central_point[1]-j 33 | else: 34 | tags[2]=1 35 | if np.max(input_array[:,central_point[1]+j,:])>0 and tags[3]==0: 36 | ymax=central_point[1]+j 37 | else: 38 | tags[3]=1 39 | return [xmin,xmax,ymin,ymax] 40 | 41 | def resize_image(test_input,resized_length): 42 | array_shape = np.shape(test_input) 43 | # print time.strftime('%Y-%m-%d %H:%M:%S'),' ',np.shape(test_input) 44 | # ranger=get_valid_area(test_input) 45 | # print time.strftime('%Y-%m-%d %H:%M:%S'),' ',ranger 46 | ranger=[0,array_shape[0]-1,0,array_shape[1]-1] 47 | sliced_img=test_input[ranger[0]:ranger[1]+1,ranger[2]:ranger[3]+1,:] 48 | size=np.array([ranger[1]-ranger[0]+1,ranger[3]-ranger[2]+1]) 49 | maxsize=np.max(size) 50 | minsize=np.min(size) 51 | 52 | sliced_size=np.shape(sliced_img) 53 | processed_img=np.zeros((maxsize,maxsize,array_shape[2]),np.float32) 54 | padding_size=maxsize-minsize 55 | try: 56 | if size[0]>size[1]: 57 | processed_img[:,padding_size/2:maxsize-padding_size/2,:]=sliced_img[:,:,:] 58 | else: 59 | processed_img[padding_size / 2:maxsize - padding_size / 2, :, :] = sliced_img[:, :, :] 60 | except Exception,e: 61 | if size[0]>size[1]: 62 | processed_img[:,padding_size/2+1:maxsize-padding_size/2,:]=sliced_img[:,:,:] 63 | else: 64 | processed_img[padding_size / 2+1:maxsize - padding_size / 2, :, :] = sliced_img[:, :, :] 65 | 66 | resized_rate=float(resized_length)/float(maxsize) 67 | if resized_rate<1: 68 | resized_img=Array_Reduce(processed_img,resized_rate,resized_rate) 69 | else: 70 | 
resized_img=Array_Zoom_in(processed_img,resized_rate,resized_rate) 71 | return resized_img,ranger 72 | 73 | def get_threshed_img(dicom_dir): 74 | img=read_dicoms(dicom_dir) 75 | space = img.GetSpacing() 76 | image_array = ST.GetArrayFromImage(img) 77 | # image_array = np.transpose(image_array,(2,1,0)) 78 | print np.shape(image_array) 79 | 80 | array_shape = np.shape(image_array) 81 | central = [(array_shape[2] - 1) / 2, (array_shape[1] - 1) / 2, (array_shape[0] - 1) / 2] 82 | print central 83 | pointslist=[] 84 | for i in range(3): 85 | for j in range(3): 86 | for k in range(3): 87 | if i!=0 or j!=0 or k!=0: 88 | pointslist.append([central[0]+i,central[1]+j,central[2]+k]) 89 | pointslist.append([central[0]+i,central[1]+j,central[2]-k]) 90 | pointslist.append([central[0]+i,central[1]-j,central[2]+k]) 91 | pointslist.append([central[0]+i,central[1]-j,central[2]-k]) 92 | pointslist.append([central[0]-i,central[1]+j,central[2]+k]) 93 | pointslist.append([central[0]-i,central[1]+j,central[2]-k]) 94 | pointslist.append([central[0]-i,central[1]-j,central[2]+k]) 95 | pointslist.append([central[0]-i,central[1]-j,central[2]-k]) 96 | threshed_mask = ST.NeighborhoodConnected(img, pointslist, -40, 97 | np.float64(np.max(image_array)), [1, 1, 1], 1.0) 98 | threshed_mask_array = ST.GetArrayFromImage(threshed_mask) 99 | 100 | threshed_array = image_array * threshed_mask_array 101 | # threshed_img = ST.GetImageFromArray(threshed_array) 102 | 103 | threshed_array = np.transpose(threshed_array, (2, 1, 0)) 104 | # threshed_array = np.float32(threshed_array) 105 | # threshed_img = ST.GetImageFromArray(threshed_array) 106 | # blured_img = ST.CurvatureAnisotropicDiffusion(threshed_img,0.0625,3,1,3) 107 | # blured_array = ST.GetArrayFromImage(blured_img) 108 | return threshed_array,space 109 | 110 | def get_organized_data(dicom_dir,resized_size): 111 | half_size = resized_size[2]/2 112 | time1=time.time() 113 | origin_array,space = get_threshed_img(dicom_dir) 114 | time2 = time.time() 115 | print "time for thresholding: ",time2-time1," s" 116 | if np.shape(origin_array)[0]==resized_size[0]: 117 | resized_array=origin_array 118 | else: 119 | resized_array,ranger = resize_image(origin_array,resized_size[0]) 120 | # shape = np.shape(origin_array) 121 | # test_inputs = [] 122 | # for i in range(half_size,shape[2]-half_size,half_size): 123 | # test_inputs.append(resized_array[:,:,i-half_size:i+half_size]) 124 | # print i 125 | # print len(test_inputs) 126 | time3 = time.time() 127 | print "time for resizing: ", time3-time2," s" 128 | return space,resized_array 129 | 130 | # def get_results(output_shape): 131 | # dicom_dir = "./3Dircadb1.2/PATIENT_DICOM" 132 | # input_shape = [384,384,4] 133 | # batch_size = 8 134 | # GPU0 = '0' 135 | # train_models_dir = './train_models/' 136 | # Net = DenseVoxNet.Network() 137 | # # X = tf.placeholder(shape=[batch_size, input_shape[0], input_shape[1], input_shape[2]], dtype=tf.float32) 138 | # X = tf.placeholder(shape=[batch_size, input_shape[0], input_shape[1], input_shape[2]], dtype=tf.float32) 139 | # # Y = tf.placeholder(shape=[batch_size, output_shape[0], output_shape[1], output_shape[2]], dtype=tf.float32) 140 | # Y = tf.placeholder(shape=[batch_size, output_shape[0], output_shape[1], output_shape[2]], dtype=tf.float32) 141 | # training = tf.placeholder(tf.bool) 142 | # Y_pred, Y_pred_modi,Y_pred_nosig = Net.ae_u(X,training,batch_size) 143 | # input_datas = get_organized_data(dicom_dir,input_shape) 144 | # time1 = time.time() 145 | # results = [] 146 | # saver = 
tf.train.Saver(max_to_keep=1) 147 | # config = tf.ConfigProto(allow_soft_placement=True) 148 | # config.gpu_options.visible_device_list = GPU0 149 | # with tf.Session(config=config) as sess: 150 | # print "restoring saved model" 151 | # saver.restore(sess, train_models_dir + 'model.cptk') 152 | # if os.path.exists(train_models_dir): 153 | # saver.restore(sess, train_models_dir + 'model.cptk') 154 | # for i in range(0,len(input_datas),4): 155 | # if i+batch_size < len(input_datas)-1: 156 | # input_data = np.zeros([batch_size,input_shape[0], input_shape[1], input_shape[2]]) 157 | # for j in range(i,i+8): 158 | # input_data[j-i,:,:,:]=input_datas[j][:,:,:] 159 | # partial_result = sess.run([Y_pred_modi],feed_dict={X:input_data,training:True}) 160 | # results.append(partial_result) 161 | # time2 = time.time() 162 | # print "time for calculating: ",time2-time1," s" 163 | # print len(results) 164 | # return results 165 | # 166 | # def test_main(): 167 | # output_shape = [256, 256, 4] 168 | # results = get_results(output_shape) 169 | # final_array = np.zeros([output_shape[0],output_shape[1],len(results)*4+4],np.float32) 170 | # for i in range(len(results)): 171 | # final_array[:,:,i*4:i*4+8]+=np.float32((results[i][:,:,:]-0.01)>0) 172 | # final_array = np.int8(final_array>0.5) 173 | # final_img = ST.GetImageFromArray(final_array) 174 | # ST.WriteImage(final_img,'./test_result.vtk') 175 | 176 | # test_main() 177 | -------------------------------------------------------------------------------- /organize_data.py: -------------------------------------------------------------------------------- 1 | import SimpleITK as ST 2 | import numpy as np 3 | import cPickle as pickle 4 | import scipy.io as sio 5 | import sys 6 | import zoom 7 | import time 8 | import random 9 | 10 | def get_range(mask,type=0): 11 | begin_switch = 0 12 | begin = 0 13 | end_switch = 0 14 | end = np.shape(mask)[2] - 1 15 | for i in range(np.shape(mask)[2]): 16 | if np.max(mask[:, :, i]) == 1 and begin_switch == 0: 17 | begin_switch = 1 18 | begin = i 19 | if np.max(mask[:, :, i]) == 0 and end_switch == 0 and begin_switch == 1: 20 | end = i 21 | end_switch = 1 22 | if end_switch: 23 | break 24 | return begin,end 25 | # 26 | # def get_organized_data_fixed_2D(meta_path, type, half_size): 27 | # dicom_datas = dict() 28 | # clipped_datas = dict() 29 | # pickle_readier = open(meta_path) 30 | # meta_data = pickle.load(pickle_readier) 31 | # for number, dataset in meta_data['matrixes'].items(): 32 | # try: 33 | # patient_data = sio.loadmat(dataset['PATIENT_DICOM']) 34 | # mask_data = sio.loadmat(dataset[type]) 35 | # original_array = patient_data['original_resized'] 36 | # mask = mask_data[type + '_mask_resized'] 37 | # dicom_datas[number] = list() 38 | # clipped_datas[number] = list() 39 | # # get the binary mask 40 | # mask = np.int8(mask > 0) 41 | # if np.max(mask) <= 0: 42 | # continue 43 | # # get the valid mask area 44 | # begin, end = get_range(mask,0) 45 | # origin = original_array[:, :, begin:end] 46 | # # clip = original_array[:,:,begin:end]*mask[:,:,begin:end] 47 | # clip = mask[:, :, begin:end] 48 | # # if number=='5': 49 | # # dicom_img = ST.GetImageFromArray(np.transpose(dicom_datas[number],(2,1,0)) ) 50 | # # clipped_img = ST.GetImageFromArray(np.transpose(clipped_data[number],(2,1,0)) ) 51 | # # ST.WriteImage(dicom_img,'./dicom_img.vtk') 52 | # # ST.WriteImage(clipped_img,'./clipped_img.vtk') 53 | # # exit(0) 54 | # print "valid area: ", begin, ":", end 55 | # for i in range(begin, end, half_size): 56 | # origin_slice = 
original_array[:, :, i - half_size:i + half_size]
57 | #                 clip_slice = mask[:, :, i - half_size:i + half_size]
58 | #                 if not 0 in np.shape(origin_slice) and not 0 in np.shape(clip_slice):
59 | #                     if np.shape(origin_slice)[-1] == half_size * 2 and np.shape(clip_slice)[-1] == half_size * 2:
60 | #                         dicom_datas[number].append(origin_slice)
61 | #                         clipped_datas[number].append(clip_slice)
62 | #         except Exception, e:
63 | #             print e
64 | #     return dicom_datas, clipped_datas
65 | #
66 | # def resize_img(img_array,input_size):
67 | #     shape = np.shape(img_array)
68 | #     ret = img_array
69 | #     if shape[0]<input_size[0] or shape[1]<input_size[1]:
70 | #         ret = zoom.Array_Zoom_in(img_array,float(input_size[0])/float(shape[0]),float(input_size[1])/float(shape[1]))   # reconstructed line; lost in the source dump
71 | #     if shape[0]>input_size[0] or shape[1]>input_size[1]:
72 | #         ret = zoom.Array_Reduce(img_array,float(input_size[0])/float(shape[0]),float(input_size[1])/float(shape[1]))
73 | #     shape_resized=np.shape(ret)
74 | #     if shape_resized[0]<input_size[0] or shape_resized[1]<input_size[1]:
75 | #         temp = np.zeros([input_size[0],input_size[1],np.shape(ret)[2]],np.float32)   # reconstructed padding branch (assumption)
76 | #         temp[0:shape_resized[0],0:shape_resized[1],:] = ret[:,:,:]
77 | #         ret = temp
78 | #     if shape_resized[0]>input_size[0] or shape_resized[1]>input_size[1]:
79 | #         ret = ret[0:input_size[0],0:input_size[1],:]
80 | #     return ret
81 | #
82 | # def get_organized_data_common(meta_path, type, half_size,input_size):
83 | #     range_type=1
84 | #     dicom_datas = dict()
85 | #     clipped_datas = dict()
86 | #     pickle_readier = open(meta_path)
87 | #     meta_data = pickle.load(pickle_readier)
88 | #     for number, dataset in meta_data['matrixes'].items():
89 | #         try:
90 | #             patient_data = sio.loadmat(dataset['PATIENT_DICOM'])
91 | #             mask_data = sio.loadmat(dataset[type])
92 | #             original_array = patient_data['original_resized']
93 | #             mask = mask_data[type + '_mask_resized']
94 | #             dicom_datas[number] = list()
95 | #             clipped_datas[number] = list()
96 | #             shape = np.shape(mask)
97 | #             # get the binary mask
98 | #             mask = np.int8(mask > 0)
99 | #             if np.max(mask) <= 0:
100 | #                 continue
101 | #             # get the valid mask area
102 | #             begin, end = get_range(mask,range_type)
103 | #             print "valid area: ", begin, ":", end
104 | #             for i in range(begin, end, half_size/2):
105 | #                 origin_slice = original_array[:, :, i - half_size:i + half_size]
106 | #                 clip_slice = mask[:, :, i - half_size:i + half_size]
107 | #                 if not 0 in np.shape(origin_slice) and not 0 in np.shape(clip_slice) and np.sum(np.float32(clip_slice))/(128.0*128*half_size*2)>0.001:
108 | #                     if np.shape(origin_slice)[2] == half_size * 2 and np.shape(clip_slice)[2] == half_size * 2:
109 | #                         dicom_datas[number].append(origin_slice)
110 | #                         clipped_datas[number].append(clip_slice)
111 | #         except Exception, e:
112 | #             print e
113 | #     return dicom_datas, clipped_datas
114 | 
115 | def get_organized_data(meta_path, single_size,epoch):
116 |     rand = random.Random()
117 |     dicom_datas = dict()
118 |     mask_datas = dict()
119 |     pickle_reader = open(meta_path)
120 |     meta_data = pickle.load(pickle_reader)
121 |     # accept_zeros = rand.sample(meta_data.keys(),8)
122 |     total_keys = meta_data.keys()[15:]
123 |     begin = epoch%len(total_keys)
124 |     end = (epoch+8)%len(total_keys)
125 |     if begin<end:
126 |         keys = total_keys[begin:end]
127 |     else:
128 |         keys = total_keys[begin:]+total_keys[:end]
129 |     # NOTE: the original body of this function between here and the
130 |     # acceptance test below was lost in the source dump; the loop that
131 |     # follows is a reconstruction and only an assumption about how the
132 |     # random training blocks were sampled.
133 |     for number in keys:
134 |         data = sio.loadmat(meta_data[number])
135 |         original_array = data['original']
136 |         mask_array = data['mask']
137 |         array_shape = np.shape(mask_array)
138 |         dicom_datas[number] = list()
139 |         mask_datas[number] = list()
140 |         # draw several random crops of size single_size from each volume;
141 |         # the crop count is part of the reconstruction
142 |         for sample_num in range(8):
143 |             i = rand.randint(0, array_shape[0] - single_size[0])
144 |             j = rand.randint(0, array_shape[1] - single_size[1])
145 |             k = rand.randint(0, array_shape[2] - single_size[2])
146 |             clipped_mask = mask_array[i:i + single_size[0], j:j + single_size[1], k:k + single_size[2]]
147 |             # keep a block only if it holds enough mask voxels; the
148 |             # required ratio decays as the epoch index grows, so later
149 |             # epochs also accept blocks that contain little or no
150 |             # airway
151 |             if np.sum(np.float32(clipped_mask))/(single_size[0]*single_size[1]*single_size[2]*1.0)>(0.1*(1-epoch*1.0/1500)):
152 |                 clipped_dicom = original_array[i:i + single_size[0], j:j + single_size[1], k:k + single_size[2]]
153 |                 dicom_datas[number].append(clipped_dicom)
154 |                 mask_datas[number].append(clipped_mask)
155 |                 # clipped_dicom = original_array[i:i + single_size[0], j:j + single_size[1], k:k + single_size[2]]
156 |                 # dicom_datas[number].append(clipped_dicom)
157 |                 # mask_datas[number].append(clipped_mask)
158 |     return dicom_datas,mask_datas
159 | #
160 | # def test():
161 | #     meta_path = '/opt/analyse_airway/data_meta.pkl'
162 | #     single_size = [64,64,64]
163 | #     dicom_datas,mask_datas=get_organized_data(meta_path,single_size)
164 | #     print dicom_datas.keys()
165 | #     print mask_datas.keys()
166 | 
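# ---- Hedged example (added; not part of the original organize_data.py) ----
# How get_organized_data is meant to be called. The meta path matches
# config['meta_path'] in airway_seg.py and the block size matches its
# input_shape; the epoch value is an arbitrary illustration.
def _example_get_organized_data():
    meta_path = '/opt/analyse_airway/data_meta.pkl'
    single_size = [64, 64, 128]
    dicom_datas, mask_datas = get_organized_data(meta_path, single_size, 0)
    for number in dicom_datas.keys():
        print number, ':', len(dicom_datas[number]), 'training blocks'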
167 | # test() -------------------------------------------------------------------------------- /zoom.py: -------------------------------------------------------------------------------- 1 | # import cv 2 | import math 3 | import numpy as np 4 | import scipy.io as sio 5 | import time 6 | import cPickle as pickle 7 | 8 | # def JZoom(image, m, n): 9 | # H = int(image.height * m - m) 10 | # W = int(image.width * n - n) 11 | # size = (W, H) 12 | # iZoom = cv.CreateImage(size, image.depth, image.nChannels) 13 | # sum = [0, 0, 0] 14 | # for i in range(H): 15 | # for j in range(W): 16 | # x1 = int(math.floor((i + 1) / m - 1)) 17 | # y1 = int(math.floor((j + 1) / n - 1)) 18 | # p = (i + 0.0) / m - x1 19 | # q = (j + 0.0) / n - y1 20 | # for k in range(3): 21 | # sum[k] = int( 22 | # image[x1, y1][k] * (1 - p) * (1 - q) + image[x1 + 1, y1][k] * p * (1 - q) + image[x1, y1 + 1][k] * ( 23 | # 1 - p) * q + image[x1 + 1, y1 + 1][k] * p * q) 24 | # iZoom[i, j] = (sum[0], sum[1], sum[2]) 25 | # return iZoom 26 | 27 | def Array_Zoom_in(image, m, n): 28 | shape=np.shape(image) 29 | H = int(shape[0] * m - m) 30 | W = int(shape[1] * n - n) 31 | iZoom = np.zeros((H,W,shape[2]),dtype=np.float32) 32 | for i in range(H): 33 | for j in range(W): 34 | x1 = int(math.floor((i + 1) / m - 1)) 35 | y1 = int(math.floor((j + 1) / n - 1)) 36 | p = (i + 0.0) / m - x1 37 | q = (j + 0.0) / n - y1 38 | for k in range(shape[2]): 39 | sum= int( 40 | image[x1, y1, k] * (1 - p) * (1 - q) + image[x1 + 1, y1, k] * p * (1 - q) + 41 | image[x1, y1 + 1, k] * (1 - p) * q + image[x1 + 1, y1 + 1,k] * p * q) 42 | iZoom[i, j, k] = sum 43 | return iZoom 44 | 45 | def Array_Reduce(image,m,n): 46 | shape=np.shape(image) 47 | H = int(shape[0] * m) 48 | W = int(shape[1] * n) 49 | iJReduce = np.zeros((H, W, shape[2]), dtype=np.float32) 50 | for c in range(shape[2]): 51 | for i in range(H): 52 | for j in range(W): 53 | x1 = int(i/m) 54 | x2 = int((i+1)/m) 55 | y1 = int(j/n) 56 | y2 = int((j+1)/n) 57 | sum = 0 58 | for k in range(x1,x2): 59 | for l in range(y1,y2): 60 | sum = sum+image[k , l, c] 61 | num = (x2-x1)*(y2-y1) 62 | iJReduce[i , j, c] = sum/num 63 | return iJReduce 64 | 65 | ''' 66 | returns an array [xmin,xmax,ymin,ymax] 67 | ''' 68 | def get_valid_area(input_array): 69 | array_shape=np.shape(input_array) 70 | central_point=[(array_shape[0]-1)/2,(array_shape[1]-1)/2] 71 | xmin=0 72 | xmax=array_shape[0]-1 73 | ymin=0 74 | ymax=array_shape[1]-1 75 | tags=[0,0,0,0] 76 | for i in range(array_shape[1]/2): 77 | if np.max(input_array[central_point[0]-i,:,:])>0 and tags[0]==0: 78 | xmin=central_point[0]-i 79 | else: 80 | tags[0]=1 81 | if np.max(input_array[central_point[0]+i,:,:])>0 and tags[1]==0: 82 | xmax=central_point[0]+i 83 | else: 84 | tags[1]=1 85 | for j in range(array_shape[1]/2): 86 | if np.max(input_array[:,central_point[1]-j,:])>0 and tags[2]==0: 87 | ymin=central_point[1]-j 88 | else: 89 | tags[2]=1 90 | if np.max(input_array[:,central_point[1]+j,:])>0 and tags[3]==0: 91 | ymax=central_point[1]+j 92 | else: 93 | tags[3]=1 94 | return [xmin,xmax,ymin,ymax] 95 | 96 | 97 | # image = cv.LoadImage('lena.jpg', 1) 98 | # iZoom1 = JZoom(image, 2, 3) 99 | # iZoom2 = JZoom(image, 2.5, 2.5) 100 | # cv.ShowImage('image', image) 101 | # cv.ShowImage('iZoom1', iZoom1) 102 | # cv.ShowImage('iZoom2', iZoom2) 103 | 104 | # img_array = np.array(iZoom2[:,:]) 105 | # print np.shape(img_array) 106 | # cv.ShowImage('iZoom2', cv.fromarray(img_array[:,:,1])) 107 | # cv.WaitKey(0) 108 | # central = [(array_shape[0] - 1) / 2, (array_shape[1] - 1) / 
2,(array_shape[2] - 1) / 2] 109 | # test_input=np.transpose(test_input,(2,1,0)) 110 | # test_img=ST.GetImageFromArray(test_input) 111 | # # NeighborhoodConnected(Image image1, VectorUIntList seedList, double lower=0, double upper=1, VectorUInt32 radius, double replaceValue=1) 112 | # threshed_mask=ST.NeighborhoodConnected(test_img,[[central[0],central[1],central[2]]],-50,np.float64(np.max(test_input)),[1,1,1],1.0) 113 | # threshed_array = ST.GetArrayFromImage(threshed_mask) 114 | # 115 | # img_array = ST.GetArrayFromImage(test_img)*threshed_array 116 | # threshed_img = ST.GetImageFromArray(img_array) 117 | # print img_array.dtype 118 | # ST.WriteImage(threshed_img,'./threshed_img1.vtk') 119 | # img_array = np.transpose(img_array,(2,1,0)) 120 | # reduced=Array_Reduce(test_input,0.5,0.5) 121 | # print time.strftime('%Y-%m-%d %H:%M:%S'),np.shape(reduced) 122 | # zoomed = Array_Zoom_in(test_input,1.377,1.377) 123 | # print time.strftime('%Y-%m-%d %H:%M:%S'),np.shape(zoomed) 124 | # print np.max(test_input[ranger[0]-1,:,:]) 125 | # print np.max(test_input[ranger[1]+1,:,:]) 126 | # print np.max(test_input[:,ranger[2]-1,:]) 127 | # print np.max(test_input[:,ranger[3]+1,:]) 128 | # test_input = sio.loadmat( 129 | # '/home/fortis/pycharmProjects/analyse_liver_data/out_put/3Dircadb1.1/PATIENT_DICOM/PATIENT_DICOM.mat')[ 130 | # 'original'] 131 | # test_mask = sio.loadmat( 132 | # '/home/fortis/pycharmProjects/analyse_liver_data/out_put/3Dircadb1.1/MASKS_DICOM/liver/liver.mat')[ 133 | # 'liver_mask'] 134 | 135 | def resize_image(test_input,resized_length): 136 | array_shape = np.shape(test_input) 137 | # print time.strftime('%Y-%m-%d %H:%M:%S'),' ',np.shape(test_input) 138 | ranger=get_valid_area(test_input) 139 | # print time.strftime('%Y-%m-%d %H:%M:%S'),' ',ranger 140 | 141 | sliced_img=test_input[ranger[0]:ranger[1]+1,ranger[2]:ranger[3]+1,:] 142 | size=np.array([ranger[1]-ranger[0]+1,ranger[3]-ranger[2]+1]) 143 | maxsize=np.max(size) 144 | minsize=np.min(size) 145 | 146 | sliced_size=np.shape(sliced_img) 147 | processed_img=np.zeros((maxsize,maxsize,array_shape[2]),np.float32) 148 | padding_size=maxsize-minsize 149 | try: 150 | if size[0]>size[1]: 151 | processed_img[:,padding_size/2:maxsize-padding_size/2,:]=sliced_img[:,:,:] 152 | else: 153 | processed_img[padding_size / 2:maxsize - padding_size / 2, :, :] = sliced_img[:, :, :] 154 | except Exception,e: 155 | if size[0]>size[1]: 156 | processed_img[:,padding_size/2+1:maxsize-padding_size/2,:]=sliced_img[:,:,:] 157 | else: 158 | processed_img[padding_size / 2+1:maxsize - padding_size / 2, :, :] = sliced_img[:, :, :] 159 | 160 | resized_rate=float(resized_length)/float(maxsize) 161 | if resized_rate<1: 162 | resized_img=Array_Reduce(processed_img,resized_rate,resized_rate) 163 | else: 164 | resized_img=Array_Zoom_in(processed_img,resized_rate,resized_rate) 165 | return resized_img,ranger 166 | 167 | def resize_mask(test_mask,ranger,resized_length): 168 | array_shape = np.shape(test_mask) 169 | 170 | sliced_img=test_mask[ranger[0]:ranger[1]+1,ranger[2]:ranger[3]+1,:] 171 | size=np.array([ranger[1]-ranger[0]+1,ranger[3]-ranger[2]+1]) 172 | maxsize=np.max(size) 173 | minsize=np.min(size) 174 | 175 | sliced_size=np.shape(sliced_img) 176 | processed_img=np.zeros((maxsize,maxsize,array_shape[2]),np.float32) 177 | padding_size=maxsize-minsize 178 | try: 179 | if size[0]>size[1]: 180 | processed_img[:,padding_size/2:maxsize-padding_size/2,:]=sliced_img[:,:,:] 181 | else: 182 | processed_img[padding_size / 2:maxsize - padding_size / 2, :, :] = 
sliced_img[:, :, :]
183 |     except Exception,e:
184 |         if size[0]>size[1]:
185 |             processed_img[:,padding_size/2+1:maxsize-padding_size/2,:]=sliced_img[:,:,:]
186 |         else:
187 |             processed_img[padding_size / 2+1:maxsize - padding_size / 2, :, :] = sliced_img[:, :, :]
188 | 
189 |     resized_rate=float(resized_length)/float(maxsize)
190 |     if resized_rate<1:
191 |         resized_img=Array_Reduce(processed_img,resized_rate,resized_rate)
192 |     else:
193 |         resized_img=Array_Zoom_in(processed_img,resized_rate,resized_rate)
194 |     return resized_img
195 | 
196 | # print np.shape(resized_img)
197 | def resize_data(resized_length):
198 |     # resized_length=128
199 |     reader1=open('/opt/analyse_liver_data/filelist.pkl','rb')
200 |     reader2=open('/opt/analyse_liver_data/data_meta.pkl','rb')
201 |     filelist=pickle.load(reader1)
202 |     meta_data=pickle.load(reader2)
203 |     successful_list=[]
204 |     failed_list=[]
205 |     for number,dataset in meta_data['matrixes'].items():
206 |         try:
207 |             original_data=sio.loadmat(dataset['PATIENT_DICOM'])
208 |             original_img_temp=original_data['original']
209 |             temp_shape = np.shape(original_img_temp)
210 |             liver_data=sio.loadmat(dataset['liver'])
211 |             liver_mask_temp=liver_data['liver_mask']
212 |             if temp_shape[0]<=resized_length and temp_shape[1]<=resized_length:
213 |                 # print "probably resized : \n",dataset['PATIENT_DICOM'],'\npassed this time'
214 |                 resized_original = original_img_temp
215 |                 resized_liver_mask = liver_mask_temp    # small volumes keep their mask unchanged
216 |                 sio.savemat(dataset['PATIENT_DICOM'], {'original': original_img_temp,'original_resized':resized_original})
217 |                 sio.savemat(dataset['liver'], {'liver_mask': liver_mask_temp,'liver_mask_resized':resized_liver_mask})
218 |                 print time.strftime('%Y-%m-%d %H:%M:%S'), ' ', np.shape(resized_original)
219 |                 print time.strftime('%Y-%m-%d %H:%M:%S'), ' ', np.shape(resized_liver_mask)
220 |                 print "======================================================================================"
221 |                 successful_list.append(dataset['PATIENT_DICOM'])
222 |                 successful_list.append(dataset['liver'])
223 |                 continue
224 |             print dataset['PATIENT_DICOM']
225 |             print dataset['liver']
226 |             print time.strftime('%Y-%m-%d %H:%M:%S'), ' ', np.shape(original_img_temp)
227 |             print time.strftime('%Y-%m-%d %H:%M:%S'), ' ', np.shape(liver_mask_temp)
228 |             resized_original,ranger_temp=resize_image(original_img_temp,resized_length)
229 |             print 'valid area: ',ranger_temp
230 |             resized_liver_mask=resize_mask(liver_mask_temp,ranger_temp,resized_length)
231 |             resized_liver_mask=np.int8(resized_liver_mask>0)
232 |             sio.savemat(dataset['PATIENT_DICOM'], {'original': original_img_temp,'original_resized':resized_original})
233 |             sio.savemat(dataset['liver'], {'liver_mask': liver_mask_temp,'liver_mask_resized':resized_liver_mask})
234 |             print time.strftime('%Y-%m-%d %H:%M:%S'), ' ', np.shape(resized_original)
235 |             print time.strftime('%Y-%m-%d %H:%M:%S'), ' ', np.shape(resized_liver_mask)
236 |             print "======================================================================================"
237 |             successful_list.append(dataset['PATIENT_DICOM'])
238 |             successful_list.append(dataset['liver'])
239 |         except Exception,e:
240 |             print e
241 |             failed_list.append(dataset['PATIENT_DICOM'])
242 |             failed_list.append(dataset['liver'])
243 | 
244 |     reader1.close()
245 |     reader2.close()
246 | 
247 |     file1=open('./successful.txt','wb')
248 |     file2=open('./failed.txt','wb')
249 | 
250 |     for item in successful_list:
251 |         file1.write(item+"\n")
252 |     for item in failed_list:
253 |         file2.write(item+"\n")
254 | 
255 |     file1.close()
256 
| file2.close() 257 | 258 | # print time.strftime('%Y-%m-%d %H:%M:%S'),' ',np.shape(test_input) 259 | # resized,rangerer=resize_image(test_input,resized_length) 260 | # print time.strftime('%Y-%m-%d %H:%M:%S'),' ',np.shape(resized) 261 | # resized_mask = resize_mask(test_mask,rangerer,resized_length) 262 | # print time.strftime('%Y-%m-%d %H:%M:%S'),' ',np.shape(resized_mask) 263 | 264 | # resize_data(512) 265 | -------------------------------------------------------------------------------- /airway_seg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tensorflow as tf 4 | import scipy.io 5 | import tools 6 | import numpy as np 7 | import time 8 | import test 9 | import SimpleITK as ST 10 | from dicom_read import read_dicoms 11 | import gc 12 | 13 | resolution = 64 14 | batch_size = 4 15 | lr_down = [0.001,0.0002,0.0001] 16 | ori_lr = 0.001 17 | power = 0.9 18 | GPU0 = '0' 19 | input_shape = [64,64,128] 20 | output_shape = [64,64,128] 21 | type_num = 0 22 | 23 | ############################################################### 24 | config={} 25 | config['train_names'] = ['chair'] 26 | for name in config['train_names']: 27 | config['X_train_'+name] = './Data/'+name+'/train_25d/voxel_grids_64/' 28 | config['Y_train_'+name] = './Data/'+name+'/train_3d/voxel_grids_64/' 29 | 30 | config['test_names']=['chair'] 31 | for name in config['test_names']: 32 | config['X_test_'+name] = './Data/'+name+'/test_25d/voxel_grids_64/' 33 | config['Y_test_'+name] = './Data/'+name+'/test_3d/voxel_grids_64/' 34 | 35 | config['resolution'] = resolution 36 | config['batch_size'] = batch_size 37 | config['meta_path'] = '/opt/analyse_airway/data_meta.pkl' 38 | config['data_size'] = input_shape 39 | 40 | ################################################################ 41 | 42 | class Network: 43 | def __init__(self): 44 | self.train_models_dir = './airway_model/' 45 | # self.train_sum_dir = './train_sum/' 46 | # self.test_results_dir = './test_results/' 47 | # self.test_sum_dir = './test_sum/' 48 | 49 | def ae_u(self,X,training,batch_size,threshold): 50 | original=16 51 | growth=10 52 | dense_layer_num=12 53 | # input layer 54 | X=tf.reshape(X,[batch_size,input_shape[0],input_shape[1],input_shape[2],1]) 55 | # image reduce layer 56 | conv_input=tools.Ops.conv3d(X,k=3,out_c=original,str=2,name='conv_input') 57 | with tf.device('/gpu:'+GPU0): 58 | ##### dense block 1 59 | c_e = [] 60 | s_e = [] 61 | layers_e=[] 62 | layers_e.append(conv_input) 63 | for i in range(dense_layer_num): 64 | c_e.append(original+growth*(i+1)) 65 | s_e.append(1) 66 | for j in range(dense_layer_num): 67 | layer = tools.Ops.batch_norm(layers_e[-1], 'bn_dense_1_' + str(j), training=training) 68 | layer = tools.Ops.xxlu(layer, name='relu') 69 | layer = tools.Ops.conv3d(layer,k=3,out_c=growth,str=s_e[j],name='dense_1_'+str(j)) 70 | next_input = tf.concat([layer,layers_e[-1]],axis=4) 71 | layers_e.append(next_input) 72 | 73 | # middle down sample 74 | mid_layer = tools.Ops.batch_norm(layers_e[-1], 'bn_mid', training=training) 75 | mid_layer = tools.Ops.xxlu(mid_layer,name='relu') 76 | mid_layer = tools.Ops.conv3d(mid_layer,k=1,out_c=original+growth*dense_layer_num,str=1,name='mid_conv') 77 | mid_layer_down = tools.Ops.maxpool3d(mid_layer,k=2,s=2,pad='SAME') 78 | 79 | ##### dense block 80 | with tf.device('/gpu:'+GPU0): 81 | c_d = [] 82 | s_d = [] 83 | layers_d = [] 84 | layers_d.append(mid_layer_down) 85 | for i in range(dense_layer_num): 86 | 
c_d.append(original+growth*(dense_layer_num+i+1)) 87 | s_d.append(1) 88 | for j in range(dense_layer_num): 89 | layer = tools.Ops.batch_norm(layers_d[-1],'bn_dense_2_'+str(j),training=training) 90 | layer = tools.Ops.xxlu(layer, name='relu') 91 | layer = tools.Ops.conv3d(layer,k=3,out_c=growth,str=s_d[j],name='dense_2_'+str(j)) 92 | next_input = tf.concat([layer,layers_d[-1]],axis=4) 93 | layers_d.append(next_input) 94 | 95 | ##### final up-sampling 96 | bn_1 = tools.Ops.batch_norm(layers_d[-1],'bn_after_dense',training=training) 97 | relu_1 = tools.Ops.xxlu(bn_1 ,name='relu') 98 | conv_27 = tools.Ops.conv3d(relu_1,k=1,out_c=original+growth*dense_layer_num*2,str=1,name='conv_up_sample_1') 99 | deconv_1 = tools.Ops.deconv3d(conv_27,k=2,out_c=128,str=2,name='deconv_up_sample_1') 100 | concat_up = tf.concat([deconv_1,mid_layer],axis=4) 101 | deconv_2 = tools.Ops.deconv3d(concat_up,k=2,out_c=64,str=2,name='deconv_up_sample_2') 102 | 103 | predict_map = tools.Ops.conv3d(deconv_2,k=1,out_c=1,str=1,name='predict_map') 104 | 105 | vox_no_sig = predict_map 106 | # vox_no_sig = tools.Ops.xxlu(vox_no_sig,name='relu') 107 | vox_sig = tf.sigmoid(predict_map) 108 | vox_sig_modified = tf.maximum(vox_sig-threshold,0.01) 109 | return vox_sig, vox_sig_modified,vox_no_sig 110 | 111 | def dis(self, X, Y,training): 112 | with tf.device('/gpu:'+GPU0): 113 | X = tf.reshape(X,[batch_size,input_shape[0],input_shape[1],input_shape[2],1]) 114 | Y = tf.reshape(Y,[batch_size,output_shape[0],output_shape[1],output_shape[2],1]) 115 | layer = tf.concat([X,Y],axis=4) 116 | c_d = [1,2,64,128,256,512] 117 | s_d = [0,2,2,2,2,2] 118 | layers_d =[] 119 | layers_d.append(layer) 120 | for i in range(1,6,1): 121 | layer = tools.Ops.conv3d(layers_d[-1],k=4,out_c=c_d[i],str=s_d[i],name='d_1'+str(i)) 122 | if i!=5: 123 | layer = tools.Ops.xxlu(layer, name='lrelu') 124 | # batch normal layer 125 | layer = tools.Ops.batch_norm(layer, 'bn_up' + str(i), training=training) 126 | layers_d.append(layer) 127 | y = tf.reshape(layers_d[-1],[batch_size,-1]) 128 | # for j in range(len(layers_d)-1): 129 | # y = tf.concat([y,tf.reshape(layers_d[j],[batch_size,-1])],axis=1) 130 | return tf.nn.sigmoid(y) 131 | 132 | def test(self,dicom_dir): 133 | # X = tf.placeholder(shape=[batch_size, input_shape[0], input_shape[1], input_shape[2]], dtype=tf.float32) 134 | g_airway = tf.Graph() 135 | with g_airway.as_default(): 136 | test_input_shape = input_shape 137 | test_batch_size = batch_size 138 | threshold = tf.placeholder(tf.float32) 139 | training = tf.placeholder(tf.bool) 140 | X = tf.placeholder(shape=[test_batch_size, test_input_shape[0], test_input_shape[1], test_input_shape[2]], 141 | dtype=tf.float32) 142 | with tf.variable_scope('ae',reuse=False): 143 | Y_pred, Y_pred_modi, Y_pred_nosig = self.ae_u(X, training, test_batch_size, threshold) 144 | 145 | # print tools.Ops.variable_count() 146 | sum_merged = tf.summary.merge_all() 147 | saver = tf.train.Saver(max_to_keep=1) 148 | config = tf.ConfigProto(allow_soft_placement=True) 149 | config.gpu_options.visible_device_list = GPU0 150 | with tf.Session(config=config) as sess: 151 | if os.path.exists(self.train_models_dir): 152 | saver.restore(sess, self.train_models_dir + 'model.cptk') 153 | # sum_writer_train = tf.summary.FileWriter(self.train_sum_dir, sess.graph) 154 | # sum_write_test = tf.summary.FileWriter(self.test_sum_dir) 155 | 156 | if os.path.isfile(self.train_models_dir + 'model.cptk.data-00000-of-00001'): 157 | print "restoring saved model" 158 | saver.restore(sess, self.train_models_dir + 
'model.cptk') 159 | else: 160 | sess.run(tf.global_variables_initializer()) 161 | test_data = tools.Test_data(dicom_dir, input_shape,'vtk_data') 162 | test_data.organize_blocks() 163 | block_numbers = test_data.blocks.keys() 164 | for i in range(0, len(block_numbers), test_batch_size): 165 | batch_numbers = [] 166 | if i + test_batch_size < len(block_numbers): 167 | temp_input = np.zeros( 168 | [test_batch_size, input_shape[0], input_shape[1], input_shape[2]]) 169 | for j in range(test_batch_size): 170 | temp_num = block_numbers[i + j] 171 | temp_block = test_data.blocks[temp_num] 172 | batch_numbers.append(temp_num) 173 | block_array = temp_block.load_data() 174 | block_shape = np.shape(block_array) 175 | temp_input[j, 0:block_shape[0], 0:block_shape[1], 0:block_shape[2]] += block_array 176 | Y_temp_pred, Y_temp_modi, Y_temp_pred_nosig = sess.run([Y_pred, Y_pred_modi, Y_pred_nosig], 177 | feed_dict={X: temp_input, 178 | training: False, 179 | threshold: 0.8}) 180 | for j in range(test_batch_size): 181 | test_data.upload_result(batch_numbers[j], Y_temp_modi[j, :, :, :]) 182 | else: 183 | temp_batch_size = len(block_numbers) - i 184 | temp_input = np.zeros( 185 | [temp_batch_size, input_shape[0], input_shape[1], input_shape[2]]) 186 | for j in range(temp_batch_size): 187 | temp_num = block_numbers[i + j] 188 | temp_block = test_data.blocks[temp_num] 189 | batch_numbers.append(temp_num) 190 | block_array = temp_block.load_data() 191 | block_shape = np.shape(block_array) 192 | temp_input[j, 0:block_shape[0], 0:block_shape[1], 0:block_shape[2]] += block_array 193 | X_temp = tf.placeholder( 194 | shape=[temp_batch_size, input_shape[0], input_shape[1], input_shape[2]], 195 | dtype=tf.float32) 196 | with tf.variable_scope('ae', reuse=True): 197 | Y_pred_temp, Y_pred_modi_temp, Y_pred_nosig_temp = self.ae_u(X_temp, training, 198 | temp_batch_size, threshold) 199 | Y_temp_pred, Y_temp_modi, Y_temp_pred_nosig = sess.run( 200 | [Y_pred_temp, Y_pred_modi_temp, Y_pred_nosig_temp], 201 | feed_dict={X_temp: temp_input, 202 | training: False, 203 | threshold: 0.8}) 204 | for j in range(temp_batch_size): 205 | test_data.upload_result(batch_numbers[j], Y_temp_modi[j, :, :, :]) 206 | test_result_array = test_data.get_result() 207 | # print "result shape: ", np.shape(test_result_array) 208 | r_s = np.shape(test_result_array) # result shape 209 | e_t = 10 # edge thickness 210 | to_be_transformed = np.zeros(r_s, np.float32) 211 | to_be_transformed[e_t:r_s[0] - e_t, e_t:r_s[1] - e_t, 0:r_s[2] - e_t] += test_result_array[ 212 | e_t:r_s[0] - e_t, 213 | e_t:r_s[1] - e_t, 214 | 0:r_s[2] - e_t] 215 | # print np.max(to_be_transformed) 216 | # print np.min(to_be_transformed) 217 | final_img = ST.GetImageFromArray(np.transpose(to_be_transformed, [2, 1, 0])) 218 | final_img.SetSpacing(test_data.space) 219 | return final_img 220 | 221 | def airway_seg(lung_img): 222 | time1 = time.time() 223 | net = Network() 224 | airway_mask = net.test(lung_img) 225 | time2 = time.time() 226 | del net 227 | gc.collect() 228 | print "Writing airway mask" 229 | ST.WriteImage(airway_mask,'./output/airway_mask.vtk') 230 | print "total time cost of airway segmentation: ",str(time2-time1),'s' 231 | return airway_mask -------------------------------------------------------------------------------- /lung_seg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tensorflow as tf 4 | import scipy.io 5 | import tools 6 | import numpy as np 7 | import time 8 | import test 9 | 
import SimpleITK as ST 10 | from dicom_read import read_dicoms 11 | import gc 12 | 13 | resolution = 64 14 | batch_size = 4 15 | lr_down = [0.001,0.0002,0.0001] 16 | ori_lr = 0.001 17 | power = 0.9 18 | GPU0 = '0' 19 | input_shape = [512,512,4] 20 | output_shape = [512,512,4] 21 | type_num = 0 22 | 23 | ############################################################### 24 | config={} 25 | config['train_names'] = ['chair'] 26 | for name in config['train_names']: 27 | config['X_train_'+name] = './Data/'+name+'/train_25d/voxel_grids_64/' 28 | config['Y_train_'+name] = './Data/'+name+'/train_3d/voxel_grids_64/' 29 | 30 | config['test_names']=['chair'] 31 | for name in config['test_names']: 32 | config['X_test_'+name] = './Data/'+name+'/test_25d/voxel_grids_64/' 33 | config['Y_test_'+name] = './Data/'+name+'/test_3d/voxel_grids_64/' 34 | 35 | config['resolution'] = resolution 36 | config['batch_size'] = batch_size 37 | config['meta_path'] = '/opt/analyse_lung/data_meta.pkl' 38 | config['data_size'] = input_shape 39 | 40 | ################################################################ 41 | 42 | class Network: 43 | def __init__(self): 44 | self.train_models_dir = './lung_model/' 45 | 46 | def ae_u(self,X,training,batch_size,threshold): 47 | original=16 48 | growth=12 49 | dense_layer_num=6 50 | # input layer 51 | X=tf.reshape(X,[batch_size,input_shape[0],input_shape[1],input_shape[2],1]) 52 | # image reduce layer 53 | # conv_input_1=tools.Ops.conv3d(X,k=3,out_c=2,str=2,name='conv_input_down') 54 | # conv_input_normed=tools.Ops.batch_norm(conv_input_1, 'bn_dense_0_0', training=training) 55 | # network start 56 | conv_input_1=tools.Ops.conv3d(X,k=3,out_c=original,str=1,name='conv_input_1') 57 | conv_input=tools.Ops.conv3d(conv_input_1,k=3,out_c=original,str=2,name='conv_input') 58 | with tf.device('/gpu:'+GPU0): 59 | ##### dense block 1 60 | c_e = [] 61 | s_e = [] 62 | layers_e=[] 63 | layers_e.append(conv_input) 64 | for i in range(dense_layer_num): 65 | c_e.append(original+growth*(i+1)) 66 | s_e.append(1) 67 | for j in range(dense_layer_num): 68 | layer = tools.Ops.batch_norm(layers_e[-1], 'bn_dense_1_' + str(j), training=training) 69 | layer = tools.Ops.xxlu(layer, name='relu') 70 | layer = tools.Ops.conv3d(layer,k=3,out_c=growth,str=s_e[j],name='dense_1_'+str(j)) 71 | next_input = tf.concat([layer,layers_e[-1]],axis=4) 72 | layers_e.append(next_input) 73 | 74 | # middle down sample 75 | mid_layer = tools.Ops.batch_norm(layers_e[-1], 'bn_mid', training=training) 76 | mid_layer = tools.Ops.xxlu(mid_layer,name='relu') 77 | mid_layer = tools.Ops.conv3d(mid_layer,k=1,out_c=original+growth*dense_layer_num,str=1,name='mid_conv') 78 | mid_layer_down = tools.Ops.maxpool3d(mid_layer,k=2,s=2,pad='SAME') 79 | 80 | ##### dense block 81 | with tf.device('/gpu:'+GPU0): 82 | # lfc = tools.Ops.xxlu(tools.Ops.fc(lfc, out_d=d1 * d2 * d3 * cc, name='fc2'),name='relu') 83 | # lfc = tf.reshape(lfc, [bat, d1, d2, d3, cc]) 84 | 85 | c_d = [] 86 | s_d = [] 87 | layers_d = [] 88 | layers_d.append(mid_layer_down) 89 | for i in range(dense_layer_num): 90 | c_d.append(original+growth*(dense_layer_num+i+1)) 91 | s_d.append(1) 92 | for j in range(dense_layer_num): 93 | layer = tools.Ops.batch_norm(layers_d[-1],'bn_dense_2_'+str(j),training=training) 94 | layer = tools.Ops.xxlu(layer, name='relu') 95 | layer = tools.Ops.conv3d(layer,k=3,out_c=growth,str=s_d[j],name='dense_2_'+str(j)) 96 | next_input = tf.concat([layer,layers_d[-1]],axis=4) 97 | layers_d.append(next_input) 98 | 99 | ##### final up-sampling 100 | bn_1 = 
tools.Ops.batch_norm(layers_d[-1],'bn_after_dense',training=training) 101 | relu_1 = tools.Ops.xxlu(bn_1 ,name='relu') 102 | conv_27 = tools.Ops.conv3d(relu_1,k=1,out_c=original+growth*dense_layer_num*2,str=1,name='conv_up_sample_1') 103 | deconv_1 = tools.Ops.deconv3d(conv_27,k=2,out_c=128,str=2,name='deconv_up_sample_1') 104 | concat_up = tf.concat([deconv_1,mid_layer],axis=4) 105 | deconv_2 = tools.Ops.deconv3d(concat_up,k=2,out_c=64,str=2,name='deconv_up_sample_2') 106 | concat_up_1 = tf.concat([deconv_2, conv_input_1], axis=4) 107 | predict_map = tools.Ops.conv3d(concat_up_1,k=1,out_c=1,str=1,name='predict_map') 108 | 109 | # zoom in layer 110 | # predict_map_normed = tools.Ops.batch_norm(predict_map,'bn_after_dense_1',training=training) 111 | # predict_map_zoomed = tools.Ops.deconv3d(predict_map_normed,k=2,out_c=1,str=2,name='deconv_zoom_3') 112 | 113 | vox_no_sig = predict_map 114 | # vox_no_sig = tools.Ops.xxlu(vox_no_sig,name='relu') 115 | vox_sig = tf.sigmoid(predict_map) 116 | vox_sig_modified = tf.maximum(vox_sig-threshold,0.01) 117 | return vox_sig, vox_sig_modified,vox_no_sig 118 | 119 | def dis(self, X, Y,training): 120 | with tf.device('/gpu:'+GPU0): 121 | X = tf.reshape(X,[batch_size,input_shape[0],input_shape[1],input_shape[2],1]) 122 | Y = tf.reshape(Y,[batch_size,output_shape[0],output_shape[1],output_shape[2],1]) 123 | layer = tf.concat([X,Y],axis=4) 124 | c_d = [1,2,64,128,2565] 125 | s_d = [0,2,2,2,2] 126 | layers_d =[] 127 | layers_d.append(layer) 128 | for i in range(1,5,1): 129 | layer = tools.Ops.conv3d(layers_d[-1],k=4,out_c=c_d[i],str=s_d[i],name='d_1'+str(i)) 130 | if i!=5: 131 | layer = tools.Ops.xxlu(layer, name='lrelu') 132 | # batch normal layer 133 | layer = tools.Ops.batch_norm(layer, 'bn_up' + str(i), training=training) 134 | layers_d.append(layer) 135 | y = tf.reshape(layers_d[-1],[batch_size,-1]) 136 | # for j in range(len(layers_d)-1): 137 | # y = tf.concat([y,tf.reshape(layers_d[j],[batch_size,-1])],axis=1) 138 | return tf.nn.sigmoid(y) 139 | 140 | def test(self,dicom_dir): 141 | g_lung = tf.Graph() 142 | with g_lung.as_default(): 143 | # X = tf.placeholder(shape=[batch_size, input_shape[0], input_shape[1], input_shape[2]], dtype=tf.float32) 144 | test_input_shape = input_shape 145 | test_batch_size = batch_size 146 | threshold = tf.placeholder(tf.float32) 147 | training = tf.placeholder(tf.bool) 148 | X = tf.placeholder(shape=[test_batch_size, test_input_shape[0], test_input_shape[1], test_input_shape[2]], 149 | dtype=tf.float32) 150 | with tf.variable_scope('lung_net'): 151 | Y_pred, Y_pred_modi, Y_pred_nosig = self.ae_u(X, training, test_batch_size, threshold) 152 | 153 | # print tools.Ops.variable_count() 154 | # sum_merged = tf.summary.merge_all() 155 | saver = tf.train.Saver(max_to_keep=1) 156 | config = tf.ConfigProto(allow_soft_placement=True) 157 | config.gpu_options.visible_device_list = GPU0 158 | # with tf.Session(config=config) as sess: 159 | sess = tf.Session(config=config) 160 | if os.path.exists(self.train_models_dir): 161 | saver.restore(sess, self.train_models_dir + 'model.cptk') 162 | 163 | if os.path.isfile(self.train_models_dir + 'model.cptk.data-00000-of-00001'): 164 | print "restoring saved model" 165 | saver.restore(sess, self.train_models_dir + 'model.cptk') 166 | else: 167 | sess.run(tf.global_variables_initializer()) 168 | test_data = tools.Test_data(dicom_dir, input_shape, 'dicom_data') 169 | test_data.organize_blocks() 170 | block_numbers = test_data.blocks.keys() 171 | for i in range(0, len(block_numbers), 
test_batch_size): 172 | batch_numbers = [] 173 | if i + test_batch_size < len(block_numbers): 174 | temp_input = np.zeros( 175 | [test_batch_size, input_shape[0], input_shape[1], input_shape[2]]) 176 | for j in range(test_batch_size): 177 | temp_num = block_numbers[i + j] 178 | temp_block = test_data.blocks[temp_num] 179 | batch_numbers.append(temp_num) 180 | block_array = temp_block.load_data() 181 | block_shape = np.shape(block_array) 182 | temp_input[j, 0:block_shape[0], 0:block_shape[1], 0:block_shape[2]] += block_array 183 | Y_temp_pred, Y_temp_modi, Y_temp_pred_nosig = sess.run([Y_pred, Y_pred_modi, Y_pred_nosig], 184 | feed_dict={X: temp_input, 185 | training: False, 186 | threshold: 0.8}) 187 | for j in range(test_batch_size): 188 | test_data.upload_result(batch_numbers[j], Y_temp_modi[j, :, :, :]) 189 | else: 190 | temp_batch_size = len(block_numbers) - i 191 | temp_input = np.zeros( 192 | [temp_batch_size, input_shape[0], input_shape[1], input_shape[2]]) 193 | for j in range(temp_batch_size): 194 | temp_num = block_numbers[i + j] 195 | temp_block = test_data.blocks[temp_num] 196 | batch_numbers.append(temp_num) 197 | block_array = temp_block.load_data() 198 | block_shape = np.shape(block_array) 199 | temp_input[j, 0:block_shape[0], 0:block_shape[1], 0:block_shape[2]] += block_array 200 | X_temp = tf.placeholder( 201 | shape=[temp_batch_size, input_shape[0], input_shape[1], input_shape[2]], 202 | dtype=tf.float32) 203 | with tf.variable_scope('lung_net', reuse=True): 204 | Y_pred_temp, Y_pred_modi_temp, Y_pred_nosig_temp = self.ae_u(X_temp, training, 205 | temp_batch_size, threshold) 206 | Y_temp_pred, Y_temp_modi, Y_temp_pred_nosig = sess.run( 207 | [Y_pred_temp, Y_pred_modi_temp, Y_pred_nosig_temp], 208 | feed_dict={X_temp: temp_input, 209 | training: False, 210 | threshold: 0.8}) 211 | for j in range(temp_batch_size): 212 | test_data.upload_result(batch_numbers[j], Y_temp_modi[j, :, :, :]) 213 | test_result_array = test_data.get_result() 214 | # print "result shape: ", np.shape(test_result_array) 215 | r_s = np.shape(test_result_array) # result shape 216 | e_t = 10 # edge thickness 217 | to_be_transformed = np.zeros(r_s, np.float32) 218 | to_be_transformed[e_t:r_s[0] - e_t, e_t:r_s[1] - e_t, 0:r_s[2] - e_t] += test_result_array[ 219 | e_t:r_s[0] - e_t, 220 | e_t:r_s[1] - e_t, 221 | 0:r_s[2] - e_t] 222 | # print np.max(to_be_transformed) 223 | # print np.min(to_be_transformed) 224 | final_img = ST.GetImageFromArray(np.transpose(to_be_transformed, [2, 1, 0])) 225 | final_img.SetSpacing(test_data.space) 226 | # print "writing final testing result" 227 | # ST.WriteImage(final_img, './lung_mask.vtk') 228 | return final_img 229 | 230 | def post_process(img,dicom_dir): 231 | # print img.GetSize() 232 | original_img = read_dicoms(dicom_dir) 233 | img_array = np.transpose(ST.GetArrayFromImage(img),[2,1,0]) 234 | img_shape = np.shape(img_array) 235 | # Get outer mask to ensure outer noise get excluded 236 | original_array = ST.GetArrayFromImage(original_img) 237 | min_val = np.min(original_array) 238 | outer_seeds = [] 239 | inner_step = 2 240 | outer_seeds.append([inner_step, inner_step, img_shape[2] - inner_step]) 241 | outer_seeds.append([inner_step, img_shape[1] - inner_step, inner_step]) 242 | outer_seeds.append([img_shape[0] - inner_step, inner_step, inner_step]) 243 | outer_seeds.append([inner_step, img_shape[1] - inner_step, img_shape[2] - inner_step]) 244 | outer_seeds.append([img_shape[0] - inner_step, inner_step, img_shape[2] - inner_step]) 245 | 
outer_seeds.append([img_shape[0] - inner_step, img_shape[1] - inner_step, inner_step]) 246 | outer_seeds.append([img_shape[0] - inner_step, img_shape[1] - inner_step, img_shape[2] - inner_step]) 247 | outer_space = ST.NeighborhoodConnected(original_img, outer_seeds, min_val * 1.0, -200, [1, 1, 0], 1.0) 248 | # ST.WriteImage(outer_space , './outer_space.vtk') 249 | outer_array = ST.GetArrayFromImage(outer_space) 250 | outer_array = np.transpose(outer_array, [2, 1, 0]) 251 | # Take out outer noise 252 | inner_array = np.float32((img_array - outer_array) > 0) 253 | inner_img = ST.GetImageFromArray(np.transpose(inner_array,[2,1,0])) 254 | # ST.WriteImage(inner_img,'./inner_mask.vtk') 255 | 256 | median_filter = ST.MedianImageFilter() 257 | median_filter.SetRadius(1) 258 | midian_img = median_filter.Execute(inner_img) 259 | midian_array = ST.GetArrayFromImage(midian_img) 260 | midian_array = np.transpose(midian_array,[2,1,0]) 261 | array_shape = np.shape(midian_array) 262 | 263 | seed = [0,0,0] 264 | max = 0 265 | for i in range(array_shape[0]): 266 | temp_max = np.sum(midian_array[i,:,:]) 267 | if max < temp_max: 268 | max = temp_max 269 | seed[0]=i 270 | max = 0 271 | for i in range(array_shape[1]): 272 | temp_max = np.sum(midian_array[:,i,:]) 273 | if max < temp_max: 274 | max = temp_max 275 | seed[1]=i 276 | max = 0 277 | for i in range(array_shape[2]): 278 | temp_max = np.sum(midian_array[:,:,i]) 279 | if max < temp_max: 280 | max = temp_max 281 | seed[2]=i 282 | # print seed 283 | growed_img = ST.NeighborhoodConnected(img, [seed], 0.9,1, [1, 1, 1], 1.0) 284 | 285 | return img,growed_img 286 | 287 | def Lung_Seg(dicom_dir): 288 | time1 = time.time() 289 | original_img = read_dicoms(dicom_dir) 290 | net = Network() 291 | final_img = net.test(dicom_dir) 292 | del net 293 | gc.collect() 294 | img_spacing = final_img.GetSpacing() 295 | time2 = time.time() 296 | print "time cost for lung sement: ",str(time2-time1),'s' 297 | time3 = time.time() 298 | final_img, growed_mask = post_process(final_img,dicom_dir) 299 | growed_mask.SetSpacing(img_spacing) 300 | print "Writing lung mask" 301 | # ST.WriteImage(growed_mask, './output/lung_mask.vtk') 302 | time4 = time.time() 303 | print "time cost for lung post_process: ",str(time4-time3),'s' 304 | final_array = ST.GetArrayFromImage(growed_mask) 305 | img_array = ST.GetArrayFromImage(original_img) 306 | lung_array = final_array*img_array 307 | # lung_array = lung_array + np.min(lung_array)*2*np.int8(lung_array==0) 308 | lung_img = ST.GetImageFromArray(lung_array) 309 | lung_img.SetSpacing(img_spacing) 310 | print "Writing lung image" 311 | # ST.WriteImage(lung_img,'./output/lung_img.vtk') 312 | return lung_img 313 | 314 | # if __name__ =="__main__": 315 | # dicom_dir = "./WANG_REN/original1" 316 | # lung_img = Lung_Seg(dicom_dir) 317 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import re 4 | from random import shuffle 5 | import tensorflow as tf 6 | import matplotlib.pyplot as plt 7 | import matplotlib.gridspec as gridspec 8 | from mpl_toolkits import mplot3d 9 | import random 10 | import organize_data 11 | from dicom_read import read_dicoms 12 | import SimpleITK as ST 13 | 14 | class Data_block: 15 | # single input data block 16 | def __init__(self,ranger,data_array): 17 | self.ranger=ranger 18 | self.data_array=data_array 19 | 20 | def get_range(self): 21 | return self.ranger 22 | 23 
|     def load_data(self):
24 |         return self.data_array
25 | 
26 | class Test_data():
27 |     # load data and translate to original array
28 |     def __init__(self,data,block_shape,type):
29 |         if type == 'dicom_data':
30 |             self.img = read_dicoms(data)
31 |         elif type == 'vtk_data':
32 |             self.img = data
33 |         self.space = self.img.GetSpacing()
34 |         self.image_array = ST.GetArrayFromImage(self.img)
35 |         self.image_array = np.transpose(self.image_array,[2,1,0])
36 |         self.image_shape = np.shape(self.image_array)
37 |         self.block_shape=block_shape
38 |         self.blocks=dict()
39 |         self.results=dict()
40 | 
41 |     # do the simple threshold function
42 |     def threshold(self,low,high):
43 |         mask_array=np.float32(np.float32(self.image_array<=high)*np.float32(self.image_array>=low))
44 |         return mask_array
45 | 
46 |     def organize_blocks(self):
47 |         block_num=0
48 |         original_shape=np.shape(self.image_array)
49 |         threshed_array = self.image_array*np.float32(self.image_array<=0)
50 |         print 'data shape: ', original_shape
51 |         for i in range(0,original_shape[0],self.block_shape[0]/2):
52 |             for j in range(0,original_shape[1],self.block_shape[1]/2):
53 |                 for k in range(0,original_shape[2],self.block_shape[2]/2):
54 |                     if i<original_shape[0] and j<original_shape[1] and k<original_shape[2]:
55 |                         ranger=[i,min(i+self.block_shape[0],original_shape[0]),j,min(j+self.block_shape[1],original_shape[1]),k,min(k+self.block_shape[2],original_shape[2])]
56 |                         partial_block=threshed_array[ranger[0]:ranger[1],ranger[2]:ranger[3],ranger[4]:ranger[5]]
57 |                         self.blocks[block_num]=Data_block(ranger,partial_block)
58 |                         block_num+=1
59 | 
60 |     # NOTE: this region was lost in the source dump; the block slicing above
61 |     # and this method's signature are reconstructed and only an assumption.
62 |     def upload_result(self,block_num,result):
63 |         ranger=self.blocks[block_num].get_range()
64 |         partial_result = np.float32(result>0)
65 |         this_result = Data_block(ranger,partial_result)
66 |         self.results[block_num]=this_result
67 | 
68 |     def get_result(self):
69 |         ret=np.zeros(self.image_shape,np.float32)
70 |         for number in self.results.keys():
71 |             try:
72 |                 ranger=self.results[number].get_range()
73 |                 xmin=ranger[0]
74 |                 xmax=ranger[1]
75 |                 ymin=ranger[2]
76 |                 ymax=ranger[3]
77 |                 zmin=ranger[4]
78 |                 zmax=ranger[5]
79 |                 temp_result = self.results[number].load_data()[:,:,:,0]
80 |                 # temp_shape = np.shape(temp_result)
81 |                 ret[xmin:xmax,ymin:ymax,zmin:zmax]+=temp_result[:xmax-xmin,:ymax-ymin,:zmax-zmin]
82 |             except Exception,e:
83 |                 print np.shape(self.results[number].load_data()[:,:,:,0]),self.results[number].get_range()
84 |         return np.float32(ret>=2)
85 | 
86 | class Data:
87 |     def __init__(self,config,epoch):
88 |         self.config = config
89 |         self.train_batch_index = 0
90 |         self.test_seq_index = 0
91 |         self.epoch = epoch
92 |         self.resolution = config['resolution']
93 |         self.batch_size = config['batch_size']
94 | 
95 |         self.train_names = config['train_names']
96 |         self.test_names = config['test_names']
97 |         self.data_size = config['data_size']
98 |         # self.X_train_files, self.Y_train_files = self.load_X_Y_files_paths_all( self.train_names,label='train')
99 |         # self.X_test_files, self.Y_test_files = self.load_X_Y_files_paths_all(self.test_names,label='test')
100 |         # print "X_train_files:",len(self.X_train_files)
101 |         # print "X_test_files:",len(self.X_test_files)
102 | 
103 |         self.train_numbers,self.test_numbers = self.load_X_Y_numbers_special(config['meta_path'],self.epoch)
104 | 
105 |         # self.total_train_batch_num = int(len(self.X_train_files) // self.batch_size) -1
106 |         # self.total_test_seq_batch = int(len(self.X_test_files) // self.batch_size) -1
107 |         print "train_numbers:",len(self.train_numbers),"---",self.train_numbers
108 |         print "test_numbers:",len(self.test_numbers),"---",self.test_numbers
109 |         self.total_train_batch_num,self.train_locs = self.load_X_Y_train_batch_num()
110 |         self.total_test_seq_batch,self.test_locs = self.load_X_Y_test_batch_num()
111 |         print "total_train_batch_num: ", self.total_train_batch_num
112 |         print "total_test_seq_batch: ",self.total_test_seq_batch
113 |         # self.check_data()
114 |         self.shuffle_X_Y_pairs()
115 |         # testing code
116 |         # for i in 
class Data:
    def __init__(self,config,epoch):
        self.config = config
        self.train_batch_index = 0
        self.test_seq_index = 0
        self.epoch = epoch
        self.resolution = config['resolution']
        self.batch_size = config['batch_size']

        self.train_names = config['train_names']
        self.test_names = config['test_names']
        self.data_size = config['data_size']
        # Alternative file-based pipeline (kept for reference):
        # self.X_train_files, self.Y_train_files = self.load_X_Y_files_paths_all(self.train_names,label='train')
        # self.X_test_files, self.Y_test_files = self.load_X_Y_files_paths_all(self.test_names,label='test')
        # print "X_train_files:",len(self.X_train_files)
        # print "X_test_files:",len(self.X_test_files)

        self.train_numbers,self.test_numbers = self.load_X_Y_numbers_special(config['meta_path'],self.epoch)

        print "train_numbers:",len(self.train_numbers),"---",self.train_numbers
        print "test_numbers:",len(self.test_numbers),"---",self.test_numbers
        self.total_train_batch_num,self.train_locs = self.load_X_Y_train_batch_num()
        self.total_test_seq_batch,self.test_locs = self.load_X_Y_test_batch_num()
        print "total_train_batch_num: ",self.total_train_batch_num
        print "total_test_seq_batch: ",self.total_test_seq_batch
        # self.check_data()
        self.shuffle_X_Y_pairs()
        # testing code
        # for i in range(0,3):
        #     X_train_voxels,Y_train_voxels = self.load_X_Y_voxel_train_next_batch()
        #     X_test_voxels,Y_test_voxels = self.load_X_Y_voxel_test_next_batch()

    @staticmethod
    def plotFromVoxels(voxels,original):
        # Scatter-plot the non-zero voxels of a generated grid and of the
        # original grid in two separate 3D figures.
        if len(voxels.shape)>3:
            x_d = voxels.shape[0]
            y_d = voxels.shape[1]
            z_d = voxels.shape[2]
            v = voxels[:,:,:,0]
            v = np.reshape(v,(x_d,y_d,z_d))
        else:
            v = voxels
        x, y, z = v.nonzero()
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(x, y, z, zdir='z', c='red')
        print "generated :",str(len(x))

        if len(original.shape)>3:
            x_d = original.shape[0]
            y_d = original.shape[1]
            z_d = original.shape[2]
            v_ori = original[:,:,:,0]
            v_ori = np.reshape(v_ori,(x_d,y_d,z_d))
        else:
            v_ori = original
        x, y, z = v_ori.nonzero()
        fig = plt.figure()
        ax_ori = fig.add_subplot(111, projection='3d')
        ax_ori.scatter(x, y, z, zdir='z', c='red')
        print "origin :",str(len(x))

        plt.show()

    def load_X_Y_files_paths_all(self, obj_names, label='train'):
        x_str = ''
        y_str = ''
        if label == 'train':
            x_str = 'X_train_'
            y_str = 'Y_train_'
        elif label == 'test':
            x_str = 'X_test_'
            y_str = 'Y_test_'
        else:
            print "label error!!"
            exit()

        X_data_files_all = []
        Y_data_files_all = []
        for name in obj_names:
            X_folder = self.config[x_str + name]
            Y_folder = self.config[y_str + name]
            X_data_files, Y_data_files = self.load_X_Y_files_paths(X_folder, Y_folder)

            for X_f, Y_f in zip(X_data_files, Y_data_files):
                if X_f[0:15] != Y_f[0:15]:
                    print "index inconsistent!!\n"
                    exit()
                X_data_files_all.append(X_folder + X_f)
                Y_data_files_all.append(Y_folder + Y_f)
        return X_data_files_all, Y_data_files_all

    def load_X_Y_files_paths(self,X_folder, Y_folder):
        X_data_files = sorted(os.listdir(X_folder))
        Y_data_files = sorted(os.listdir(Y_folder))
        return X_data_files, Y_data_files

    def voxel_grid_padding(self,a):
        # Centre a 4-D voxel grid [x,y,z,channel] inside a cube of side
        # self.resolution, centre-cropping any axis that is too large.
        x_d = a.shape[0]
        y_d = a.shape[1]
        z_d = a.shape[2]
        channel = a.shape[3]
        resolution = self.resolution
        size = [resolution, resolution, resolution, channel]
        b = np.zeros(size)

        bx_s = 0; bx_e = size[0]; by_s = 0; by_e = size[1]; bz_s = 0; bz_e = size[2]
        ax_s = 0; ax_e = x_d; ay_s = 0; ay_e = y_d; az_s = 0; az_e = z_d
        if x_d > size[0]:
            ax_s = int((x_d - size[0]) / 2)
            ax_e = ax_s + size[0]
        else:
            bx_s = int((size[0] - x_d) / 2)
            bx_e = bx_s + x_d

        if y_d > size[1]:
            ay_s = int((y_d - size[1]) / 2)
            ay_e = ay_s + size[1]
        else:
            by_s = int((size[1] - y_d) / 2)
            by_e = by_s + y_d

        if z_d > size[2]:
            az_s = int((z_d - size[2]) / 2)
            az_e = az_s + size[2]
        else:
            bz_s = int((size[2] - z_d) / 2)
            bz_e = bz_s + z_d
        b[bx_s:bx_e, by_s:by_e, bz_s:bz_e, :] = a[ax_s:ax_e, ay_s:ay_e, az_s:az_e, :]

        return b
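    # Worked example for voxel_grid_padding (illustrative numbers, assuming
    # resolution = 64): a grid of shape (32, 40, 20, 1) is centred, e.g. for
    # x_d = 32: bx_s = (64-32)/2 = 16 and bx_e = 16+32 = 48; an oversized axis
    # such as a hypothetical x_d = 80 is centre-cropped instead:
    # ax_s = (80-64)/2 = 8, ax_e = 8+64 = 72. The result therefore always has
    # shape (resolution, resolution, resolution, channel).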
    def load_single_voxel_grid(self,path):
        # Grid dimensions are parsed from the file name, which is expected to
        # end with '..._{x}_{y}_{z}.<ext>'.
        temp = re.split('_', path.split('.')[-2])
        x_d = int(temp[len(temp) - 3])
        y_d = int(temp[len(temp) - 2])
        z_d = int(temp[len(temp) - 1])

        a = np.loadtxt(path)
        if len(a) <= 0:
            print " load_single_voxel_grid error: ", path
            exit()

        voxel_grid = np.zeros((x_d, y_d, z_d, 1))
        for i in a:
            voxel_grid[int(i[0]), int(i[1]), int(i[2]), 0] = 1  # occupied

        # Data.plotFromVoxels(voxel_grid)
        voxel_grid = self.voxel_grid_padding(voxel_grid)
        return voxel_grid

    def load_X_Y_voxel_grids(self,X_data_files, Y_data_files):
        if len(X_data_files) != self.batch_size or len(Y_data_files) != self.batch_size:
            print "load_X_Y_voxel_grids error:", X_data_files, Y_data_files
            exit()

        X_voxel_grids = []
        Y_voxel_grids = []
        for X_f, Y_f in zip(X_data_files, Y_data_files):
            X_voxel_grids.append(self.load_single_voxel_grid(X_f))
            Y_voxel_grids.append(self.load_single_voxel_grid(Y_f))

        X_voxel_grids = np.asarray(X_voxel_grids)
        Y_voxel_grids = np.asarray(Y_voxel_grids)
        return X_voxel_grids, Y_voxel_grids

    def load_X_Y_numbers_special(self,meta_path,epoch):
        self.dicom_origin,self.mask = organize_data.get_organized_data(meta_path,self.data_size,epoch)
        numbers = []
        train_numbers = []
        test_numbers = []
        for number in self.mask.keys():
            if len(self.mask[number])>0:
                numbers.append(number)
        # hold out one randomly chosen case for testing; the rest are training
        for i in range(1):
            test_numbers.append(numbers[random.randint(0,len(numbers)-1)])
        for number in numbers:
            if not number in test_numbers:
                train_numbers.append(number)
        return train_numbers,test_numbers

    def load_X_Y_train_batch_num(self):
        # Count all (case, block) pairs available for training and record
        # their locations.
        total_num = 0
        locs = []
        for number in self.train_numbers:
            for i in range(len(self.mask[number])):
                total_num = total_num+1
                locs.append([number,i])
        return int(total_num/self.batch_size),locs

    def load_X_Y_test_batch_num(self):
        total_num = 0
        locs = []
        for number in self.test_numbers:
            for i in range(len(self.mask[number])):
                total_num = total_num + 1
                locs.append([number,i])
        return int(total_num / self.batch_size),locs

    def shuffle_X_Y_files(self, label='train'):
        X_new = []; Y_new = []
        if label == 'train':
            X = self.X_train_files; Y = self.Y_train_files
            self.train_batch_index = 0
            index = range(len(X))
            shuffle(index)
            for i in index:
                X_new.append(X[i])
                Y_new.append(Y[i])
            self.X_train_files = X_new
            self.Y_train_files = Y_new
        elif label == 'test':
            X = self.X_test_files; Y = self.Y_test_files
            self.test_seq_index = 0
            index = range(len(X))
            shuffle(index)
            for i in index:
                X_new.append(X[i])
                Y_new.append(Y[i])
            self.X_test_files = X_new
            self.Y_test_files = Y_new
        else:
            print "shuffle_X_Y_files error!\n"
            exit()

    def shuffle_X_Y_pairs(self):
        # Shuffle the (case, block) locations; each X/Y pair stays aligned
        # because both sides are addressed through the same location pair.
        train_locs_new = []
        test_locs_new = []
        trains = self.train_locs
        tests = self.test_locs
        self.train_batch_index = 0
        train_index = range(len(trains))
        test_index = range(len(tests))
        shuffle(train_index)
        shuffle(test_index)
        for i in train_index:
            train_locs_new.append(trains[i])
        for j in test_index:
            test_locs_new.append(tests[j])
        self.train_locs = train_locs_new
        self.test_locs = test_locs_new
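    # Example (illustrative values only): with two training cases {0, 2}
    # holding 3 and 2 blocks respectively, load_X_Y_train_batch_num() yields
    #   locs = [[0,0],[0,1],[0,2],[2,0],[2,1]]
    # and shuffle_X_Y_pairs() permutes that list between epochs, e.g. to
    #   [[2,1],[0,0],[0,2],[2,0],[0,1]]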
    ###################### voxel grids
    def load_X_Y_voxel_grids_train_next_batch(self):
        X_data_files = self.X_train_files[self.batch_size * self.train_batch_index:self.batch_size * (self.train_batch_index + 1)]
        Y_data_files = self.Y_train_files[self.batch_size * self.train_batch_index:self.batch_size * (self.train_batch_index + 1)]
        self.train_batch_index += 1

        X_voxel_grids, Y_voxel_grids = self.load_X_Y_voxel_grids(X_data_files, Y_data_files)
        return X_voxel_grids, Y_voxel_grids

    def load_X_Y_voxel_train_next_batch(self):
        temp_locs = self.train_locs[self.batch_size*self.train_batch_index:self.batch_size*(self.train_batch_index+1)]
        X_data_voxels = []
        Y_data_voxels = []
        for pair in temp_locs:
            X_data_voxels.append(self.dicom_origin[pair[0]][pair[1]])
            Y_data_voxels.append(self.mask[pair[0]][pair[1]])
        self.train_batch_index += 1
        # zero-pad every block into the fixed training size
        X_data = np.zeros([self.batch_size,self.data_size[0],self.data_size[1],self.data_size[2]],np.float32)
        Y_data = np.zeros([self.batch_size,self.data_size[0],self.data_size[1],self.data_size[2]],np.float32)
        for i in range(len(X_data_voxels)):
            temp_X = X_data_voxels[i][:,:,:]
            temp_Y = Y_data_voxels[i][:,:,:]
            shape_X = np.shape(temp_X)
            shape_Y = np.shape(temp_Y)
            X_data[i,:shape_X[0],:shape_X[1],:shape_X[2]] = temp_X
            Y_data[i,:shape_Y[0],:shape_Y[1],:shape_Y[2]] = temp_Y

        return X_data,Y_data

    def load_X_Y_voxel_grids_test_next_batch(self,fix_sample=False):
        if fix_sample:
            random.seed(45)
        idx = random.sample(range(len(self.X_test_files)), self.batch_size)
        X_test_files_batch = []
        Y_test_files_batch = []
        for i in idx:
            X_test_files_batch.append(self.X_test_files[i])
            Y_test_files_batch.append(self.Y_test_files[i])

        X_test_batch, Y_test_batch = self.load_X_Y_voxel_grids(X_test_files_batch, Y_test_files_batch)
        return X_test_batch, Y_test_batch

    def load_X_Y_voxel_test_next_batch(self,fix_sample=False):
        if fix_sample:
            random.seed(45)
        idx = random.sample(range(len(self.test_locs)), self.batch_size)
        X_test_voxels_batch = []
        Y_test_voxels_batch = []
        for i in idx:
            temp_pair = self.test_locs[i]
            X_test_voxels_batch.append(self.dicom_origin[temp_pair[0]][temp_pair[1]])
            Y_test_voxels_batch.append(self.mask[temp_pair[0]][temp_pair[1]])
        X_data = np.zeros([self.batch_size,self.data_size[0],self.data_size[1],self.data_size[2]],np.float32)
        Y_data = np.zeros([self.batch_size,self.data_size[0],self.data_size[1],self.data_size[2]],np.float32)
        for i in range(len(X_test_voxels_batch)):
            temp_X = X_test_voxels_batch[i][:,:,:]
            temp_Y = Y_test_voxels_batch[i][:,:,:]
            shape_X = np.shape(temp_X)
            shape_Y = np.shape(temp_Y)
            X_data[i,:shape_X[0],:shape_X[1],:shape_X[2]] = temp_X
            Y_data[i,:shape_Y[0],:shape_Y[1],:shape_Y[2]] = temp_Y
        return X_data,Y_data
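    # Example (hedged sketch): one training epoch over the voxel pairs.
    # `config` contents and `train_step` are hypothetical; only the Data
    # calls below are from this file.
    #
    # data = Data(config, epoch=0)
    # for batch in range(data.total_train_batch_num):
    #     X_batch, Y_batch = data.load_X_Y_voxel_train_next_batch()
    #     # X_batch, Y_batch: [batch_size] + data_size arrays, zero-padded
    #     train_step(X_batch, Y_batch)                    # hypothetical
    # X_test, Y_test = data.load_X_Y_voxel_test_next_batch(fix_sample=True)
    # data.shuffle_X_Y_pairs()                            # re-shuffle for the next epoch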
    ################### check datas
    def check_data(self):
        # Verify that every origin/mask pair matches the configured data size.
        # (The original version reset `tag` to True on every passing pair, so
        # only the last pair decided the outcome; fixed here.)
        fail_list = []
        tag = True
        for pair in self.train_locs + self.test_locs:
            shape1 = np.shape(self.dicom_origin[pair[0]][pair[1]])
            shape2 = np.shape(self.mask[pair[0]][pair[1]])
            if not (shape1[0]==shape2[0]==self.data_size[0] and shape1[1]==shape2[1]==self.data_size[1] and shape1[2]==shape2[2]==self.data_size[2]):
                tag = False
                fail_list.append(pair)
                print shape1
                print shape2
                print "=============================================="
        if tag:
            print "checked!"
        else:
            print "some are failed"
            for item in fail_list:
                print item

class Ops:

    @staticmethod
    def lrelu(x, leak=0.2):
        # leaky ReLU written as a linear combination of x and |x|
        f1 = 0.5 * (1 + leak)
        f2 = 0.5 * (1 - leak)
        return f1 * x + f2 * abs(x)

    @staticmethod
    def relu(x):
        return tf.nn.relu(x)

    @staticmethod
    def xxlu(x,name='relu'):
        if name == 'relu':
            return Ops.relu(x)
        if name == 'lrelu':
            return Ops.lrelu(x,leak=0.2)

    @staticmethod
    def variable_sum(var, name):
        with tf.name_scope(name):
            try:
                mean = tf.reduce_mean(var)
                tf.summary.scalar('mean', mean)
                stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
                tf.summary.scalar('stddev', stddev)
                tf.summary.scalar('max', tf.reduce_max(var))
                tf.summary.scalar('min', tf.reduce_min(var))
                tf.summary.histogram('histogram', var)
            except Exception,e:
                print e

    @staticmethod
    def variable_count():
        total_para = 0
        for variable in tf.trainable_variables():
            shape = variable.get_shape()
            variable_para = 1
            for dim in shape:
                variable_para *= dim.value
            total_para += variable_para
        return total_para

    @staticmethod
    def fc(x, out_d, name):
        xavier_init = tf.contrib.layers.xavier_initializer()
        zero_init = tf.zeros_initializer()
        in_d = x.get_shape()[1]
        w = tf.get_variable(name + '_w', [in_d, out_d], initializer=xavier_init)
        b = tf.get_variable(name + '_b', [out_d], initializer=zero_init)
        y = tf.nn.bias_add(tf.matmul(x, w), b)
        Ops.variable_sum(w, name)
        return y

    @staticmethod
    def maxpool3d(x,k,s,pad='SAME'):
        ker = [1,k,k,k,1]
        strides = [1,s,s,s,1]
        y = tf.nn.max_pool3d(x,ksize=ker,strides=strides,padding=pad)
        return y

    @staticmethod
    def conv3d(x, k, out_c, str, name, pad='SAME'):
        xavier_init = tf.contrib.layers.xavier_initializer()
        zero_init = tf.zeros_initializer()
        in_c = x.get_shape()[4]
        w = tf.get_variable(name + '_w', [k, k, k, in_c, out_c], initializer=xavier_init)
        b = tf.get_variable(name + '_b', [out_c], initializer=zero_init)

        stride = [1, str, str, str, 1]
        y = tf.nn.bias_add(tf.nn.conv3d(x, w, stride, pad), b)
        Ops.variable_sum(w, name)
        return y
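    # Example (hedged sketch): composing the helpers above into one 3-D
    # encoder/decoder step. Shapes and layer names are illustrative only.
    #
    # X = tf.placeholder(tf.float32, [1, 64, 64, 64, 1])
    # c1 = Ops.xxlu(Ops.conv3d(X, k=3, out_c=16, str=1, name='c1'), name='lrelu')
    # p1 = Ops.maxpool3d(c1, k=2, s=2)                        # -> [1, 32, 32, 32, 16]
    # d1 = Ops.deconv3d(p1, k=2, out_c=8, str=2, name='d1')   # -> [1, 64, 64, 64, 8]
    # print Ops.variable_count()                              # total trainable parameters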
    @staticmethod
    def deconv3d(x, k, out_c, str, name, pad='SAME'):
        xavier_init = tf.contrib.layers.xavier_initializer()
        zero_init = tf.zeros_initializer()
        bat, in_d1, in_d2, in_d3, in_c = [int(d) for d in x.get_shape()]
        w = tf.get_variable(name + '_w', [k, k, k, out_c, in_c], initializer=xavier_init)
        b = tf.get_variable(name + '_b', [out_c], initializer=zero_init)
        out_shape = [bat, in_d1 * str, in_d2 * str, in_d3 * str, out_c]
        stride = [1, str, str, str, 1]
        y = tf.nn.conv3d_transpose(x, w, output_shape=out_shape, strides=stride, padding=pad)
        y = tf.nn.bias_add(y, b)
        Ops.variable_sum(w, name)
        return y

    @staticmethod
    def batch_norm(x, name_scope, training, epsilon=1e-3, decay=0.999):
        '''Batch normalization over all axes except the last (channel) axis;
        `training` is a boolean tensor selecting batch vs. population statistics.'''
        with tf.variable_scope(name_scope):
            size = x.get_shape().as_list()[-1]
            x_shape = x.get_shape()
            axis = list(range(len(x_shape) - 1))
            scale = tf.get_variable('scale', [size], initializer=tf.constant_initializer(0.1))
            offset = tf.get_variable('offset', [size])

            pop_mean = tf.get_variable('pop_mean', [size], initializer=tf.zeros_initializer, trainable=False)
            pop_var = tf.get_variable('pop_var', [size], initializer=tf.ones_initializer, trainable=False)
            batch_mean, batch_var = tf.nn.moments(x, axis)

            # moving averages of the population statistics, updated only when
            # the batch-statistics branch is taken
            train_mean_op = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
            train_var_op = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))

            def batch_statistics():
                with tf.control_dependencies([train_mean_op, train_var_op]):
                    return tf.nn.batch_normalization(x, batch_mean, batch_var, offset, scale, epsilon)

            def population_statistics():
                return tf.nn.batch_normalization(x, pop_mean, pop_var, offset, scale, epsilon)

            return tf.cond(training, batch_statistics, population_statistics)
--------------------------------------------------------------------------------